Commit 3ccb6528 authored by Jakub Klinkovský
Browse files

Generalized DistributedMeshSynchronizer to allow synchronization of plain...

Generalized DistributedMeshSynchronizer to allow synchronization of plain arrays with multiple values per element
parent 34e9dac0
Loading
Loading
Loading
Loading
+41 −14
Original line number Diff line number Diff line
@@ -19,6 +19,16 @@ namespace TNL {
namespace Meshes {
namespace DistributedMeshes {

// Type trait detecting whether a type declares a nested `MeshType`.
//
// Primary template: chosen when the partial specialization below is not
// viable; reports `false`.
template< typename Type, typename = void >
struct HasMeshType : std::false_type
{};

// Partial specialization: takes part in overload resolution only when
// `Type::MeshType` is a valid type (SFINAE through enable_if_type);
// reports `true`.
template< typename Type >
struct HasMeshType< Type,
                    typename Containers::Expressions::enable_if_type< typename Type::MeshType >::type >
: std::true_type
{};

template< typename DistributedMesh,
          int EntityDimension = DistributedMesh::getMeshDimension() >
class DistributedMeshSynchronizer
@@ -149,16 +159,32 @@ public:
      }
   }

   // Synchronizes a mesh function, i.e. a type that declares a nested
   // `MeshType` (this overload is enabled only when HasMeshType is true).
   //
   // At compile time it is verified that the entity dimension of the mesh
   // function equals EntityDimension and that its mesh type is the local mesh
   // type of DistributedMesh. The underlying data container is then forwarded
   // to synchronize() again.
   //
   // \param function  mesh function whose data (function.getData()) is synchronized
   template< typename MeshFunction,
             std::enable_if_t< HasMeshType< MeshFunction >::value, bool > = true >
   void synchronize( MeshFunction& function )
   {
      static_assert( MeshFunction::getEntitiesDimension() == EntityDimension,
                     "the mesh function's entity dimension does not match" );
      static_assert( std::is_same< typename MeshFunction::MeshType, typename DistributedMesh::MeshType >::value,
                     "The type of the mesh function's mesh does not match the local mesh." );

      // No size check here: synchronizeArray asserts that the array size
      // matches the last ghost offset (with valuesPerElement = 1 by default).
      synchronize( function.getData() );
   }

   // Synchronizes a plain array, i.e. any type that does not declare a nested
   // `MeshType` (this overload is enabled only when HasMeshType is false).
   // Forwards to synchronizeArray with its default valuesPerElement.
   //
   // \param array  array to synchronize
   template< typename Array,
             std::enable_if_t< ! HasMeshType< Array >::value, bool > = true >
   void synchronize( Array& array )
   {
      // NOTE: this is only a wrapper because nvcc does not accept
      // __cuda_callable__ lambdas inside methods constrained with enable_if,
      // so the implementation lives in the unconstrained synchronizeArray.
      synchronizeArray( array );
   }

   template< typename Array >
   void synchronizeArray( Array& array, int valuesPerElement = 1 )
   {
      TNL_ASSERT_EQ( array.getSize(), valuesPerElement * ghostOffsets[ ghostOffsets.getSize() - 1 ],
                     "The array does not have the expected size." );
      using ValueType = typename Array::ValueType;

      // GOTCHA: https://devblogs.nvidia.com/cuda-pro-tip-always-set-current-device-avoid-multithreading-bugs/
      #ifdef HAVE_CUDA
@@ -170,7 +196,7 @@ public:
      const int nproc = CommunicatorType::GetSize( group );

      // allocate send buffers (setSize does nothing if the array size is already correct)
      sendBuffers.setSize( ghostNeighborOffsets[ nproc ] * sizeof(RealType) );
      sendBuffers.setSize( valuesPerElement * ghostNeighborOffsets[ nproc ] * sizeof(ValueType) );

      // buffer for asynchronous communication requests
      std::vector< typename CommunicatorType::Request > requests;
@@ -179,19 +205,20 @@ public:
      for( int j = 0; j < nproc; j++ ) {
         if( ghostEntitiesCounts( rank, j ) > 0 ) {
            requests.push_back( CommunicatorType::IRecv(
                     function.getData().getData() + ghostOffsets[ j ],
                     ghostEntitiesCounts( rank, j ),
                     array.getData() + valuesPerElement * ghostOffsets[ j ],
                     valuesPerElement * ghostEntitiesCounts( rank, j ),
                     j, 0, group ) );
         }
      }

      Containers::ArrayView< RealType, DeviceType, GlobalIndexType > sendBuffersView;
      sendBuffersView.bind( reinterpret_cast<RealType*>( sendBuffers.getData() ), ghostNeighborOffsets[ nproc ] );
      Containers::ArrayView< ValueType, DeviceType, GlobalIndexType > sendBuffersView;
      sendBuffersView.bind( reinterpret_cast<ValueType*>( sendBuffers.getData() ), valuesPerElement * ghostNeighborOffsets[ nproc ] );
      const auto ghostNeighborsView = ghostNeighbors.getConstView();
      const auto functionDataView = function.getData().getConstView();
      auto copy_kernel = [sendBuffersView, functionDataView, ghostNeighborsView] __cuda_callable__ ( GlobalIndexType k, GlobalIndexType offset ) mutable
      const auto arrayView = array.getConstView();
      auto copy_kernel = [sendBuffersView, arrayView, ghostNeighborsView, valuesPerElement] __cuda_callable__ ( GlobalIndexType k, GlobalIndexType offset ) mutable
      {
         sendBuffersView[ offset + k ] = functionDataView[ ghostNeighborsView[ offset + k ] ];
         for( int i = 0; i < valuesPerElement; i++ )
            sendBuffersView[ i + valuesPerElement * (offset + k) ] = arrayView[ i + valuesPerElement * ghostNeighborsView[ offset + k ] ];
      };

      for( int i = 0; i < nproc; i++ ) {
@@ -202,8 +229,8 @@ public:

            // issue async send operation
            requests.push_back( CommunicatorType::ISend(
                     sendBuffersView.getData() + ghostNeighborOffsets[ i ],
                     ghostEntitiesCounts( i, rank ),
                     sendBuffersView.getData() + valuesPerElement * ghostNeighborOffsets[ i ],
                     valuesPerElement * ghostEntitiesCounts( i, rank ),
                     i, 0, group ) );
         }
      }