From 0f8eb296551dbe58c43132365650beb85fe97897 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com>
Date: Sat, 28 Dec 2019 17:12:54 +0100
Subject: [PATCH] ViewType and ConstViewType added to Matrix(View) and
 SparseMatrix(View).

---
 src/TNL/Containers/Segments/CSR.h             |  2 ++
 src/TNL/Containers/Segments/CSRView.h         |  2 ++
 src/TNL/Containers/Segments/Ellpack.h         |  2 ++
 src/TNL/Containers/Segments/EllpackView.h     |  2 ++
 src/TNL/Containers/Segments/SlicedEllpack.h   |  2 ++
 .../Containers/Segments/SlicedEllpackView.h   |  2 ++
 src/TNL/Matrices/MatrixView.h                 | 20 +++++++----
 src/TNL/Matrices/MatrixView.hpp               | 22 ++++++++++++
 src/TNL/Matrices/SparseMatrix.h               |  8 +++++
 src/TNL/Matrices/SparseMatrix.hpp             | 36 +++++++++++++++++++
 src/TNL/Matrices/SparseMatrixView.h           |  9 +++++
 src/TNL/Matrices/SparseMatrixView.hpp         | 36 ++++++++++++++++++-
 12 files changed, 136 insertions(+), 7 deletions(-)

diff --git a/src/TNL/Containers/Segments/CSR.h b/src/TNL/Containers/Segments/CSR.h
index add07f1dff..f140605590 100644
--- a/src/TNL/Containers/Segments/CSR.h
+++ b/src/TNL/Containers/Segments/CSR.h
@@ -30,6 +30,8 @@ class CSR
       using IndexType = Index;
       using OffsetsHolder = Containers::Vector< IndexType, DeviceType, typename std::remove_const< IndexType >::type, IndexAllocator >;
       using SegmentsSizes = OffsetsHolder;
+      template< typename Device_, typename Index_ >
+      using ViewTemplate = CSRView< Device_, Index_ >;
       using ViewType = CSRView< Device, Index >;
       using ConstViewType = CSRView< Device, std::add_const_t< Index > >;
 
diff --git a/src/TNL/Containers/Segments/CSRView.h b/src/TNL/Containers/Segments/CSRView.h
index 2f89579702..4917df9e8e 100644
--- a/src/TNL/Containers/Segments/CSRView.h
+++ b/src/TNL/Containers/Segments/CSRView.h
@@ -29,6 +29,8 @@ class CSRView
       using OffsetsView = typename Containers::VectorView< IndexType, DeviceType, IndexType >;
       using ConstOffsetsView = typename Containers::Vector< IndexType, DeviceType, IndexType >::ConstViewType;
       using ViewType = CSRView;
+      template< typename Device_, typename Index_ >
+      using ViewTemplate = CSRView< Device_, Index_ >;
       using ConstViewType = CSRView< Device, std::add_const_t< Index > >;
 
       __cuda_callable__
diff --git a/src/TNL/Containers/Segments/Ellpack.h b/src/TNL/Containers/Segments/Ellpack.h
index b9b3e63c1e..8cb430b6a4 100644
--- a/src/TNL/Containers/Segments/Ellpack.h
+++ b/src/TNL/Containers/Segments/Ellpack.h
@@ -32,6 +32,8 @@ class Ellpack
       static constexpr bool getRowMajorOrder() { return RowMajorOrder; }
       using OffsetsHolder = Containers::Vector< IndexType, DeviceType, IndexType >;
       using SegmentsSizes = OffsetsHolder;
+      template< typename Device_, typename Index_ >
+      using ViewTemplate = EllpackView< Device_, Index_ >;
       using ViewType = EllpackView< Device, Index, RowMajorOrder, Alignment >;
       //using ConstViewType = EllpackView< Device, std::add_const_t< Index >, RowMajorOrder, Alignment >;
 
diff --git a/src/TNL/Containers/Segments/EllpackView.h b/src/TNL/Containers/Segments/EllpackView.h
index adbfee629c..6c6926be92 100644
--- a/src/TNL/Containers/Segments/EllpackView.h
+++ b/src/TNL/Containers/Segments/EllpackView.h
@@ -33,6 +33,8 @@ class EllpackView
       static constexpr bool getRowMajorOrder() { return RowMajorOrder; }
       using OffsetsHolder = Containers::Vector< IndexType, DeviceType, IndexType >;
       using SegmentsSizes = OffsetsHolder;
+      template< typename Device_, typename Index_ >
+      using ViewTemplate = EllpackView< Device_, Index_ >;
       using ViewType = EllpackView;
       //using ConstViewType = EllpackView< Device, std::add_const_t< Index > >;
 
diff --git a/src/TNL/Containers/Segments/SlicedEllpack.h b/src/TNL/Containers/Segments/SlicedEllpack.h
index 9c2e7157f7..946c9b642c 100644
--- a/src/TNL/Containers/Segments/SlicedEllpack.h
+++ b/src/TNL/Containers/Segments/SlicedEllpack.h
@@ -32,6 +32,8 @@ class SlicedEllpack
       static constexpr int getSliceSize() { return SliceSize; }
       static constexpr bool getRowMajorOrder() { return RowMajorOrder; }
       using ViewType = SlicedEllpackView< Device, Index, RowMajorOrder, SliceSize >;
+      template< typename Device_, typename Index_ >
+      using ViewTemplate = SlicedEllpackView< Device_, Index_ >;
       using ConstViewType = SlicedEllpackView< Device, std::add_const_t< Index >, RowMajorOrder, SliceSize >;
 
       SlicedEllpack();
diff --git a/src/TNL/Containers/Segments/SlicedEllpackView.h b/src/TNL/Containers/Segments/SlicedEllpackView.h
index 275baacf5f..adcf9ef5a0 100644
--- a/src/TNL/Containers/Segments/SlicedEllpackView.h
+++ b/src/TNL/Containers/Segments/SlicedEllpackView.h
@@ -31,6 +31,8 @@ class SlicedEllpackView
       using OffsetsView = typename Containers::VectorView< IndexType, DeviceType, typename std::remove_const < IndexType >::type >;
       static constexpr int getSliceSize() { return SliceSize; }
       static constexpr bool getRowMajorOrder() { return RowMajorOrder; }
+      template< typename Device_, typename Index_ >
+      using ViewTemplate = SlicedEllpackView< Device_, Index_ >;
       using ViewType = SlicedEllpackView;
       using ConstViewType = SlicedEllpackView< Device, std::add_const_t< Index > >;
 
diff --git a/src/TNL/Matrices/MatrixView.h b/src/TNL/Matrices/MatrixView.h
index a2fa975cf1..80fa28acfd 100644
--- a/src/TNL/Matrices/MatrixView.h
+++ b/src/TNL/Matrices/MatrixView.h
@@ -29,12 +29,14 @@ class MatrixView : public Object
 {
 public:
    using RealType = Real;
-   typedef Device DeviceType;
-   typedef Index IndexType;
-   typedef Containers::Vector< IndexType, DeviceType, IndexType > CompressedRowLengthsVector;
-   typedef Containers::VectorView< IndexType, DeviceType, IndexType > CompressedRowLengthsVectorView;
-   typedef typename CompressedRowLengthsVectorView::ConstViewType ConstCompressedRowLengthsVectorView;
-   typedef Containers::VectorView< RealType, DeviceType, IndexType > ValuesView;
+   using DeviceType = Device;
+   using IndexType = Index;
+   using CompressedRowLengthsVector = Containers::Vector< IndexType, DeviceType, IndexType >;
+   using CompressedRowLengthsVectorView = Containers::VectorView< IndexType, DeviceType, IndexType >;
+   using ConstCompressedRowLengthsVectorView = typename CompressedRowLengthsVectorView::ConstViewType;
+   using ValuesView = Containers::VectorView< RealType, DeviceType, IndexType >;
+   using ViewType = MatrixView< typename std::remove_const< Real >::type, Device, Index >;
+   using ConstViewType = MatrixView< typename std::add_const< Real >::type, Device, Index >;
 
    __cuda_callable__
    MatrixView();
@@ -47,6 +49,12 @@ public:
    __cuda_callable__
    MatrixView( const MatrixView& view ) = default;
 
+   __cuda_callable__
+   ViewType getView();
+
+   __cuda_callable__
+   ConstViewType getConstView() const;
+
    virtual IndexType getRowLength( const IndexType row ) const = 0;
 
    // TODO: implementation is not parallel
diff --git a/src/TNL/Matrices/MatrixView.hpp b/src/TNL/Matrices/MatrixView.hpp
index bd3d9beaee..55ebc3d671 100644
--- a/src/TNL/Matrices/MatrixView.hpp
+++ b/src/TNL/Matrices/MatrixView.hpp
@@ -42,6 +42,28 @@ MatrixView( const IndexType rows_,
 {
 }
 
+template< typename Real,
+          typename Device,
+          typename Index >
+__cuda_callable__
+auto
+MatrixView< Real, Device, Index >::
+getView() ->ViewType
+{
+   return ViewType( rows, columns, values.getView() );
+}
+
+template< typename Real,
+          typename Device,
+          typename Index >
+__cuda_callable__
+auto
+MatrixView< Real, Device, Index >::
+getConstView() const -> ConstViewType
+{
+   return ConstViewType( rows, columns, values.getConstView() );
+}
+
 template< typename Real,
           typename Device,
           typename Index >
diff --git a/src/TNL/Matrices/SparseMatrix.h b/src/TNL/Matrices/SparseMatrix.h
index 5f02e9fdec..558cbb5b10 100644
--- a/src/TNL/Matrices/SparseMatrix.h
+++ b/src/TNL/Matrices/SparseMatrix.h
@@ -34,6 +34,8 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator >
       template< typename Device_, typename Index_, typename IndexAllocator_ >
       using SegmentsTemplate = Segments< Device_, Index_, IndexAllocator_ >;
       using SegmentsType = Segments< Device, Index, IndexAllocator >;
+      template< typename Device_, typename Index_ >
+      using SegmentsViewTemplate = typename SegmentsType::ViewTemplate< Device_, Index >;
       using DeviceType = Device;
       using IndexType = Index;
       using RealAllocatorType = RealAllocator;
@@ -43,6 +45,8 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator >
       using ConstRowsCapacitiesView = typename RowsCapacitiesView::ConstViewType;
       using ValuesVectorType = typename Matrix< Real, Device, Index, RealAllocator >::ValuesVector;
       using ColumnsVectorType = Containers::Vector< IndexType, DeviceType, IndexType, IndexAllocatorType >;
+      using ViewType = SparseMatrixView< Real, Device, Index, MatrixType, SegmentsViewTemplate >;
+      using ConstViewType = SparseMatrixView< typename std::add_const< Real >::type, Device, Index, MatrixType, SegmentsViewTemplate >;
 
       // TODO: remove this - it is here only for compatibility with original matrix implementation
       typedef Containers::Vector< IndexType, DeviceType, IndexType > CompressedRowLengthsVector;
@@ -63,6 +67,10 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator >
                     const RealAllocatorType& realAllocator = RealAllocatorType(),
                     const IndexAllocatorType& indexAllocator = IndexAllocatorType() );
 
+      ViewType getView();
+
+      ConstViewType getConstView() const;
+
       static String getSerializationType();
 
       virtual String getSerializationTypeVirtual() const;
diff --git a/src/TNL/Matrices/SparseMatrix.hpp b/src/TNL/Matrices/SparseMatrix.hpp
index 08eae92b45..8af68bd4dc 100644
--- a/src/TNL/Matrices/SparseMatrix.hpp
+++ b/src/TNL/Matrices/SparseMatrix.hpp
@@ -73,6 +73,42 @@ SparseMatrix( const IndexType rows,
 {
 }
 
+template< typename Real,
+          typename Device,
+          typename Index,
+          typename MatrixType,
+          template< typename, typename, typename > class Segments,
+          typename RealAllocator,
+          typename IndexAllocator >
+auto
+SparseMatrix< Real, Device, Index, MatrixType, Segments, RealAllocator, IndexAllocator >::
+getView() -> ViewType
+{
+   return ViewType( this->getRows(), 
+                    this->getColumns(),
+                    this->getValues().getView(),
+                    this->getColumnsIndexes().getView(),
+                    this->segments.getView() );
+}
+
+template< typename Real,
+          typename Device,
+          typename Index,
+          typename MatrixType,
+          template< typename, typename, typename > class Segments,
+          typename RealAllocator,
+          typename IndexAllocator >
+auto
+SparseMatrix< Real, Device, Index, MatrixType, Segments, RealAllocator, IndexAllocator >::
+getConstView() const -> ConstViewType
+{
+   return ConstViewType( this->getRows(),
+                         this->getColumns(),
+                         this->getValues().getConstView(),
+                         this->getColumnsIndexes().getConstView(),
+                         this->segments.getConstView() );
+}
+
 template< typename Real,
           typename Device,
           typename Index,
diff --git a/src/TNL/Matrices/SparseMatrixView.h b/src/TNL/Matrices/SparseMatrixView.h
index b40d9c0c2d..847c21dd5f 100644
--- a/src/TNL/Matrices/SparseMatrixView.h
+++ b/src/TNL/Matrices/SparseMatrixView.h
@@ -37,6 +37,9 @@ class SparseMatrixView : public MatrixView< Real, Device, Index >
       using ConstRowsCapacitiesView = typename RowsCapacitiesView::ConstViewType;
       using ValuesViewType = typename MatrixView< Real, Device, Index >::ValuesView;
       using ColumnsViewType = Containers::VectorView< IndexType, DeviceType, IndexType >;
+      using ViewType = SparseMatrixView< typename std::remove_const< Real >::type, Device, Index, MatrixType, SegmentsViewTemplate >;
+      using ConstViewType = SparseMatrixView< typename std::add_const< Real >::type, Device, Index, MatrixType, SegmentsViewTemplate >;
+
 
       // TODO: remove this - it is here only for compatibility with original matrix implementation
       typedef Containers::Vector< IndexType, DeviceType, IndexType > CompressedRowLengthsVector;
@@ -61,6 +64,12 @@ class SparseMatrixView : public MatrixView< Real, Device, Index >
       //__cuda_callable__
       //SparseMatrixView( const SparseMatrixView&& m ) = default;
 
+      __cuda_callable__
+      ViewType getView();
+
+      __cuda_callable__
+      ConstViewType getConstView() const;
+
       static String getSerializationType();
 
       virtual String getSerializationTypeVirtual() const;
diff --git a/src/TNL/Matrices/SparseMatrixView.hpp b/src/TNL/Matrices/SparseMatrixView.hpp
index 0c49cd58d2..ffcba43dce 100644
--- a/src/TNL/Matrices/SparseMatrixView.hpp
+++ b/src/TNL/Matrices/SparseMatrixView.hpp
@@ -41,7 +41,41 @@ SparseMatrixView( const IndexType rows,
                   ColumnsViewType& columnIndexes,
                   SegmentsViewType& segments )
  : MatrixView< Real, Device, Index >( rows, columns, values ), columnIndexes( columnIndexes ), segments( segments )
-{  
+{
+}
+
+template< typename Real,
+          typename Device,
+          typename Index,
+          typename MatrixType,
+          template< typename, typename > class SegmentsView >
+__cuda_callable__
+auto
+SparseMatrixView< Real, Device, Index, MatrixType, SegmentsView >::
+getView() -> ViewType
+{
+   return ViewType( this->getRows(), 
+                    this->getColumns(),
+                    this->getValues().getView(),
+                    this->getColumnsIndexes().getView(),
+                    this->segments.getView() );
+}
+
+template< typename Real,
+          typename Device,
+          typename Index,
+          typename MatrixType,
+          template< typename, typename > class SegmentsView >
+__cuda_callable__
+auto
+SparseMatrixView< Real, Device, Index, MatrixType, SegmentsView >::
+getConstView() const -> ConstViewType
+{
+   return ConstViewType( this->getRows(),
+                         this->getColumns(),
+                         this->getValues().getConstView(),
+                         this->getColumnsIndexes().getConstView(),
+                         this->segments.getConstView() );
 }
 
 template< typename Real,
-- 
GitLab