diff --git a/src/TNL/Containers/Segments/CSR.h b/src/TNL/Containers/Segments/CSR.h index add07f1dff5c587f3cde22d953a1390728f464e8..f140605590934619a644f556ad2c212daab6f2e5 100644 --- a/src/TNL/Containers/Segments/CSR.h +++ b/src/TNL/Containers/Segments/CSR.h @@ -30,6 +30,8 @@ class CSR using IndexType = Index; using OffsetsHolder = Containers::Vector< IndexType, DeviceType, typename std::remove_const< IndexType >::type, IndexAllocator >; using SegmentsSizes = OffsetsHolder; + template< typename Device_, typename Index_ > + using ViewTemplate = CSRView< Device_, Index_ >; using ViewType = CSRView< Device, Index >; using ConstViewType = CSRView< Device, std::add_const_t< Index > >; diff --git a/src/TNL/Containers/Segments/CSRView.h b/src/TNL/Containers/Segments/CSRView.h index 2f89579702b543cf76f47d32e20bebf2c497828c..4917df9e8ea1044763136bfd2a2b644e5452ff67 100644 --- a/src/TNL/Containers/Segments/CSRView.h +++ b/src/TNL/Containers/Segments/CSRView.h @@ -29,6 +29,8 @@ class CSRView using OffsetsView = typename Containers::VectorView< IndexType, DeviceType, IndexType >; using ConstOffsetsView = typename Containers::Vector< IndexType, DeviceType, IndexType >::ConstViewType; using ViewType = CSRView; + template< typename Device_, typename Index_ > + using ViewTemplate = CSRView< Device_, Index_ >; using ConstViewType = CSRView< Device, std::add_const_t< Index > >; __cuda_callable__ diff --git a/src/TNL/Containers/Segments/Ellpack.h b/src/TNL/Containers/Segments/Ellpack.h index b9b3e63c1efd7e92bc83eabc78d439637c448247..8cb430b6a4a57bf8c7733b7972a5af7b286af0f4 100644 --- a/src/TNL/Containers/Segments/Ellpack.h +++ b/src/TNL/Containers/Segments/Ellpack.h @@ -32,6 +32,8 @@ class Ellpack static constexpr bool getRowMajorOrder() { return RowMajorOrder; } using OffsetsHolder = Containers::Vector< IndexType, DeviceType, IndexType >; using SegmentsSizes = OffsetsHolder; + template< typename Device_, typename Index_ > + using ViewTemplate = EllpackView< Device_, Index_ >; using ViewType = EllpackView< Device, Index, RowMajorOrder, Alignment >; //using ConstViewType = EllpackView< Device, std::add_const_t< Index >, RowMajorOrder, Alignment >; diff --git a/src/TNL/Containers/Segments/EllpackView.h b/src/TNL/Containers/Segments/EllpackView.h index adbfee629c03d9ff49c572781cabc4a95c0ee0ba..6c6926be92f8cfc593fb7090dddb48acf2a6034d 100644 --- a/src/TNL/Containers/Segments/EllpackView.h +++ b/src/TNL/Containers/Segments/EllpackView.h @@ -33,6 +33,8 @@ class EllpackView static constexpr bool getRowMajorOrder() { return RowMajorOrder; } using OffsetsHolder = Containers::Vector< IndexType, DeviceType, IndexType >; using SegmentsSizes = OffsetsHolder; + template< typename Device_, typename Index_ > + using ViewTemplate = EllpackView< Device_, Index_ >; using ViewType = EllpackView; //using ConstViewType = EllpackView< Device, std::add_const_t< Index > >; diff --git a/src/TNL/Containers/Segments/SlicedEllpack.h b/src/TNL/Containers/Segments/SlicedEllpack.h index 9c2e7157f73d46094d0bb9bb8af251f164de9588..946c9b642c5102248043625618fc7085303cfd73 100644 --- a/src/TNL/Containers/Segments/SlicedEllpack.h +++ b/src/TNL/Containers/Segments/SlicedEllpack.h @@ -32,6 +32,8 @@ class SlicedEllpack static constexpr int getSliceSize() { return SliceSize; } static constexpr bool getRowMajorOrder() { return RowMajorOrder; } using ViewType = SlicedEllpackView< Device, Index, RowMajorOrder, SliceSize >; + template< typename Device_, typename Index_ > + using ViewTemplate = SlicedEllpackView< Device_, Index_ >; using ConstViewType = SlicedEllpackView< Device, std::add_const_t< Index >, RowMajorOrder, SliceSize >; SlicedEllpack(); diff --git a/src/TNL/Containers/Segments/SlicedEllpackView.h b/src/TNL/Containers/Segments/SlicedEllpackView.h index 275baacf5f5cdc2f367bee8ece316a4106f47ef8..adcf9ef5a090dca45e052733420dcecd17e1b3b5 100644 --- a/src/TNL/Containers/Segments/SlicedEllpackView.h +++ b/src/TNL/Containers/Segments/SlicedEllpackView.h @@ -31,6 +31,8 @@ class SlicedEllpackView using OffsetsView = typename Containers::VectorView< IndexType, DeviceType, typename std::remove_const < IndexType >::type >; static constexpr int getSliceSize() { return SliceSize; } static constexpr bool getRowMajorOrder() { return RowMajorOrder; } + template< typename Device_, typename Index_ > + using ViewTemplate = SlicedEllpackView< Device_, Index_ >; using ViewType = SlicedEllpackView; using ConstViewType = SlicedEllpackView< Device, std::add_const_t< Index > >; diff --git a/src/TNL/Matrices/MatrixView.h b/src/TNL/Matrices/MatrixView.h index a2fa975cf1ae0e438870c36b667d437d40f3b4e8..80fa28acfd3b205ee3ea9581258aeae721336e2d 100644 --- a/src/TNL/Matrices/MatrixView.h +++ b/src/TNL/Matrices/MatrixView.h @@ -29,12 +29,14 @@ class MatrixView : public Object { public: using RealType = Real; - typedef Device DeviceType; - typedef Index IndexType; - typedef Containers::Vector< IndexType, DeviceType, IndexType > CompressedRowLengthsVector; - typedef Containers::VectorView< IndexType, DeviceType, IndexType > CompressedRowLengthsVectorView; - typedef typename CompressedRowLengthsVectorView::ConstViewType ConstCompressedRowLengthsVectorView; - typedef Containers::VectorView< RealType, DeviceType, IndexType > ValuesView; + using DeviceType = Device; + using IndexType = Index; + using CompressedRowLengthsVector = Containers::Vector< IndexType, DeviceType, IndexType >; + using CompressedRowLengthsVectorView = Containers::VectorView< IndexType, DeviceType, IndexType >; + using ConstCompressedRowLengthsVectorView = typename CompressedRowLengthsVectorView::ConstViewType; + using ValuesView = Containers::VectorView< RealType, DeviceType, IndexType >; + using ViewType = MatrixView< typename std::remove_const< Real >::type, Device, Index >; + using ConstViewType = MatrixView< typename std::add_const< Real >::type, Device, Index >; __cuda_callable__ MatrixView(); @@ -47,6 +49,12 @@ public: __cuda_callable__ MatrixView( const MatrixView& view ) = default; + __cuda_callable__ + ViewType getView(); + + __cuda_callable__ + ConstViewType getConstView() const; + virtual IndexType getRowLength( const IndexType row ) const = 0; // TODO: implementation is not parallel diff --git a/src/TNL/Matrices/MatrixView.hpp b/src/TNL/Matrices/MatrixView.hpp index bd3d9beaee73a8d55bc73c1e9de870a4c14b330c..55ebc3d67116ac6928a9eeabd647919eb3059535 100644 --- a/src/TNL/Matrices/MatrixView.hpp +++ b/src/TNL/Matrices/MatrixView.hpp @@ -42,6 +42,28 @@ MatrixView( const IndexType rows_, { } +template< typename Real, + typename Device, + typename Index > +__cuda_callable__ +auto +MatrixView< Real, Device, Index >:: +getView() ->ViewType +{ + return ViewType( rows, columns, values.getView() ); +} + +template< typename Real, + typename Device, + typename Index > +__cuda_callable__ +auto +MatrixView< Real, Device, Index >:: +getConstView() const -> ConstViewType +{ + return ConstViewType( rows, columns, values.getConstView() ); +} + template< typename Real, typename Device, typename Index > diff --git a/src/TNL/Matrices/SparseMatrix.h b/src/TNL/Matrices/SparseMatrix.h index 5f02e9fdecdb9194caf35cffce136c6ce51ef0de..558cbb5b1028a9ced91cb10324ef25d5ddb9f1e6 100644 --- a/src/TNL/Matrices/SparseMatrix.h +++ b/src/TNL/Matrices/SparseMatrix.h @@ -34,6 +34,8 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator > template< typename Device_, typename Index_, typename IndexAllocator_ > using SegmentsTemplate = Segments< Device_, Index_, IndexAllocator_ >; using SegmentsType = Segments< Device, Index, IndexAllocator >; + template< typename Device_, typename Index_ > + using SegmentsViewTemplate = typename SegmentsType::ViewTemplate< Device_, Index >; using DeviceType = Device; using IndexType = Index; using RealAllocatorType = RealAllocator; @@ -43,6 +45,8 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator > using ConstRowsCapacitiesView = typename RowsCapacitiesView::ConstViewType; using ValuesVectorType = typename Matrix< Real, Device, Index, RealAllocator >::ValuesVector; using ColumnsVectorType = Containers::Vector< IndexType, DeviceType, IndexType, IndexAllocatorType >; + using ViewType = SparseMatrixView< Real, Device, Index, MatrixType, SegmentsViewTemplate >; + using ConstViewType = SparseMatrixView< typename std::add_const< Real >::type, Device, Index, MatrixType, SegmentsViewTemplate >; // TODO: remove this - it is here only for compatibility with original matrix implementation typedef Containers::Vector< IndexType, DeviceType, IndexType > CompressedRowLengthsVector; @@ -63,6 +67,10 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator > const RealAllocatorType& realAllocator = RealAllocatorType(), const IndexAllocatorType& indexAllocator = IndexAllocatorType() ); + ViewType getView(); + + ConstViewType getConstView() const; + static String getSerializationType(); virtual String getSerializationTypeVirtual() const; diff --git a/src/TNL/Matrices/SparseMatrix.hpp b/src/TNL/Matrices/SparseMatrix.hpp index 08eae92b45a569630c315c94b3ab073e8301ac33..8af68bd4dc02615667054a0d237e70d83e262133 100644 --- a/src/TNL/Matrices/SparseMatrix.hpp +++ b/src/TNL/Matrices/SparseMatrix.hpp @@ -73,6 +73,42 @@ SparseMatrix( const IndexType rows, { } +template< typename Real, + typename Device, + typename Index, + typename MatrixType, + template< typename, typename, typename > class Segments, + typename RealAllocator, + typename IndexAllocator > +auto +SparseMatrix< Real, Device, Index, MatrixType, Segments, RealAllocator, IndexAllocator >:: +getView() -> ViewType +{ + return ViewType( this->getRows(), + this->getColumns(), + this->getValues().getView(), + this->getColumnsIndexes().getView(), + this->segments.getView() ); +} + +template< typename Real, + typename Device, + typename Index, + typename MatrixType, + template< typename, typename, typename > class Segments, + typename RealAllocator, + typename IndexAllocator > +auto +SparseMatrix< Real, Device, Index, MatrixType, Segments, RealAllocator, IndexAllocator >:: +getConstView() const -> ConstViewType +{ + return ConstViewType( this->getRows(), + this->getColumns(), + this->getValues().getConstView(), + this->getColumnsIndexes().getConstView(), + this->segments.getConstView() ); +} + template< typename Real, typename Device, typename Index, diff --git a/src/TNL/Matrices/SparseMatrixView.h b/src/TNL/Matrices/SparseMatrixView.h index b40d9c0c2dd476d5682933477ba7c553e488f284..847c21dd5fc6c94bf7b1109c12fae77e48bdab76 100644 --- a/src/TNL/Matrices/SparseMatrixView.h +++ b/src/TNL/Matrices/SparseMatrixView.h @@ -37,6 +37,9 @@ class SparseMatrixView : public MatrixView< Real, Device, Index > using ConstRowsCapacitiesView = typename RowsCapacitiesView::ConstViewType; using ValuesViewType = typename MatrixView< Real, Device, Index >::ValuesView; using ColumnsViewType = Containers::VectorView< IndexType, DeviceType, IndexType >; + using ViewType = SparseMatrixView< typename std::remove_const< Real >::type, Device, Index, MatrixType, SegmentsViewTemplate >; + using ConstViewType = SparseMatrixView< typename std::add_const< Real >::type, Device, Index, MatrixType, SegmentsViewTemplate >; + // TODO: remove this - it is here only for compatibility with original matrix implementation typedef Containers::Vector< IndexType, DeviceType, IndexType > CompressedRowLengthsVector; @@ -61,6 +64,12 @@ class SparseMatrixView : public MatrixView< Real, Device, Index > //__cuda_callable__ //SparseMatrixView( const SparseMatrixView&& m ) = default; + __cuda_callable__ + ViewType getView(); + + __cuda_callable__ + ConstViewType getConstView() const; + static String getSerializationType(); virtual String getSerializationTypeVirtual() const; diff --git a/src/TNL/Matrices/SparseMatrixView.hpp b/src/TNL/Matrices/SparseMatrixView.hpp index 0c49cd58d2c82d8a02339d800aba3bcad2a298cd..ffcba43dce57309f03118b95e5ea9ad75c1e31a4 100644 --- a/src/TNL/Matrices/SparseMatrixView.hpp +++ b/src/TNL/Matrices/SparseMatrixView.hpp @@ -41,7 +41,41 @@ SparseMatrixView( const IndexType rows, ColumnsViewType& columnIndexes, SegmentsViewType& segments ) : MatrixView< Real, Device, Index >( rows, columns, values ), columnIndexes( columnIndexes ), segments( segments ) -{ +{ +} + +template< typename Real, + typename Device, + typename Index, + typename MatrixType, + template< typename, typename > class SegmentsView > +__cuda_callable__ +auto +SparseMatrixView< Real, Device, Index, MatrixType, SegmentsView >:: +getView() -> ViewType +{ + return ViewType( this->getRows(), + this->getColumns(), + this->getValues().getView(), + this->getColumnsIndexes().getView(), + this->segments.getView() ); +} + +template< typename Real, + typename Device, + typename Index, + typename MatrixType, + template< typename, typename > class SegmentsView > +__cuda_callable__ +auto +SparseMatrixView< Real, Device, Index, MatrixType, SegmentsView >:: +getConstView() const -> ConstViewType +{ + return ConstViewType( this->getRows(), + this->getColumns(), + this->getValues().getConstView(), + this->getColumnsIndexes().getConstView(), + this->segments.getConstView() ); } template< typename Real,