Loading src/TNL/Containers/DistributedNDArray.h +47 −0 Original line number Diff line number Diff line Loading @@ -192,6 +192,53 @@ public: dispatch( localBegins, localEnds, f ); } // iterate over local elements which are not neighbours of *global* boundaries template< typename Device2 = DeviceType, typename Func > void forInternal( Func f ) const { // add static sizes using Begins = __ndarray_impl::LocalBeginsHolder< SizesHolderType, 1 >; // add dynamic sizes Begins begins; __ndarray_impl::SetSizesAddHelper< 1, Begins, SizesHolderType, Overlaps >::add( begins, SizesHolderType{} ); __ndarray_impl::SetSizesMaxHelper< Begins, LocalBeginsType >::max( begins, localBegins ); // subtract static sizes using Ends = typename __ndarray_impl::SubtractedSizesHolder< SizesHolderType, 1 >::type; // subtract dynamic sizes Ends ends; __ndarray_impl::SetSizesSubtractHelper< 1, Ends, SizesHolderType, Overlaps >::subtract( ends, globalSizes ); __ndarray_impl::SetSizesMinHelper< Ends, SizesHolderType >::min( ends, localEnds ); __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch; dispatch( begins, ends, f ); } // iterate over local elements inside the given [begins, ends) range specified by global indices template< typename Device2 = DeviceType, typename Func, typename Begins, typename Ends > void forInternal( Func f, const Begins& begins, const Ends& ends ) const { // TODO: assert "localBegins <= begins <= localEnds", "localBegins <= ends <= localEnds" __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch; dispatch( begins, ends, f ); } // iterate over local elements which are not neighbours of overlaps (if all overlaps are 0, it is equivalent to forAll) template< typename Device2 = DeviceType, typename Func > void forLocalInternal( Func f ) const { // add dynamic sizes LocalBeginsType begins; __ndarray_impl::SetSizesAddHelper< 1, LocalBeginsType, SizesHolderType, Overlaps >::add( begins, localBegins, false ); // subtract dynamic sizes SizesHolderType ends; __ndarray_impl::SetSizesSubtractHelper< 1, SizesHolderType, SizesHolderType, Overlaps >::subtract( ends, localEnds, false ); __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch; dispatch( begins, ends, f ); } // extra methods Loading src/TNL/Containers/DistributedNDArrayView.h +47 −0 Original line number Diff line number Diff line Loading @@ -221,6 +221,53 @@ public: dispatch( localBegins, localEnds, f ); } // iterate over local elements which are not neighbours of *global* boundaries template< typename Device2 = DeviceType, typename Func > void forInternal( Func f ) const { // add static sizes using Begins = __ndarray_impl::LocalBeginsHolder< SizesHolderType, 1 >; // add dynamic sizes Begins begins; __ndarray_impl::SetSizesAddHelper< 1, Begins, SizesHolderType, Overlaps >::add( begins, SizesHolderType{} ); __ndarray_impl::SetSizesMaxHelper< Begins, LocalBeginsType >::max( begins, localBegins ); // subtract static sizes using Ends = typename __ndarray_impl::SubtractedSizesHolder< SizesHolderType, 1 >::type; // subtract dynamic sizes Ends ends; __ndarray_impl::SetSizesSubtractHelper< 1, Ends, SizesHolderType, Overlaps >::subtract( ends, globalSizes ); __ndarray_impl::SetSizesMinHelper< Ends, SizesHolderType >::min( ends, localEnds ); __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch; dispatch( begins, ends, f ); } // iterate over local elements inside the given [begins, ends) range specified by global indices template< typename Device2 = DeviceType, typename Func, typename Begins, typename Ends > void forInternal( Func f, const Begins& begins, const Ends& ends ) const { // TODO: assert "localBegins <= begins <= localEnds", "localBegins <= ends <= localEnds" __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch; dispatch( begins, ends, f ); } // iterate over local elements which are not neighbours of overlaps (if all overlaps are 0, it is equivalent to forAll) template< typename Device2 = DeviceType, typename Func > void forLocalInternal( Func f ) const { // add dynamic sizes LocalBeginsType begins; __ndarray_impl::SetSizesAddHelper< 1, LocalBeginsType, SizesHolderType, Overlaps >::add( begins, localBegins, false ); // subtract dynamic sizes SizesHolderType ends; __ndarray_impl::SetSizesSubtractHelper< 1, SizesHolderType, SizesHolderType, Overlaps >::subtract( ends, localEnds, false ); __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch; dispatch( begins, ends, f ); } protected: NDArrayView localView; CommunicationGroup group = Communicator::NullGroup; Loading src/TNL/Containers/ndarray/Indexing.h +118 −8 Original line number Diff line number Diff line Loading @@ -118,28 +118,138 @@ void setSizesHelper( SizesHolder& holder, template< std::size_t ConstValue, typename TargetHolder, typename SourceHolder, typename Overlaps = make_constant_index_sequence< TargetHolder::getDimension(), 0 >, std::size_t level = TargetHolder::getDimension() - 1 > struct SetSizesSubtractHelper { static void subtract( TargetHolder& target, const SourceHolder& source ) const SourceHolder& source, bool negateOverlaps = true ) { if( source.template getStaticSize< level >() == 0 ) { if( negateOverlaps ) target.template setSize< level >( source.template getSize< level >() - ConstValue * ! get< level >( Overlaps{} ) ); else target.template setSize< level >( source.template getSize< level >() - ConstValue * !! get< level >( Overlaps{} ) ); } SetSizesSubtractHelper< ConstValue, TargetHolder, SourceHolder, Overlaps, level - 1 >::subtract( target, source ); } }; template< std::size_t ConstValue, typename TargetHolder, typename SourceHolder, typename Overlaps > struct SetSizesSubtractHelper< ConstValue, TargetHolder, SourceHolder, Overlaps, 0 > { if( source.template getStaticSize< level >() == 0 ) target.template setSize< level >( source.template getSize< level >() - ConstValue ); SetSizesSubtractHelper< ConstValue, TargetHolder, SourceHolder, level - 1 >::subtract( target, source ); static void subtract( TargetHolder& target, const SourceHolder& source, bool negateOverlaps = true ) { if( source.template getStaticSize< 0 >() == 0 ) { if( negateOverlaps ) target.template setSize< 0 >( source.template getSize< 0 >() - ConstValue * ! get< 0 >( Overlaps{} ) ); else target.template setSize< 0 >( source.template getSize< 0 >() - ConstValue * !! get< 0 >( Overlaps{} ) ); } } }; // helper for the forInternal method (DistributedNDArray) template< std::size_t ConstValue, typename TargetHolder, typename SourceHolder, typename Overlaps = make_constant_index_sequence< TargetHolder::getDimension(), 0 >, std::size_t level = TargetHolder::getDimension() - 1 > struct SetSizesAddHelper { static void add( TargetHolder& target, const SourceHolder& source, bool negateOverlaps = true ) { if( source.template getStaticSize< level >() == 0 ) { if( negateOverlaps ) target.template setSize< level >( source.template getSize< level >() + ConstValue * ! get< level >( Overlaps{} ) ); else target.template setSize< level >( source.template getSize< level >() + ConstValue * !! get< level >( Overlaps{} ) ); } SetSizesAddHelper< ConstValue, TargetHolder, SourceHolder, Overlaps, level - 1 >::add( target, source ); } }; template< std::size_t ConstValue, typename TargetHolder, typename SourceHolder, typename Overlaps > struct SetSizesAddHelper< ConstValue, TargetHolder, SourceHolder, Overlaps, 0 > { static void add( TargetHolder& target, const SourceHolder& source, bool negateOverlaps = true ) { if( source.template getStaticSize< 0 >() == 0 ) { if( negateOverlaps ) target.template setSize< 0 >( source.template getSize< 0 >() + ConstValue * ! get< 0 >( Overlaps{} ) ); else target.template setSize< 0 >( source.template getSize< 0 >() + ConstValue * !! get< 0 >( Overlaps{} ) ); } } }; // helper for the forInternal method (DistributedNDArray) template< typename TargetHolder, typename SourceHolder, std::size_t level = TargetHolder::getDimension() - 1 > struct SetSizesMaxHelper { static void max( TargetHolder& target, const SourceHolder& source ) { if( source.template getStaticSize< level >() == 0 ) target.template setSize< level >( std::max( target.template getSize< level >(), source.template getSize< level >() ) ); SetSizesMaxHelper< TargetHolder, SourceHolder, level - 1 >::max( target, source ); } }; template< typename TargetHolder, typename SourceHolder > struct SetSizesSubtractHelper< ConstValue, TargetHolder, SourceHolder, 0 > struct SetSizesMaxHelper< TargetHolder, SourceHolder, 0 > { static void subtract( TargetHolder& target, static void max( TargetHolder& target, const SourceHolder& source ) { if( source.template getStaticSize< 0 >() == 0 ) target.template setSize< 0 >( std::max( target.template getSize< 0 >(), source.template getSize< 0 >() ) ); } }; // helper for the forInternal method (DistributedNDArray) template< typename TargetHolder, typename SourceHolder, std::size_t level = TargetHolder::getDimension() - 1 > struct SetSizesMinHelper { static void min( TargetHolder& target, const SourceHolder& source ) { if( source.template getStaticSize< level >() == 0 ) target.template setSize< level >( std::min( target.template getSize< level >(), source.template getSize< level >() ) ); SetSizesMinHelper< TargetHolder, SourceHolder, level - 1 >::min( target, source ); } }; template< typename TargetHolder, typename SourceHolder > struct SetSizesMinHelper< TargetHolder, SourceHolder, 0 > { static void min( TargetHolder& target, const SourceHolder& source ) { if( source.template getStaticSize< 0 >() == 0 ) target.template setSize< 0 >( source.template getSize< 0 >() - ConstValue ); target.template setSize< 0 >( std::min( target.template getSize< 0 >(), source.template getSize< 0 >() ) ); } }; Loading src/TNL/Containers/ndarray/SizesHolder.h +10 −6 Original line number Diff line number Diff line Loading @@ -276,14 +276,16 @@ struct SubtractedSizesHolder< SizesHolder< Index, sizes... >, ConstValue > // wrapper for localBegins in DistributedNDArray (static sizes cannot be distributed, begins are always 0) template< typename SizesHolder > template< typename SizesHolder, // overridable value is useful in the forInternal method std::size_t ConstValue = 0 > struct LocalBeginsHolder : public SizesHolder { template< std::size_t dimension > static constexpr std::size_t getStaticSize() { static_assert( dimension < SizesHolder::getDimension(), "Invalid dimension passed to getStaticSize()." ); return 0; return ConstValue; } template< std::size_t level > Loading @@ -291,18 +293,20 @@ struct LocalBeginsHolder : public SizesHolder typename SizesHolder::IndexType getSize() const { if( SizesHolder::template getStaticSize< level >() != 0 ) return 0; return ConstValue; return SizesHolder::template getSize< level >(); } }; template< typename Index, std::size_t... sizes > std::ostream& operator<<( std::ostream& str, const __ndarray_impl::LocalBeginsHolder< SizesHolder< Index, sizes... > >& holder ) std::size_t... sizes, std::size_t ConstValue > std::ostream& operator<<( std::ostream& str, const __ndarray_impl::LocalBeginsHolder< SizesHolder< Index, sizes... >, ConstValue >& holder ) { str << "LocalBeginsHolder< SizesHolder< "; TemplateStaticFor< std::size_t, 0, sizeof...(sizes) - 1, __ndarray_impl::SizesHolderStaticSizePrinter >::execHost( str, (SizesHolder< Index, sizes... >) holder ); str << holder.template getStaticSize< sizeof...(sizes) - 1 >() << " > >( "; str << holder.template getStaticSize< sizeof...(sizes) - 1 >() << " >, "; str << ConstValue << " >( "; TemplateStaticFor< std::size_t, 0, sizeof...(sizes) - 1, __ndarray_impl::SizesHolderSizePrinter >::execHost( str, holder ); str << holder.template getSize< sizeof...(sizes) - 1 >() << " )"; return str; Loading src/UnitTests/Containers/ndarray/DistributedNDArrayTest.h +84 −0 Original line number Diff line number Diff line Loading @@ -292,6 +292,90 @@ TYPED_TEST( DistributedNDArrayTest, forAll ) test_helper_forAll( this->distributedNDArray ); } // separate function because nvcc does not allow __cuda_callable__ lambdas inside // private or protected methods (which are created by TYPED_TEST macro) template< typename DistributedArray > void test_helper_forInternal( DistributedArray& a ) { using IndexType = typename DistributedArray::IndexType; const auto localRange = a.template getLocalRange< 0 >(); auto a_view = a.getView(); auto setter = [=] __cuda_callable__ ( IndexType i ) mutable { a_view( i ) += 1; }; a.setValue( 0 ); a.forInternal( setter ); for( int gi = localRange.getBegin(); gi < localRange.getEnd(); gi++ ) { if( gi == 0 || gi == a.template getSize< 0 >() - 1 ) EXPECT_EQ( a.getElement( gi ), 0 ) << "gi = " << gi; else EXPECT_EQ( a.getElement( gi ), 1 ) << "gi = " << gi; } a.setValue( 0 ); a_view.forInternal( setter ); for( int gi = localRange.getBegin(); gi < localRange.getEnd(); gi++ ) { if( gi == 0 || gi == a.template getSize< 0 >() - 1 ) EXPECT_EQ( a.getElement( gi ), 0 ) << "gi = " << gi; else EXPECT_EQ( a.getElement( gi ), 1 ) << "gi = " << gi; } } TYPED_TEST( DistributedNDArrayTest, forInternal ) { test_helper_forInternal( this->distributedNDArray ); } // separate function because nvcc does not allow __cuda_callable__ lambdas inside // private or protected methods (which are created by TYPED_TEST macro) template< typename DistributedArray > void test_helper_forLocalInternal( DistributedArray& a ) { using IndexType = typename DistributedArray::IndexType; const auto localRange = a.template getLocalRange< 0 >(); auto a_view = a.getView(); auto setter = [=] __cuda_callable__ ( IndexType i ) mutable { a_view( i ) += 1; }; a.setValue( 0 ); // equivalent to forAll because all overlaps are 0 a.forLocalInternal( setter ); for( int gi = localRange.getBegin(); gi < localRange.getEnd(); gi++ ) EXPECT_EQ( a.getElement( gi ), 1 ) << "gi = " << gi; a.setValue( 0 ); // equivalent to forAll because all overlaps are 0 a_view.forLocalInternal( setter ); for( int gi = localRange.getBegin(); gi < localRange.getEnd(); gi++ ) EXPECT_EQ( a.getElement( gi ), 1 ) << "gi = " << gi; } TYPED_TEST( DistributedNDArrayTest, forLocalInternal ) { test_helper_forLocalInternal( this->distributedNDArray ); } #endif // HAVE_GTEST Loading Loading
src/TNL/Containers/DistributedNDArray.h +47 −0 Original line number Diff line number Diff line Loading @@ -192,6 +192,53 @@ public: dispatch( localBegins, localEnds, f ); } // iterate over local elements which are not neighbours of *global* boundaries template< typename Device2 = DeviceType, typename Func > void forInternal( Func f ) const { // add static sizes using Begins = __ndarray_impl::LocalBeginsHolder< SizesHolderType, 1 >; // add dynamic sizes Begins begins; __ndarray_impl::SetSizesAddHelper< 1, Begins, SizesHolderType, Overlaps >::add( begins, SizesHolderType{} ); __ndarray_impl::SetSizesMaxHelper< Begins, LocalBeginsType >::max( begins, localBegins ); // subtract static sizes using Ends = typename __ndarray_impl::SubtractedSizesHolder< SizesHolderType, 1 >::type; // subtract dynamic sizes Ends ends; __ndarray_impl::SetSizesSubtractHelper< 1, Ends, SizesHolderType, Overlaps >::subtract( ends, globalSizes ); __ndarray_impl::SetSizesMinHelper< Ends, SizesHolderType >::min( ends, localEnds ); __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch; dispatch( begins, ends, f ); } // iterate over local elements inside the given [begins, ends) range specified by global indices template< typename Device2 = DeviceType, typename Func, typename Begins, typename Ends > void forInternal( Func f, const Begins& begins, const Ends& ends ) const { // TODO: assert "localBegins <= begins <= localEnds", "localBegins <= ends <= localEnds" __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch; dispatch( begins, ends, f ); } // iterate over local elements which are not neighbours of overlaps (if all overlaps are 0, it is equivalent to forAll) template< typename Device2 = DeviceType, typename Func > void forLocalInternal( Func f ) const { // add dynamic sizes LocalBeginsType begins; __ndarray_impl::SetSizesAddHelper< 1, LocalBeginsType, SizesHolderType, Overlaps >::add( begins, localBegins, false ); // subtract dynamic sizes SizesHolderType ends; __ndarray_impl::SetSizesSubtractHelper< 1, SizesHolderType, SizesHolderType, Overlaps >::subtract( ends, localEnds, false ); __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch; dispatch( begins, ends, f ); } // extra methods Loading
src/TNL/Containers/DistributedNDArrayView.h +47 −0 Original line number Diff line number Diff line Loading @@ -221,6 +221,53 @@ public: dispatch( localBegins, localEnds, f ); } // iterate over local elements which are not neighbours of *global* boundaries template< typename Device2 = DeviceType, typename Func > void forInternal( Func f ) const { // add static sizes using Begins = __ndarray_impl::LocalBeginsHolder< SizesHolderType, 1 >; // add dynamic sizes Begins begins; __ndarray_impl::SetSizesAddHelper< 1, Begins, SizesHolderType, Overlaps >::add( begins, SizesHolderType{} ); __ndarray_impl::SetSizesMaxHelper< Begins, LocalBeginsType >::max( begins, localBegins ); // subtract static sizes using Ends = typename __ndarray_impl::SubtractedSizesHolder< SizesHolderType, 1 >::type; // subtract dynamic sizes Ends ends; __ndarray_impl::SetSizesSubtractHelper< 1, Ends, SizesHolderType, Overlaps >::subtract( ends, globalSizes ); __ndarray_impl::SetSizesMinHelper< Ends, SizesHolderType >::min( ends, localEnds ); __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch; dispatch( begins, ends, f ); } // iterate over local elements inside the given [begins, ends) range specified by global indices template< typename Device2 = DeviceType, typename Func, typename Begins, typename Ends > void forInternal( Func f, const Begins& begins, const Ends& ends ) const { // TODO: assert "localBegins <= begins <= localEnds", "localBegins <= ends <= localEnds" __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch; dispatch( begins, ends, f ); } // iterate over local elements which are not neighbours of overlaps (if all overlaps are 0, it is equivalent to forAll) template< typename Device2 = DeviceType, typename Func > void forLocalInternal( Func f ) const { // add dynamic sizes LocalBeginsType begins; __ndarray_impl::SetSizesAddHelper< 1, LocalBeginsType, SizesHolderType, Overlaps >::add( begins, localBegins, false ); // subtract dynamic sizes SizesHolderType ends; __ndarray_impl::SetSizesSubtractHelper< 1, SizesHolderType, SizesHolderType, Overlaps >::subtract( ends, localEnds, false ); __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch; dispatch( begins, ends, f ); } protected: NDArrayView localView; CommunicationGroup group = Communicator::NullGroup; Loading
src/TNL/Containers/ndarray/Indexing.h +118 −8 Original line number Diff line number Diff line Loading @@ -118,28 +118,138 @@ void setSizesHelper( SizesHolder& holder, template< std::size_t ConstValue, typename TargetHolder, typename SourceHolder, typename Overlaps = make_constant_index_sequence< TargetHolder::getDimension(), 0 >, std::size_t level = TargetHolder::getDimension() - 1 > struct SetSizesSubtractHelper { static void subtract( TargetHolder& target, const SourceHolder& source ) const SourceHolder& source, bool negateOverlaps = true ) { if( source.template getStaticSize< level >() == 0 ) { if( negateOverlaps ) target.template setSize< level >( source.template getSize< level >() - ConstValue * ! get< level >( Overlaps{} ) ); else target.template setSize< level >( source.template getSize< level >() - ConstValue * !! get< level >( Overlaps{} ) ); } SetSizesSubtractHelper< ConstValue, TargetHolder, SourceHolder, Overlaps, level - 1 >::subtract( target, source ); } }; template< std::size_t ConstValue, typename TargetHolder, typename SourceHolder, typename Overlaps > struct SetSizesSubtractHelper< ConstValue, TargetHolder, SourceHolder, Overlaps, 0 > { if( source.template getStaticSize< level >() == 0 ) target.template setSize< level >( source.template getSize< level >() - ConstValue ); SetSizesSubtractHelper< ConstValue, TargetHolder, SourceHolder, level - 1 >::subtract( target, source ); static void subtract( TargetHolder& target, const SourceHolder& source, bool negateOverlaps = true ) { if( source.template getStaticSize< 0 >() == 0 ) { if( negateOverlaps ) target.template setSize< 0 >( source.template getSize< 0 >() - ConstValue * ! get< 0 >( Overlaps{} ) ); else target.template setSize< 0 >( source.template getSize< 0 >() - ConstValue * !! get< 0 >( Overlaps{} ) ); } } }; // helper for the forInternal method (DistributedNDArray) template< std::size_t ConstValue, typename TargetHolder, typename SourceHolder, typename Overlaps = make_constant_index_sequence< TargetHolder::getDimension(), 0 >, std::size_t level = TargetHolder::getDimension() - 1 > struct SetSizesAddHelper { static void add( TargetHolder& target, const SourceHolder& source, bool negateOverlaps = true ) { if( source.template getStaticSize< level >() == 0 ) { if( negateOverlaps ) target.template setSize< level >( source.template getSize< level >() + ConstValue * ! get< level >( Overlaps{} ) ); else target.template setSize< level >( source.template getSize< level >() + ConstValue * !! get< level >( Overlaps{} ) ); } SetSizesAddHelper< ConstValue, TargetHolder, SourceHolder, Overlaps, level - 1 >::add( target, source ); } }; template< std::size_t ConstValue, typename TargetHolder, typename SourceHolder, typename Overlaps > struct SetSizesAddHelper< ConstValue, TargetHolder, SourceHolder, Overlaps, 0 > { static void add( TargetHolder& target, const SourceHolder& source, bool negateOverlaps = true ) { if( source.template getStaticSize< 0 >() == 0 ) { if( negateOverlaps ) target.template setSize< 0 >( source.template getSize< 0 >() + ConstValue * ! get< 0 >( Overlaps{} ) ); else target.template setSize< 0 >( source.template getSize< 0 >() + ConstValue * !! get< 0 >( Overlaps{} ) ); } } }; // helper for the forInternal method (DistributedNDArray) template< typename TargetHolder, typename SourceHolder, std::size_t level = TargetHolder::getDimension() - 1 > struct SetSizesMaxHelper { static void max( TargetHolder& target, const SourceHolder& source ) { if( source.template getStaticSize< level >() == 0 ) target.template setSize< level >( std::max( target.template getSize< level >(), source.template getSize< level >() ) ); SetSizesMaxHelper< TargetHolder, SourceHolder, level - 1 >::max( target, source ); } }; template< typename TargetHolder, typename SourceHolder > struct SetSizesSubtractHelper< ConstValue, TargetHolder, SourceHolder, 0 > struct SetSizesMaxHelper< TargetHolder, SourceHolder, 0 > { static void subtract( TargetHolder& target, static void max( TargetHolder& target, const SourceHolder& source ) { if( source.template getStaticSize< 0 >() == 0 ) target.template setSize< 0 >( std::max( target.template getSize< 0 >(), source.template getSize< 0 >() ) ); } }; // helper for the forInternal method (DistributedNDArray) template< typename TargetHolder, typename SourceHolder, std::size_t level = TargetHolder::getDimension() - 1 > struct SetSizesMinHelper { static void min( TargetHolder& target, const SourceHolder& source ) { if( source.template getStaticSize< level >() == 0 ) target.template setSize< level >( std::min( target.template getSize< level >(), source.template getSize< level >() ) ); SetSizesMinHelper< TargetHolder, SourceHolder, level - 1 >::min( target, source ); } }; template< typename TargetHolder, typename SourceHolder > struct SetSizesMinHelper< TargetHolder, SourceHolder, 0 > { static void min( TargetHolder& target, const SourceHolder& source ) { if( source.template getStaticSize< 0 >() == 0 ) target.template setSize< 0 >( source.template getSize< 0 >() - ConstValue ); target.template setSize< 0 >( std::min( target.template getSize< 0 >(), source.template getSize< 0 >() ) ); } }; Loading
src/TNL/Containers/ndarray/SizesHolder.h +10 −6 Original line number Diff line number Diff line Loading @@ -276,14 +276,16 @@ struct SubtractedSizesHolder< SizesHolder< Index, sizes... >, ConstValue > // wrapper for localBegins in DistributedNDArray (static sizes cannot be distributed, begins are always 0) template< typename SizesHolder > template< typename SizesHolder, // overridable value is useful in the forInternal method std::size_t ConstValue = 0 > struct LocalBeginsHolder : public SizesHolder { template< std::size_t dimension > static constexpr std::size_t getStaticSize() { static_assert( dimension < SizesHolder::getDimension(), "Invalid dimension passed to getStaticSize()." ); return 0; return ConstValue; } template< std::size_t level > Loading @@ -291,18 +293,20 @@ struct LocalBeginsHolder : public SizesHolder typename SizesHolder::IndexType getSize() const { if( SizesHolder::template getStaticSize< level >() != 0 ) return 0; return ConstValue; return SizesHolder::template getSize< level >(); } }; template< typename Index, std::size_t... sizes > std::ostream& operator<<( std::ostream& str, const __ndarray_impl::LocalBeginsHolder< SizesHolder< Index, sizes... > >& holder ) std::size_t... sizes, std::size_t ConstValue > std::ostream& operator<<( std::ostream& str, const __ndarray_impl::LocalBeginsHolder< SizesHolder< Index, sizes... >, ConstValue >& holder ) { str << "LocalBeginsHolder< SizesHolder< "; TemplateStaticFor< std::size_t, 0, sizeof...(sizes) - 1, __ndarray_impl::SizesHolderStaticSizePrinter >::execHost( str, (SizesHolder< Index, sizes... >) holder ); str << holder.template getStaticSize< sizeof...(sizes) - 1 >() << " > >( "; str << holder.template getStaticSize< sizeof...(sizes) - 1 >() << " >, "; str << ConstValue << " >( "; TemplateStaticFor< std::size_t, 0, sizeof...(sizes) - 1, __ndarray_impl::SizesHolderSizePrinter >::execHost( str, holder ); str << holder.template getSize< sizeof...(sizes) - 1 >() << " )"; return str; Loading
src/UnitTests/Containers/ndarray/DistributedNDArrayTest.h +84 −0 Original line number Diff line number Diff line Loading @@ -292,6 +292,90 @@ TYPED_TEST( DistributedNDArrayTest, forAll ) test_helper_forAll( this->distributedNDArray ); } // separate function because nvcc does not allow __cuda_callable__ lambdas inside // private or protected methods (which are created by TYPED_TEST macro) template< typename DistributedArray > void test_helper_forInternal( DistributedArray& a ) { using IndexType = typename DistributedArray::IndexType; const auto localRange = a.template getLocalRange< 0 >(); auto a_view = a.getView(); auto setter = [=] __cuda_callable__ ( IndexType i ) mutable { a_view( i ) += 1; }; a.setValue( 0 ); a.forInternal( setter ); for( int gi = localRange.getBegin(); gi < localRange.getEnd(); gi++ ) { if( gi == 0 || gi == a.template getSize< 0 >() - 1 ) EXPECT_EQ( a.getElement( gi ), 0 ) << "gi = " << gi; else EXPECT_EQ( a.getElement( gi ), 1 ) << "gi = " << gi; } a.setValue( 0 ); a_view.forInternal( setter ); for( int gi = localRange.getBegin(); gi < localRange.getEnd(); gi++ ) { if( gi == 0 || gi == a.template getSize< 0 >() - 1 ) EXPECT_EQ( a.getElement( gi ), 0 ) << "gi = " << gi; else EXPECT_EQ( a.getElement( gi ), 1 ) << "gi = " << gi; } } TYPED_TEST( DistributedNDArrayTest, forInternal ) { test_helper_forInternal( this->distributedNDArray ); } // separate function because nvcc does not allow __cuda_callable__ lambdas inside // private or protected methods (which are created by TYPED_TEST macro) template< typename DistributedArray > void test_helper_forLocalInternal( DistributedArray& a ) { using IndexType = typename DistributedArray::IndexType; const auto localRange = a.template getLocalRange< 0 >(); auto a_view = a.getView(); auto setter = [=] __cuda_callable__ ( IndexType i ) mutable { a_view( i ) += 1; }; a.setValue( 0 ); // equivalent to forAll because all overlaps are 0 a.forLocalInternal( setter ); for( int gi = localRange.getBegin(); gi < localRange.getEnd(); gi++ ) EXPECT_EQ( a.getElement( gi ), 1 ) << "gi = " << gi; a.setValue( 0 ); // equivalent to forAll because all overlaps are 0 a_view.forLocalInternal( setter ); for( int gi = localRange.getBegin(); gi < localRange.getEnd(); gi++ ) EXPECT_EQ( a.getElement( gi ), 1 ) << "gi = " << gi; } TYPED_TEST( DistributedNDArrayTest, forLocalInternal ) { test_helper_forLocalInternal( this->distributedNDArray ); } #endif // HAVE_GTEST Loading