Loading src/TNL/DistributedContainers/CMakeLists.txt +1 −0 Original line number Diff line number Diff line Loading @@ -11,6 +11,7 @@ SET( headers DistributedArray.h DistributedVectorView_impl.h Partitioner.h Subrange.h ThreePartVector.h ) INSTALL( FILES ${headers} DESTINATION ${TNL_TARGET_INCLUDE_DIRECTORY}/DistributedContainers ) src/TNL/DistributedContainers/DistributedSpMV.h +64 −37 Original line number Diff line number Diff line Loading @@ -18,9 +18,11 @@ // buffers #include <vector> #include <utility> // std::pair #include <limits> // std::numeric_limits #include <TNL/Matrices/Dense.h> #include <TNL/Containers/Vector.h> #include <TNL/Containers/VectorView.h> #include <TNL/DistributedContainers/ThreePartVector.h> // operations #include <type_traits> // std::add_const Loading @@ -43,26 +45,31 @@ public: using CommunicationGroup = typename CommunicatorType::CommunicationGroup; using Partitioner = DistributedContainers::Partitioner< typename Matrix::IndexType, Communicator >; // - communication pattern matrix is an nproc x nproc binary matrix C, where // C_ij = 1 iff the i-th process needs data from the j-th process // - communication pattern: vector components whose indices are in the range // [start_ij, end_ij) are copied from the j-th process to the i-th process // (an empty range with start_ij == end_ij indicates that there is no // communication between the i-th and j-th processes) // - communication pattern matrices - we need to assemble two nproc x nproc // matrices commPatternStarts and commPatternEnds holding the values // start_ij and end_ij respectively // - assembly of the i-th row involves traversal of the local matrix stored // in the i-th process // - assembly the full matrix needs all-to-all communication // - assembly of the full matrix needs all-to-all communication void updateCommunicationPattern( const MatrixType& localMatrix, CommunicationGroup group ) { const int rank = CommunicatorType::GetRank( group ); const int nproc = CommunicatorType::GetSize( group ); commPattern.setDimensions( nproc, nproc ); commPatternStarts.setDimensions( nproc, nproc ); commPatternEnds.setDimensions( nproc, nproc ); // pass the localMatrix to the device const Pointers::DevicePointer< const MatrixType > localMatrixPointer( localMatrix ); // buffer for the local row of the commPattern matrix // using AtomicBool = Atomic< bool, DeviceType >; // FIXME: CUDA does not support atomic operations for bool using AtomicBool = Atomic< int, DeviceType >; Containers::Array< AtomicBool, DeviceType > buffer( nproc ); buffer.setValue( false ); using AtomicIndex = Atomic< IndexType, DeviceType >; Containers::Array< AtomicIndex, DeviceType > span_starts( nproc ), span_ends( nproc ); span_starts.setValue( std::numeric_limits<IndexType>::max() ); span_ends.setValue( 0 ); // optimization for banded matrices using AtomicIndex = Atomic< IndexType, DeviceType >; Loading @@ -71,7 +78,7 @@ public: local_span.setElement( 1, localMatrix.getRows() ); // span end auto kernel = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix, AtomicBool* buffer, AtomicIndex* local_span ) AtomicIndex* span_starts, AtomicIndex* span_ends, AtomicIndex* local_span ) { const IndexType columns = localMatrix->getColumns(); const auto row = localMatrix->getRow( i ); Loading @@ -82,8 +89,9 @@ public: if( j < columns ) { const int owner = Partitioner::getOwner( j, columns, nproc ); // atomic assignment buffer[ owner ].store( true ); // update comm_left/Right span_starts[ owner ].fetch_min( j ); span_ends[ owner ].fetch_max( j + 1 ); // update comm_left/right if( owner < rank ) comm_left = true; if( owner > rank ) Loading @@ -100,7 +108,8 @@ public: ParallelFor< DeviceType >::exec( (IndexType) 0, localMatrix.getRows(), kernel, &localMatrixPointer.template getData< DeviceType >(), buffer.getData(), span_starts.getData(), span_ends.getData(), local_span.getData() ); Loading @@ -108,16 +117,19 @@ public: localOnlySpan.first = local_span.getElement( 0 ); localOnlySpan.second = local_span.getElement( 1 ); // copy the buffer into all rows of the preCommPattern matrix Matrices::Dense< bool, Devices::Host, int > preCommPattern; preCommPattern.setLike( commPattern ); // copy the buffer into all rows of the commPattern* matrices for( int j = 0; j < nproc; j++ ) for( int i = 0; i < nproc; i++ ) preCommPattern.setElementFast( j, i, buffer.getElement( i ) ); for( int i = 0; i < nproc; i++ ) { commPatternStarts.setElementFast( j, i, span_starts.getElement( i ) ); commPatternEnds.setElementFast( j, i, span_ends.getElement( i ) ); } // assemble the commPattern matrix CommunicatorType::Alltoall( &preCommPattern(0, 0), nproc, &commPattern(0, 0), nproc, // assemble the commPattern* matrices CommunicatorType::Alltoall( &commPatternStarts(0, 0), nproc, &commPatternStarts(0, 0), nproc, group ); CommunicatorType::Alltoall( &commPatternEnds(0, 0), nproc, &commPatternEnds(0, 0), nproc, group ); } Loading @@ -132,28 +144,37 @@ public: const int nproc = CommunicatorType::GetSize( group ); // update communication pattern if( commPattern.getRows() != nproc ) if( commPatternStarts.getRows() != nproc || commPatternEnds.getRows() != nproc ) updateCommunicationPattern( localMatrix, group ); // prepare buffers globalBuffer.setSize( localMatrix.getColumns() ); commRequests.clear(); globalBuffer.init( Partitioner::getOffset( localMatrix.getColumns(), rank, nproc ), inVector.getLocalVectorView(), localMatrix.getColumns() - Partitioner::getOffset( localMatrix.getColumns(), rank, nproc ) - inVector.getLocalVectorView().getSize() ); const auto globalBufferView = globalBuffer.getConstView(); // send our data to all processes that need it for( int i = 0; i < commPattern.getRows(); i++ ) if( commPattern( i, rank ) ) for( int i = 0; i < commPatternStarts.getRows(); i++ ) { if( i == rank ) continue; if( commPatternStarts( i, rank ) < commPatternEnds( i, rank ) ) commRequests.push_back( CommunicatorType::ISend( inVector.getLocalVectorView().getData(), inVector.getLocalVectorView().getSize(), inVector.getLocalVectorView().getData() + commPatternStarts( i, rank ) - Partitioner::getOffset( localMatrix.getColumns(), rank, nproc ), commPatternEnds( i, rank ) - commPatternStarts( i, rank ), i, 0, group ) ); } // receive data that we need for( int j = 0; j < commPattern.getRows(); j++ ) if( commPattern( rank, j ) ) for( int j = 0; j < commPatternStarts.getRows(); j++ ) { if( j == rank ) continue; if( commPatternStarts( rank, j ) < commPatternEnds( rank, j ) ) commRequests.push_back( CommunicatorType::IRecv( &globalBuffer[ Partitioner::getOffset( globalBuffer.getSize(), j, nproc ) ], Partitioner::getSizeForRank( globalBuffer.getSize(), j, nproc ), &globalBuffer[ commPatternStarts( rank, j ) ], commPatternEnds( rank, j ) - commPatternStarts( rank, j ), j, 0, group ) ); } // general variant if( localOnlySpan.first >= localOnlySpan.second ) { Loading @@ -161,8 +182,14 @@ public: CommunicatorType::WaitAll( &commRequests[0], commRequests.size() ); // perform matrix-vector multiplication auto outView = outVector.getLocalVectorView(); localMatrix.vectorProduct( globalBuffer, outView ); auto outVectorView = outVector.getLocalVectorView(); const Pointers::DevicePointer< const MatrixType > localMatrixPointer( localMatrix ); auto kernel = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable { outVectorView[ i ] = localMatrix->rowVectorProduct( i, globalBufferView ); }; ParallelFor< DeviceType >::exec( (IndexType) 0, localMatrix.getRows(), kernel, &localMatrixPointer.template getData< DeviceType >() ); } // optimization for banded matrices else { Loading @@ -183,7 +210,6 @@ public: CommunicatorType::WaitAll( &commRequests[0], commRequests.size() ); // finish the multiplication by adding the non-local entries Containers::VectorView< RealType, DeviceType, IndexType > globalBufferView( globalBuffer ); auto kernel2 = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable { outVectorView[ i ] = localMatrix->rowVectorProduct( i, globalBufferView ); Loading @@ -197,7 +223,8 @@ public: void reset() { commPattern.reset(); commPatternStarts.reset(); commPatternEnds.reset(); localOnlySpan.first = localOnlySpan.second = 0; globalBuffer.reset(); commRequests.clear(); Loading @@ -205,13 +232,13 @@ public: protected: // communication pattern Matrices::Dense< bool, Devices::Host, int > commPattern; Matrices::Dense< IndexType, Devices::Host, int > commPatternStarts, commPatternEnds; // span of rows with only block-diagonal entries std::pair< IndexType, IndexType > localOnlySpan; // global buffer for non-local elements of the vector Containers::Vector< RealType, DeviceType, IndexType > globalBuffer; ThreePartVector< RealType, DeviceType, IndexType > globalBuffer; // buffer for asynchronous communication requests std::vector< typename CommunicatorType::Request > commRequests; Loading src/TNL/DistributedContainers/ThreePartVector.h 0 → 100644 +157 −0 Original line number Diff line number Diff line /*************************************************************************** ThreePartVector.h - description ------------------- begin : Dec 19, 2018 copyright : (C) 2018 by Tomas Oberhuber et al. email : tomas.oberhuber@fjfi.cvut.cz ***************************************************************************/ /* See Copyright Notice in tnl/Copyright */ // Implemented by: Jakub Klinkovský #pragma once #include <TNL/Containers/Vector.h> #include <TNL/Containers/VectorView.h> namespace TNL { namespace DistributedContainers { template< typename Real, typename Device = Devices::Host, typename Index = int > class ThreePartVectorView { public: using RealType = Real; using DeviceType = Device; using IndexType = Index; using VectorView = Containers::VectorView< Real, Device, Index >; ThreePartVectorView() = default; ThreePartVectorView( const ThreePartVectorView& ) = default; ThreePartVectorView( ThreePartVectorView&& ) = default; ThreePartVectorView( VectorView view_left, VectorView view_mid, VectorView view_right ) { bind( view_left, view_mid, view_right ); } void bind( VectorView view_left, VectorView view_mid, VectorView view_right ) { left.bind( view_left ); middle.bind( view_mid ); right.bind( view_right ); } void reset() { left.reset(); middle.reset(); right.reset(); } // __cuda_callable__ // Real& operator[]( Index i ) // { // if( i < left.getSize() ) // return left[ i ]; // else if( i < left.getSize() + middle.getSize() ) // return middle[ i - left.getSize() ]; // else // return right[ i - left.getSize() - middle.getSize() ]; // } __cuda_callable__ const Real& operator[]( Index i ) const { if( i < left.getSize() ) return left[ i ]; else if( i < left.getSize() + middle.getSize() ) return middle[ i - left.getSize() ]; else return right[ i - left.getSize() - middle.getSize() ]; } friend std::ostream& operator<<( std::ostream& str, const ThreePartVectorView& v ) { str << "[\n\tleft: " << v.left << ",\n\tmiddle: " << v.middle << ",\n\tright: " << v.right << "\n]"; return str; } protected: VectorView left, middle, right; }; template< typename Real, typename Device = Devices::Host, typename Index = int > class ThreePartVector { using ConstReal = typename std::add_const< Real >::type; public: using RealType = Real; using DeviceType = Device; using IndexType = Index; using Vector = Containers::Vector< Real, Device, Index >; using VectorView = Containers::VectorView< Real, Device, Index >; using ConstVectorView = Containers::VectorView< ConstReal, Device, Index >; ThreePartVector() = default; ThreePartVector( ThreePartVector& ) = default; void init( Index size_left, ConstVectorView view_mid, Index size_right ) { left.setSize( size_left ); middle.bind( view_mid ); right.setSize( size_right ); } void reset() { left.reset(); middle.reset(); right.reset(); } ThreePartVectorView< ConstReal, Device, Index > getConstView() { return {left, middle, right}; } // __cuda_callable__ // Real& operator[]( Index i ) // { // if( i < left.getSize() ) // return left[ i ]; // else if( i < left.getSize() + middle.getSize() ) // return middle[ i - left.getSize() ]; // else // return right[ i - left.getSize() - middle.getSize() ]; // } __cuda_callable__ const Real& operator[]( Index i ) const { if( i < left.getSize() ) return left[ i ]; else if( i < left.getSize() + middle.getSize() ) return middle[ i - left.getSize() ]; else return right[ i - left.getSize() - middle.getSize() ]; } friend std::ostream& operator<<( std::ostream& str, const ThreePartVector& v ) { str << "[\n\tleft: " << v.left << ",\n\tmiddle: " << v.middle << ",\n\tright: " << v.right << "\n]"; return str; } protected: Vector left, right; ConstVectorView middle; }; } // namespace DistributedContainers } // namespace TNL src/UnitTests/DistributedContainers/DistributedMatrixTest.h +6 −2 Original line number Diff line number Diff line Loading @@ -214,7 +214,9 @@ TYPED_TEST( DistributedMatrixTest, vectorProduct_globalInput ) DistributedVector outVector( this->matrix.getLocalRowRange(), this->globalSize, this->matrix.getCommunicationGroup() ); this->matrix.vectorProduct( inVector, outVector ); EXPECT_EQ( outVector, this->rowLengths ); EXPECT_EQ( outVector, this->rowLengths ) << "outVector.getLocalVectorView() = " << outVector.getLocalVectorView() << ",\nthis->rowLengths.getLocalVectorView() = " << this->rowLengths.getLocalVectorView(); } TYPED_TEST( DistributedMatrixTest, vectorProduct_distributedInput ) Loading @@ -229,7 +231,9 @@ TYPED_TEST( DistributedMatrixTest, vectorProduct_distributedInput ) DistributedVector outVector( this->matrix.getLocalRowRange(), this->globalSize, this->matrix.getCommunicationGroup() ); this->matrix.vectorProduct( inVector, outVector ); EXPECT_EQ( outVector, this->rowLengths ); EXPECT_EQ( outVector, this->rowLengths ) << "outVector.getLocalVectorView() = " << outVector.getLocalVectorView() << ",\nthis->rowLengths.getLocalVectorView() = " << this->rowLengths.getLocalVectorView(); } #endif // HAVE_GTEST Loading Loading
src/TNL/DistributedContainers/CMakeLists.txt +1 −0 Original line number Diff line number Diff line Loading @@ -11,6 +11,7 @@ SET( headers DistributedArray.h DistributedVectorView_impl.h Partitioner.h Subrange.h ThreePartVector.h ) INSTALL( FILES ${headers} DESTINATION ${TNL_TARGET_INCLUDE_DIRECTORY}/DistributedContainers )
src/TNL/DistributedContainers/DistributedSpMV.h +64 −37 Original line number Diff line number Diff line Loading @@ -18,9 +18,11 @@ // buffers #include <vector> #include <utility> // std::pair #include <limits> // std::numeric_limits #include <TNL/Matrices/Dense.h> #include <TNL/Containers/Vector.h> #include <TNL/Containers/VectorView.h> #include <TNL/DistributedContainers/ThreePartVector.h> // operations #include <type_traits> // std::add_const Loading @@ -43,26 +45,31 @@ public: using CommunicationGroup = typename CommunicatorType::CommunicationGroup; using Partitioner = DistributedContainers::Partitioner< typename Matrix::IndexType, Communicator >; // - communication pattern matrix is an nproc x nproc binary matrix C, where // C_ij = 1 iff the i-th process needs data from the j-th process // - communication pattern: vector components whose indices are in the range // [start_ij, end_ij) are copied from the j-th process to the i-th process // (an empty range with start_ij == end_ij indicates that there is no // communication between the i-th and j-th processes) // - communication pattern matrices - we need to assemble two nproc x nproc // matrices commPatternStarts and commPatternEnds holding the values // start_ij and end_ij respectively // - assembly of the i-th row involves traversal of the local matrix stored // in the i-th process // - assembly the full matrix needs all-to-all communication // - assembly of the full matrix needs all-to-all communication void updateCommunicationPattern( const MatrixType& localMatrix, CommunicationGroup group ) { const int rank = CommunicatorType::GetRank( group ); const int nproc = CommunicatorType::GetSize( group ); commPattern.setDimensions( nproc, nproc ); commPatternStarts.setDimensions( nproc, nproc ); commPatternEnds.setDimensions( nproc, nproc ); // pass the localMatrix to the device const Pointers::DevicePointer< const MatrixType > localMatrixPointer( localMatrix ); // buffer for the local row of the commPattern matrix // using AtomicBool = Atomic< bool, DeviceType >; // FIXME: CUDA does not support atomic operations for bool using AtomicBool = Atomic< int, DeviceType >; Containers::Array< AtomicBool, DeviceType > buffer( nproc ); buffer.setValue( false ); using AtomicIndex = Atomic< IndexType, DeviceType >; Containers::Array< AtomicIndex, DeviceType > span_starts( nproc ), span_ends( nproc ); span_starts.setValue( std::numeric_limits<IndexType>::max() ); span_ends.setValue( 0 ); // optimization for banded matrices using AtomicIndex = Atomic< IndexType, DeviceType >; Loading @@ -71,7 +78,7 @@ public: local_span.setElement( 1, localMatrix.getRows() ); // span end auto kernel = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix, AtomicBool* buffer, AtomicIndex* local_span ) AtomicIndex* span_starts, AtomicIndex* span_ends, AtomicIndex* local_span ) { const IndexType columns = localMatrix->getColumns(); const auto row = localMatrix->getRow( i ); Loading @@ -82,8 +89,9 @@ public: if( j < columns ) { const int owner = Partitioner::getOwner( j, columns, nproc ); // atomic assignment buffer[ owner ].store( true ); // update comm_left/Right span_starts[ owner ].fetch_min( j ); span_ends[ owner ].fetch_max( j + 1 ); // update comm_left/right if( owner < rank ) comm_left = true; if( owner > rank ) Loading @@ -100,7 +108,8 @@ public: ParallelFor< DeviceType >::exec( (IndexType) 0, localMatrix.getRows(), kernel, &localMatrixPointer.template getData< DeviceType >(), buffer.getData(), span_starts.getData(), span_ends.getData(), local_span.getData() ); Loading @@ -108,16 +117,19 @@ public: localOnlySpan.first = local_span.getElement( 0 ); localOnlySpan.second = local_span.getElement( 1 ); // copy the buffer into all rows of the preCommPattern matrix Matrices::Dense< bool, Devices::Host, int > preCommPattern; preCommPattern.setLike( commPattern ); // copy the buffer into all rows of the commPattern* matrices for( int j = 0; j < nproc; j++ ) for( int i = 0; i < nproc; i++ ) preCommPattern.setElementFast( j, i, buffer.getElement( i ) ); for( int i = 0; i < nproc; i++ ) { commPatternStarts.setElementFast( j, i, span_starts.getElement( i ) ); commPatternEnds.setElementFast( j, i, span_ends.getElement( i ) ); } // assemble the commPattern matrix CommunicatorType::Alltoall( &preCommPattern(0, 0), nproc, &commPattern(0, 0), nproc, // assemble the commPattern* matrices CommunicatorType::Alltoall( &commPatternStarts(0, 0), nproc, &commPatternStarts(0, 0), nproc, group ); CommunicatorType::Alltoall( &commPatternEnds(0, 0), nproc, &commPatternEnds(0, 0), nproc, group ); } Loading @@ -132,28 +144,37 @@ public: const int nproc = CommunicatorType::GetSize( group ); // update communication pattern if( commPattern.getRows() != nproc ) if( commPatternStarts.getRows() != nproc || commPatternEnds.getRows() != nproc ) updateCommunicationPattern( localMatrix, group ); // prepare buffers globalBuffer.setSize( localMatrix.getColumns() ); commRequests.clear(); globalBuffer.init( Partitioner::getOffset( localMatrix.getColumns(), rank, nproc ), inVector.getLocalVectorView(), localMatrix.getColumns() - Partitioner::getOffset( localMatrix.getColumns(), rank, nproc ) - inVector.getLocalVectorView().getSize() ); const auto globalBufferView = globalBuffer.getConstView(); // send our data to all processes that need it for( int i = 0; i < commPattern.getRows(); i++ ) if( commPattern( i, rank ) ) for( int i = 0; i < commPatternStarts.getRows(); i++ ) { if( i == rank ) continue; if( commPatternStarts( i, rank ) < commPatternEnds( i, rank ) ) commRequests.push_back( CommunicatorType::ISend( inVector.getLocalVectorView().getData(), inVector.getLocalVectorView().getSize(), inVector.getLocalVectorView().getData() + commPatternStarts( i, rank ) - Partitioner::getOffset( localMatrix.getColumns(), rank, nproc ), commPatternEnds( i, rank ) - commPatternStarts( i, rank ), i, 0, group ) ); } // receive data that we need for( int j = 0; j < commPattern.getRows(); j++ ) if( commPattern( rank, j ) ) for( int j = 0; j < commPatternStarts.getRows(); j++ ) { if( j == rank ) continue; if( commPatternStarts( rank, j ) < commPatternEnds( rank, j ) ) commRequests.push_back( CommunicatorType::IRecv( &globalBuffer[ Partitioner::getOffset( globalBuffer.getSize(), j, nproc ) ], Partitioner::getSizeForRank( globalBuffer.getSize(), j, nproc ), &globalBuffer[ commPatternStarts( rank, j ) ], commPatternEnds( rank, j ) - commPatternStarts( rank, j ), j, 0, group ) ); } // general variant if( localOnlySpan.first >= localOnlySpan.second ) { Loading @@ -161,8 +182,14 @@ public: CommunicatorType::WaitAll( &commRequests[0], commRequests.size() ); // perform matrix-vector multiplication auto outView = outVector.getLocalVectorView(); localMatrix.vectorProduct( globalBuffer, outView ); auto outVectorView = outVector.getLocalVectorView(); const Pointers::DevicePointer< const MatrixType > localMatrixPointer( localMatrix ); auto kernel = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable { outVectorView[ i ] = localMatrix->rowVectorProduct( i, globalBufferView ); }; ParallelFor< DeviceType >::exec( (IndexType) 0, localMatrix.getRows(), kernel, &localMatrixPointer.template getData< DeviceType >() ); } // optimization for banded matrices else { Loading @@ -183,7 +210,6 @@ public: CommunicatorType::WaitAll( &commRequests[0], commRequests.size() ); // finish the multiplication by adding the non-local entries Containers::VectorView< RealType, DeviceType, IndexType > globalBufferView( globalBuffer ); auto kernel2 = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable { outVectorView[ i ] = localMatrix->rowVectorProduct( i, globalBufferView ); Loading @@ -197,7 +223,8 @@ public: void reset() { commPattern.reset(); commPatternStarts.reset(); commPatternEnds.reset(); localOnlySpan.first = localOnlySpan.second = 0; globalBuffer.reset(); commRequests.clear(); Loading @@ -205,13 +232,13 @@ public: protected: // communication pattern Matrices::Dense< bool, Devices::Host, int > commPattern; Matrices::Dense< IndexType, Devices::Host, int > commPatternStarts, commPatternEnds; // span of rows with only block-diagonal entries std::pair< IndexType, IndexType > localOnlySpan; // global buffer for non-local elements of the vector Containers::Vector< RealType, DeviceType, IndexType > globalBuffer; ThreePartVector< RealType, DeviceType, IndexType > globalBuffer; // buffer for asynchronous communication requests std::vector< typename CommunicatorType::Request > commRequests; Loading
src/TNL/DistributedContainers/ThreePartVector.h 0 → 100644 +157 −0 Original line number Diff line number Diff line /*************************************************************************** ThreePartVector.h - description ------------------- begin : Dec 19, 2018 copyright : (C) 2018 by Tomas Oberhuber et al. email : tomas.oberhuber@fjfi.cvut.cz ***************************************************************************/ /* See Copyright Notice in tnl/Copyright */ // Implemented by: Jakub Klinkovský #pragma once #include <TNL/Containers/Vector.h> #include <TNL/Containers/VectorView.h> namespace TNL { namespace DistributedContainers { template< typename Real, typename Device = Devices::Host, typename Index = int > class ThreePartVectorView { public: using RealType = Real; using DeviceType = Device; using IndexType = Index; using VectorView = Containers::VectorView< Real, Device, Index >; ThreePartVectorView() = default; ThreePartVectorView( const ThreePartVectorView& ) = default; ThreePartVectorView( ThreePartVectorView&& ) = default; ThreePartVectorView( VectorView view_left, VectorView view_mid, VectorView view_right ) { bind( view_left, view_mid, view_right ); } void bind( VectorView view_left, VectorView view_mid, VectorView view_right ) { left.bind( view_left ); middle.bind( view_mid ); right.bind( view_right ); } void reset() { left.reset(); middle.reset(); right.reset(); } // __cuda_callable__ // Real& operator[]( Index i ) // { // if( i < left.getSize() ) // return left[ i ]; // else if( i < left.getSize() + middle.getSize() ) // return middle[ i - left.getSize() ]; // else // return right[ i - left.getSize() - middle.getSize() ]; // } __cuda_callable__ const Real& operator[]( Index i ) const { if( i < left.getSize() ) return left[ i ]; else if( i < left.getSize() + middle.getSize() ) return middle[ i - left.getSize() ]; else return right[ i - left.getSize() - middle.getSize() ]; } friend std::ostream& operator<<( std::ostream& str, const ThreePartVectorView& v ) { str << "[\n\tleft: " << v.left << ",\n\tmiddle: " << v.middle << ",\n\tright: " << v.right << "\n]"; return str; } protected: VectorView left, middle, right; }; template< typename Real, typename Device = Devices::Host, typename Index = int > class ThreePartVector { using ConstReal = typename std::add_const< Real >::type; public: using RealType = Real; using DeviceType = Device; using IndexType = Index; using Vector = Containers::Vector< Real, Device, Index >; using VectorView = Containers::VectorView< Real, Device, Index >; using ConstVectorView = Containers::VectorView< ConstReal, Device, Index >; ThreePartVector() = default; ThreePartVector( ThreePartVector& ) = default; void init( Index size_left, ConstVectorView view_mid, Index size_right ) { left.setSize( size_left ); middle.bind( view_mid ); right.setSize( size_right ); } void reset() { left.reset(); middle.reset(); right.reset(); } ThreePartVectorView< ConstReal, Device, Index > getConstView() { return {left, middle, right}; } // __cuda_callable__ // Real& operator[]( Index i ) // { // if( i < left.getSize() ) // return left[ i ]; // else if( i < left.getSize() + middle.getSize() ) // return middle[ i - left.getSize() ]; // else // return right[ i - left.getSize() - middle.getSize() ]; // } __cuda_callable__ const Real& operator[]( Index i ) const { if( i < left.getSize() ) return left[ i ]; else if( i < left.getSize() + middle.getSize() ) return middle[ i - left.getSize() ]; else return right[ i - left.getSize() - middle.getSize() ]; } friend std::ostream& operator<<( std::ostream& str, const ThreePartVector& v ) { str << "[\n\tleft: " << v.left << ",\n\tmiddle: " << v.middle << ",\n\tright: " << v.right << "\n]"; return str; } protected: Vector left, right; ConstVectorView middle; }; } // namespace DistributedContainers } // namespace TNL
src/UnitTests/DistributedContainers/DistributedMatrixTest.h +6 −2 Original line number Diff line number Diff line Loading @@ -214,7 +214,9 @@ TYPED_TEST( DistributedMatrixTest, vectorProduct_globalInput ) DistributedVector outVector( this->matrix.getLocalRowRange(), this->globalSize, this->matrix.getCommunicationGroup() ); this->matrix.vectorProduct( inVector, outVector ); EXPECT_EQ( outVector, this->rowLengths ); EXPECT_EQ( outVector, this->rowLengths ) << "outVector.getLocalVectorView() = " << outVector.getLocalVectorView() << ",\nthis->rowLengths.getLocalVectorView() = " << this->rowLengths.getLocalVectorView(); } TYPED_TEST( DistributedMatrixTest, vectorProduct_distributedInput ) Loading @@ -229,7 +231,9 @@ TYPED_TEST( DistributedMatrixTest, vectorProduct_distributedInput ) DistributedVector outVector( this->matrix.getLocalRowRange(), this->globalSize, this->matrix.getCommunicationGroup() ); this->matrix.vectorProduct( inVector, outVector ); EXPECT_EQ( outVector, this->rowLengths ); EXPECT_EQ( outVector, this->rowLengths ) << "outVector.getLocalVectorView() = " << outVector.getLocalVectorView() << ",\nthis->rowLengths.getLocalVectorView() = " << this->rowLengths.getLocalVectorView(); } #endif // HAVE_GTEST Loading