From 0f8eb296551dbe58c43132365650beb85fe97897 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com> Date: Sat, 28 Dec 2019 17:12:54 +0100 Subject: [PATCH] ViewType and ConstViewType added to Matrix(View) and SparseMatrix(View). --- src/TNL/Containers/Segments/CSR.h | 2 ++ src/TNL/Containers/Segments/CSRView.h | 2 ++ src/TNL/Containers/Segments/Ellpack.h | 2 ++ src/TNL/Containers/Segments/EllpackView.h | 2 ++ src/TNL/Containers/Segments/SlicedEllpack.h | 2 ++ .../Containers/Segments/SlicedEllpackView.h | 2 ++ src/TNL/Matrices/MatrixView.h | 20 +++++++---- src/TNL/Matrices/MatrixView.hpp | 22 ++++++++++++ src/TNL/Matrices/SparseMatrix.h | 8 +++++ src/TNL/Matrices/SparseMatrix.hpp | 36 +++++++++++++++++++ src/TNL/Matrices/SparseMatrixView.h | 9 +++++ src/TNL/Matrices/SparseMatrixView.hpp | 36 ++++++++++++++++++- 12 files changed, 136 insertions(+), 7 deletions(-) diff --git a/src/TNL/Containers/Segments/CSR.h b/src/TNL/Containers/Segments/CSR.h index add07f1dff..f140605590 100644 --- a/src/TNL/Containers/Segments/CSR.h +++ b/src/TNL/Containers/Segments/CSR.h @@ -30,6 +30,8 @@ class CSR using IndexType = Index; using OffsetsHolder = Containers::Vector< IndexType, DeviceType, typename std::remove_const< IndexType >::type, IndexAllocator >; using SegmentsSizes = OffsetsHolder; + template< typename Device_, typename Index_ > + using ViewTemplate = CSRView< Device_, Index_ >; using ViewType = CSRView< Device, Index >; using ConstViewType = CSRView< Device, std::add_const_t< Index > >; diff --git a/src/TNL/Containers/Segments/CSRView.h b/src/TNL/Containers/Segments/CSRView.h index 2f89579702..4917df9e8e 100644 --- a/src/TNL/Containers/Segments/CSRView.h +++ b/src/TNL/Containers/Segments/CSRView.h @@ -29,6 +29,8 @@ class CSRView using OffsetsView = typename Containers::VectorView< IndexType, DeviceType, IndexType >; using ConstOffsetsView = typename Containers::Vector< IndexType, DeviceType, IndexType >::ConstViewType; using ViewType = CSRView; + template< typename Device_, typename Index_ > + using ViewTemplate = CSRView< Device_, Index_ >; using ConstViewType = CSRView< Device, std::add_const_t< Index > >; __cuda_callable__ diff --git a/src/TNL/Containers/Segments/Ellpack.h b/src/TNL/Containers/Segments/Ellpack.h index b9b3e63c1e..8cb430b6a4 100644 --- a/src/TNL/Containers/Segments/Ellpack.h +++ b/src/TNL/Containers/Segments/Ellpack.h @@ -32,6 +32,8 @@ class Ellpack static constexpr bool getRowMajorOrder() { return RowMajorOrder; } using OffsetsHolder = Containers::Vector< IndexType, DeviceType, IndexType >; using SegmentsSizes = OffsetsHolder; + template< typename Device_, typename Index_ > + using ViewTemplate = EllpackView< Device_, Index_ >; using ViewType = EllpackView< Device, Index, RowMajorOrder, Alignment >; //using ConstViewType = EllpackView< Device, std::add_const_t< Index >, RowMajorOrder, Alignment >; diff --git a/src/TNL/Containers/Segments/EllpackView.h b/src/TNL/Containers/Segments/EllpackView.h index adbfee629c..6c6926be92 100644 --- a/src/TNL/Containers/Segments/EllpackView.h +++ b/src/TNL/Containers/Segments/EllpackView.h @@ -33,6 +33,8 @@ class EllpackView static constexpr bool getRowMajorOrder() { return RowMajorOrder; } using OffsetsHolder = Containers::Vector< IndexType, DeviceType, IndexType >; using SegmentsSizes = OffsetsHolder; + template< typename Device_, typename Index_ > + using ViewTemplate = EllpackView< Device_, Index_ >; using ViewType = EllpackView; //using ConstViewType = EllpackView< Device, std::add_const_t< Index > >; diff --git a/src/TNL/Containers/Segments/SlicedEllpack.h b/src/TNL/Containers/Segments/SlicedEllpack.h index 9c2e7157f7..946c9b642c 100644 --- a/src/TNL/Containers/Segments/SlicedEllpack.h +++ b/src/TNL/Containers/Segments/SlicedEllpack.h @@ -32,6 +32,8 @@ class SlicedEllpack static constexpr int getSliceSize() { return SliceSize; } static constexpr bool getRowMajorOrder() { return RowMajorOrder; } using ViewType = SlicedEllpackView< Device, Index, RowMajorOrder, SliceSize >; + template< typename Device_, typename Index_ > + using ViewTemplate = SlicedEllpackView< Device_, Index_ >; using ConstViewType = SlicedEllpackView< Device, std::add_const_t< Index >, RowMajorOrder, SliceSize >; SlicedEllpack(); diff --git a/src/TNL/Containers/Segments/SlicedEllpackView.h b/src/TNL/Containers/Segments/SlicedEllpackView.h index 275baacf5f..adcf9ef5a0 100644 --- a/src/TNL/Containers/Segments/SlicedEllpackView.h +++ b/src/TNL/Containers/Segments/SlicedEllpackView.h @@ -31,6 +31,8 @@ class SlicedEllpackView using OffsetsView = typename Containers::VectorView< IndexType, DeviceType, typename std::remove_const < IndexType >::type >; static constexpr int getSliceSize() { return SliceSize; } static constexpr bool getRowMajorOrder() { return RowMajorOrder; } + template< typename Device_, typename Index_ > + using ViewTemplate = SlicedEllpackView< Device_, Index_ >; using ViewType = SlicedEllpackView; using ConstViewType = SlicedEllpackView< Device, std::add_const_t< Index > >; diff --git a/src/TNL/Matrices/MatrixView.h b/src/TNL/Matrices/MatrixView.h index a2fa975cf1..80fa28acfd 100644 --- a/src/TNL/Matrices/MatrixView.h +++ b/src/TNL/Matrices/MatrixView.h @@ -29,12 +29,14 @@ class MatrixView : public Object { public: using RealType = Real; - typedef Device DeviceType; - typedef Index IndexType; - typedef Containers::Vector< IndexType, DeviceType, IndexType > CompressedRowLengthsVector; - typedef Containers::VectorView< IndexType, DeviceType, IndexType > CompressedRowLengthsVectorView; - typedef typename CompressedRowLengthsVectorView::ConstViewType ConstCompressedRowLengthsVectorView; - typedef Containers::VectorView< RealType, DeviceType, IndexType > ValuesView; + using DeviceType = Device; + using IndexType = Index; + using CompressedRowLengthsVector = Containers::Vector< IndexType, DeviceType, IndexType >; + using CompressedRowLengthsVectorView = Containers::VectorView< IndexType, DeviceType, IndexType >; + using ConstCompressedRowLengthsVectorView = typename CompressedRowLengthsVectorView::ConstViewType; + using ValuesView = Containers::VectorView< RealType, DeviceType, IndexType >; + using ViewType = MatrixView< typename std::remove_const< Real >::type, Device, Index >; + using ConstViewType = MatrixView< typename std::add_const< Real >::type, Device, Index >; __cuda_callable__ MatrixView(); @@ -47,6 +49,12 @@ public: __cuda_callable__ MatrixView( const MatrixView& view ) = default; + __cuda_callable__ + ViewType getView(); + + __cuda_callable__ + ConstViewType getConstView() const; + virtual IndexType getRowLength( const IndexType row ) const = 0; // TODO: implementation is not parallel diff --git a/src/TNL/Matrices/MatrixView.hpp b/src/TNL/Matrices/MatrixView.hpp index bd3d9beaee..55ebc3d671 100644 --- a/src/TNL/Matrices/MatrixView.hpp +++ b/src/TNL/Matrices/MatrixView.hpp @@ -42,6 +42,28 @@ MatrixView( const IndexType rows_, { } +template< typename Real, + typename Device, + typename Index > +__cuda_callable__ +auto +MatrixView< Real, Device, Index >:: +getView() ->ViewType +{ + return ViewType( rows, columns, values.getView() ); +} + +template< typename Real, + typename Device, + typename Index > +__cuda_callable__ +auto +MatrixView< Real, Device, Index >:: +getConstView() const -> ConstViewType +{ + return ConstViewType( rows, columns, values.getConstView() ); +} + template< typename Real, typename Device, typename Index > diff --git a/src/TNL/Matrices/SparseMatrix.h b/src/TNL/Matrices/SparseMatrix.h index 5f02e9fdec..558cbb5b10 100644 --- a/src/TNL/Matrices/SparseMatrix.h +++ b/src/TNL/Matrices/SparseMatrix.h @@ -34,6 +34,8 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator > template< typename Device_, typename Index_, typename IndexAllocator_ > using SegmentsTemplate = Segments< Device_, Index_, IndexAllocator_ >; using SegmentsType = Segments< Device, Index, IndexAllocator >; + template< typename Device_, typename Index_ > + using SegmentsViewTemplate = typename SegmentsType::ViewTemplate< Device_, Index >; using DeviceType = Device; using IndexType = Index; using RealAllocatorType = RealAllocator; @@ -43,6 +45,8 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator > using ConstRowsCapacitiesView = typename RowsCapacitiesView::ConstViewType; using ValuesVectorType = typename Matrix< Real, Device, Index, RealAllocator >::ValuesVector; using ColumnsVectorType = Containers::Vector< IndexType, DeviceType, IndexType, IndexAllocatorType >; + using ViewType = SparseMatrixView< Real, Device, Index, MatrixType, SegmentsViewTemplate >; + using ConstViewType = SparseMatrixView< typename std::add_const< Real >::type, Device, Index, MatrixType, SegmentsViewTemplate >; // TODO: remove this - it is here only for compatibility with original matrix implementation typedef Containers::Vector< IndexType, DeviceType, IndexType > CompressedRowLengthsVector; @@ -63,6 +67,10 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator > const RealAllocatorType& realAllocator = RealAllocatorType(), const IndexAllocatorType& indexAllocator = IndexAllocatorType() ); + ViewType getView(); + + ConstViewType getConstView() const; + static String getSerializationType(); virtual String getSerializationTypeVirtual() const; diff --git a/src/TNL/Matrices/SparseMatrix.hpp b/src/TNL/Matrices/SparseMatrix.hpp index 08eae92b45..8af68bd4dc 100644 --- a/src/TNL/Matrices/SparseMatrix.hpp +++ b/src/TNL/Matrices/SparseMatrix.hpp @@ -73,6 +73,42 @@ SparseMatrix( const IndexType rows, { } +template< typename Real, + typename Device, + typename Index, + typename MatrixType, + template< typename, typename, typename > class Segments, + typename RealAllocator, + typename IndexAllocator > +auto +SparseMatrix< Real, Device, Index, MatrixType, Segments, RealAllocator, IndexAllocator >:: +getView() -> ViewType +{ + return ViewType( this->getRows(), + this->getColumns(), + this->getValues().getView(), + this->getColumnsIndexes().getView(), + this->segments.getView() ); +} + +template< typename Real, + typename Device, + typename Index, + typename MatrixType, + template< typename, typename, typename > class Segments, + typename RealAllocator, + typename IndexAllocator > +auto +SparseMatrix< Real, Device, Index, MatrixType, Segments, RealAllocator, IndexAllocator >:: +getConstView() const -> ConstViewType +{ + return ConstViewType( this->getRows(), + this->getColumns(), + this->getValues().getConstView(), + this->getColumnsIndexes().getConstView(), + this->segments.getConstView() ); +} + template< typename Real, typename Device, typename Index, diff --git a/src/TNL/Matrices/SparseMatrixView.h b/src/TNL/Matrices/SparseMatrixView.h index b40d9c0c2d..847c21dd5f 100644 --- a/src/TNL/Matrices/SparseMatrixView.h +++ b/src/TNL/Matrices/SparseMatrixView.h @@ -37,6 +37,9 @@ class SparseMatrixView : public MatrixView< Real, Device, Index > using ConstRowsCapacitiesView = typename RowsCapacitiesView::ConstViewType; using ValuesViewType = typename MatrixView< Real, Device, Index >::ValuesView; using ColumnsViewType = Containers::VectorView< IndexType, DeviceType, IndexType >; + using ViewType = SparseMatrixView< typename std::remove_const< Real >::type, Device, Index, MatrixType, SegmentsViewTemplate >; + using ConstViewType = SparseMatrixView< typename std::add_const< Real >::type, Device, Index, MatrixType, SegmentsViewTemplate >; + // TODO: remove this - it is here only for compatibility with original matrix implementation typedef Containers::Vector< IndexType, DeviceType, IndexType > CompressedRowLengthsVector; @@ -61,6 +64,12 @@ class SparseMatrixView : public MatrixView< Real, Device, Index > //__cuda_callable__ //SparseMatrixView( const SparseMatrixView&& m ) = default; + __cuda_callable__ + ViewType getView(); + + __cuda_callable__ + ConstViewType getConstView() const; + static String getSerializationType(); virtual String getSerializationTypeVirtual() const; diff --git a/src/TNL/Matrices/SparseMatrixView.hpp b/src/TNL/Matrices/SparseMatrixView.hpp index 0c49cd58d2..ffcba43dce 100644 --- a/src/TNL/Matrices/SparseMatrixView.hpp +++ b/src/TNL/Matrices/SparseMatrixView.hpp @@ -41,7 +41,41 @@ SparseMatrixView( const IndexType rows, ColumnsViewType& columnIndexes, SegmentsViewType& segments ) : MatrixView< Real, Device, Index >( rows, columns, values ), columnIndexes( columnIndexes ), segments( segments ) -{ +{ +} + +template< typename Real, + typename Device, + typename Index, + typename MatrixType, + template< typename, typename > class SegmentsView > +__cuda_callable__ +auto +SparseMatrixView< Real, Device, Index, MatrixType, SegmentsView >:: +getView() -> ViewType +{ + return ViewType( this->getRows(), + this->getColumns(), + this->getValues().getView(), + this->getColumnsIndexes().getView(), + this->segments.getView() ); +} + +template< typename Real, + typename Device, + typename Index, + typename MatrixType, + template< typename, typename > class SegmentsView > +__cuda_callable__ +auto +SparseMatrixView< Real, Device, Index, MatrixType, SegmentsView >:: +getConstView() const -> ConstViewType +{ + return ConstViewType( this->getRows(), + this->getColumns(), + this->getValues().getConstView(), + this->getColumnsIndexes().getConstView(), + this->segments.getConstView() ); } template< typename Real, -- GitLab