diff --git a/src/TNL/Containers/Segments/CSR.h b/src/TNL/Containers/Segments/CSR.h index ddf56b67d5594413d8f61e2c053b480734b30741..df7cb5686e03c8f7209c02febe0d039c4a1152c9 100644 --- a/src/TNL/Containers/Segments/CSR.h +++ b/src/TNL/Containers/Segments/CSR.h @@ -35,7 +35,7 @@ class CSR using ViewTemplate = CSRView< Device_, Index_ >; using ViewType = CSRView< Device, Index >; using ConstViewType = CSRView< Device, std::add_const_t< Index > >; - using SegmentView = CSRSegmentView< IndexType >; + using SegmentViewType = CSRSegmentView< IndexType >; CSR(); @@ -45,6 +45,8 @@ class CSR CSR( const CSR&& segments ); + static String getSerializationType(); + /** * \brief Set sizes of particular segments. */ @@ -86,7 +88,7 @@ class CSR void getSegmentAndLocalIndex( const Index globalIdx, Index& segmentIdx, Index& localIdx ) const; __cuda_callable__ - SegmentView getSegmentView( const IndexType segmentIdx ) const; + SegmentViewType getSegmentView( const IndexType segmentIdx ) const; /*** * \brief Go over all segments and for each segment element call diff --git a/src/TNL/Containers/Segments/CSR.hpp b/src/TNL/Containers/Segments/CSR.hpp index 16e8a7763ebf2f3f62ce9c4971d7ce116cbd43c7..9ab2186c36f01fc8c4806659ab01ce84bb17ee63 100644 --- a/src/TNL/Containers/Segments/CSR.hpp +++ b/src/TNL/Containers/Segments/CSR.hpp @@ -54,6 +54,16 @@ CSR( const CSR&& csr ) : offsets( std::move( csr.offsets ) ) } +template< typename Device, + typename Index, + typename IndexAllocator > +String +CSR< Device, Index, IndexAllocator >:: +getSerializationType() +{ + return "CSR< [any_device], " + TNL::getSerializationType< IndexType >() + " >"; +} + template< typename Device, typename Index, typename IndexAllocator > @@ -164,7 +174,7 @@ template< typename Device, __cuda_callable__ auto CSR< Device, Index, IndexAllocator >:: -getSegmentView( const IndexType segmentIdx ) const -> SegmentView +getSegmentView( const IndexType segmentIdx ) const -> SegmentViewType { return SegmentView( offsets[ segmentIdx ], offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ] ); } diff --git a/src/TNL/Containers/Segments/CSRView.h b/src/TNL/Containers/Segments/CSRView.h index 3af5798f740ebc715cd0292ee2f050544802d6dd..860a35a0a27a679232ae4109a05e43522ccab132 100644 --- a/src/TNL/Containers/Segments/CSRView.h +++ b/src/TNL/Containers/Segments/CSRView.h @@ -33,7 +33,7 @@ class CSRView template< typename Device_, typename Index_ > using ViewTemplate = CSRView< Device_, Index_ >; using ConstViewType = CSRView< Device, std::add_const_t< Index > >; - using SegmentView = CSRSegmentView< IndexType >; + using SegmentViewType = CSRSegmentView< IndexType >; __cuda_callable__ CSRView(); @@ -50,6 +50,8 @@ class CSRView __cuda_callable__ CSRView( const CSRView&& csr_view ); + static String getSerializationType(); + ViewType getView(); ConstViewType getConstView() const; @@ -85,7 +87,7 @@ class CSRView void getSegmentAndLocalIndex( const Index globalIdx, Index& segmentIdx, Index& localIdx ) const; __cuda_callable__ - SegmentView getSegmentView( const IndexType segmentIdx ) const; + SegmentViewType getSegmentView( const IndexType segmentIdx ) const; /*** * \brief Go over all segments and for each segment element call diff --git a/src/TNL/Containers/Segments/CSRView.hpp b/src/TNL/Containers/Segments/CSRView.hpp index 0135c8c6810dea0e36292c32915fa32e63dceaac..f4f59370d46a5cbcb8502ff5e30d181397c382ca 100644 --- a/src/TNL/Containers/Segments/CSRView.hpp +++ b/src/TNL/Containers/Segments/CSRView.hpp @@ -64,6 +64,15 @@ CSRView( const CSRView&& csr_view ) { } +template< typename Device, + typename Index > +String +CSRView< Device, Index >:: +getSerializationType() +{ + return "CSR< [any_device], " + TNL::getSerializationType< IndexType >() + " >"; +} + template< typename Device, typename Index > typename CSRView< Device, Index >::ViewType @@ -154,9 +163,9 @@ template< typename Device, __cuda_callable__ auto CSRView< Device, Index >:: -getSegmentView( const IndexType segmentIdx ) const -> SegmentView +getSegmentView( const IndexType segmentIdx ) const -> SegmentViewType { - return SegmentView( offsets[ segmentIdx ], offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ] ); + return SegmentViewType( offsets[ segmentIdx ], offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ] ); } template< typename Device, diff --git a/src/TNL/Containers/Segments/Ellpack.h b/src/TNL/Containers/Segments/Ellpack.h index 0ecae8e7d3e70cb2c04497b4e25b98481b1472ad..f73155335d11de10e0b86188cb832d072b0b8694 100644 --- a/src/TNL/Containers/Segments/Ellpack.h +++ b/src/TNL/Containers/Segments/Ellpack.h @@ -37,7 +37,7 @@ class Ellpack using ViewTemplate = EllpackView< Device_, Index_ >; using ViewType = EllpackView< Device, Index, RowMajorOrder, Alignment >; //using ConstViewType = EllpackView< Device, std::add_const_t< Index >, RowMajorOrder, Alignment >; - using SegmentView = EllpackSegmentView< IndexType >; + using SegmentViewType = EllpackSegmentView< IndexType >; Ellpack(); @@ -50,6 +50,8 @@ class Ellpack Ellpack( const Ellpack&& segments ); + static String getSerializationType(); + ViewType getView(); //ConstViewType getConstView() const; @@ -83,7 +85,7 @@ class Ellpack void getSegmentAndLocalIndex( const Index globalIdx, Index& segmentIdx, Index& localIdx ) const; __cuda_callable__ - SegmentView getSegmentView( const IndexType segmentIdx ) const; + SegmentViewType getSegmentView( const IndexType segmentIdx ) const; /*** * \brief Go over all segments and for each segment element call diff --git a/src/TNL/Containers/Segments/Ellpack.hpp b/src/TNL/Containers/Segments/Ellpack.hpp index 762d314dd659ea83dff076daa8d9aca3a616da7c..9f7702a6f2e071cfd255a2c69ed20b6f9a9b2856 100644 --- a/src/TNL/Containers/Segments/Ellpack.hpp +++ b/src/TNL/Containers/Segments/Ellpack.hpp @@ -76,6 +76,18 @@ Ellpack( const Ellpack&& ellpack ) { } +template< typename Device, + typename Index, + typename IndexAllocator, + bool RowMajorOrder, + int Alignment > +String +Ellpack< Device, Index, IndexAllocator, RowMajorOrder, Alignment >:: +getSerializationType() +{ + return "Ellpack< [any_device], " + TNL::getSerializationType< IndexType >() + " >"; +} + template< typename Device, typename Index, typename IndexAllocator, @@ -224,7 +236,7 @@ template< typename Device, __cuda_callable__ auto Ellpack< Device, Index, IndexAllocator, RowMajorOrder, Alignment >:: -getSegmentView( const IndexType segmentIdx ) const -> SegmentView +getSegmentView( const IndexType segmentIdx ) const -> SegmentViewType { if( RowMajorOrder ) return SegmentView( segmentIdx * this->segmentSize, this->segmentSize, 1 ); diff --git a/src/TNL/Containers/Segments/EllpackView.h b/src/TNL/Containers/Segments/EllpackView.h index 185321adb2ec720b63dff65539d7745da776f7ce..682eeeb4a76ee1b69d4a6a78b1c95023021ee770 100644 --- a/src/TNL/Containers/Segments/EllpackView.h +++ b/src/TNL/Containers/Segments/EllpackView.h @@ -38,7 +38,7 @@ class EllpackView using ViewTemplate = EllpackView< Device_, Index_ >; using ViewType = EllpackView; //using ConstViewType = EllpackView< Device, std::add_const_t< Index > >; - using SegmentView = EllpackSegmentView< IndexType >; + using SegmentViewType = EllpackSegmentView< IndexType >; __cuda_callable__ EllpackView(); @@ -52,6 +52,8 @@ class EllpackView __cuda_callable__ EllpackView( const EllpackView&& ellpackView ); + static String getSerializationType(); + ViewType getView(); //ConstViewType getConstView() const; @@ -78,7 +80,7 @@ class EllpackView void getSegmentAndLocalIndex( const Index globalIdx, Index& segmentIdx, Index& localIdx ) const; __cuda_callable__ - SegmentView getSegmentView( const IndexType segmentIdx ) const; + SegmentViewType getSegmentView( const IndexType segmentIdx ) const; /*** * \brief Go over all segments and for each segment element call diff --git a/src/TNL/Containers/Segments/EllpackView.hpp b/src/TNL/Containers/Segments/EllpackView.hpp index 914d30a2e9a30147bae0f311398756022fadb2c6..f5dba4f3d7fb65c5e144de97fde23f58a2893505 100644 --- a/src/TNL/Containers/Segments/EllpackView.hpp +++ b/src/TNL/Containers/Segments/EllpackView.hpp @@ -63,6 +63,17 @@ EllpackView( const EllpackView&& ellpack ) { } +template< typename Device, + typename Index, + bool RowMajorOrder, + int Alignment > +String +EllpackView< Device, Index, RowMajorOrder, Alignment >:: +getSerializationType() +{ + return "Ellpack< [any_device], " + TNL::getSerializationType< IndexType >() + " >"; +} + template< typename Device, typename Index, bool RowMajorOrder, @@ -167,12 +178,12 @@ template< typename Device, __cuda_callable__ auto EllpackView< Device, Index, RowMajorOrder, Alignment >:: -getSegmentView( const IndexType segmentIdx ) const -> SegmentView +getSegmentView( const IndexType segmentIdx ) const -> SegmentViewType { if( RowMajorOrder ) - return SegmentView( segmentIdx * this->segmentSize, this->segmentSize, 1 ); + return SegmentViewType( segmentIdx * this->segmentSize, this->segmentSize, 1 ); else - return SegmentView( segmentIdx, this->segmentSize, this->alignedSize ); + return SegmentViewType( segmentIdx, this->segmentSize, this->alignedSize ); } template< typename Device, diff --git a/src/TNL/Containers/Segments/SlicedEllpack.h b/src/TNL/Containers/Segments/SlicedEllpack.h index 8c01e8a28672d3ab5a1a055ac280327aec7068fb..1c110b1f11b18457bf9eeb8a70746b256739f6af 100644 --- a/src/TNL/Containers/Segments/SlicedEllpack.h +++ b/src/TNL/Containers/Segments/SlicedEllpack.h @@ -36,7 +36,7 @@ class SlicedEllpack template< typename Device_, typename Index_ > using ViewTemplate = SlicedEllpackView< Device_, Index_ >; using ConstViewType = SlicedEllpackView< Device, std::add_const_t< Index >, RowMajorOrder, SliceSize >; - using SegmentView = EllpackSegmentView< IndexType >; + using SegmentViewType = EllpackSegmentView< IndexType >; SlicedEllpack(); @@ -46,6 +46,8 @@ class SlicedEllpack SlicedEllpack( const SlicedEllpack&& segments ); + static String getSerializationType(); + ViewType getView(); ConstViewType getConstView() const; @@ -79,7 +81,7 @@ class SlicedEllpack void getSegmentAndLocalIndex( const Index globalIdx, Index& segmentIdx, Index& localIdx ) const; __cuda_callable__ - SegmentView getSegmentView( const IndexType segmentIdx ) const; + SegmentViewType getSegmentView( const IndexType segmentIdx ) const; /*** * \brief Go over all segments and for each segment element call diff --git a/src/TNL/Containers/Segments/SlicedEllpack.hpp b/src/TNL/Containers/Segments/SlicedEllpack.hpp index 1f647970415845a328b5ba29461d2fc35a50b26c..e2aec924d58ef22c5c896bc815506a3e468b68de 100644 --- a/src/TNL/Containers/Segments/SlicedEllpack.hpp +++ b/src/TNL/Containers/Segments/SlicedEllpack.hpp @@ -69,6 +69,18 @@ SlicedEllpack( const SlicedEllpack&& slicedEllpack ) { } +template< typename Device, + typename Index, + typename IndexAllocator, + bool RowMajorOrder, + int SliceSize > +String +SlicedEllpack< Device, Index, IndexAllocator, RowMajorOrder, SliceSize >:: +getSerializationType() +{ + return "SlicedEllpack< [any_device], " + TNL::getSerializationType< IndexType >() + " >"; +} + template< typename Device, typename Index, typename IndexAllocator, @@ -249,7 +261,7 @@ template< typename Device, __cuda_callable__ auto SlicedEllpack< Device, Index, IndexAllocator, RowMajorOrder, SliceSize >:: -getSegmentView( const IndexType segmentIdx ) const -> SegmentView +getSegmentView( const IndexType segmentIdx ) const -> SegmentViewType { const IndexType sliceIdx = segmentIdx / SliceSize; const IndexType segmentInSliceIdx = segmentIdx % SliceSize; @@ -257,7 +269,7 @@ getSegmentView( const IndexType segmentIdx ) const -> SegmentView const IndexType& segmentSize = this->sliceSegmentSizes[ sliceIdx ]; if( RowMajorOrder ) - return SegmentView( sliceOffset, segmentSize, 1 ); + return SegmentView( sliceOffset + segmentInSliceIdx * segmentSize, segmentSize, 1 ); else return SegmentView( sliceOffset + segmentInSliceIdx, segmentSize, SliceSize ); } diff --git a/src/TNL/Containers/Segments/SlicedEllpackView.h b/src/TNL/Containers/Segments/SlicedEllpackView.h index 890814b8195dec8ddccf2fbf34a289fd8342abeb..e87c75229a9c16ad17023204fdd838ecdc09c425 100644 --- a/src/TNL/Containers/Segments/SlicedEllpackView.h +++ b/src/TNL/Containers/Segments/SlicedEllpackView.h @@ -36,7 +36,7 @@ class SlicedEllpackView using ViewTemplate = SlicedEllpackView< Device_, Index_ >; using ViewType = SlicedEllpackView; using ConstViewType = SlicedEllpackView< Device, std::add_const_t< Index > >; - using SegmentView = EllpackSegmentView< IndexType >; + using SegmentViewType = EllpackSegmentView< IndexType >; __cuda_callable__ SlicedEllpackView(); @@ -54,6 +54,8 @@ class SlicedEllpackView __cuda_callable__ SlicedEllpackView( const SlicedEllpackView&& slicedEllpackView ); + static String getSerializationType(); + ViewType getView(); ConstViewType getConstView() const; @@ -80,7 +82,7 @@ class SlicedEllpackView void getSegmentAndLocalIndex( const Index globalIdx, Index& segmentIdx, Index& localIdx ) const; __cuda_callable__ - SegmentView getSegmentView( const IndexType segmentIdx ) const; + SegmentViewType getSegmentView( const IndexType segmentIdx ) const; /*** * \brief Go over all segments and for each segment element call diff --git a/src/TNL/Containers/Segments/SlicedEllpackView.hpp b/src/TNL/Containers/Segments/SlicedEllpackView.hpp index 45e33b236dc46128b15b87ebb3dc8852def48071..139a09a15e515f1a5ffe83cbd4f3c5f2c8933490 100644 --- a/src/TNL/Containers/Segments/SlicedEllpackView.hpp +++ b/src/TNL/Containers/Segments/SlicedEllpackView.hpp @@ -72,6 +72,17 @@ SlicedEllpackView( const SlicedEllpackView&& slicedEllpackView ) { } +template< typename Device, + typename Index, + bool RowMajorOrder, + int SliceSize > +String +SlicedEllpackView< Device, Index, RowMajorOrder, SliceSize >:: +getSerializationType() +{ + return "SlicedEllpack< [any_device], " + TNL::getSerializationType< IndexType >() + " >"; +} + template< typename Device, typename Index, bool RowMajorOrder, @@ -203,7 +214,7 @@ template< typename Device, __cuda_callable__ auto SlicedEllpackView< Device, Index, RowMajorOrder, SliceSize >:: -getSegmentView( const IndexType segmentIdx ) const -> SegmentView +getSegmentView( const IndexType segmentIdx ) const -> SegmentViewType { const IndexType sliceIdx = segmentIdx / SliceSize; const IndexType segmentInSliceIdx = segmentIdx % SliceSize; @@ -211,9 +222,9 @@ getSegmentView( const IndexType segmentIdx ) const -> SegmentView const IndexType& segmentSize = this->sliceSegmentSizes[ sliceIdx ]; if( RowMajorOrder ) - return SegmentView( sliceOffset, segmentSize, 1 ); + return SegmentViewType( sliceOffset + segmentInSliceIdx * segmentSize, segmentSize, 1 ); else - return SegmentView( sliceOffset + segmentInSliceIdx, segmentSize, SliceSize ); + return SegmentViewType( sliceOffset + segmentInSliceIdx, segmentSize, SliceSize ); } template< typename Device, diff --git a/src/TNL/Matrices/Matrix.h b/src/TNL/Matrices/Matrix.h index 96409c89b74831e20e5676454321f427a72fef43..66a686046724b93ebdc4210bb9514ad2e42b5fe2 100644 --- a/src/TNL/Matrices/Matrix.h +++ b/src/TNL/Matrices/Matrix.h @@ -47,9 +47,9 @@ public: const IndexType columns, const RealAllocatorType& allocator = RealAllocatorType() ); - ViewType getView(); + /*ViewType getView(); - ConstViewType getConstView() const; + ConstViewType getConstView() const;*/ virtual void setDimensions( const IndexType rows, const IndexType columns ); diff --git a/src/TNL/Matrices/Matrix.hpp b/src/TNL/Matrices/Matrix.hpp index 91b81ffcf0e9e6635e7a49a0db035276e4a3bc61..3a09d0088ad9f9814b3b1822124884227c713373 100644 --- a/src/TNL/Matrices/Matrix.hpp +++ b/src/TNL/Matrices/Matrix.hpp @@ -43,7 +43,7 @@ Matrix( const IndexType rows_, const IndexType columns_, const RealAllocatorType { } -template< typename Real, +/*template< typename Real, typename Device, typename Index, typename RealAllocator > @@ -63,7 +63,7 @@ Matrix< Real, Device, Index, RealAllocator >:: getConstView() const -> ConstViewType { return ConstViewType( rows, columns, values.getConstView() ); -} +}*/ template< typename Real, typename Device, diff --git a/src/TNL/Matrices/MatrixView.h b/src/TNL/Matrices/MatrixView.h index 80fa28acfd3b205ee3ea9581258aeae721336e2d..18a9fb488a1bf6104dc0fee1e45be5b8936a0e20 100644 --- a/src/TNL/Matrices/MatrixView.h +++ b/src/TNL/Matrices/MatrixView.h @@ -49,11 +49,11 @@ public: __cuda_callable__ MatrixView( const MatrixView& view ) = default; - __cuda_callable__ - ViewType getView(); + //__cuda_callable__ + //ViewType getView(); - __cuda_callable__ - ConstViewType getConstView() const; + //__cuda_callable__ + //ConstViewType getConstView() const; virtual IndexType getRowLength( const IndexType row ) const = 0; @@ -65,7 +65,7 @@ public: IndexType getNumberOfMatrixElements() const; - virtual IndexType getNumberOfNonzeroMatrixElements() const = 0; + virtual IndexType getNumberOfNonzeroMatrixElements() const; void reset(); diff --git a/src/TNL/Matrices/MatrixView.hpp b/src/TNL/Matrices/MatrixView.hpp index 55ebc3d67116ac6928a9eeabd647919eb3059535..0473f52b84ddd15d7d253eeacfaeadd6f4a3a336 100644 --- a/src/TNL/Matrices/MatrixView.hpp +++ b/src/TNL/Matrices/MatrixView.hpp @@ -42,7 +42,7 @@ MatrixView( const IndexType rows_, { } -template< typename Real, +/*template< typename Real, typename Device, typename Index > __cuda_callable__ @@ -62,7 +62,7 @@ MatrixView< Real, Device, Index >:: getConstView() const -> ConstViewType { return ConstViewType( rows, columns, values.getConstView() ); -} +}*/ template< typename Real, typename Device, diff --git a/src/TNL/Matrices/SparseMatrix.h b/src/TNL/Matrices/SparseMatrix.h index 46c02dfb0d3ed69e9880661a983bc5ef27437989..8169f89f2f29fcbc0c177b7d9ef1b3bac50f4958 100644 --- a/src/TNL/Matrices/SparseMatrix.h +++ b/src/TNL/Matrices/SparseMatrix.h @@ -36,8 +36,8 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator > using SegmentsTemplate = Segments< Device_, Index_, IndexAllocator_ >; using SegmentsType = Segments< Device, Index, IndexAllocator >; template< typename Device_, typename Index_ > - using SegmentsViewTemplate = typename SegmentsType::ViewTemplate< Device_, Index >; - using SegmentViewType = typename SegmentsType::ViewType; + using SegmentsViewTemplate = typename SegmentsType::template ViewTemplate< Device_, Index >; + using SegmentViewType = typename SegmentsType::SegmentViewType; using DeviceType = Device; using IndexType = Index; using RealAllocatorType = RealAllocator; @@ -46,10 +46,12 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator > using RowsCapacitiesView = Containers::VectorView< IndexType, DeviceType, IndexType >; using ConstRowsCapacitiesView = typename RowsCapacitiesView::ConstViewType; using ValuesVectorType = typename Matrix< Real, Device, Index, RealAllocator >::ValuesVector; - using ColumnsVectorType = Containers::Vector< IndexType, DeviceType, IndexType, IndexAllocatorType >; + using ValuesViewType = typename ValuesVectorType::ViewType; + using ColumnsIndexesVectorType = Containers::Vector< IndexType, DeviceType, IndexType, IndexAllocatorType >; + using ColumnsIndexesViewType = typename ColumnsIndexesVectorType::ViewType; using ViewType = SparseMatrixView< Real, Device, Index, MatrixType, SegmentsViewTemplate >; using ConstViewType = SparseMatrixView< typename std::add_const< Real >::type, Device, Index, MatrixType, SegmentsViewTemplate >; - using RowView = SparseMatrixRowView< RealType, SegmentViewType >; + using RowView = SparseMatrixRowView< SegmentViewType, ValuesViewType, ColumnsIndexesViewType >; // TODO: remove this - it is here only for compatibility with original matrix implementation typedef Containers::Vector< IndexType, DeviceType, IndexType > CompressedRowLengthsVector; @@ -246,7 +248,7 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator > // TODO: restore it and also in Matrix // protected: - ColumnsVectorType columnIndexes; + ColumnsIndexesVectorType columnIndexes; SegmentsType segments; diff --git a/src/TNL/Matrices/SparseMatrix.hpp b/src/TNL/Matrices/SparseMatrix.hpp index 3f26c95cae73c4bc3b65873c0d9a049c8be1ec64..c0dd3b9a3ecd9623177374f81b39bad96e3f7528 100644 --- a/src/TNL/Matrices/SparseMatrix.hpp +++ b/src/TNL/Matrices/SparseMatrix.hpp @@ -87,7 +87,7 @@ getView() -> ViewType return ViewType( this->getRows(), this->getColumns(), this->getValues().getView(), - this->getColumnsIndexes().getView(), + this->columnIndexes.getView(), this->segments.getView() ); } @@ -105,7 +105,7 @@ getConstView() const -> ConstViewType return ConstViewType( this->getRows(), this->getColumns(), this->getValues().getConstView(), - this->getColumnsIndexes().getConstView(), + this->columnIndexes.getConstView(), this->segments.getConstView() ); } @@ -299,8 +299,6 @@ SparseMatrix< Real, Device, Index, MatrixType, Segments, RealAllocator, IndexAll reset() { Matrix< Real, Device, Index >::reset(); - this->columnIndexes.reset(); - } template< typename Real, diff --git a/src/TNL/Matrices/SparseMatrixRowView.h b/src/TNL/Matrices/SparseMatrixRowView.h index c6d0468f99f7ea9f6db616ea24d6fd969ad3ddf5..19445f531441a15b815c25cfdab519706b4a0abf 100644 --- a/src/TNL/Matrices/SparseMatrixRowView.h +++ b/src/TNL/Matrices/SparseMatrixRowView.h @@ -13,23 +13,23 @@ namespace TNL { namespace Matrices { -template< typename Real, - typename SegmentView > +template< typename SegmentView, + typename ValuesView, + typename ColumnsIndexesView > class SparseMatrixRowView { public: - using RealType = Real; + using RealType = typename ValuesView::RealType; using SegmentViewType = SegmentView; - using DeviceType = typename SegmentViewType::DeviceType; using IndexType = typename SegmentViewType::IndexType; - using ValuesView = Containers::VectorView< RealType, DeviceType, IndexType >; - using ColumnIndexesView = Containers::VectorView< IndexType, DeviceType, IndexType >; + using ValuesViewType = ValuesView; + using ColumnsIndexesViewType = ColumnsIndexesView; __cuda_callable__ - SparseMatrixRowView( const SegmentView& segmentView, - const ValuesView& values, - const ColumnIndexesView& columnIndexes ); + SparseMatrixRowView( const SegmentViewType& segmentView, + const ValuesViewType& values, + const ColumnsIndexesViewType& columnIndexes ); __cuda_callable__ IndexType getSize() const; @@ -52,11 +52,11 @@ class SparseMatrixRowView const RealType& value ); protected: - SegmentView segmentView; + SegmentViewType segmentView; - ValuesView values; + ValuesViewType values; - ColumnIndexesView columnIndexes; + ColumnsIndexesViewType columnIndexes; }; } // namespace Matrices } // namespace TNL diff --git a/src/TNL/Matrices/SparseMatrixRowView.hpp b/src/TNL/Matrices/SparseMatrixRowView.hpp index 364bb8e2eb4e6a7780d57b16422c5544c21232a5..70dac874e04ff1ec9632da7940ebde93644a869c 100644 --- a/src/TNL/Matrices/SparseMatrixRowView.hpp +++ b/src/TNL/Matrices/SparseMatrixRowView.hpp @@ -15,70 +15,77 @@ namespace TNL { namespace Matrices { -template< typename Real, - typename SegmentView > +template< typename SegmentView, + typename ValuesView, + typename ColumnsIndexesView > __cuda_callable__ -SparseMatrixRowView< Real, SegmentView >:: -SparseMatrixRowView( const SegmentView& segmentView, - const ValuesView& values, - const ColumnIndexesView& columnIndexes ) +SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >:: +SparseMatrixRowView( const SegmentViewType& segmentView, + const ValuesViewType& values, + const ColumnsIndexesViewType& columnIndexes ) : segmentView( segmentView ), values( values ), columnIndexes( columnIndexes ) { } -template< typename Real, - typename SegmentView > +template< typename SegmentView, + typename ValuesView, + typename ColumnsIndexesView > __cuda_callable__ auto -SparseMatrixRowView< Real, SegmentView >:: +SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >:: getSize() const -> IndexType { return segmentView.getSize(); } -template< typename Real, - typename SegmentView > +template< typename SegmentView, + typename ValuesView, + typename ColumnsIndexesView > __cuda_callable__ auto -SparseMatrixRowView< Real, SegmentView >:: +SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >:: getColumnIndex( const IndexType localIdx ) const -> const IndexType& { TNL_ASSERT_LT( localIdx, this->getSize(), "Local index exceeds matrix row capacity." ); return columnIndexes[ segmentView.getGlobalIndex( localIdx ) ]; } -template< typename Real, - typename SegmentView > +template< typename SegmentView, + typename ValuesView, + typename ColumnsIndexesView > __cuda_callable__ auto -SparseMatrixRowView< Real, SegmentView >:: +SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >:: getColumnIndex( const IndexType localIdx ) -> IndexType& { TNL_ASSERT_LT( localIdx, this->getSize(), "Local index exceeds matrix row capacity." ); return columnIndexes[ segmentView.getGlobalIndex( localIdx ) ]; } -template< typename Real, - typename SegmentView > +template< typename SegmentView, + typename ValuesView, + typename ColumnsIndexesView > __cuda_callable__ auto -SparseMatrixRowView< Real, SegmentView >:: +SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >:: getValue( const IndexType localIdx ) const -> const RealType& { TNL_ASSERT_LT( localIdx, this->getSize(), "Local index exceeds matrix row capacity." ); return values[ segmentView.getGlobalIndex( localIdx ) ]; } -template< typename Real, - typename SegmentView > +template< typename SegmentView, + typename ValuesView, + typename ColumnsIndexesView > __cuda_callable__ auto -SparseMatrixRowView< Real, SegmentView >:: +SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >:: getValue( const IndexType localIdx ) -> RealType& { TNL_ASSERT_LT( localIdx, this->getSize(), "Local index exceeds matrix row capacity." ); return values[ segmentView.getGlobalIndex( localIdx ) ]; } -template< typename Real, - typename SegmentView > +template< typename SegmentView, + typename ValuesView, + typename ColumnsIndexesView > __cuda_callable__ void -SparseMatrixRowView< Real, SegmentView >:: +SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >:: setElement( const IndexType localIdx, const IndexType column, const RealType& value ) diff --git a/src/TNL/Matrices/SparseMatrixView.h b/src/TNL/Matrices/SparseMatrixView.h index a674ee807054c7cbaa207e109c364e7bdcfb3c76..714692df85b717cb1eb4ba1ccdb13434e58548db 100644 --- a/src/TNL/Matrices/SparseMatrixView.h +++ b/src/TNL/Matrices/SparseMatrixView.h @@ -32,15 +32,16 @@ class SparseMatrixView : public MatrixView< Real, Device, Index > template< typename Device_, typename Index_ > using SegmentsViewTemplate = SegmentsView< Device_, Index_ >; using SegmentsViewType = SegmentsView< Device, Index >; + using SegmentViewType = typename SegmentsViewType::SegmentViewType; using DeviceType = Device; using IndexType = Index; using RowsCapacitiesView = Containers::VectorView< IndexType, DeviceType, IndexType >; using ConstRowsCapacitiesView = typename RowsCapacitiesView::ConstViewType; using ValuesViewType = typename MatrixView< Real, Device, Index >::ValuesView; - using ColumnsViewType = Containers::VectorView< IndexType, DeviceType, IndexType >; + using ColumnsIndexesViewType = Containers::VectorView< IndexType, DeviceType, IndexType >; using ViewType = SparseMatrixView< typename std::remove_const< Real >::type, Device, Index, MatrixType, SegmentsViewTemplate >; using ConstViewType = SparseMatrixView< typename std::add_const< Real >::type, Device, Index, MatrixType, SegmentsViewTemplate >; - using RowView = SparseMatrixRowView< RealType, SegmentsViewType >; + using RowView = SparseMatrixRowView< SegmentViewType, ValuesViewType, ColumnsIndexesViewType >; // TODO: remove this - it is here only for compatibility with original matrix implementation typedef Containers::Vector< IndexType, DeviceType, IndexType > CompressedRowLengthsVector; @@ -55,9 +56,9 @@ class SparseMatrixView : public MatrixView< Real, Device, Index > __cuda_callable__ SparseMatrixView( const IndexType rows, const IndexType columns, - ValuesViewType& values, - ColumnsViewType& columnIndexes, - SegmentsViewType& segments ); + const ValuesViewType& values, + const ColumnsIndexesViewType& columnIndexes, + const SegmentsViewType& segments ); __cuda_callable__ SparseMatrixView( const SparseMatrixView& m ) = default; @@ -204,7 +205,7 @@ class SparseMatrixView : public MatrixView< Real, Device, Index > protected: - ColumnsViewType columnIndexes; + ColumnsIndexesViewType columnIndexes; SegmentsViewType segments; }; diff --git a/src/TNL/Matrices/SparseMatrixView.hpp b/src/TNL/Matrices/SparseMatrixView.hpp index 3f97431248d19f06b6a8ede31683b8b0d9de117b..5ac494a9b6a593fed8f7c8580f6169bb4eb0c65a 100644 --- a/src/TNL/Matrices/SparseMatrixView.hpp +++ b/src/TNL/Matrices/SparseMatrixView.hpp @@ -37,9 +37,9 @@ __cuda_callable__ SparseMatrixView< Real, Device, Index, MatrixType, SegmentsView >:: SparseMatrixView( const IndexType rows, const IndexType columns, - ValuesViewType& values, - ColumnsViewType& columnIndexes, - SegmentsViewType& segments ) + const ValuesViewType& values, + const ColumnsIndexesViewType& columnIndexes, + const SegmentsViewType& segments ) : MatrixView< Real, Device, Index >( rows, columns, values ), columnIndexes( columnIndexes ), segments( segments ) { } @@ -57,7 +57,7 @@ getView() -> ViewType return ViewType( this->getRows(), this->getColumns(), this->getValues().getView(), - this->getColumnsIndexes().getView(), + this->columnIndexes.getView(), this->segments.getView() ); } @@ -89,7 +89,7 @@ getSerializationType() { return String( "Matrices::SparseMatrix< " ) + TNL::getSerializationType< RealType >() + ", " + - TNL::getSerializationType< SegmentsView >() + ", [any_device], " + + TNL::getSerializationType< SegmentsViewType >() + ", [any_device], " + TNL::getSerializationType< IndexType >() + ", [any_allocator] >"; } @@ -648,7 +648,7 @@ void SparseMatrixView< Real, Device, Index, MatrixType, SegmentsView >:: save( File& file ) const { - Matrix< RealType, DeviceType, IndexType >::save( file ); + MatrixView< RealType, DeviceType, IndexType >::save( file ); file << this->columnIndexes; this->segments.save( file ); } diff --git a/src/UnitTests/Matrices/SparseMatrixTest.hpp b/src/UnitTests/Matrices/SparseMatrixTest.hpp index 07a60178f172239cc907cf742a6922011b39d9f2..72dfc90e8ed12730a0e15dd6d204de8529321145 100644 --- a/src/UnitTests/Matrices/SparseMatrixTest.hpp +++ b/src/UnitTests/Matrices/SparseMatrixTest.hpp @@ -11,6 +11,7 @@ #include <TNL/Containers/Vector.h> #include <TNL/Containers/VectorView.h> #include <TNL/Math.h> +#include <TNL/Algorithms/ParallelFor.h> #include <iostream> // Temporary, until test_OperatorEquals doesn't work for all formats. @@ -249,6 +250,232 @@ void test_Reset() EXPECT_EQ( m.getColumns(), 0 ); } +template< typename Matrix > +void test_GetRow() +{ + using RealType = typename Matrix::RealType; + using DeviceType = typename Matrix::DeviceType; + using IndexType = typename Matrix::IndexType; + +/* + * Sets up the following 10x10 sparse matrix: + * + * / 1 0 2 0 3 0 4 0 0 0 \ + * | 5 6 7 0 0 0 0 0 0 0 | + * | 8 9 10 11 12 13 14 15 0 0 | + * | 16 17 0 0 0 0 0 0 0 0 | + * | 18 0 0 0 0 0 0 0 0 0 | + * | 19 0 0 0 0 0 0 0 0 0 | + * | 20 0 0 0 0 0 0 0 0 0 | + * | 21 0 0 0 0 0 0 0 0 0 | + * | 22 23 24 25 26 27 28 29 30 31 | + * \ 32 33 34 35 36 37 38 39 40 41 / + */ + + const IndexType rows = 10; + const IndexType cols = 10; + + Matrix m( rows, cols ); + + typename Matrix::CompressedRowLengthsVector rowLengths; + rowLengths.setSize( rows ); + rowLengths.setElement( 0, 4 ); + rowLengths.setElement( 1, 3 ); + rowLengths.setElement( 2, 8 ); + rowLengths.setElement( 3, 2 ); + for( IndexType i = 4; i < rows - 2; i++ ) + { + rowLengths.setElement( i, 1 ); + } + rowLengths.setElement( 8, 10 ); + rowLengths.setElement( 9, 10 ); + m.setCompressedRowLengths( rowLengths ); + + /*RealType value = 1; + for( IndexType i = 0; i < 4; i++ ) + m.setElement( 0, 2 * i, value++ ); + + for( IndexType i = 0; i < 3; i++ ) + m.setElement( 1, i, value++ ); + + for( IndexType i = 0; i < 8; i++ ) + m.setElement( 2, i, value++ ); + + for( IndexType i = 0; i < 2; i++ ) + m.setElement( 3, i, value++ ); + + for( IndexType i = 4; i < 8; i++ ) + m.setElement( i, 0, value++ ); + + for( IndexType j = 8; j < rows; j++) + { + for( IndexType i = 0; i < cols; i++ ) + m.setElement( j, i, value++ ); + }*/ + auto matrixView = m.getView(); + auto f = [=] __cuda_callable__ ( const IndexType rowIdx ) mutable { + auto row = matrixView.getRow( rowIdx ); + RealType val; + switch( rowIdx ) + { + case 0: + val = 1; + for( IndexType i = 0; i < 4; i++ ) + row.setElement( i, 2 * i, val++ ); + break; + case 1: + val = 5; + for( IndexType i = 0; i < 3; i++ ) + row.setElement( i, i, val++ ); + break; + case 2: + val = 8; + for( IndexType i = 0; i < 8; i++ ) + row.setElement( i, i, val++ ); + break; + case 3: + val = 16; + for( IndexType i = 0; i < 2; i++ ) + row.setElement( i, i, val++ ); + break; + case 4: + row.setElement( 0, 0, 18 ); + break; + case 5: + row.setElement( 0, 0, 19 ); + break; + case 6: + row.setElement( 0, 0, 20 ); + break; + case 7: + row.setElement( 0, 0, 21 ); + break; + case 8: + val = 22; + for( IndexType i = 0; i < rows; i++ ) + row.setElement( i, i, val++ ); + break; + case 9: + val = 32; + for( IndexType i = 0; i < rows; i++ ) + row.setElement( i, i, val++ ); + break; + } + }; + TNL::Algorithms::ParallelFor< DeviceType >::exec( ( IndexType ) 0, rows, f ); + + EXPECT_EQ( m.getElement( 0, 0 ), 1 ); + EXPECT_EQ( m.getElement( 0, 1 ), 0 ); + EXPECT_EQ( m.getElement( 0, 2 ), 2 ); + EXPECT_EQ( m.getElement( 0, 3 ), 0 ); + EXPECT_EQ( m.getElement( 0, 4 ), 3 ); + EXPECT_EQ( m.getElement( 0, 5 ), 0 ); + EXPECT_EQ( m.getElement( 0, 6 ), 4 ); + EXPECT_EQ( m.getElement( 0, 7 ), 0 ); + EXPECT_EQ( m.getElement( 0, 8 ), 0 ); + EXPECT_EQ( m.getElement( 0, 9 ), 0 ); + + EXPECT_EQ( m.getElement( 1, 0 ), 5 ); + EXPECT_EQ( m.getElement( 1, 1 ), 6 ); + EXPECT_EQ( m.getElement( 1, 2 ), 7 ); + EXPECT_EQ( m.getElement( 1, 3 ), 0 ); + EXPECT_EQ( m.getElement( 1, 4 ), 0 ); + EXPECT_EQ( m.getElement( 1, 5 ), 0 ); + EXPECT_EQ( m.getElement( 1, 6 ), 0 ); + EXPECT_EQ( m.getElement( 1, 7 ), 0 ); + EXPECT_EQ( m.getElement( 1, 8 ), 0 ); + EXPECT_EQ( m.getElement( 1, 9 ), 0 ); + + EXPECT_EQ( m.getElement( 2, 0 ), 8 ); + EXPECT_EQ( m.getElement( 2, 1 ), 9 ); + EXPECT_EQ( m.getElement( 2, 2 ), 10 ); + EXPECT_EQ( m.getElement( 2, 3 ), 11 ); + EXPECT_EQ( m.getElement( 2, 4 ), 12 ); + EXPECT_EQ( m.getElement( 2, 5 ), 13 ); + EXPECT_EQ( m.getElement( 2, 6 ), 14 ); + EXPECT_EQ( m.getElement( 2, 7 ), 15 ); + EXPECT_EQ( m.getElement( 2, 8 ), 0 ); + EXPECT_EQ( m.getElement( 2, 9 ), 0 ); + + EXPECT_EQ( m.getElement( 3, 0 ), 16 ); + EXPECT_EQ( m.getElement( 3, 1 ), 17 ); + EXPECT_EQ( m.getElement( 3, 2 ), 0 ); + EXPECT_EQ( m.getElement( 3, 3 ), 0 ); + EXPECT_EQ( m.getElement( 3, 4 ), 0 ); + EXPECT_EQ( m.getElement( 3, 5 ), 0 ); + EXPECT_EQ( m.getElement( 3, 6 ), 0 ); + EXPECT_EQ( m.getElement( 3, 7 ), 0 ); + EXPECT_EQ( m.getElement( 3, 8 ), 0 ); + EXPECT_EQ( m.getElement( 3, 9 ), 0 ); + + EXPECT_EQ( m.getElement( 4, 0 ), 18 ); + EXPECT_EQ( m.getElement( 4, 1 ), 0 ); + EXPECT_EQ( m.getElement( 4, 2 ), 0 ); + EXPECT_EQ( m.getElement( 4, 3 ), 0 ); + EXPECT_EQ( m.getElement( 4, 4 ), 0 ); + EXPECT_EQ( m.getElement( 4, 5 ), 0 ); + EXPECT_EQ( m.getElement( 4, 6 ), 0 ); + EXPECT_EQ( m.getElement( 4, 7 ), 0 ); + EXPECT_EQ( m.getElement( 4, 8 ), 0 ); + EXPECT_EQ( m.getElement( 4, 9 ), 0 ); + + EXPECT_EQ( m.getElement( 5, 0 ), 19 ); + EXPECT_EQ( m.getElement( 5, 1 ), 0 ); + EXPECT_EQ( m.getElement( 5, 2 ), 0 ); + EXPECT_EQ( m.getElement( 5, 3 ), 0 ); + EXPECT_EQ( m.getElement( 5, 4 ), 0 ); + EXPECT_EQ( m.getElement( 5, 5 ), 0 ); + EXPECT_EQ( m.getElement( 5, 6 ), 0 ); + EXPECT_EQ( m.getElement( 5, 7 ), 0 ); + EXPECT_EQ( m.getElement( 5, 8 ), 0 ); + EXPECT_EQ( m.getElement( 5, 9 ), 0 ); + + EXPECT_EQ( m.getElement( 6, 0 ), 20 ); + EXPECT_EQ( m.getElement( 6, 1 ), 0 ); + EXPECT_EQ( m.getElement( 6, 2 ), 0 ); + EXPECT_EQ( m.getElement( 6, 3 ), 0 ); + EXPECT_EQ( m.getElement( 6, 4 ), 0 ); + EXPECT_EQ( m.getElement( 6, 5 ), 0 ); + EXPECT_EQ( m.getElement( 6, 6 ), 0 ); + EXPECT_EQ( m.getElement( 6, 7 ), 0 ); + EXPECT_EQ( m.getElement( 6, 8 ), 0 ); + EXPECT_EQ( m.getElement( 6, 9 ), 0 ); + + EXPECT_EQ( m.getElement( 7, 0 ), 21 ); + EXPECT_EQ( m.getElement( 7, 1 ), 0 ); + EXPECT_EQ( m.getElement( 7, 2 ), 0 ); + EXPECT_EQ( m.getElement( 7, 3 ), 0 ); + EXPECT_EQ( m.getElement( 7, 4 ), 0 ); + EXPECT_EQ( m.getElement( 7, 5 ), 0 ); + EXPECT_EQ( m.getElement( 7, 6 ), 0 ); + EXPECT_EQ( m.getElement( 7, 7 ), 0 ); + EXPECT_EQ( m.getElement( 7, 8 ), 0 ); + EXPECT_EQ( m.getElement( 7, 9 ), 0 ); + + EXPECT_EQ( m.getElement( 8, 0 ), 22 ); + EXPECT_EQ( m.getElement( 8, 1 ), 23 ); + EXPECT_EQ( m.getElement( 8, 2 ), 24 ); + EXPECT_EQ( m.getElement( 8, 3 ), 25 ); + EXPECT_EQ( m.getElement( 8, 4 ), 26 ); + EXPECT_EQ( m.getElement( 8, 5 ), 27 ); + EXPECT_EQ( m.getElement( 8, 6 ), 28 ); + EXPECT_EQ( m.getElement( 8, 7 ), 29 ); + EXPECT_EQ( m.getElement( 8, 8 ), 30 ); + EXPECT_EQ( m.getElement( 8, 9 ), 31 ); + + EXPECT_EQ( m.getElement( 9, 0 ), 32 ); + EXPECT_EQ( m.getElement( 9, 1 ), 33 ); + EXPECT_EQ( m.getElement( 9, 2 ), 34 ); + EXPECT_EQ( m.getElement( 9, 3 ), 35 ); + EXPECT_EQ( m.getElement( 9, 4 ), 36 ); + EXPECT_EQ( m.getElement( 9, 5 ), 37 ); + EXPECT_EQ( m.getElement( 9, 6 ), 38 ); + EXPECT_EQ( m.getElement( 9, 7 ), 39 ); + EXPECT_EQ( m.getElement( 9, 8 ), 40 ); + EXPECT_EQ( m.getElement( 9, 9 ), 41 ); +} + + template< typename Matrix > void test_SetElement() { diff --git a/src/UnitTests/Matrices/SparseMatrixTest_CSR_segments.h b/src/UnitTests/Matrices/SparseMatrixTest_CSR_segments.h index 353dcdbb0e1721f54803a20b04c417ca04bcca3f..e86e34f0ac882b485106cca0903436a3ba5b1a36 100644 --- a/src/UnitTests/Matrices/SparseMatrixTest_CSR_segments.h +++ b/src/UnitTests/Matrices/SparseMatrixTest_CSR_segments.h @@ -94,6 +94,14 @@ TYPED_TEST( CSRMatrixTest, resetTest ) test_Reset< CSRMatrixType >(); } +TYPED_TEST( CSRMatrixTest, getRowTest ) +{ + using CSRMatrixType = typename TestFixture::CSRMatrixType; + + test_GetRow< CSRMatrixType >(); +} + + TYPED_TEST( CSRMatrixTest, setElementTest ) { using CSRMatrixType = typename TestFixture::CSRMatrixType; diff --git a/src/UnitTests/Matrices/SparseMatrixTest_Ellpack_segments.h b/src/UnitTests/Matrices/SparseMatrixTest_Ellpack_segments.h index b7dc338345e4300786d6844f31ab70bf4637eeaa..f597e31993ba14405aa242cf63840d04c3159cf9 100644 --- a/src/UnitTests/Matrices/SparseMatrixTest_Ellpack_segments.h +++ b/src/UnitTests/Matrices/SparseMatrixTest_Ellpack_segments.h @@ -105,6 +105,13 @@ TYPED_TEST( EllpackMatrixTest, resetTest ) test_Reset< EllpackMatrixType >(); } +TYPED_TEST( EllpackMatrixTest, getRowTest ) +{ + using EllpackMatrixType = typename TestFixture::EllpackMatrixType; + + test_GetRow< EllpackMatrixType >(); +} + TYPED_TEST( EllpackMatrixTest, setElementTest ) { using EllpackMatrixType = typename TestFixture::EllpackMatrixType; diff --git a/src/UnitTests/Matrices/SparseMatrixTest_SlicedEllpack_segments.h b/src/UnitTests/Matrices/SparseMatrixTest_SlicedEllpack_segments.h index b2404fe68cc01c163757b4568bc2706508c24c50..172ed722ac53ab47b23cd847e45f2a428af4b161 100644 --- a/src/UnitTests/Matrices/SparseMatrixTest_SlicedEllpack_segments.h +++ b/src/UnitTests/Matrices/SparseMatrixTest_SlicedEllpack_segments.h @@ -106,6 +106,13 @@ TYPED_TEST( SlicedEllpackMatrixTest, resetTest ) test_Reset< SlicedEllpackMatrixType >(); } +TYPED_TEST( SlicedEllpackMatrixTest, getRowTest ) +{ + using SlicedEllpackMatrixType = typename TestFixture::SlicedEllpackMatrixType; + + test_GetRow< SlicedEllpackMatrixType >(); +} + TYPED_TEST( SlicedEllpackMatrixTest, setElementTest ) { using SlicedEllpackMatrixType = typename TestFixture::SlicedEllpackMatrixType;