Commit 8b91dfcc authored by Jakub Klinkovský's avatar Jakub Klinkovský Committed by Jakub Klinkovský
Browse files

DistributedNDArray: added forAll method

parent 07d933dc
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -184,6 +184,14 @@ public:
      return ! (*this == other);
   }

   // iterate over all local elements
   template< typename Device2 = DeviceType, typename Func >
   void forAll( Func f ) const
   {
      __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch;
      dispatch( localBegins, localEnds, f );
   }


   // extra methods

+8 −0
Original line number Diff line number Diff line
@@ -213,6 +213,14 @@ public:
      return ! (*this == other);
   }

   // iterate over all local elements
   template< typename Device2 = DeviceType, typename Func >
   void forAll( Func f ) const
   {
      __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch;
      dispatch( localBegins, localEnds, f );
   }

protected:
   NDArrayView localView;
   CommunicationGroup group = Communicator::NullGroup;
+33 −0
Original line number Diff line number Diff line
@@ -259,6 +259,39 @@ TYPED_TEST( DistributedNDArrayTest, comparisonOperators )
   EXPECT_TRUE( u == v );
}

// separate function because nvcc does not allow __cuda_callable__ lambdas inside
// private or protected methods (which are created by TYPED_TEST macro)
template< typename DistributedArray >
void test_helper_forAll( DistributedArray& a )
{
   using IndexType = typename DistributedArray::IndexType;

   const auto localRange = a.template getLocalRange< 0 >();
   auto a_view = a.getView();

   auto setter = [=] __cuda_callable__ ( IndexType i ) mutable
   {
      a_view( i ) += 1;
   };

   a.setValue( 0 );
   a.forAll( setter );

   for( int gi = localRange.getBegin(); gi < localRange.getEnd(); gi++ )
      EXPECT_EQ( a.getElement( gi ), 1 );

   a.setValue( 0 );
   a_view.forAll( setter );

   for( int gi = localRange.getBegin(); gi < localRange.getEnd(); gi++ )
      EXPECT_EQ( a.getElement( gi ), 1 );
}

TYPED_TEST( DistributedNDArrayTest, forAll )
{
   test_helper_forAll( this->distributedNDArray );
}

#endif  // HAVE_GTEST