diff --git a/src/TNL/Containers/Segments/CSR.h b/src/TNL/Containers/Segments/CSR.h
index b83e43f1d146091219e3948afdbb94bfa0ae0b4e..add07f1dff5c587f3cde22d953a1390728f464e8 100644
--- a/src/TNL/Containers/Segments/CSR.h
+++ b/src/TNL/Containers/Segments/CSR.h
@@ -93,7 +93,6 @@ class CSR
       template< typename Function, typename... Args >
       void forAll( Function& f, Args... args ) const;
 
-
       /***
        * \brief Go over all segments and perform a reduction in each of them.
        */
@@ -103,6 +102,11 @@ class CSR
       template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
       void allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const;
 
+      CSR& operator=( const CSR& rhsSegments ) = default;
+
+      template< typename Device_, typename Index_, typename IndexAllocator_ >
+      CSR& operator=( const CSR< Device_, Index_, IndexAllocator_ >& source );
+
       void save( File& file ) const;
 
       void load( File& file );
diff --git a/src/TNL/Containers/Segments/CSR.hpp b/src/TNL/Containers/Segments/CSR.hpp
index 280ed6ebf212d0419ab16ac7dbccb7ddd664c507..61720869caa418d189d105ce6f0769e0690d9b35 100644
--- a/src/TNL/Containers/Segments/CSR.hpp
+++ b/src/TNL/Containers/Segments/CSR.hpp
@@ -221,6 +221,23 @@ allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Re
    this->segmentsReduction( 0, this->getSegmentsCount(), fetch, reduction, keeper, zero, args... );
 }
 
+/***
+ * \brief Assignment from CSR segments with a possibly different device,
+ * index type or index allocator. Only the offsets are copied, using the
+ * assignment operator of the offsets member.
+ */
+template< typename Device,
+          typename Index,
+          typename IndexAllocator >
+   template< typename Device_, typename Index_, typename IndexAllocator_ >
+CSR< Device, Index, IndexAllocator >&
+CSR< Device, Index, IndexAllocator >::
+operator=( const CSR< Device_, Index_, IndexAllocator_ >& source )
+{
+   this->offsets = source.offsets;
+   return *this;
+}
+
 template< typename Device,
           typename Index,
           typename IndexAllocator >
diff --git a/src/TNL/Containers/Segments/Ellpack.h b/src/TNL/Containers/Segments/Ellpack.h
index 9c81a84281925b2ec971bceba5a161aa464c83e6..b9b3e63c1efd7e92bc83eabc78d439637c448247 100644
--- a/src/TNL/Containers/Segments/Ellpack.h
+++ b/src/TNL/Containers/Segments/Ellpack.h
@@ -100,6 +100,11 @@ class Ellpack
       template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
       void allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const;
 
+      Ellpack& operator=( const Ellpack& source ) = default;
+
+      template< typename Device_, typename Index_, typename IndexAllocator_, bool RowMajorOrder_, int Alignment_ >
+      Ellpack& operator=( const Ellpack< Device_, Index_, IndexAllocator_, RowMajorOrder_, Alignment_ >& source );
+
       void save( File& file ) const;
 
       void load( File& file );
diff --git a/src/TNL/Containers/Segments/Ellpack.hpp b/src/TNL/Containers/Segments/Ellpack.hpp
index 482c87d4f36274edf8aef622f576ce07773df3b9..97d30d3147ecf266f334f9bddc0219a423d89a23 100644
--- a/src/TNL/Containers/Segments/Ellpack.hpp
+++ b/src/TNL/Containers/Segments/Ellpack.hpp
@@ -322,6 +322,27 @@ allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Re
    this->segmentsReduction( 0, this->getSegmentsCount(), fetch, reduction, keeper, zero, args... );
 }
 
+/***
+ * \brief Assignment from Ellpack segments with possibly different template
+ * parameters. The aligned size is recomputed from the copied size with the
+ * alignment of this object rather than taken from the source.
+ */
+template< typename Device,
+          typename Index,
+          typename IndexAllocator,
+          bool RowMajorOrder,
+          int Alignment >
+   template< typename Device_, typename Index_, typename IndexAllocator_, bool RowMajorOrder_, int Alignment_ >
+Ellpack< Device, Index, IndexAllocator, RowMajorOrder, Alignment >&
+Ellpack< Device, Index, IndexAllocator, RowMajorOrder, Alignment >::
+operator=( const Ellpack< Device_, Index_, IndexAllocator_, RowMajorOrder_, Alignment_ >& source )
+{
+   this->segmentSize = source.segmentSize;
+   this->size = source.size;
+   this->alignedSize = roundUpDivision( size, this->getAlignment() ) * this->getAlignment();
+   return *this;
+}
+
 template< typename Device,
           typename Index,
           typename IndexAllocator,
diff --git a/src/TNL/Containers/Segments/SlicedEllpack.h b/src/TNL/Containers/Segments/SlicedEllpack.h
index fc514c51f3edcff69682e35378515c9d90dc8ffc..9c2e7157f73d46094d0bb9bb8af251f164de9588 100644
--- a/src/TNL/Containers/Segments/SlicedEllpack.h
+++ b/src/TNL/Containers/Segments/SlicedEllpack.h
@@ -96,6 +96,11 @@ class SlicedEllpack
       template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
       void allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const;
 
+      SlicedEllpack& operator=( const SlicedEllpack& source ) = default;
+
+      template< typename Device_, typename Index_, typename IndexAllocator_, bool RowMajorOrder_ >
+      SlicedEllpack& operator=( const SlicedEllpack< Device_, Index_, IndexAllocator_, RowMajorOrder_, SliceSize >& source );
+
       void save( File& file ) const;
 
       void load( File& file );
diff --git a/src/TNL/Containers/Segments/SlicedEllpack.hpp b/src/TNL/Containers/Segments/SlicedEllpack.hpp
index bdf28ff73b46de6e9423eebcd03849d7ef5cca2b..ad83f666a00a073fb30f65e50a085e84ba6abd05 100644
--- a/src/TNL/Containers/Segments/SlicedEllpack.hpp
+++ b/src/TNL/Containers/Segments/SlicedEllpack.hpp
@@ -356,6 +356,29 @@ allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Re
    this->segmentsReduction( 0, this->getSegmentsCount(), fetch, reduction, keeper, zero, args... );
 }
 
+/***
+ * \brief Assignment from SlicedEllpack segments with the same slice size but
+ * possibly different remaining template parameters. The size, aligned size,
+ * segments count, slice offsets and slice segment sizes are copied.
+ */
+template< typename Device,
+          typename Index,
+          typename IndexAllocator,
+          bool RowMajorOrder,
+          int SliceSize >
+   template< typename Device_, typename Index_, typename IndexAllocator_, bool RowMajorOrder_ >
+SlicedEllpack< Device, Index, IndexAllocator, RowMajorOrder, SliceSize >&
+SlicedEllpack< Device, Index, IndexAllocator, RowMajorOrder, SliceSize >::
+operator=( const SlicedEllpack< Device_, Index_, IndexAllocator_, RowMajorOrder_, SliceSize >& source )
+{
+   this->size = source.size;
+   this->alignedSize = source.alignedSize;
+   this->segmentsCount = source.segmentsCount;
+   this->sliceOffsets = source.sliceOffsets;
+   this->sliceSegmentSizes = source.sliceSegmentSizes;
+   return *this;
+}
+
 template< typename Device,
           typename Index,
           typename IndexAllocator,
diff --git a/src/TNL/Matrices/SparseMatrix.h b/src/TNL/Matrices/SparseMatrix.h
index 1512f8574d1e4a29534b61a63f1ebed40c25af82..8c8fef599a9f57fcd8af11ec028b0217de45cbe2 100644
--- a/src/TNL/Matrices/SparseMatrix.h
+++ b/src/TNL/Matrices/SparseMatrix.h
@@ -218,7 +218,7 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator >
 
       SegmentsType segments;
 
-      IndexAllocator indexAlloctor;
+      IndexAllocator indexAllocator;
 
       RealAllocator realAllocator;
 
diff --git a/src/TNL/Matrices/SparseMatrix.hpp b/src/TNL/Matrices/SparseMatrix.hpp
index e24ed2f448b8b9dc8fe7a78f7f0b8af0431fbf77..5de4473ab51f121553fbae9148afe7a8044a00bd 100644
--- a/src/TNL/Matrices/SparseMatrix.hpp
+++ b/src/TNL/Matrices/SparseMatrix.hpp
@@ -680,7 +680,7 @@ operator=( const SparseMatrix& matrix )
    Matrix< Real, Device, Index >::operator=( matrix );
    this->columnIndexes = matrix.columnIndexes;
    this->segments = matrix.segments;
-   this->indexAlloctor = matrix.indexAllocator;
+   this->indexAllocator = matrix.indexAllocator;
    this->realAllocator = matrix.realAllocator;
 }
 
@@ -702,29 +702,93 @@ SparseMatrix< Real, Segments, Device, Index, RealAllocator, IndexAllocator >::
 operator=( const SparseMatrix< Real2, Segments2, Device2, Index2, RealAllocator2, IndexAllocator2 >& matrix )
 {
    using RHSMatrixType = SparseMatrix< Real2, Segments2, Device2, Index2, RealAllocator2, IndexAllocator2 >;
+   // Take over the compressed row lengths of the source matrix.
+   RowsCapacitiesType rowLengths;
+   matrix.getCompressedRowLengths( rowLengths );
+   this->setCompressedRowLengths( rowLengths );
+
+   // TODO: Replace this with SparseMatrixView
+   const auto matrix_columns_view = matrix.columnIndexes.getConstView();
+   const auto matrix_values_view = matrix.values.getConstView();
+   const IndexType paddingIndex = this->getPaddingIndex();
+   auto this_columns_view = this->columnIndexes.getView();
+   auto this_values_view = this->values.getView();
+
    if( std::is_same< Device, Device2 >::value )
    {
-      /*RowsCapacitiesType rowLengths;
-      matrix.getCompressedRowLengths( rowLengths );
-      this->setCompressedRowLengths( rowLengths );
-      // TODO: Replace this with SparseMatrixView
-      const auto matrix_columns_view = matrix.columnIndexes.getConstView();
-      const auto matrix_values_view = matrix.values.getConstView();
-      const auto segments_view = this->segments.getConstView();
-      auto this_columns_view = this->columnIndexes.getView();
-      auto this_values_view = this->values.getView();
-      const IndexType paddingIndex = this->getPaddingIndex();
-      auto f = [=] __cuda_callable__ ( IndexType rowIdx, IndexType localIdx, IndexType globalIdx ) {
+      // Same device: copy directly, remapping each source element to the
+      // position given by this matrix' segments for ( rowIdx, localIdx ).
+      const auto this_segments_view = this->segments.getView();
+      auto f = [=] __cuda_callable__ ( IndexType rowIdx, IndexType localIdx, IndexType globalIdx ) mutable {
          const IndexType column = matrix_columns_view[ globalIdx ];
          if( column != paddingIndex )
          {
             const RealType value = matrix_values_view[ globalIdx ];
-            IndexType thisGlobalIdx = segments_view.getGlobalIdx( rowIdx, localIdx );
+            IndexType thisGlobalIdx = this_segments_view.getGlobalIndex( rowIdx, localIdx );
             this_columns_view[ thisGlobalIdx ] = column;
             this_values_view[ thisGlobalIdx ] = value;
          }
       };
-      matrix.forAllRows( f );*/
+      matrix.forAllRows( f );
+   }
+   else
+   {
+      // Different devices: transfer the elements in chunks of bufferRowsCount
+      // rows through two pairs of dense buffers, one pair per device type.
+      const IndexType maxRowLength = max( rowLengths );
+      const IndexType bufferRowsCount( 128 );
+      const size_t bufferSize = bufferRowsCount * maxRowLength;
+      Containers::Vector< Real2, Device2, Index2, RealAllocator2 > matrixValuesBuffer( bufferSize );
+      Containers::Vector< Index2, Device2, Index2, IndexAllocator2 > matrixColumnsBuffer( bufferSize );
+      Containers::Vector< RealType, DeviceType, IndexType, RealAllocator > thisValuesBuffer( bufferSize );
+      Containers::Vector< IndexType, DeviceType, IndexType, IndexAllocator > thisColumnsBuffer( bufferSize );
+      auto matrixValuesBuffer_view = matrixValuesBuffer.getView();
+      auto matrixColumnsBuffer_view = matrixColumnsBuffer.getView();
+      auto thisValuesBuffer_view = thisValuesBuffer.getView();
+      auto thisColumnsBuffer_view = thisColumnsBuffer.getView();
+
+      IndexType baseRow( 0 );
+      const IndexType rowsCount = this->getRows();
+      while( baseRow < rowsCount )
+      {
+         const IndexType lastRow = min( baseRow + bufferRowsCount, rowsCount );
+         thisColumnsBuffer = paddingIndex;
+         // The source-side buffer is what gets copied over below, so it must be
+         // reset too; otherwise slots skipped by f1 keep stale column indexes.
+         matrixColumnsBuffer = paddingIndex;
+
+         ////
+         // Copy matrix elements into buffer
+         auto f1 = [=] __cuda_callable__ ( IndexType rowIdx, IndexType localIdx, IndexType globalIdx ) mutable {
+            const IndexType column = matrix_columns_view[ globalIdx ];
+            if( column != paddingIndex )
+            {
+               const IndexType bufferIdx = ( rowIdx - baseRow ) * maxRowLength + localIdx;
+               matrixValuesBuffer_view[ bufferIdx ] = matrix_values_view[ globalIdx ];
+               matrixColumnsBuffer_view[ bufferIdx ] = column;
+            }
+         };
+         matrix.forRows( baseRow, lastRow, f1 );
+
+         ////
+         // Copy the source matrix buffer to this matrix buffer
+         thisValuesBuffer_view = matrixValuesBuffer_view;
+         thisColumnsBuffer_view = matrixColumnsBuffer_view;
+
+         ////
+         // Copy matrix elements from the buffer to the matrix
+         auto f2 = [=] __cuda_callable__ ( IndexType rowIdx, IndexType localIdx, IndexType globalIdx ) mutable {
+            const IndexType bufferIdx = ( rowIdx - baseRow ) * maxRowLength + localIdx;
+            const IndexType column = thisColumnsBuffer_view[ bufferIdx ];
+            if( column != paddingIndex )
+            {
+               this_columns_view[ globalIdx ] = column;
+               this_values_view[ globalIdx ] = thisValuesBuffer_view[ bufferIdx ];
+            }
+         };
+         this->forRows( baseRow, lastRow, f2 );
+         baseRow += bufferRowsCount;
+      }
    }
 }
 
diff --git a/src/UnitTests/Matrices/SparseMatrixCopyTest.h b/src/UnitTests/Matrices/SparseMatrixCopyTest.h
index 9b09ef4d45cc4ee2e27fb582aa6f56e3de7e09b5..684a6a8713b534edb02c89778008332f54095795 100644
--- a/src/UnitTests/Matrices/SparseMatrixCopyTest.h
+++ b/src/UnitTests/Matrices/SparseMatrixCopyTest.h
@@ -12,12 +12,31 @@
 #include <TNL/Matrices/Ellpack.h>
 #include <TNL/Matrices/SlicedEllpack.h>
 
-using CSR_host = TNL::Matrices::CSR< int, TNL::Devices::Host, int >;
+#include <TNL/Matrices/SparseMatrix.h>
+#include <TNL/Containers/Segments/CSR.h>
+#include <TNL/Containers/Segments/Ellpack.h>
+#include <TNL/Containers/Segments/SlicedEllpack.h>
+
+/*using CSR_host = TNL::Matrices::CSR< int, TNL::Devices::Host, int >;
 using CSR_cuda = TNL::Matrices::CSR< int, TNL::Devices::Cuda, int >;
 using E_host = TNL::Matrices::Ellpack< int, TNL::Devices::Host, int >;
 using E_cuda = TNL::Matrices::Ellpack< int, TNL::Devices::Cuda, int >;
 using SE_host = TNL::Matrices::SlicedEllpack< int, TNL::Devices::Host, int, 2 >;
-using SE_cuda = TNL::Matrices::SlicedEllpack< int, TNL::Devices::Cuda, int, 2 >;
+using SE_cuda = TNL::Matrices::SlicedEllpack< int, TNL::Devices::Cuda, int, 2 >;*/
+
+template< typename Device, typename Index, typename IndexAllocator >
+using EllpackSegments = TNL::Containers::Segments::Ellpack< Device, Index, IndexAllocator >;
+
+template< typename Device, typename Index, typename IndexAllocator >
+using SlicedEllpackSegments = TNL::Containers::Segments::SlicedEllpack< Device, Index, IndexAllocator >;
+
+using CSR_host = TNL::Matrices::SparseMatrix< int, TNL::Containers::Segments::CSR, TNL::Devices::Host, int >;
+using CSR_cuda = TNL::Matrices::SparseMatrix< int, TNL::Containers::Segments::CSR, TNL::Devices::Cuda, int >;
+using E_host = TNL::Matrices::SparseMatrix< int, EllpackSegments, TNL::Devices::Host, int >;
+using E_cuda = TNL::Matrices::SparseMatrix< int, EllpackSegments, TNL::Devices::Cuda, int >;
+using SE_host = TNL::Matrices::SparseMatrix< int, SlicedEllpackSegments, TNL::Devices::Host, int >;
+using SE_cuda = TNL::Matrices::SparseMatrix< int, SlicedEllpackSegments, TNL::Devices::Cuda, int >;
+
 
 #ifdef HAVE_GTEST
 #include <gtest/gtest.h>
@@ -388,7 +407,8 @@ void testConversion()
       checkTriDiagMatrix( triDiag1 );
 
       Matrix2 triDiag2;
-      TNL::Matrices::copySparseMatrix( triDiag2, triDiag1 );
+      //TNL::Matrices::copySparseMatrix( triDiag2, triDiag1 );
+      triDiag2 = triDiag1;
       checkTriDiagMatrix( triDiag2 );
    }
 
@@ -400,7 +420,8 @@ void testConversion()
       checkAntiTriDiagMatrix( antiTriDiag1 );
 
       Matrix2 antiTriDiag2;
-      TNL::Matrices::copySparseMatrix( antiTriDiag2, antiTriDiag1 );
+      //TNL::Matrices::copySparseMatrix( antiTriDiag2, antiTriDiag1 );
+      antiTriDiag2 = antiTriDiag1;
       checkAntiTriDiagMatrix( antiTriDiag2 );
    }
 
@@ -411,7 +432,8 @@ void testConversion()
       checkUnevenRowSizeMatrix( unevenRowSize1 );
 
       Matrix2 unevenRowSize2;
-      TNL::Matrices::copySparseMatrix( unevenRowSize2, unevenRowSize1 );
+      //TNL::Matrices::copySparseMatrix( unevenRowSize2, unevenRowSize1 );
+      unevenRowSize2 = unevenRowSize1;
       checkUnevenRowSizeMatrix( unevenRowSize2 );
    }
 }