Commit ad476ebd authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Merge branch 'lbm-hack' into 'develop'

LBM hacks: add optimizations of the ndarray synchronizer as presented at WSC2019

See merge request !76
parents c386b09e fd751724
Loading
Loading
Loading
Loading
+72 −33
Original line number Diff line number Diff line
@@ -19,7 +19,12 @@
namespace TNL {
namespace Containers {

template< typename DistributedNDArray >
template< typename DistributedNDArray,
          // This can be set to false to optimize out buffering when it is not needed
          // (e.g. for LBM with 1D distribution and specific orientation of the ndarray)
          bool buffered = true,
          // switch for the LBM hack: only 9 of 27 distribution functions will be sent
          bool LBM_HACK = false >
class DistributedNDArraySynchronizer
{
public:
@@ -117,11 +122,18 @@ protected:
         SizesHolder bufferSize( localEnds );
         bufferSize.template setSize< dim >( overlap );

         // allocate buffers
         dim_buffers.left_send_buffer.setSize( bufferSize );
         dim_buffers.left_recv_buffer.setSize( bufferSize );
         dim_buffers.right_send_buffer.setSize( bufferSize );
         dim_buffers.right_recv_buffer.setSize( bufferSize );

         // bind views to the buffers
         dim_buffers.left_send_view.bind( dim_buffers.left_send_buffer.getView() );
         dim_buffers.left_recv_view.bind( dim_buffers.left_recv_buffer.getView() );
         dim_buffers.right_send_view.bind( dim_buffers.right_send_buffer.getView() );
         dim_buffers.right_recv_view.bind( dim_buffers.right_recv_buffer.getView() );

         // TODO: check overlap offsets for 2D and 3D distributions (watch out for the corners - maybe use SetSizesSubtractOverlapsHelper?)

         // offsets for left-send
@@ -153,36 +165,47 @@ protected:
   {
      static void exec( Buffers& buffers, DistributedNDArrayView& array_view, bool to_buffer )
      {
         // skip if there are no overlaps
         const std::size_t overlap = __ndarray_impl::get< dim >( typename DistributedNDArray::OverlapsType{} );
         if( overlap == 0 )
            return;

         auto& dim_buffers = buffers.template getDimBuffers< dim >();

         if( buffered ) {
            // TODO: specify CUDA stream for the copy, otherwise async won't work !!!
         CopyKernel< decltype(dim_buffers.left_send_buffer.getView()) > copy_kernel;
            CopyKernel< decltype(dim_buffers.left_send_view) > copy_kernel;
            copy_kernel.array_view.bind( array_view );
            copy_kernel.to_buffer = to_buffer;

            if( to_buffer ) {
            copy_kernel.buffer_view.bind( dim_buffers.left_send_buffer.getView() );
               copy_kernel.buffer_view.bind( dim_buffers.left_send_view );
               copy_kernel.array_offsets = dim_buffers.left_send_offsets;
            dim_buffers.left_send_buffer.forAll( copy_kernel );
               dim_buffers.left_send_view.forAll( copy_kernel );

            copy_kernel.buffer_view.bind( dim_buffers.right_send_buffer.getView() );
               copy_kernel.buffer_view.bind( dim_buffers.right_send_view );
               copy_kernel.array_offsets = dim_buffers.right_send_offsets;
            dim_buffers.right_send_buffer.forAll( copy_kernel );
               dim_buffers.right_send_view.forAll( copy_kernel );
            }
            else {
            copy_kernel.buffer_view.bind( dim_buffers.left_recv_buffer.getView() );
               copy_kernel.buffer_view.bind( dim_buffers.left_recv_view );
               copy_kernel.array_offsets = dim_buffers.left_recv_offsets;
            dim_buffers.left_recv_buffer.forAll( copy_kernel );
               dim_buffers.left_recv_view.forAll( copy_kernel );

            copy_kernel.buffer_view.bind( dim_buffers.right_recv_buffer.getView() );
               copy_kernel.buffer_view.bind( dim_buffers.right_recv_view );
               copy_kernel.array_offsets = dim_buffers.right_recv_offsets;
            dim_buffers.right_recv_buffer.forAll( copy_kernel );
               dim_buffers.right_recv_view.forAll( copy_kernel );
            }
         }
         else {
            // avoid buffering - bind buffer views directly to the array
            dim_buffers.left_send_view.bind( &call_with_offsets( dim_buffers.left_send_offsets, array_view ) );
            dim_buffers.left_recv_view.bind( &call_with_offsets( dim_buffers.left_recv_offsets, array_view ) );
            dim_buffers.right_send_view.bind( &call_with_offsets( dim_buffers.right_send_offsets, array_view ) );
            dim_buffers.right_recv_view.bind( &call_with_offsets( dim_buffers.right_recv_offsets, array_view ) );
         }

      }
   };

   template< std::size_t dim >
@@ -197,19 +220,35 @@ protected:

         auto& dim_buffers = buffers.template getDimBuffers< dim >();

         requests.push_back( Communicator::ISend( dim_buffers.left_send_buffer.getStorageArray().getData(),
                                                  dim_buffers.left_send_buffer.getStorageSize(),
         if( LBM_HACK == false ) {
            requests.push_back( Communicator::ISend( dim_buffers.left_send_view.getData(),
                                                     dim_buffers.left_send_view.getStorageSize(),
                                                     dim_buffers.left_neighbor, 0, group ) );
         requests.push_back( Communicator::IRecv( dim_buffers.left_recv_buffer.getStorageArray().getData(),
                                                  dim_buffers.left_recv_buffer.getStorageSize(),
            requests.push_back( Communicator::IRecv( dim_buffers.left_recv_view.getData(),
                                                     dim_buffers.left_recv_view.getStorageSize(),
                                                     dim_buffers.left_neighbor, 1, group ) );
         requests.push_back( Communicator::ISend( dim_buffers.right_send_buffer.getStorageArray().getData(),
                                                  dim_buffers.right_send_buffer.getStorageSize(),
            requests.push_back( Communicator::ISend( dim_buffers.right_send_view.getData(),
                                                     dim_buffers.right_send_view.getStorageSize(),
                                                     dim_buffers.right_neighbor, 1, group ) );
         requests.push_back( Communicator::IRecv( dim_buffers.right_recv_buffer.getStorageArray().getData(),
                                                  dim_buffers.right_recv_buffer.getStorageSize(),
            requests.push_back( Communicator::IRecv( dim_buffers.right_recv_view.getData(),
                                                     dim_buffers.right_recv_view.getStorageSize(),
                                                     dim_buffers.right_neighbor, 0, group ) );
         }
         else {
            requests.push_back( Communicator::ISend( dim_buffers.left_send_view.getData() + 0,
                                                     dim_buffers.left_send_view.getStorageSize() / 27 * 9,
                                                     dim_buffers.left_neighbor, 0, group ) );
            requests.push_back( Communicator::IRecv( dim_buffers.left_recv_view.getData() + dim_buffers.left_recv_view.getStorageSize() / 27 * 18,
                                                     dim_buffers.left_recv_view.getStorageSize() / 27 * 9,
                                                     dim_buffers.left_neighbor, 1, group ) );
            requests.push_back( Communicator::ISend( dim_buffers.right_send_view.getData() + dim_buffers.left_recv_view.getStorageSize() / 27 * 18,
                                                     dim_buffers.right_send_view.getStorageSize() / 27 * 9,
                                                     dim_buffers.right_neighbor, 1, group ) );
            requests.push_back( Communicator::IRecv( dim_buffers.right_recv_view.getData() + 0,
                                                     dim_buffers.right_recv_view.getStorageSize() / 27 * 9,
                                                     dim_buffers.right_neighbor, 0, group ) );
         }
      }
   };

#ifdef __NVCC__
+42 −1
Original line number Diff line number Diff line
@@ -18,6 +18,47 @@ namespace TNL {
namespace Containers {
namespace __ndarray_impl {

template< typename OffsetsHolder,
          typename Sequence >
struct OffsetsHelper
{};

template< typename OffsetsHolder,
          std::size_t... N >
struct OffsetsHelper< OffsetsHolder, std::index_sequence< N... > >
{
   template< typename Func >
   __cuda_callable__
   static auto apply( const OffsetsHolder& offsets, Func&& f ) -> decltype(auto)
   {
      return f( offsets.template getSize< N >()... );
   }

   template< typename Func >
   static auto apply_host( const OffsetsHolder& offsets, Func&& f ) -> decltype(auto)
   {
      return f( offsets.template getSize< N >()... );
   }
};

template< typename OffsetsHolder,
          typename Func >
__cuda_callable__
auto call_with_offsets( const OffsetsHolder& offsets, Func&& f ) -> decltype(auto)
{
   return OffsetsHelper< OffsetsHolder, std::make_index_sequence< OffsetsHolder::getDimension() > >
          ::apply( offsets, std::forward< Func >( f ) );
}

template< typename OffsetsHolder,
          typename Func >
auto host_call_with_offsets( const OffsetsHolder& offsets, Func&& f ) -> decltype(auto)
{
   return OffsetsHelper< OffsetsHolder, std::make_index_sequence< OffsetsHolder::getDimension() > >
          ::apply_host( offsets, std::forward< Func >( f ) );
}


template< typename OffsetsHolder,
          typename Sequence >
struct IndexShiftHelper
@@ -56,7 +97,7 @@ auto call_with_shifted_indices( const OffsetsHolder& offsets, Func&& f, Indices&
template< typename OffsetsHolder,
          typename Func,
          typename... Indices >
auto host_call_with_unshifted_indices( const OffsetsHolder& offsets, Func&& f, Indices&&... indices ) -> decltype(auto)
auto host_call_with_shifted_indices( const OffsetsHolder& offsets, Func&& f, Indices&&... indices ) -> decltype(auto)
{
   return IndexShiftHelper< OffsetsHolder, std::make_index_sequence< sizeof...( Indices ) > >
          ::apply_host( offsets, std::forward< Func >( f ), std::forward< Indices >( indices )... );
+6 −0
Original line number Diff line number Diff line
@@ -31,6 +31,7 @@ struct SynchronizerBuffersLayer
                                typename DistributedNDArray::PermutationType,
                                typename DistributedNDArray::DeviceType >;
   NDArrayType left_send_buffer, left_recv_buffer, right_send_buffer, right_recv_buffer;
   typename NDArrayType::ViewType left_send_view, left_recv_view, right_send_view, right_recv_view;
   typename DistributedNDArray::LocalBeginsType left_send_offsets, left_recv_offsets, right_send_offsets, right_recv_offsets;

   int left_neighbor = -1;
@@ -43,6 +44,11 @@ struct SynchronizerBuffersLayer
      right_send_buffer.reset();
      right_recv_buffer.reset();

      left_send_view.reset();
      left_recv_view.reset();
      right_send_view.reset();
      right_recv_view.reset();

      left_send_offsets = left_recv_offsets = right_send_offsets = right_recv_offsets = typename DistributedNDArray::LocalBeginsType{};

      left_neighbor = right_neighbor = -1;