Loading src/TNL/Containers/DistributedArray.h +1 −3 Original line number Diff line number Diff line Loading @@ -168,9 +168,7 @@ public: // TODO: serialization (save, load) protected: LocalRangeType localRange; IndexType globalSize = 0; CommunicationGroup group = Communicator::NullGroup; ViewType view; LocalArrayType localData; }; Loading src/TNL/Containers/DistributedArray.hpp +21 −62 Original line number Diff line number Diff line Loading @@ -39,11 +39,9 @@ DistributedArray< Value, Device, Index, Communicator >:: setDistribution( LocalRangeType localRange, IndexType globalSize, CommunicationGroup group ) { TNL_ASSERT_LE( localRange.getEnd(), globalSize, "end of the local range is outside of the global range" ); this->localRange = localRange; this->globalSize = globalSize; this->group = group; if( group != Communicator::NullGroup ) localData.setSize( localRange.getSize() ); view.bind( localRange, globalSize, group, localData.getView() ); } template< typename Value, Loading @@ -54,7 +52,7 @@ const Subrange< Index >& DistributedArray< Value, Device, Index, Communicator >:: getLocalRange() const { return localRange; return view.getLocalRange(); } template< typename Value, Loading @@ -65,7 +63,7 @@ typename Communicator::CommunicationGroup DistributedArray< Value, Device, Index, Communicator >:: getCommunicationGroup() const { return group; return view.getCommunicationGroup(); } template< typename Value, Loading Loading @@ -99,18 +97,7 @@ void DistributedArray< Value, Device, Index, Communicator >:: copyFromGlobal( ConstLocalViewType globalArray ) { TNL_ASSERT_EQ( getSize(), globalArray.getSize(), "given global array has different size than the distributed array" ); LocalViewType localView( localData ); const LocalRangeType localRange = getLocalRange(); auto kernel = [=] __cuda_callable__ ( IndexType i ) mutable { localView[ i ] = globalArray[ localRange.getGlobalIndex( i ) ]; }; Algorithms::ParallelFor< DeviceType >::exec( (IndexType) 0, localRange.getSize(), kernel ); view.copyFromGlobal( globalArray ); } Loading @@ -126,7 +113,7 @@ typename DistributedArray< Value, Device, Index, Communicator >::ViewType DistributedArray< Value, Device, Index, Communicator >:: getView() { return ViewType( getLocalRange(), getSize(), getCommunicationGroup(), getLocalView() ); return view; } template< typename Value, Loading @@ -137,7 +124,7 @@ typename DistributedArray< Value, Device, Index, Communicator >::ConstViewType DistributedArray< Value, Device, Index, Communicator >:: getConstView() const { return ConstViewType( getLocalRange(), getSize(), getCommunicationGroup(), getConstLocalView() ); return view.getConstView(); } template< typename Value, Loading Loading @@ -169,10 +156,8 @@ void DistributedArray< Value, Device, Index, Communicator >:: setLike( const Array& array ) { localRange = array.getLocalRange(); globalSize = array.getSize(); group = array.getCommunicationGroup(); localData.setLike( array.getConstLocalView() ); view.bind( array.getLocalRange(), array.getSize(), array.getCommunicationGroup(), localData.getView() ); } template< typename Value, Loading @@ -183,9 +168,7 @@ void DistributedArray< Value, Device, Index, Communicator >:: reset() { localRange.reset(); globalSize = 0; group = Communicator::NullGroup; view.reset(); localData.reset(); } Loading @@ -197,7 +180,7 @@ bool DistributedArray< Value, Device, Index, Communicator >:: empty() const { return getSize() == 0; return view.empty(); } template< typename Value, Loading @@ -208,7 +191,7 @@ Index DistributedArray< Value, Device, Index, Communicator >:: getSize() const { return globalSize; return view.getSize(); } template< typename Value, Loading @@ -219,7 +202,7 @@ void DistributedArray< Value, Device, Index, Communicator >:: setValue( ValueType value ) { localData.setValue( value ); view.setValue( value ); } template< typename Value, Loading @@ -230,8 +213,7 @@ void DistributedArray< Value, Device, Index, Communicator >:: setElement( IndexType i, ValueType value ) { const IndexType li = localRange.getLocalIndex( i ); localData.setElement( li, value ); view.setElement( i, value ); } template< typename Value, Loading @@ -242,8 +224,7 @@ Value DistributedArray< Value, Device, Index, Communicator >:: getElement( IndexType i ) const { const IndexType li = localRange.getLocalIndex( i ); return localData.getElement( li ); return view.getElement( i ); } template< typename Value, Loading @@ -255,8 +236,7 @@ Value& DistributedArray< Value, Device, Index, Communicator >:: operator[]( IndexType i ) { const IndexType li = localRange.getLocalIndex( i ); return localData[ li ]; return view[ i ]; } template< typename Value, Loading @@ -268,8 +248,7 @@ const Value& DistributedArray< Value, Device, Index, Communicator >:: operator[]( IndexType i ) const { const IndexType li = localRange.getLocalIndex( i ); return localData[ li ]; return view[ i ]; } template< typename Value, Loading @@ -281,7 +260,7 @@ DistributedArray< Value, Device, Index, Communicator >:: operator=( const DistributedArray& array ) { setLike( array ); localData = array.getConstLocalView(); view = array; return *this; } Loading @@ -295,7 +274,7 @@ DistributedArray< Value, Device, Index, Communicator >:: operator=( const Array& array ) { setLike( array ); localData = array.getConstLocalView(); view = array; return *this; } Loading @@ -308,17 +287,7 @@ bool DistributedArray< Value, Device, Index, Communicator >:: operator==( const Array& array ) const { // we can't run allreduce if the communication groups are different if( group != array.getCommunicationGroup() ) return false; const bool localResult = localRange == array.getLocalRange() && globalSize == array.getSize() && localData == array.getConstLocalView(); bool result = true; if( group != CommunicatorType::NullGroup ) CommunicatorType::Allreduce( &localResult, &result, 1, MPI_LAND, group ); return result; return view == array; } template< typename Value, Loading @@ -330,7 +299,7 @@ bool DistributedArray< Value, Device, Index, Communicator >:: operator!=( const Array& array ) const { return ! (*this == array); return view != array; } template< typename Value, Loading @@ -341,12 +310,7 @@ bool DistributedArray< Value, Device, Index, Communicator >:: containsValue( ValueType value ) const { bool result = false; if( group != CommunicatorType::NullGroup ) { const bool localResult = localData.containsValue( value ); CommunicatorType::Allreduce( &localResult, &result, 1, MPI_LOR, group ); } return result; return view.containsValue( value ); } template< typename Value, Loading @@ -357,12 +321,7 @@ bool DistributedArray< Value, Device, Index, Communicator >:: containsOnlyValue( ValueType value ) const { bool result = true; if( group != CommunicatorType::NullGroup ) { const bool localResult = localData.containsOnlyValue( value ); CommunicatorType::Allreduce( &localResult, &result, 1, MPI_LAND, group ); } return result; return view.containsOnlyValue( value ); } } // namespace Containers Loading src/TNL/Containers/DistributedArrayView.h +6 −3 Original line number Diff line number Diff line Loading @@ -74,9 +74,12 @@ public: __cuda_callable__ DistributedArrayView( DistributedArrayView&& ) = default; // method for rebinding (reinitialization) // Note that you can also bind directly to Array and other types implicitly // convertible to ArrayView. // method for rebinding (reinitialization) to raw data __cuda_callable__ void bind( const LocalRangeType& localRange, IndexType globalSize, CommunicationGroup group, LocalViewType localData ); // Note that you can also bind directly to DistributedArray and other types implicitly // convertible to DistributedArrayView. __cuda_callable__ void bind( DistributedArrayView view ); Loading src/TNL/Containers/DistributedArrayView.hpp +18 −0 Original line number Diff line number Diff line Loading @@ -31,6 +31,24 @@ DistributedArrayView( const DistributedArrayView< Value_, Device, Index, Communi localData( view.getConstLocalView() ) {} template< typename Value, typename Device, typename Index, typename Communicator > __cuda_callable__ void DistributedArrayView< Value, Device, Index, Communicator >:: bind( const LocalRangeType& localRange, IndexType globalSize, CommunicationGroup group, LocalViewType localData ) { TNL_ASSERT_EQ( localData.getSize(), localRange.getSize(), "The local array size does not match the local range of the distributed array." ); this->localRange = localRange; this->globalSize = globalSize; this->group = group; this->localData.bind( localData ); } template< typename Value, typename Device, typename Index, Loading Loading
src/TNL/Containers/DistributedArray.h +1 −3 Original line number Diff line number Diff line Loading @@ -168,9 +168,7 @@ public: // TODO: serialization (save, load) protected: LocalRangeType localRange; IndexType globalSize = 0; CommunicationGroup group = Communicator::NullGroup; ViewType view; LocalArrayType localData; }; Loading
src/TNL/Containers/DistributedArray.hpp +21 −62 Original line number Diff line number Diff line Loading @@ -39,11 +39,9 @@ DistributedArray< Value, Device, Index, Communicator >:: setDistribution( LocalRangeType localRange, IndexType globalSize, CommunicationGroup group ) { TNL_ASSERT_LE( localRange.getEnd(), globalSize, "end of the local range is outside of the global range" ); this->localRange = localRange; this->globalSize = globalSize; this->group = group; if( group != Communicator::NullGroup ) localData.setSize( localRange.getSize() ); view.bind( localRange, globalSize, group, localData.getView() ); } template< typename Value, Loading @@ -54,7 +52,7 @@ const Subrange< Index >& DistributedArray< Value, Device, Index, Communicator >:: getLocalRange() const { return localRange; return view.getLocalRange(); } template< typename Value, Loading @@ -65,7 +63,7 @@ typename Communicator::CommunicationGroup DistributedArray< Value, Device, Index, Communicator >:: getCommunicationGroup() const { return group; return view.getCommunicationGroup(); } template< typename Value, Loading Loading @@ -99,18 +97,7 @@ void DistributedArray< Value, Device, Index, Communicator >:: copyFromGlobal( ConstLocalViewType globalArray ) { TNL_ASSERT_EQ( getSize(), globalArray.getSize(), "given global array has different size than the distributed array" ); LocalViewType localView( localData ); const LocalRangeType localRange = getLocalRange(); auto kernel = [=] __cuda_callable__ ( IndexType i ) mutable { localView[ i ] = globalArray[ localRange.getGlobalIndex( i ) ]; }; Algorithms::ParallelFor< DeviceType >::exec( (IndexType) 0, localRange.getSize(), kernel ); view.copyFromGlobal( globalArray ); } Loading @@ -126,7 +113,7 @@ typename DistributedArray< Value, Device, Index, Communicator >::ViewType DistributedArray< Value, Device, Index, Communicator >:: getView() { return ViewType( getLocalRange(), getSize(), getCommunicationGroup(), getLocalView() ); return view; } template< typename Value, Loading @@ -137,7 +124,7 @@ typename DistributedArray< Value, Device, Index, Communicator >::ConstViewType DistributedArray< Value, Device, Index, Communicator >:: getConstView() const { return ConstViewType( getLocalRange(), getSize(), getCommunicationGroup(), getConstLocalView() ); return view.getConstView(); } template< typename Value, Loading Loading @@ -169,10 +156,8 @@ void DistributedArray< Value, Device, Index, Communicator >:: setLike( const Array& array ) { localRange = array.getLocalRange(); globalSize = array.getSize(); group = array.getCommunicationGroup(); localData.setLike( array.getConstLocalView() ); view.bind( array.getLocalRange(), array.getSize(), array.getCommunicationGroup(), localData.getView() ); } template< typename Value, Loading @@ -183,9 +168,7 @@ void DistributedArray< Value, Device, Index, Communicator >:: reset() { localRange.reset(); globalSize = 0; group = Communicator::NullGroup; view.reset(); localData.reset(); } Loading @@ -197,7 +180,7 @@ bool DistributedArray< Value, Device, Index, Communicator >:: empty() const { return getSize() == 0; return view.empty(); } template< typename Value, Loading @@ -208,7 +191,7 @@ Index DistributedArray< Value, Device, Index, Communicator >:: getSize() const { return globalSize; return view.getSize(); } template< typename Value, Loading @@ -219,7 +202,7 @@ void DistributedArray< Value, Device, Index, Communicator >:: setValue( ValueType value ) { localData.setValue( value ); view.setValue( value ); } template< typename Value, Loading @@ -230,8 +213,7 @@ void DistributedArray< Value, Device, Index, Communicator >:: setElement( IndexType i, ValueType value ) { const IndexType li = localRange.getLocalIndex( i ); localData.setElement( li, value ); view.setElement( i, value ); } template< typename Value, Loading @@ -242,8 +224,7 @@ Value DistributedArray< Value, Device, Index, Communicator >:: getElement( IndexType i ) const { const IndexType li = localRange.getLocalIndex( i ); return localData.getElement( li ); return view.getElement( i ); } template< typename Value, Loading @@ -255,8 +236,7 @@ Value& DistributedArray< Value, Device, Index, Communicator >:: operator[]( IndexType i ) { const IndexType li = localRange.getLocalIndex( i ); return localData[ li ]; return view[ i ]; } template< typename Value, Loading @@ -268,8 +248,7 @@ const Value& DistributedArray< Value, Device, Index, Communicator >:: operator[]( IndexType i ) const { const IndexType li = localRange.getLocalIndex( i ); return localData[ li ]; return view[ i ]; } template< typename Value, Loading @@ -281,7 +260,7 @@ DistributedArray< Value, Device, Index, Communicator >:: operator=( const DistributedArray& array ) { setLike( array ); localData = array.getConstLocalView(); view = array; return *this; } Loading @@ -295,7 +274,7 @@ DistributedArray< Value, Device, Index, Communicator >:: operator=( const Array& array ) { setLike( array ); localData = array.getConstLocalView(); view = array; return *this; } Loading @@ -308,17 +287,7 @@ bool DistributedArray< Value, Device, Index, Communicator >:: operator==( const Array& array ) const { // we can't run allreduce if the communication groups are different if( group != array.getCommunicationGroup() ) return false; const bool localResult = localRange == array.getLocalRange() && globalSize == array.getSize() && localData == array.getConstLocalView(); bool result = true; if( group != CommunicatorType::NullGroup ) CommunicatorType::Allreduce( &localResult, &result, 1, MPI_LAND, group ); return result; return view == array; } template< typename Value, Loading @@ -330,7 +299,7 @@ bool DistributedArray< Value, Device, Index, Communicator >:: operator!=( const Array& array ) const { return ! (*this == array); return view != array; } template< typename Value, Loading @@ -341,12 +310,7 @@ bool DistributedArray< Value, Device, Index, Communicator >:: containsValue( ValueType value ) const { bool result = false; if( group != CommunicatorType::NullGroup ) { const bool localResult = localData.containsValue( value ); CommunicatorType::Allreduce( &localResult, &result, 1, MPI_LOR, group ); } return result; return view.containsValue( value ); } template< typename Value, Loading @@ -357,12 +321,7 @@ bool DistributedArray< Value, Device, Index, Communicator >:: containsOnlyValue( ValueType value ) const { bool result = true; if( group != CommunicatorType::NullGroup ) { const bool localResult = localData.containsOnlyValue( value ); CommunicatorType::Allreduce( &localResult, &result, 1, MPI_LAND, group ); } return result; return view.containsOnlyValue( value ); } } // namespace Containers Loading
src/TNL/Containers/DistributedArrayView.h +6 −3 Original line number Diff line number Diff line Loading @@ -74,9 +74,12 @@ public: __cuda_callable__ DistributedArrayView( DistributedArrayView&& ) = default; // method for rebinding (reinitialization) // Note that you can also bind directly to Array and other types implicitly // convertible to ArrayView. // method for rebinding (reinitialization) to raw data __cuda_callable__ void bind( const LocalRangeType& localRange, IndexType globalSize, CommunicationGroup group, LocalViewType localData ); // Note that you can also bind directly to DistributedArray and other types implicitly // convertible to DistributedArrayView. __cuda_callable__ void bind( DistributedArrayView view ); Loading
src/TNL/Containers/DistributedArrayView.hpp +18 −0 Original line number Diff line number Diff line Loading @@ -31,6 +31,24 @@ DistributedArrayView( const DistributedArrayView< Value_, Device, Index, Communi localData( view.getConstLocalView() ) {} template< typename Value, typename Device, typename Index, typename Communicator > __cuda_callable__ void DistributedArrayView< Value, Device, Index, Communicator >:: bind( const LocalRangeType& localRange, IndexType globalSize, CommunicationGroup group, LocalViewType localData ) { TNL_ASSERT_EQ( localData.getSize(), localRange.getSize(), "The local array size does not match the local range of the distributed array." ); this->localRange = localRange; this->globalSize = globalSize; this->group = group; this->localData.bind( localData ); } template< typename Value, typename Device, typename Index, Loading