Loading src/Benchmarks/DistSpMV/tnl-benchmark-distributed-spmv.h +2 −2 Original line number Diff line number Diff line Loading @@ -228,10 +228,10 @@ struct SpmvBenchmark const auto group = CommunicatorType::AllGroup; const auto localRange = Partitioner::splitRange( matrix.getRows(), group ); DistributedMatrix distributedMatrix( localRange, matrix.getRows(), matrix.getColumns(), group ); DistributedVector distributedVector( localRange, matrix.getRows(), group ); DistributedVector distributedVector( localRange, 0, matrix.getRows(), group ); // copy the row lengths from the global matrix to the distributed matrix DistributedRowLengths distributedRowLengths( localRange, matrix.getRows(), group ); DistributedRowLengths distributedRowLengths( localRange, 0, matrix.getRows(), group ); for( IndexType i = 0; i < distributedMatrix.getLocalMatrix().getRows(); i++ ) { const auto gi = distributedMatrix.getLocalRowRange().getGlobalIndex( i ); distributedRowLengths[ gi ] = matrix.getRowCapacity( gi ); Loading src/Benchmarks/LinearSolvers/tnl-benchmark-linear-solvers.h +3 −3 Original line number Diff line number Diff line Loading @@ -435,11 +435,11 @@ struct LinearSolversBenchmark const auto group = CommunicatorType::AllGroup; const auto localRange = Partitioner::splitRange( matrixPointer->getRows(), group ); SharedPointer< DistributedMatrix > distMatrixPointer( localRange, matrixPointer->getRows(), matrixPointer->getColumns(), group ); DistributedVector dist_x0( localRange, matrixPointer->getRows(), group ); DistributedVector dist_b( localRange, matrixPointer->getRows(), group ); DistributedVector dist_x0( localRange, 0, matrixPointer->getRows(), group ); DistributedVector dist_b( localRange, 0, matrixPointer->getRows(), group ); // copy the row capacities from the global matrix to the distributed matrix DistributedRowLengths distributedRowLengths( localRange, matrixPointer->getRows(), group ); DistributedRowLengths distributedRowLengths( localRange, 0, matrixPointer->getRows(), group ); for( IndexType i = 0; i < distMatrixPointer->getLocalMatrix().getRows(); i++ ) { const auto gi = distMatrixPointer->getLocalRowRange().getGlobalIndex( i ); distributedRowLengths[ gi ] = matrixPointer->getRowCapacity( gi ); Loading src/TNL/Containers/DistributedArray.h +39 −21 Original line number Diff line number Diff line Loading @@ -37,6 +37,7 @@ public: using ConstLocalViewType = Containers::ArrayView< std::add_const_t< Value >, Device, Index >; using ViewType = DistributedArrayView< Value, Device, Index, Communicator >; using ConstViewType = DistributedArrayView< std::add_const_t< Value >, Device, Index, Communicator >; using SynchronizerType = typename ViewType::SynchronizerType; /** * \brief A template which allows to quickly obtain a \ref DistributedArray type with changed template parameters. Loading @@ -50,46 +51,54 @@ public: DistributedArray() = default; DistributedArray( const DistributedArray& ) = default; // Copy-constructor does deep copy. DistributedArray( const DistributedArray& ); DistributedArray( LocalRangeType localRange, Index globalSize, CommunicationGroup group = Communicator::AllGroup ); DistributedArray( LocalRangeType localRange, Index ghosts, Index globalSize, CommunicationGroup group = Communicator::AllGroup ); void setDistribution( LocalRangeType localRange, Index globalSize, CommunicationGroup group = Communicator::AllGroup ); void setDistribution( LocalRangeType localRange, Index ghosts, Index globalSize, CommunicationGroup group = Communicator::AllGroup ); const LocalRangeType& getLocalRange() const; IndexType getGhosts() const; CommunicationGroup getCommunicationGroup() const; /** * \brief Returns a modifiable view of the local part of the array. * * If \e begin or \e end is set to a non-zero value, a view for the * sub-interval `[begin, end)` is returned. Otherwise a view for whole * local part of the array view is returned. * * \param begin The beginning of the array view sub-interval. It is 0 by * default. * \param end The end of the array view sub-interval. The default value is 0 * which is, however, replaced with the array size. */ LocalViewType getLocalView(); /** * \brief Returns a non-modifiable view of the local part of the array. * * If \e begin or \e end is set to a non-zero value, a view for the * sub-interval `[begin, end)` is returned. Otherwise a view for whole * local part of the array view is returned. * * \param begin The beginning of the array view sub-interval. It is 0 by * default. * \param end The end of the array view sub-interval. The default value is 0 * which is, however, replaced with the array size. */ ConstLocalViewType getConstLocalView() const; /** * \brief Returns a modifiable view of the local part of the array, * including ghost values. */ LocalViewType getLocalViewWithGhosts(); /** * \brief Returns a non-modifiable view of the local part of the array, * including ghost values. */ ConstLocalViewType getConstLocalViewWithGhosts() const; void copyFromGlobal( ConstLocalViewType globalArray ); // synchronizer stuff void setSynchronizer( std::shared_ptr< SynchronizerType > synchronizer, int valuesPerElement = 1 ); std::shared_ptr< SynchronizerType > getSynchronizer() const; int getValuesPerElement() const; void startSynchronization(); void waitForSynchronization() const; // Usual Array methods follow below. Loading Loading @@ -170,6 +179,15 @@ public: protected: ViewType view; LocalArrayType localData; private: template< typename Array, std::enable_if_t< std::is_same< typename Array::DeviceType, DeviceType >::value, bool > = true > static void setSynchronizerHelper( ViewType& view, const Array& array ) { view.setSynchronizer( array.getSynchronizer(), array.getValuesPerElement() ); } template< typename Array, std::enable_if_t< ! std::is_same< typename Array::DeviceType, DeviceType >::value, bool > = true > static void setSynchronizerHelper( ViewType& view, const Array& array ) {} }; } // namespace Containers Loading src/TNL/Containers/DistributedArray.hpp +111 −9 Original line number Diff line number Diff line Loading @@ -25,9 +25,20 @@ template< typename Value, typename Index, typename Communicator > DistributedArray< Value, Device, Index, Communicator >:: DistributedArray( LocalRangeType localRange, IndexType globalSize, CommunicationGroup group ) DistributedArray( const DistributedArray& array ) { setDistribution( localRange, globalSize, group ); setLike( array ); localData = array.getConstLocalViewWithGhosts(); } template< typename Value, typename Device, typename Index, typename Communicator > DistributedArray< Value, Device, Index, Communicator >:: DistributedArray( LocalRangeType localRange, IndexType ghosts, IndexType globalSize, CommunicationGroup group ) { setDistribution( localRange, ghosts, globalSize, group ); } template< typename Value, Loading @@ -36,12 +47,12 @@ template< typename Value, typename Communicator > void DistributedArray< Value, Device, Index, Communicator >:: setDistribution( LocalRangeType localRange, IndexType globalSize, CommunicationGroup group ) setDistribution( LocalRangeType localRange, IndexType ghosts, IndexType globalSize, CommunicationGroup group ) { TNL_ASSERT_LE( localRange.getEnd(), globalSize, "end of the local range is outside of the global range" ); if( group != Communicator::NullGroup ) localData.setSize( localRange.getSize() ); view.bind( localRange, globalSize, group, localData.getView() ); localData.setSize( localRange.getSize() + ghosts ); view.bind( localRange, ghosts, globalSize, group, localData.getView() ); } template< typename Value, Loading @@ -55,6 +66,17 @@ getLocalRange() const return view.getLocalRange(); } template< typename Value, typename Device, typename Index, typename Communicator > Index DistributedArray< Value, Device, Index, Communicator >:: getGhosts() const { return view.getGhosts(); } template< typename Value, typename Device, typename Index, Loading @@ -74,7 +96,7 @@ typename DistributedArray< Value, Device, Index, Communicator >::LocalViewType DistributedArray< Value, Device, Index, Communicator >:: getLocalView() { return localData.getView(); return view.getLocalView(); } template< typename Value, Loading @@ -85,7 +107,29 @@ typename DistributedArray< Value, Device, Index, Communicator >::ConstLocalViewT DistributedArray< Value, Device, Index, Communicator >:: getConstLocalView() const { return localData.getConstView(); return view.getConstLocalView(); } template< typename Value, typename Device, typename Index, typename Communicator > typename DistributedArray< Value, Device, Index, Communicator >::LocalViewType DistributedArray< Value, Device, Index, Communicator >:: getLocalViewWithGhosts() { return view.getLocalViewWithGhosts(); } template< typename Value, typename Device, typename Index, typename Communicator > typename DistributedArray< Value, Device, Index, Communicator >::ConstLocalViewType DistributedArray< Value, Device, Index, Communicator >:: getConstLocalViewWithGhosts() const { return view.getConstLocalViewWithGhosts(); } Loading @@ -100,6 +144,61 @@ copyFromGlobal( ConstLocalViewType globalArray ) view.copyFromGlobal( globalArray ); } template< typename Value, typename Device, typename Index, typename Communicator > void DistributedArray< Value, Device, Index, Communicator >:: setSynchronizer( std::shared_ptr< SynchronizerType > synchronizer, int valuesPerElement ) { view.setSynchronizer( synchronizer, valuesPerElement ); } template< typename Value, typename Device, typename Index, typename Communicator > std::shared_ptr< typename DistributedArrayView< Value, Device, Index, Communicator >::SynchronizerType > DistributedArray< Value, Device, Index, Communicator >:: getSynchronizer() const { return view.getSynchronizer(); } template< typename Value, typename Device, typename Index, typename Communicator > int DistributedArray< Value, Device, Index, Communicator >:: getValuesPerElement() const { return view.getValuesPerElement(); } template< typename Value, typename Device, typename Index, typename Communicator > void DistributedArray< Value, Device, Index, Communicator >:: startSynchronization() { view.startSynchronization(); } template< typename Value, typename Device, typename Index, typename Communicator > void DistributedArray< Value, Device, Index, Communicator >:: waitForSynchronization() const { view.waitForSynchronization(); } /* * Usual Array methods follow below. Loading Loading @@ -156,8 +255,11 @@ void DistributedArray< Value, Device, Index, Communicator >:: setLike( const Array& array ) { localData.setLike( array.getConstLocalView() ); view.bind( array.getLocalRange(), array.getSize(), array.getCommunicationGroup(), localData.getView() ); localData.setLike( array.getConstLocalViewWithGhosts() ); view.bind( array.getLocalRange(), array.getGhosts(), array.getSize(), array.getCommunicationGroup(), localData.getView() ); // set, but do not unset, the synchronizer if( array.getSynchronizer() ) setSynchronizerHelper( view, array ); } template< typename Value, Loading src/TNL/Containers/DistributedArrayView.h +31 −5 Original line number Diff line number Diff line Loading @@ -12,9 +12,12 @@ #pragma once #include <memory> #include <TNL/Containers/ArrayView.h> #include <TNL/Communicators/MpiCommunicator.h> #include <TNL/Containers/Subrange.h> #include <TNL/Containers/ByteArraySynchronizer.h> namespace TNL { namespace Containers { Loading @@ -36,6 +39,7 @@ public: using ConstLocalViewType = Containers::ArrayView< std::add_const_t< Value >, Device, Index >; using ViewType = DistributedArrayView< Value, Device, Index, Communicator >; using ConstViewType = DistributedArrayView< std::add_const_t< Value >, Device, Index, Communicator >; using SynchronizerType = ByteArraySynchronizer< DeviceType, IndexType >; /** * \brief A template which allows to quickly obtain a \ref DistributedArrayView type with changed template parameters. Loading @@ -48,11 +52,12 @@ public: // Initialization by raw data DistributedArrayView( const LocalRangeType& localRange, IndexType globalSize, CommunicationGroup group, LocalViewType localData ) : localRange(localRange), globalSize(globalSize), group(group), localData(localData) DistributedArrayView( const LocalRangeType& localRange, IndexType ghosts, IndexType globalSize, CommunicationGroup group, LocalViewType localData ) : localRange(localRange), ghosts(ghosts), globalSize(globalSize), group(group), localData(localData) { TNL_ASSERT_EQ( localData.getSize(), localRange.getSize(), TNL_ASSERT_EQ( localData.getSize(), localRange.getSize() + ghosts, "The local array size does not match the local range of the distributed array." ); TNL_ASSERT_GE( ghosts, 0, "The ghosts count must be non-negative." ); } DistributedArrayView() = default; Loading @@ -68,27 +73,44 @@ public: DistributedArrayView( DistributedArrayView&& ) = default; // method for rebinding (reinitialization) to raw data void bind( const LocalRangeType& localRange, IndexType globalSize, CommunicationGroup group, LocalViewType localData ); void bind( const LocalRangeType& localRange, IndexType ghosts, IndexType globalSize, CommunicationGroup group, LocalViewType localData ); // Note that you can also bind directly to DistributedArray and other types implicitly // convertible to DistributedArrayView. void bind( DistributedArrayView view ); // binding to local array via raw pointer // (local range, global size and communication group are preserved) // (local range, ghosts, global size and communication group are preserved) template< typename Value_ > void bind( Value_* data, IndexType localSize ); const LocalRangeType& getLocalRange() const; IndexType getGhosts() const; CommunicationGroup getCommunicationGroup() const; LocalViewType getLocalView(); ConstLocalViewType getConstLocalView() const; LocalViewType getLocalViewWithGhosts(); ConstLocalViewType getConstLocalViewWithGhosts() const; void copyFromGlobal( ConstLocalViewType globalArray ); // synchronizer stuff void setSynchronizer( std::shared_ptr< SynchronizerType > synchronizer, int valuesPerElement = 1 ); std::shared_ptr< SynchronizerType > getSynchronizer() const; int getValuesPerElement() const; void startSynchronization(); void waitForSynchronization() const; /* * Usual ArrayView methods follow below. Loading Loading @@ -156,9 +178,13 @@ public: protected: LocalRangeType localRange; IndexType ghosts = 0; IndexType globalSize = 0; CommunicationGroup group = Communicator::NullGroup; LocalViewType localData; std::shared_ptr< SynchronizerType > synchronizer = nullptr; int valuesPerElement = 1; }; } // namespace Containers Loading Loading
src/Benchmarks/DistSpMV/tnl-benchmark-distributed-spmv.h +2 −2 Original line number Diff line number Diff line Loading @@ -228,10 +228,10 @@ struct SpmvBenchmark const auto group = CommunicatorType::AllGroup; const auto localRange = Partitioner::splitRange( matrix.getRows(), group ); DistributedMatrix distributedMatrix( localRange, matrix.getRows(), matrix.getColumns(), group ); DistributedVector distributedVector( localRange, matrix.getRows(), group ); DistributedVector distributedVector( localRange, 0, matrix.getRows(), group ); // copy the row lengths from the global matrix to the distributed matrix DistributedRowLengths distributedRowLengths( localRange, matrix.getRows(), group ); DistributedRowLengths distributedRowLengths( localRange, 0, matrix.getRows(), group ); for( IndexType i = 0; i < distributedMatrix.getLocalMatrix().getRows(); i++ ) { const auto gi = distributedMatrix.getLocalRowRange().getGlobalIndex( i ); distributedRowLengths[ gi ] = matrix.getRowCapacity( gi ); Loading
src/Benchmarks/LinearSolvers/tnl-benchmark-linear-solvers.h +3 −3 Original line number Diff line number Diff line Loading @@ -435,11 +435,11 @@ struct LinearSolversBenchmark const auto group = CommunicatorType::AllGroup; const auto localRange = Partitioner::splitRange( matrixPointer->getRows(), group ); SharedPointer< DistributedMatrix > distMatrixPointer( localRange, matrixPointer->getRows(), matrixPointer->getColumns(), group ); DistributedVector dist_x0( localRange, matrixPointer->getRows(), group ); DistributedVector dist_b( localRange, matrixPointer->getRows(), group ); DistributedVector dist_x0( localRange, 0, matrixPointer->getRows(), group ); DistributedVector dist_b( localRange, 0, matrixPointer->getRows(), group ); // copy the row capacities from the global matrix to the distributed matrix DistributedRowLengths distributedRowLengths( localRange, matrixPointer->getRows(), group ); DistributedRowLengths distributedRowLengths( localRange, 0, matrixPointer->getRows(), group ); for( IndexType i = 0; i < distMatrixPointer->getLocalMatrix().getRows(); i++ ) { const auto gi = distMatrixPointer->getLocalRowRange().getGlobalIndex( i ); distributedRowLengths[ gi ] = matrixPointer->getRowCapacity( gi ); Loading
src/TNL/Containers/DistributedArray.h +39 −21 Original line number Diff line number Diff line Loading @@ -37,6 +37,7 @@ public: using ConstLocalViewType = Containers::ArrayView< std::add_const_t< Value >, Device, Index >; using ViewType = DistributedArrayView< Value, Device, Index, Communicator >; using ConstViewType = DistributedArrayView< std::add_const_t< Value >, Device, Index, Communicator >; using SynchronizerType = typename ViewType::SynchronizerType; /** * \brief A template which allows to quickly obtain a \ref DistributedArray type with changed template parameters. Loading @@ -50,46 +51,54 @@ public: DistributedArray() = default; DistributedArray( const DistributedArray& ) = default; // Copy-constructor does deep copy. DistributedArray( const DistributedArray& ); DistributedArray( LocalRangeType localRange, Index globalSize, CommunicationGroup group = Communicator::AllGroup ); DistributedArray( LocalRangeType localRange, Index ghosts, Index globalSize, CommunicationGroup group = Communicator::AllGroup ); void setDistribution( LocalRangeType localRange, Index globalSize, CommunicationGroup group = Communicator::AllGroup ); void setDistribution( LocalRangeType localRange, Index ghosts, Index globalSize, CommunicationGroup group = Communicator::AllGroup ); const LocalRangeType& getLocalRange() const; IndexType getGhosts() const; CommunicationGroup getCommunicationGroup() const; /** * \brief Returns a modifiable view of the local part of the array. * * If \e begin or \e end is set to a non-zero value, a view for the * sub-interval `[begin, end)` is returned. Otherwise a view for whole * local part of the array view is returned. * * \param begin The beginning of the array view sub-interval. It is 0 by * default. * \param end The end of the array view sub-interval. The default value is 0 * which is, however, replaced with the array size. */ LocalViewType getLocalView(); /** * \brief Returns a non-modifiable view of the local part of the array. * * If \e begin or \e end is set to a non-zero value, a view for the * sub-interval `[begin, end)` is returned. Otherwise a view for whole * local part of the array view is returned. * * \param begin The beginning of the array view sub-interval. It is 0 by * default. * \param end The end of the array view sub-interval. The default value is 0 * which is, however, replaced with the array size. */ ConstLocalViewType getConstLocalView() const; /** * \brief Returns a modifiable view of the local part of the array, * including ghost values. */ LocalViewType getLocalViewWithGhosts(); /** * \brief Returns a non-modifiable view of the local part of the array, * including ghost values. */ ConstLocalViewType getConstLocalViewWithGhosts() const; void copyFromGlobal( ConstLocalViewType globalArray ); // synchronizer stuff void setSynchronizer( std::shared_ptr< SynchronizerType > synchronizer, int valuesPerElement = 1 ); std::shared_ptr< SynchronizerType > getSynchronizer() const; int getValuesPerElement() const; void startSynchronization(); void waitForSynchronization() const; // Usual Array methods follow below. Loading Loading @@ -170,6 +179,15 @@ public: protected: ViewType view; LocalArrayType localData; private: template< typename Array, std::enable_if_t< std::is_same< typename Array::DeviceType, DeviceType >::value, bool > = true > static void setSynchronizerHelper( ViewType& view, const Array& array ) { view.setSynchronizer( array.getSynchronizer(), array.getValuesPerElement() ); } template< typename Array, std::enable_if_t< ! std::is_same< typename Array::DeviceType, DeviceType >::value, bool > = true > static void setSynchronizerHelper( ViewType& view, const Array& array ) {} }; } // namespace Containers Loading
src/TNL/Containers/DistributedArray.hpp +111 −9 Original line number Diff line number Diff line Loading @@ -25,9 +25,20 @@ template< typename Value, typename Index, typename Communicator > DistributedArray< Value, Device, Index, Communicator >:: DistributedArray( LocalRangeType localRange, IndexType globalSize, CommunicationGroup group ) DistributedArray( const DistributedArray& array ) { setDistribution( localRange, globalSize, group ); setLike( array ); localData = array.getConstLocalViewWithGhosts(); } template< typename Value, typename Device, typename Index, typename Communicator > DistributedArray< Value, Device, Index, Communicator >:: DistributedArray( LocalRangeType localRange, IndexType ghosts, IndexType globalSize, CommunicationGroup group ) { setDistribution( localRange, ghosts, globalSize, group ); } template< typename Value, Loading @@ -36,12 +47,12 @@ template< typename Value, typename Communicator > void DistributedArray< Value, Device, Index, Communicator >:: setDistribution( LocalRangeType localRange, IndexType globalSize, CommunicationGroup group ) setDistribution( LocalRangeType localRange, IndexType ghosts, IndexType globalSize, CommunicationGroup group ) { TNL_ASSERT_LE( localRange.getEnd(), globalSize, "end of the local range is outside of the global range" ); if( group != Communicator::NullGroup ) localData.setSize( localRange.getSize() ); view.bind( localRange, globalSize, group, localData.getView() ); localData.setSize( localRange.getSize() + ghosts ); view.bind( localRange, ghosts, globalSize, group, localData.getView() ); } template< typename Value, Loading @@ -55,6 +66,17 @@ getLocalRange() const return view.getLocalRange(); } template< typename Value, typename Device, typename Index, typename Communicator > Index DistributedArray< Value, Device, Index, Communicator >:: getGhosts() const { return view.getGhosts(); } template< typename Value, typename Device, typename Index, Loading @@ -74,7 +96,7 @@ typename DistributedArray< Value, Device, Index, Communicator >::LocalViewType DistributedArray< Value, Device, Index, Communicator >:: getLocalView() { return localData.getView(); return view.getLocalView(); } template< typename Value, Loading @@ -85,7 +107,29 @@ typename DistributedArray< Value, Device, Index, Communicator >::ConstLocalViewT DistributedArray< Value, Device, Index, Communicator >:: getConstLocalView() const { return localData.getConstView(); return view.getConstLocalView(); } template< typename Value, typename Device, typename Index, typename Communicator > typename DistributedArray< Value, Device, Index, Communicator >::LocalViewType DistributedArray< Value, Device, Index, Communicator >:: getLocalViewWithGhosts() { return view.getLocalViewWithGhosts(); } template< typename Value, typename Device, typename Index, typename Communicator > typename DistributedArray< Value, Device, Index, Communicator >::ConstLocalViewType DistributedArray< Value, Device, Index, Communicator >:: getConstLocalViewWithGhosts() const { return view.getConstLocalViewWithGhosts(); } Loading @@ -100,6 +144,61 @@ copyFromGlobal( ConstLocalViewType globalArray ) view.copyFromGlobal( globalArray ); } template< typename Value, typename Device, typename Index, typename Communicator > void DistributedArray< Value, Device, Index, Communicator >:: setSynchronizer( std::shared_ptr< SynchronizerType > synchronizer, int valuesPerElement ) { view.setSynchronizer( synchronizer, valuesPerElement ); } template< typename Value, typename Device, typename Index, typename Communicator > std::shared_ptr< typename DistributedArrayView< Value, Device, Index, Communicator >::SynchronizerType > DistributedArray< Value, Device, Index, Communicator >:: getSynchronizer() const { return view.getSynchronizer(); } template< typename Value, typename Device, typename Index, typename Communicator > int DistributedArray< Value, Device, Index, Communicator >:: getValuesPerElement() const { return view.getValuesPerElement(); } template< typename Value, typename Device, typename Index, typename Communicator > void DistributedArray< Value, Device, Index, Communicator >:: startSynchronization() { view.startSynchronization(); } template< typename Value, typename Device, typename Index, typename Communicator > void DistributedArray< Value, Device, Index, Communicator >:: waitForSynchronization() const { view.waitForSynchronization(); } /* * Usual Array methods follow below. Loading Loading @@ -156,8 +255,11 @@ void DistributedArray< Value, Device, Index, Communicator >:: setLike( const Array& array ) { localData.setLike( array.getConstLocalView() ); view.bind( array.getLocalRange(), array.getSize(), array.getCommunicationGroup(), localData.getView() ); localData.setLike( array.getConstLocalViewWithGhosts() ); view.bind( array.getLocalRange(), array.getGhosts(), array.getSize(), array.getCommunicationGroup(), localData.getView() ); // set, but do not unset, the synchronizer if( array.getSynchronizer() ) setSynchronizerHelper( view, array ); } template< typename Value, Loading
src/TNL/Containers/DistributedArrayView.h +31 −5 Original line number Diff line number Diff line Loading @@ -12,9 +12,12 @@ #pragma once #include <memory> #include <TNL/Containers/ArrayView.h> #include <TNL/Communicators/MpiCommunicator.h> #include <TNL/Containers/Subrange.h> #include <TNL/Containers/ByteArraySynchronizer.h> namespace TNL { namespace Containers { Loading @@ -36,6 +39,7 @@ public: using ConstLocalViewType = Containers::ArrayView< std::add_const_t< Value >, Device, Index >; using ViewType = DistributedArrayView< Value, Device, Index, Communicator >; using ConstViewType = DistributedArrayView< std::add_const_t< Value >, Device, Index, Communicator >; using SynchronizerType = ByteArraySynchronizer< DeviceType, IndexType >; /** * \brief A template which allows to quickly obtain a \ref DistributedArrayView type with changed template parameters. Loading @@ -48,11 +52,12 @@ public: // Initialization by raw data DistributedArrayView( const LocalRangeType& localRange, IndexType globalSize, CommunicationGroup group, LocalViewType localData ) : localRange(localRange), globalSize(globalSize), group(group), localData(localData) DistributedArrayView( const LocalRangeType& localRange, IndexType ghosts, IndexType globalSize, CommunicationGroup group, LocalViewType localData ) : localRange(localRange), ghosts(ghosts), globalSize(globalSize), group(group), localData(localData) { TNL_ASSERT_EQ( localData.getSize(), localRange.getSize(), TNL_ASSERT_EQ( localData.getSize(), localRange.getSize() + ghosts, "The local array size does not match the local range of the distributed array." ); TNL_ASSERT_GE( ghosts, 0, "The ghosts count must be non-negative." ); } DistributedArrayView() = default; Loading @@ -68,27 +73,44 @@ public: DistributedArrayView( DistributedArrayView&& ) = default; // method for rebinding (reinitialization) to raw data void bind( const LocalRangeType& localRange, IndexType globalSize, CommunicationGroup group, LocalViewType localData ); void bind( const LocalRangeType& localRange, IndexType ghosts, IndexType globalSize, CommunicationGroup group, LocalViewType localData ); // Note that you can also bind directly to DistributedArray and other types implicitly // convertible to DistributedArrayView. void bind( DistributedArrayView view ); // binding to local array via raw pointer // (local range, global size and communication group are preserved) // (local range, ghosts, global size and communication group are preserved) template< typename Value_ > void bind( Value_* data, IndexType localSize ); const LocalRangeType& getLocalRange() const; IndexType getGhosts() const; CommunicationGroup getCommunicationGroup() const; LocalViewType getLocalView(); ConstLocalViewType getConstLocalView() const; LocalViewType getLocalViewWithGhosts(); ConstLocalViewType getConstLocalViewWithGhosts() const; void copyFromGlobal( ConstLocalViewType globalArray ); // synchronizer stuff void setSynchronizer( std::shared_ptr< SynchronizerType > synchronizer, int valuesPerElement = 1 ); std::shared_ptr< SynchronizerType > getSynchronizer() const; int getValuesPerElement() const; void startSynchronization(); void waitForSynchronization() const; /* * Usual ArrayView methods follow below. Loading Loading @@ -156,9 +178,13 @@ public: protected: LocalRangeType localRange; IndexType ghosts = 0; IndexType globalSize = 0; CommunicationGroup group = Communicator::NullGroup; LocalViewType localData; std::shared_ptr< SynchronizerType > synchronizer = nullptr; int valuesPerElement = 1; }; } // namespace Containers Loading