Commit 82e0c538 authored by Jakub Klinkovský, committed by Jakub Klinkovský
Browse files

DistributedNDArray: added forBoundary and forLocalBoundary methods

parent f9853a86
Loading
Loading
Loading
Loading
+47 −0
Original line number Diff line number Diff line
@@ -223,6 +223,37 @@ public:
      dispatch( begins, ends, f );
   }

   // iterate over local elements which are neighbours of *global* boundaries
   template< typename Device2 = DeviceType, typename Func >
   void forBoundary( Func f ) const
   {
      // add static sizes
      using SkipBegins = __ndarray_impl::LocalBeginsHolder< SizesHolderType, 1 >;
      // add dynamic sizes
      SkipBegins skipBegins;
      __ndarray_impl::SetSizesAddHelper< 1, SkipBegins, SizesHolderType, Overlaps >::add( skipBegins, SizesHolderType{} );
      __ndarray_impl::SetSizesMaxHelper< SkipBegins, LocalBeginsType >::max( skipBegins, localBegins );

      // subtract static sizes
      using SkipEnds = typename __ndarray_impl::SubtractedSizesHolder< SizesHolderType, 1 >::type;
      // subtract dynamic sizes
      SkipEnds skipEnds;
      __ndarray_impl::SetSizesSubtractHelper< 1, SkipEnds, SizesHolderType, Overlaps >::subtract( skipEnds, globalSizes );
      __ndarray_impl::SetSizesMinHelper< SkipEnds, SizesHolderType >::min( skipEnds, localEnds );

      __ndarray_impl::BoundaryExecutorDispatcher< PermutationType, Device2 > dispatch;
      dispatch( localBegins, skipBegins, skipEnds, localEnds, f );
   }

   // iterate over local elements outside the given [skipBegins, skipEnds) range specified by global indices
   template< typename Device2 = DeviceType, typename Func, typename SkipBegins, typename SkipEnds >
   void forBoundary( Func f, const SkipBegins& skipBegins, const SkipEnds& skipEnds ) const
   {
      // TODO: assert "localBegins <= skipBegins <= localEnds", "localBegins <= skipEnds <= localEnds"
      __ndarray_impl::BoundaryExecutorDispatcher< PermutationType, Device2 > dispatch;
      dispatch( localBegins, skipBegins, skipEnds, localEnds, f );
   }

   // iterate over local elements which are not neighbours of overlaps (if all overlaps are 0, it is equivalent to forAll)
   template< typename Device2 = DeviceType, typename Func >
   void forLocalInternal( Func f ) const
@@ -239,6 +270,22 @@ public:
      dispatch( begins, ends, f );
   }

   // iterate over local elements which are neighbours of overlaps (if all overlaps are 0, it has no effect)
   template< typename Device2 = DeviceType, typename Func >
   void forLocalBoundary( Func f ) const
   {
      // add dynamic sizes
      LocalBeginsType skipBegins;
      __ndarray_impl::SetSizesAddHelper< 1, LocalBeginsType, SizesHolderType, Overlaps >::add( skipBegins, localBegins, false );

      // subtract dynamic sizes
      SizesHolderType skipEnds;
      __ndarray_impl::SetSizesSubtractHelper< 1, SizesHolderType, SizesHolderType, Overlaps >::subtract( skipEnds, localEnds, false );

      __ndarray_impl::BoundaryExecutorDispatcher< PermutationType, Device2 > dispatch;
      dispatch( localBegins, skipBegins, skipEnds, localEnds, f );
   }


   // extra methods

+47 −0
Original line number Diff line number Diff line
@@ -252,6 +252,37 @@ public:
      dispatch( begins, ends, f );
   }

   // iterate over local elements which are neighbours of *global* boundaries
   template< typename Device2 = DeviceType, typename Func >
   void forBoundary( Func f ) const
   {
      // add static sizes
      using SkipBegins = __ndarray_impl::LocalBeginsHolder< SizesHolderType, 1 >;
      // add dynamic sizes
      SkipBegins skipBegins;
      __ndarray_impl::SetSizesAddHelper< 1, SkipBegins, SizesHolderType, Overlaps >::add( skipBegins, SizesHolderType{} );
      __ndarray_impl::SetSizesMaxHelper< SkipBegins, LocalBeginsType >::max( skipBegins, localBegins );

      // subtract static sizes
      using SkipEnds = typename __ndarray_impl::SubtractedSizesHolder< SizesHolderType, 1 >::type;
      // subtract dynamic sizes
      SkipEnds skipEnds;
      __ndarray_impl::SetSizesSubtractHelper< 1, SkipEnds, SizesHolderType, Overlaps >::subtract( skipEnds, globalSizes );
      __ndarray_impl::SetSizesMinHelper< SkipEnds, SizesHolderType >::min( skipEnds, localEnds );

      __ndarray_impl::BoundaryExecutorDispatcher< PermutationType, Device2 > dispatch;
      dispatch( localBegins, skipBegins, skipEnds, localEnds, f );
   }

   // iterate over local elements outside the given [skipBegins, skipEnds) range specified by global indices
   template< typename Device2 = DeviceType, typename Func, typename SkipBegins, typename SkipEnds >
   void forBoundary( Func f, const SkipBegins& skipBegins, const SkipEnds& skipEnds ) const
   {
      // TODO: assert "localBegins <= skipBegins <= localEnds", "localBegins <= skipEnds <= localEnds"
      __ndarray_impl::BoundaryExecutorDispatcher< PermutationType, Device2 > dispatch;
      dispatch( localBegins, skipBegins, skipEnds, localEnds, f );
   }

   // iterate over local elements which are not neighbours of overlaps (if all overlaps are 0, it is equivalent to forAll)
   template< typename Device2 = DeviceType, typename Func >
   void forLocalInternal( Func f ) const
@@ -268,6 +299,22 @@ public:
      dispatch( begins, ends, f );
   }

   // iterate over local elements which are neighbours of overlaps (if all overlaps are 0, it has no effect)
   template< typename Device2 = DeviceType, typename Func >
   void forLocalBoundary( Func f ) const
   {
      // add dynamic sizes
      LocalBeginsType skipBegins;
      __ndarray_impl::SetSizesAddHelper< 1, LocalBeginsType, SizesHolderType, Overlaps >::add( skipBegins, localBegins, false );

      // subtract dynamic sizes
      SizesHolderType skipEnds;
      __ndarray_impl::SetSizesSubtractHelper< 1, SizesHolderType, SizesHolderType, Overlaps >::subtract( skipEnds, localEnds, false );

      __ndarray_impl::BoundaryExecutorDispatcher< PermutationType, Device2 > dispatch;
      dispatch( localBegins, skipBegins, skipEnds, localEnds, f );
   }

protected:
   NDArrayView localView;
   CommunicationGroup group = Communicator::NullGroup;
+84 −0
Original line number Diff line number Diff line
@@ -376,6 +376,90 @@ TYPED_TEST( DistributedNDArrayTest, forLocalInternal )
   test_helper_forLocalInternal( this->distributedNDArray );
}

// separate function because nvcc does not allow __cuda_callable__ lambdas inside
// private or protected methods (which are created by TYPED_TEST macro)
//
// Exercises forBoundary on the array and on its view. After each traversal the
// elements of the local range (dimension 0) whose global index is 0 or
// getSize< 0 >() - 1 must equal 1 and all other elements of the local range
// must equal 0.
template< typename DistributedArray >
void test_helper_forBoundary( DistributedArray& a )
{
   using IndexType = typename DistributedArray::IndexType;

   const auto localRange = a.template getLocalRange< 0 >();
   // index of the last global element along dimension 0 (loop-invariant)
   const auto lastGlobalIndex = a.template getSize< 0 >() - 1;
   auto a_view = a.getView();

   auto setter = [=] __cuda_callable__ ( IndexType i ) mutable
   {
      a_view( i ) += 1;
   };

   // Verifies the expected pattern over the local range; `caller` only labels
   // the failure message so the two checks below can be told apart.
   auto expectOnlyGlobalBoundarySet = [&] ( const char* caller )
   {
      // IndexType instead of int: the range bounds are not narrowed and the
      // comparisons are done in the array's own index type
      for( IndexType gi = localRange.getBegin(); gi < localRange.getEnd(); gi++ )
      {
         const int expected = ( gi == 0 || gi == lastGlobalIndex ) ? 1 : 0;
         EXPECT_EQ( a.getElement( gi ), expected )
            << "gi = " << gi << " (" << caller << ")";
      }
   };

   a.setValue( 0 );
   a.forBoundary( setter );
   expectOnlyGlobalBoundarySet( "a.forBoundary" );

   a.setValue( 0 );
   a_view.forBoundary( setter );
   expectOnlyGlobalBoundarySet( "a_view.forBoundary" );
}

// Runs the forBoundary helper on the fixture's distributed array.
TYPED_TEST( DistributedNDArrayTest, forBoundary )
{
   auto& array = this->distributedNDArray;
   test_helper_forBoundary( array );
}

// separate function because nvcc does not allow __cuda_callable__ lambdas inside
// private or protected methods (which are created by TYPED_TEST macro)
//
// Exercises forLocalBoundary on the array and on its view. All overlaps are 0
// in this test, so the traversal is expected to visit nothing and every
// element of the local range (dimension 0) must stay 0.
template< typename DistributedArray >
void test_helper_forLocalBoundary( DistributedArray& a )
{
   using IndexType = typename DistributedArray::IndexType;

   const auto localRange = a.template getLocalRange< 0 >();
   auto a_view = a.getView();

   auto setter = [=] __cuda_callable__ ( IndexType i ) mutable
   {
      a_view( i ) += 1;
   };

   // Verifies that no element of the local range was touched; `caller` only
   // labels the failure message so the two checks below can be told apart.
   auto expectAllZero = [&] ( const char* caller )
   {
      // IndexType instead of int: the range bounds are not narrowed
      for( IndexType gi = localRange.getBegin(); gi < localRange.getEnd(); gi++ )
         EXPECT_EQ( a.getElement( gi ), 0 )
            << "gi = " << gi << " (" << caller << ")";
   };

   a.setValue( 0 );
   // empty set because all overlaps are 0
   a.forLocalBoundary( setter );
   expectAllZero( "a.forLocalBoundary" );

   a.setValue( 0 );
   // empty set because all overlaps are 0
   a_view.forLocalBoundary( setter );
   expectAllZero( "a_view.forLocalBoundary" );
}

// Runs the forLocalBoundary helper on the fixture's distributed array.
TYPED_TEST( DistributedNDArrayTest, forLocalBoundary )
{
   auto& array = this->distributedNDArray;
   test_helper_forLocalBoundary( array );
}

#endif  // HAVE_GTEST