Commit 4ad2ba2f authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Implemented host-device data transfers in TNL::File for HIP

parent d5831d45
Loading
Loading
Loading
Loading
+27 −6
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@
#include <TNL/String.h>
#include <TNL/Allocators/Host.h>
#include <TNL/Allocators/Cuda.h>
#include <TNL/Allocators/Hip.h>

namespace TNL {

@@ -144,7 +145,8 @@ class File
      template< typename Type,
                typename SourceType,
                typename Allocator,
                typename = std::enable_if_t< ! std::is_same< Allocator, Allocators::Cuda< Type > >::value > >
                typename = std::enable_if_t< ! std::is_same< Allocator, Allocators::Cuda< Type > >::value
                                          && ! std::is_same< Allocator, Allocators::Hip< Type > >::value > >
      void load_impl( Type* buffer, std::streamsize elements );

      // implementation for \ref Allocators::Cuda
@@ -155,11 +157,21 @@ class File
                typename = void >
      void load_impl( Type* buffer, std::streamsize elements );

      // implementation for \ref Allocators::Hip
      template< typename Type,
                typename SourceType,
                typename Allocator,
                typename = std::enable_if_t< std::is_same< Allocator, Allocators::Hip< Type > >::value >,
                typename = void,
                typename = void >
      void load_impl( Type* buffer, std::streamsize elements );

      // implementation for all allocators which allocate data accessible from host
      template< typename Type,
                typename TargetType,
                typename Allocator,
                typename = std::enable_if_t< ! std::is_same< Allocator, Allocators::Cuda< Type > >::value > >
                typename = std::enable_if_t< ! std::is_same< Allocator, Allocators::Cuda< Type > >::value
                                          && ! std::is_same< Allocator, Allocators::Hip< Type > >::value > >
      void save_impl( const Type* buffer, std::streamsize elements );

      // implementation for \ref Allocators::Cuda
@@ -170,6 +182,15 @@ class File
                typename = void >
      void save_impl( const Type* buffer, std::streamsize elements );

      // implementation for \ref Allocators::Hip
      template< typename Type,
                typename TargetType,
                typename Allocator,
                typename = std::enable_if_t< std::is_same< Allocator, Allocators::Hip< Type > >::value >,
                typename = void,
                typename = void >
      void save_impl( const Type* buffer, std::streamsize elements );

      std::fstream file;
      String fileName;
};
+108 −0
Original line number Diff line number Diff line
@@ -20,6 +20,9 @@
#include <TNL/Cuda/CheckDevice.h>
#include <TNL/Cuda/LaunchHelpers.h>
#include <TNL/Exceptions/CudaSupportMissing.h>
#include <TNL/Hip/CheckDevice.h>
#include <TNL/Hip/LaunchHelpers.h>
#include <TNL/Exceptions/HipSupportMissing.h>
#include <TNL/Exceptions/FileSerializationError.h>
#include <TNL/Exceptions/FileDeserializationError.h>
#include <TNL/Exceptions/NotImplementedError.h>
@@ -171,6 +174,58 @@ void File::load_impl( Type* buffer, std::streamsize elements )
#endif
}

// Allocators::Hip
template< typename Type,
          typename SourceType,
          typename Allocator,
          typename, typename, typename >
void File::load_impl( Type* buffer, std::streamsize elements )
{
#ifdef HAVE_HIP
   const std::streamsize host_buffer_size = std::min( Hip::getTransferBufferSize() / (std::streamsize) sizeof(Type), elements );
   using BaseType = typename std::remove_cv< Type >::type;
   std::unique_ptr< BaseType[] > host_buffer{ new BaseType[ host_buffer_size ] };

   std::streamsize readElements = 0;
   if( std::is_same< Type, SourceType >::value )
   {
      while( readElements < elements )
      {
         const std::streamsize transfer = std::min( elements - readElements, host_buffer_size );
         file.read( reinterpret_cast<char*>(host_buffer.get()), sizeof(Type) * transfer );
         hipMemcpy( (void*) &buffer[ readElements ],
                    (void*) host_buffer.get(),
                    transfer * sizeof( Type ),
                    hipMemcpyHostToDevice );
         TNL_CHECK_HIP_DEVICE;
         readElements += transfer;
      }
   }
   else
   {
      const std::streamsize cast_buffer_size = std::min( Hip::getTransferBufferSize() / (std::streamsize) sizeof(SourceType), elements );
      using BaseType = typename std::remove_cv< SourceType >::type;
      std::unique_ptr< BaseType[] > cast_buffer{ new BaseType[ cast_buffer_size ] };

      while( readElements < elements )
      {
         const std::streamsize transfer = std::min( elements - readElements, cast_buffer_size );
         file.read( reinterpret_cast<char*>(cast_buffer.get()), sizeof(SourceType) * transfer );
         for( std::streamsize i = 0; i < transfer; i++ )
            host_buffer[ i ] = static_cast< Type >( cast_buffer[ i ] );
         hipMemcpy( (void*) &buffer[ readElements ],
                    (void*) host_buffer.get(),
                    transfer * sizeof( Type ),
                    hipMemcpyHostToDevice );
         TNL_CHECK_HIP_DEVICE;
         readElements += transfer;
      }
   }
#else
   throw Exceptions::HipSupportMissing();
#endif
}

template< typename Type,
          typename TargetType,
          typename Allocator >
@@ -266,6 +321,59 @@ void File::save_impl( const Type* buffer, std::streamsize elements )
#endif
}

// Allocators::Hip
template< typename Type,
          typename TargetType,
          typename Allocator,
          typename, typename, typename >
void File::save_impl( const Type* buffer, std::streamsize elements )
{
#ifdef HAVE_HIP
   const std::streamsize host_buffer_size = std::min( Hip::getTransferBufferSize() / (std::streamsize) sizeof(Type), elements );
   using BaseType = typename std::remove_cv< Type >::type;
   std::unique_ptr< BaseType[] > host_buffer{ new BaseType[ host_buffer_size ] };

   std::streamsize writtenElements = 0;
   if( std::is_same< Type, TargetType >::value )
   {
      while( writtenElements < elements )
      {
         const std::streamsize transfer = std::min( elements - writtenElements, host_buffer_size );
         hipMemcpy( (void*) host_buffer.get(),
                    (void*) &buffer[ writtenElements ],
                    transfer * sizeof(Type),
                    hipMemcpyDeviceToHost );
         TNL_CHECK_HIP_DEVICE;
         file.write( reinterpret_cast<const char*>(host_buffer.get()), sizeof(Type) * transfer );
         writtenElements += transfer;
      }
   }
   else
   {
      const std::streamsize cast_buffer_size = std::min( Hip::getTransferBufferSize() / (std::streamsize) sizeof(TargetType), elements );
      using BaseType = typename std::remove_cv< TargetType >::type;
      std::unique_ptr< BaseType[] > cast_buffer{ new BaseType[ cast_buffer_size ] };

      while( writtenElements < elements )
      {
         const std::streamsize transfer = std::min( elements - writtenElements, host_buffer_size );
         hipMemcpy( (void*) host_buffer.get(),
                    (void*) &buffer[ writtenElements ],
                    transfer * sizeof(Type),
                    hipMemcpyDeviceToHost );
         TNL_CHECK_HIP_DEVICE;
         for( std::streamsize i = 0; i < transfer; i++ )
            cast_buffer[ i ] = static_cast< TargetType >( host_buffer[ i ] );

         file.write( reinterpret_cast<const char*>(cast_buffer.get()), sizeof(TargetType) * transfer );
         writtenElements += transfer;
      }
   }
#else
   throw Exceptions::HipSupportMissing();
#endif
}

inline bool fileExists( const String& fileName )
{
   std::fstream file;