Commit 5184793e authored by Jakub Klinkovský

Added base class ByteArraySynchronizer

parent 98fe52f6
TNL/Containers/ByteArraySynchronizer.h  +32 −0
/***************************************************************************
                          ByteArraySynchronizer.h  -  description
                             -------------------
    begin                : November 17, 2020
    copyright            : (C) 2020 by Tomas Oberhuber et al.
    email                : tomas.oberhuber@fjfi.cvut.cz
 ***************************************************************************/

/* See Copyright Notice in tnl/Copyright */

// Implemented by: Jakub Klinkovský

#pragma once

#include <cstdint>

#include <TNL/Containers/ArrayView.h>

namespace TNL {
namespace Containers {

template< typename Device, typename Index >
class ByteArraySynchronizer
{
public:
   using ByteArrayView = ArrayView< std::uint8_t, Device, Index >;

   virtual void synchronizeByteArray( ByteArrayView& array, int bytesPerValue ) = 0;

   virtual ~ByteArraySynchronizer() = default;
};

} // namespace Containers
} // namespace TNL
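
The base class erases the element type: a concrete synchronizer implements the single virtual method synchronizeByteArray, which sees the data as raw bytes with a fixed number of bytes per value, so one override can serve arrays of any value type. Below is a minimal host-side sketch of such a subclass; the class name EchoSynchronizer and its trivial byte-copying body are illustrative only and not part of this commit.

#include <cstdint>
#include <TNL/Containers/ByteArraySynchronizer.h>
#include <TNL/Devices/Host.h>

// Illustrative subclass: instead of exchanging ghost values over MPI,
// it copies the bytes of the first value over every following value.
struct EchoSynchronizer
: public TNL::Containers::ByteArraySynchronizer< TNL::Devices::Host, int >
{
   virtual void synchronizeByteArray( ByteArrayView& array, int bytesPerValue ) override
   {
      for( int i = bytesPerValue; i < array.getSize(); i++ )
         array[ i ] = array[ i % bytesPerValue ];
   }
};
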
DistributedMeshSynchronizer.h  +27 −12
@@ -12,6 +12,7 @@

#pragma once

+#include <TNL/Containers/ByteArraySynchronizer.h>
#include <TNL/Containers/Vector.h>
#include <TNL/Matrices/DenseMatrix.h>

@@ -32,11 +33,15 @@ struct HasMeshType< T, typename Containers::Expressions::enable_if_type< typenam
template< typename DistributedMesh,
          int EntityDimension = DistributedMesh::getMeshDimension() >
class DistributedMeshSynchronizer
+: public Containers::ByteArraySynchronizer< typename DistributedMesh::DeviceType, typename DistributedMesh::GlobalIndexType >
{
+   using Base = Containers::ByteArraySynchronizer< typename DistributedMesh::DeviceType, typename DistributedMesh::GlobalIndexType >;

public:
   using DeviceType = typename DistributedMesh::DeviceType;
   using GlobalIndexType = typename DistributedMesh::GlobalIndexType;
   using CommunicatorType = typename DistributedMesh::CommunicatorType;
+   using ByteArrayView = typename Base::ByteArrayView;

   DistributedMeshSynchronizer() = default;

@@ -182,10 +187,20 @@ public:
   template< typename Array >
   void synchronizeArray( Array& array, int valuesPerElement = 1 )
   {
-      TNL_ASSERT_EQ( array.getSize(), valuesPerElement * ghostOffsets[ ghostOffsets.getSize() - 1 ],
-                     "The array does not have the expected size." );
+      static_assert( std::is_same< typename Array::DeviceType, DeviceType >::value,
+                     "mismatched DeviceType of the array" );
      using ValueType = typename Array::ValueType;

+      ByteArrayView view;
+      view.bind( reinterpret_cast<std::uint8_t*>( array.getData() ), sizeof(ValueType) * array.getSize() );
+      synchronizeByteArray( view, sizeof(ValueType) * valuesPerElement );
+   }

+   virtual void synchronizeByteArray( ByteArrayView& array, int bytesPerValue ) override
+   {
+      TNL_ASSERT_EQ( array.getSize(), bytesPerValue * ghostOffsets[ ghostOffsets.getSize() - 1 ],
+                     "The array does not have the expected size." );

      // GOTCHA: https://devblogs.nvidia.com/cuda-pro-tip-always-set-current-device-avoid-multithreading-bugs/
      #ifdef HAVE_CUDA
      if( std::is_same< DeviceType, Devices::Cuda >::value )
@@ -196,7 +211,7 @@ public:
      const int nproc = CommunicatorType::GetSize( group );

      // allocate send buffers (setSize does nothing if the array size is already correct)
-      sendBuffers.setSize( valuesPerElement * ghostNeighborOffsets[ nproc ] * sizeof(ValueType) );
+      sendBuffers.setSize( bytesPerValue * ghostNeighborOffsets[ nproc ] );

      // buffer for asynchronous communication requests
      std::vector< typename CommunicatorType::Request > requests;
@@ -205,20 +220,20 @@ public:
      for( int j = 0; j < nproc; j++ ) {
         if( ghostEntitiesCounts( rank, j ) > 0 ) {
            requests.push_back( CommunicatorType::IRecv(
-                     array.getData() + valuesPerElement * ghostOffsets[ j ],
-                     valuesPerElement * ghostEntitiesCounts( rank, j ),
+                     array.getData() + bytesPerValue * ghostOffsets[ j ],
+                     bytesPerValue * ghostEntitiesCounts( rank, j ),
                     j, 0, group ) );
         }
      }

-      Containers::ArrayView< ValueType, DeviceType, GlobalIndexType > sendBuffersView;
-      sendBuffersView.bind( reinterpret_cast<ValueType*>( sendBuffers.getData() ), valuesPerElement * ghostNeighborOffsets[ nproc ] );
+      ByteArrayView sendBuffersView;
+      sendBuffersView.bind( sendBuffers.getData(), bytesPerValue * ghostNeighborOffsets[ nproc ] );
      const auto ghostNeighborsView = ghostNeighbors.getConstView();
      const auto arrayView = array.getConstView();
-      auto copy_kernel = [sendBuffersView, arrayView, ghostNeighborsView, valuesPerElement] __cuda_callable__ ( GlobalIndexType k, GlobalIndexType offset ) mutable
+      auto copy_kernel = [sendBuffersView, arrayView, ghostNeighborsView, bytesPerValue] __cuda_callable__ ( GlobalIndexType k, GlobalIndexType offset ) mutable
      {
-         for( int i = 0; i < valuesPerElement; i++ )
-            sendBuffersView[ i + valuesPerElement * (offset + k) ] = arrayView[ i + valuesPerElement * ghostNeighborsView[ offset + k ] ];
+         for( int i = 0; i < bytesPerValue; i++ )
+            sendBuffersView[ i + bytesPerValue * (offset + k) ] = arrayView[ i + bytesPerValue * ghostNeighborsView[ offset + k ] ];
      };

      for( int i = 0; i < nproc; i++ ) {
@@ -229,8 +244,8 @@ public:

            // issue async send operation
            requests.push_back( CommunicatorType::ISend(
-                     sendBuffersView.getData() + valuesPerElement * ghostNeighborOffsets[ i ],
-                     valuesPerElement * ghostEntitiesCounts( i, rank ),
+                     sendBuffersView.getData() + bytesPerValue * ghostNeighborOffsets[ i ],
+                     bytesPerValue * ghostEntitiesCounts( i, rank ),
                     i, 0, group ) );
         }
      }
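
With this refactoring, synchronizeArray becomes a thin typed wrapper: it binds the array's storage to a ByteArrayView through a reinterpret_cast and forwards to the virtual synchronizeByteArray with bytesPerValue = sizeof(ValueType) * valuesPerElement (for example, three doubles per ghost entity give 24 bytes per value). The same pattern can be written against the base class alone, which is what the new override enables. The helper below is a hedged sketch under that assumption; the name synchronizeThroughBase is made up for illustration and is not part of TNL.

#include <cstdint>
#include <TNL/Containers/Array.h>
#include <TNL/Containers/ByteArraySynchronizer.h>

// Illustrative free function: the equivalent of the new synchronizeArray(),
// written against nothing but the ByteArraySynchronizer interface.
template< typename Value, typename Device, typename Index >
void synchronizeThroughBase( TNL::Containers::Array< Value, Device, Index >& array,
                             TNL::Containers::ByteArraySynchronizer< Device, Index >& synchronizer,
                             int valuesPerElement = 1 )
{
   using Base = TNL::Containers::ByteArraySynchronizer< Device, Index >;
   typename Base::ByteArrayView view;
   // reinterpret the typed storage as a byte array, as synchronizeArray() does above
   view.bind( reinterpret_cast< std::uint8_t* >( array.getData() ),
              sizeof( Value ) * array.getSize() );
   // one logical value = valuesPerElement elements = this many bytes
   synchronizer.synchronizeByteArray( view, sizeof( Value ) * valuesPerElement );
}

A DistributedMeshSynchronizer can then be stored or passed around as a ByteArraySynchronizer< DeviceType, GlobalIndexType >&, independent of the value type being synchronized.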