Commit d18c2cf0 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Refactor operators for SyncDirection in DistributedNDArraySynchronizer

parent ac498c7f
Loading
Loading
Loading
Loading
+22 −8
Original line number Diff line number Diff line
@@ -67,10 +67,24 @@ enum class SyncDirection : std::uint8_t
   // FrontBottomLeft = Front | Bottom | Left,
};

inline bool
inline SyncDirection
operator&( SyncDirection a, SyncDirection b )
{
   return std::uint8_t( a ) & std::uint8_t( b );
   return static_cast< SyncDirection >( static_cast< std::uint8_t >( a ) & static_cast< std::uint8_t >( b ) );
}

inline SyncDirection
operator|( SyncDirection a, SyncDirection b )
{
   return static_cast< SyncDirection >( static_cast< std::uint8_t >( a ) | static_cast< std::uint8_t >( b ) );
}

// this operator makes `a -= b` equivalent to `a &= ~b`, i.e. it clears all bits from b in a
inline SyncDirection&
operator-=( SyncDirection& a, SyncDirection b )
{
   a = static_cast< SyncDirection >( static_cast< std::uint8_t >( a ) & ~ static_cast< std::uint8_t >( b ) );
   return a;
}

template< typename DistributedNDArray,
@@ -434,26 +448,26 @@ protected:
         copy_kernel.to_buffer = to_buffer;

         if( to_buffer ) {
            if( mask & SyncDirection::Left ) {
            if( ( mask & SyncDirection::Left ) != SyncDirection::None ) {
               copy_kernel.buffer_view.bind( dim_buffers.left_send_view );
               copy_kernel.array_offsets = dim_buffers.left_send_offsets;
               dim_buffers.left_send_view.forAll( copy_kernel );
            }

            if( mask & SyncDirection::Right ) {
            if( ( mask & SyncDirection::Right ) != SyncDirection::None ) {
               copy_kernel.buffer_view.bind( dim_buffers.right_send_view );
               copy_kernel.array_offsets = dim_buffers.right_send_offsets;
               dim_buffers.right_send_view.forAll( copy_kernel );
            }
         }
         else {
            if( mask & SyncDirection::Right ) {
            if( ( mask & SyncDirection::Right ) != SyncDirection::None ) {
               copy_kernel.buffer_view.bind( dim_buffers.left_recv_view );
               copy_kernel.array_offsets = dim_buffers.left_recv_offsets;
               dim_buffers.left_recv_view.forAll( copy_kernel );
            }

            if( mask & SyncDirection::Left ) {
            if( ( mask & SyncDirection::Left ) != SyncDirection::None ) {
               copy_kernel.buffer_view.bind( dim_buffers.right_recv_view );
               copy_kernel.array_offsets = dim_buffers.right_recv_offsets;
               dim_buffers.right_recv_view.forAll( copy_kernel );
@@ -486,7 +500,7 @@ protected:

      auto& dim_buffers = buffers.template getDimBuffers< dim >();

      if( mask & SyncDirection::Left ) {
      if( ( mask & SyncDirection::Left ) != SyncDirection::None ) {
         requests.push_back( MPI::Isend( dim_buffers.left_send_view.getData(),
                                         dim_buffers.left_send_view.getStorageSize(),
                                         dim_buffers.left_neighbor,
@@ -498,7 +512,7 @@ protected:
                                         tag_from_right,
                                         communicator ) );
      }
      if( mask & SyncDirection::Right ) {
      if( ( mask & SyncDirection::Right ) != SyncDirection::None ) {
         requests.push_back( MPI::Isend( dim_buffers.right_send_view.getData(),
                                         dim_buffers.right_send_view.getStorageSize(),
                                         dim_buffers.right_neighbor,