Commit f9853a86 authored by Jakub Klinkovský's avatar Jakub Klinkovský Committed by Jakub Klinkovský
Browse files

DistributedNDArray: added forInternal and forLocalInternal methods

parent 8b91dfcc
Loading
Loading
Loading
Loading
+47 −0
Original line number Diff line number Diff line
@@ -192,6 +192,53 @@ public:
      dispatch( localBegins, localEnds, f );
   }

   // iterate over local elements which are not neighbours of *global* boundaries
   template< typename Device2 = DeviceType, typename Func >
   void forInternal( Func f ) const
   {
      // add static sizes
      using Begins = __ndarray_impl::LocalBeginsHolder< SizesHolderType, 1 >;
      // add dynamic sizes
      Begins begins;
      __ndarray_impl::SetSizesAddHelper< 1, Begins, SizesHolderType, Overlaps >::add( begins, SizesHolderType{} );
      __ndarray_impl::SetSizesMaxHelper< Begins, LocalBeginsType >::max( begins, localBegins );

      // subtract static sizes
      using Ends = typename __ndarray_impl::SubtractedSizesHolder< SizesHolderType, 1 >::type;
      // subtract dynamic sizes
      Ends ends;
      __ndarray_impl::SetSizesSubtractHelper< 1, Ends, SizesHolderType, Overlaps >::subtract( ends, globalSizes );
      __ndarray_impl::SetSizesMinHelper< Ends, SizesHolderType >::min( ends, localEnds );

      __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch;
      dispatch( begins, ends, f );
   }

   // iterate over local elements inside the given [begins, ends) range specified by global indices
   template< typename Device2 = DeviceType, typename Func, typename Begins, typename Ends >
   void forInternal( Func f, const Begins& begins, const Ends& ends ) const
   {
      // TODO: assert "localBegins <= begins <= localEnds", "localBegins <= ends <= localEnds"
      __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch;
      dispatch( begins, ends, f );
   }

   // iterate over local elements which are not neighbours of overlaps (if all overlaps are 0, it is equivalent to forAll)
   template< typename Device2 = DeviceType, typename Func >
   void forLocalInternal( Func f ) const
   {
      // add dynamic sizes
      LocalBeginsType begins;
      __ndarray_impl::SetSizesAddHelper< 1, LocalBeginsType, SizesHolderType, Overlaps >::add( begins, localBegins, false );

      // subtract dynamic sizes
      SizesHolderType ends;
      __ndarray_impl::SetSizesSubtractHelper< 1, SizesHolderType, SizesHolderType, Overlaps >::subtract( ends, localEnds, false );

      __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch;
      dispatch( begins, ends, f );
   }


   // extra methods

+47 −0
Original line number Diff line number Diff line
@@ -221,6 +221,53 @@ public:
      dispatch( localBegins, localEnds, f );
   }

   // iterate over local elements which are not neighbours of *global* boundaries
   template< typename Device2 = DeviceType, typename Func >
   void forInternal( Func f ) const
   {
      // add static sizes
      using Begins = __ndarray_impl::LocalBeginsHolder< SizesHolderType, 1 >;
      // add dynamic sizes
      Begins begins;
      __ndarray_impl::SetSizesAddHelper< 1, Begins, SizesHolderType, Overlaps >::add( begins, SizesHolderType{} );
      __ndarray_impl::SetSizesMaxHelper< Begins, LocalBeginsType >::max( begins, localBegins );

      // subtract static sizes
      using Ends = typename __ndarray_impl::SubtractedSizesHolder< SizesHolderType, 1 >::type;
      // subtract dynamic sizes
      Ends ends;
      __ndarray_impl::SetSizesSubtractHelper< 1, Ends, SizesHolderType, Overlaps >::subtract( ends, globalSizes );
      __ndarray_impl::SetSizesMinHelper< Ends, SizesHolderType >::min( ends, localEnds );

      __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch;
      dispatch( begins, ends, f );
   }

   // iterate over local elements inside the given [begins, ends) range specified by global indices
   template< typename Device2 = DeviceType, typename Func, typename Begins, typename Ends >
   void forInternal( Func f, const Begins& begins, const Ends& ends ) const
   {
      // TODO: assert "localBegins <= begins <= localEnds", "localBegins <= ends <= localEnds"
      __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch;
      dispatch( begins, ends, f );
   }

   // iterate over local elements which are not neighbours of overlaps (if all overlaps are 0, it is equivalent to forAll)
   template< typename Device2 = DeviceType, typename Func >
   void forLocalInternal( Func f ) const
   {
      // add dynamic sizes
      LocalBeginsType begins;
      __ndarray_impl::SetSizesAddHelper< 1, LocalBeginsType, SizesHolderType, Overlaps >::add( begins, localBegins, false );

      // subtract dynamic sizes
      SizesHolderType ends;
      __ndarray_impl::SetSizesSubtractHelper< 1, SizesHolderType, SizesHolderType, Overlaps >::subtract( ends, localEnds, false );

      __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch;
      dispatch( begins, ends, f );
   }

protected:
   NDArrayView localView;
   CommunicationGroup group = Communicator::NullGroup;
+118 −8
Original line number Diff line number Diff line
@@ -118,28 +118,138 @@ void setSizesHelper( SizesHolder& holder,
template< std::size_t ConstValue,
          typename TargetHolder,
          typename SourceHolder,
          typename Overlaps = make_constant_index_sequence< TargetHolder::getDimension(), 0 >,
          std::size_t level = TargetHolder::getDimension() - 1 >
struct SetSizesSubtractHelper
{
   static void subtract( TargetHolder& target,
                         const SourceHolder& source )
                         const SourceHolder& source,
                         bool negateOverlaps = true )
   {
      if( source.template getStaticSize< level >() == 0 ) {
         if( negateOverlaps )
            target.template setSize< level >( source.template getSize< level >() - ConstValue * ! get< level >( Overlaps{} ) );
         else
            target.template setSize< level >( source.template getSize< level >() - ConstValue * !! get< level >( Overlaps{} ) );
      }
      SetSizesSubtractHelper< ConstValue, TargetHolder, SourceHolder, Overlaps, level - 1 >::subtract( target, source );
   }
};

template< std::size_t ConstValue,
          typename TargetHolder,
          typename SourceHolder,
          typename Overlaps >
struct SetSizesSubtractHelper< ConstValue, TargetHolder, SourceHolder, Overlaps, 0 >
{
      if( source.template getStaticSize< level >() == 0 )
         target.template setSize< level >( source.template getSize< level >() - ConstValue );
      SetSizesSubtractHelper< ConstValue, TargetHolder, SourceHolder, level - 1 >::subtract( target, source );
   static void subtract( TargetHolder& target,
                         const SourceHolder& source,
                         bool negateOverlaps = true )
   {
      if( source.template getStaticSize< 0 >() == 0 ) {
         if( negateOverlaps )
            target.template setSize< 0 >( source.template getSize< 0 >() - ConstValue * ! get< 0 >( Overlaps{} ) );
         else
            target.template setSize< 0 >( source.template getSize< 0 >() - ConstValue * !! get< 0 >( Overlaps{} ) );
      }
   }
};


// helper for the forInternal method (DistributedNDArray)
template< std::size_t ConstValue,
          typename TargetHolder,
          typename SourceHolder,
          typename Overlaps = make_constant_index_sequence< TargetHolder::getDimension(), 0 >,
          std::size_t level = TargetHolder::getDimension() - 1 >
struct SetSizesAddHelper
{
   static void add( TargetHolder& target,
                    const SourceHolder& source,
                    bool negateOverlaps = true )
   {
      if( source.template getStaticSize< level >() == 0 ) {
         if( negateOverlaps )
            target.template setSize< level >( source.template getSize< level >() + ConstValue * ! get< level >( Overlaps{} ) );
         else
            target.template setSize< level >( source.template getSize< level >() + ConstValue * !! get< level >( Overlaps{} ) );
      }
      SetSizesAddHelper< ConstValue, TargetHolder, SourceHolder, Overlaps, level - 1 >::add( target, source );
   }
};

template< std::size_t ConstValue,
          typename TargetHolder,
          typename SourceHolder,
          typename Overlaps >
struct SetSizesAddHelper< ConstValue, TargetHolder, SourceHolder, Overlaps, 0 >
{
   static void add( TargetHolder& target,
                    const SourceHolder& source,
                    bool negateOverlaps = true )
   {
      if( source.template getStaticSize< 0 >() == 0 ) {
         if( negateOverlaps )
            target.template setSize< 0 >( source.template getSize< 0 >() + ConstValue * ! get< 0 >( Overlaps{} ) );
         else
            target.template setSize< 0 >( source.template getSize< 0 >() + ConstValue * !! get< 0 >( Overlaps{} ) );
      }
   }
};


// helper for the forInternal method (DistributedNDArray)
template< typename TargetHolder,
          typename SourceHolder,
          std::size_t level = TargetHolder::getDimension() - 1 >
struct SetSizesMaxHelper
{
   static void max( TargetHolder& target,
                    const SourceHolder& source )
   {
      if( source.template getStaticSize< level >() == 0 )
         target.template setSize< level >( std::max( target.template getSize< level >(), source.template getSize< level >() ) );
      SetSizesMaxHelper< TargetHolder, SourceHolder, level - 1 >::max( target, source );
   }
};

template< typename TargetHolder,
          typename SourceHolder >
struct SetSizesSubtractHelper< ConstValue, TargetHolder, SourceHolder, 0 >
struct SetSizesMaxHelper< TargetHolder, SourceHolder, 0 >
{
   static void subtract( TargetHolder& target,
   static void max( TargetHolder& target,
                    const SourceHolder& source )
   {
      if( source.template getStaticSize< 0 >() == 0 )
         target.template setSize< 0 >( std::max( target.template getSize< 0 >(), source.template getSize< 0 >() ) );
   }
};


// helper for the forInternal method (DistributedNDArray)
template< typename TargetHolder,
          typename SourceHolder,
          std::size_t level = TargetHolder::getDimension() - 1 >
struct SetSizesMinHelper
{
   static void min( TargetHolder& target,
                    const SourceHolder& source )
   {
      if( source.template getStaticSize< level >() == 0 )
         target.template setSize< level >( std::min( target.template getSize< level >(), source.template getSize< level >() ) );
      SetSizesMinHelper< TargetHolder, SourceHolder, level - 1 >::min( target, source );
   }
};

template< typename TargetHolder,
          typename SourceHolder >
struct SetSizesMinHelper< TargetHolder, SourceHolder, 0 >
{
   static void min( TargetHolder& target,
                    const SourceHolder& source )
   {
      if( source.template getStaticSize< 0 >() == 0 )
         target.template setSize< 0 >( source.template getSize< 0 >() - ConstValue );
         target.template setSize< 0 >( std::min( target.template getSize< 0 >(), source.template getSize< 0 >() ) );
   }
};

+10 −6
Original line number Diff line number Diff line
@@ -276,14 +276,16 @@ struct SubtractedSizesHolder< SizesHolder< Index, sizes... >, ConstValue >


// wrapper for localBegins in DistributedNDArray (static sizes cannot be distributed, begins are always 0)
template< typename SizesHolder >
template< typename SizesHolder,
          // overridable value is useful in the forInternal method
          std::size_t ConstValue = 0 >
struct LocalBeginsHolder : public SizesHolder
{
   template< std::size_t dimension >
   static constexpr std::size_t getStaticSize()
   {
      static_assert( dimension < SizesHolder::getDimension(), "Invalid dimension passed to getStaticSize()." );
      return 0;
      return ConstValue;
   }

   template< std::size_t level >
@@ -291,18 +293,20 @@ struct LocalBeginsHolder : public SizesHolder
   typename SizesHolder::IndexType getSize() const
   {
      if( SizesHolder::template getStaticSize< level >() != 0 )
         return 0;
         return ConstValue;
      return SizesHolder::template getSize< level >();
   }
};

template< typename Index,
          std::size_t... sizes >
std::ostream& operator<<( std::ostream& str, const __ndarray_impl::LocalBeginsHolder< SizesHolder< Index, sizes... > >& holder )
          std::size_t... sizes,
          std::size_t ConstValue >
std::ostream& operator<<( std::ostream& str, const __ndarray_impl::LocalBeginsHolder< SizesHolder< Index, sizes... >, ConstValue >& holder )
{
   str << "LocalBeginsHolder< SizesHolder< ";
   TemplateStaticFor< std::size_t, 0, sizeof...(sizes) - 1, __ndarray_impl::SizesHolderStaticSizePrinter >::execHost( str, (SizesHolder< Index, sizes... >) holder );
   str << holder.template getStaticSize< sizeof...(sizes) - 1 >() << " > >( ";
   str << holder.template getStaticSize< sizeof...(sizes) - 1 >() << " >, ";
   str << ConstValue << " >( ";
   TemplateStaticFor< std::size_t, 0, sizeof...(sizes) - 1, __ndarray_impl::SizesHolderSizePrinter >::execHost( str, holder );
   str << holder.template getSize< sizeof...(sizes) - 1 >() << " )";
   return str;
+84 −0
Original line number Diff line number Diff line
@@ -292,6 +292,90 @@ TYPED_TEST( DistributedNDArrayTest, forAll )
   test_helper_forAll( this->distributedNDArray );
}

// separate function because nvcc does not allow __cuda_callable__ lambdas inside
// private or protected methods (which are created by TYPED_TEST macro)
template< typename DistributedArray >
void test_helper_forInternal( DistributedArray& a )
{
   using IndexType = typename DistributedArray::IndexType;

   const auto localRange = a.template getLocalRange< 0 >();
   auto a_view = a.getView();

   auto setter = [=] __cuda_callable__ ( IndexType i ) mutable
   {
      a_view( i ) += 1;
   };

   a.setValue( 0 );
   a.forInternal( setter );

   for( int gi = localRange.getBegin(); gi < localRange.getEnd(); gi++ )
   {
      if( gi == 0 || gi == a.template getSize< 0 >() - 1 )
         EXPECT_EQ( a.getElement( gi ), 0 )
            << "gi = " << gi;
      else
         EXPECT_EQ( a.getElement( gi ), 1 )
            << "gi = " << gi;
   }

   a.setValue( 0 );
   a_view.forInternal( setter );

   for( int gi = localRange.getBegin(); gi < localRange.getEnd(); gi++ )
   {
      if( gi == 0 || gi == a.template getSize< 0 >() - 1 )
         EXPECT_EQ( a.getElement( gi ), 0 )
            << "gi = " << gi;
      else
         EXPECT_EQ( a.getElement( gi ), 1 )
            << "gi = " << gi;
   }
}

TYPED_TEST( DistributedNDArrayTest, forInternal )
{
   test_helper_forInternal( this->distributedNDArray );
}

// separate function because nvcc does not allow __cuda_callable__ lambdas inside
// private or protected methods (which are created by TYPED_TEST macro)
template< typename DistributedArray >
void test_helper_forLocalInternal( DistributedArray& a )
{
   using IndexType = typename DistributedArray::IndexType;

   const auto localRange = a.template getLocalRange< 0 >();
   auto a_view = a.getView();

   auto setter = [=] __cuda_callable__ ( IndexType i ) mutable
   {
      a_view( i ) += 1;
   };

   a.setValue( 0 );
   // equivalent to forAll because all overlaps are 0
   a.forLocalInternal( setter );

   for( int gi = localRange.getBegin(); gi < localRange.getEnd(); gi++ )
      EXPECT_EQ( a.getElement( gi ), 1 )
            << "gi = " << gi;

   a.setValue( 0 );
   // equivalent to forAll because all overlaps are 0
   a_view.forLocalInternal( setter );

   for( int gi = localRange.getBegin(); gi < localRange.getEnd(); gi++ )
      EXPECT_EQ( a.getElement( gi ), 1 )
            << "gi = " << gi;
}

TYPED_TEST( DistributedNDArrayTest, forLocalInternal )
{
   test_helper_forLocalInternal( this->distributedNDArray );
}

#endif  // HAVE_GTEST