From a87203a0b91da16f5e16b379e33ee09a8e2a1c57 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com>
Date: Sun, 22 Dec 2019 20:52:43 +0100
Subject: [PATCH] Implementing sparse matrix assignment.

---
 src/TNL/Containers/Segments/CSR.h             |  6 +-
 src/TNL/Containers/Segments/CSR.hpp           | 11 +++
 src/TNL/Containers/Segments/Ellpack.h         |  5 ++
 src/TNL/Containers/Segments/Ellpack.hpp       | 15 ++++
 src/TNL/Containers/Segments/SlicedEllpack.h   |  5 ++
 src/TNL/Containers/Segments/SlicedEllpack.hpp | 17 ++++
 src/TNL/Matrices/SparseMatrix.h               |  2 +-
 src/TNL/Matrices/SparseMatrix.hpp             | 84 +++++++++++++++----
 src/UnitTests/Matrices/SparseMatrixCopyTest.h | 32 +++++--
 9 files changed, 156 insertions(+), 21 deletions(-)

diff --git a/src/TNL/Containers/Segments/CSR.h b/src/TNL/Containers/Segments/CSR.h
index b83e43f1d1..add07f1dff 100644
--- a/src/TNL/Containers/Segments/CSR.h
+++ b/src/TNL/Containers/Segments/CSR.h
@@ -93,7 +93,6 @@ class CSR
       template< typename Function, typename... Args >
       void forAll( Function& f, Args... args ) const;
 
-
       /***
        * \brief Go over all segments and perform a reduction in each of them.
        */
@@ -103,6 +102,11 @@ class CSR
       template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
       void allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const;
 
+      CSR& operator=( const CSR& rhsSegments ) = default;
+
+      template< typename Device_, typename Index_, typename IndexAllocator_ >
+      CSR& operator=( const CSR< Device_, Index_, IndexAllocator_ >& source );
+
       void save( File& file ) const;
 
       void load( File& file );
diff --git a/src/TNL/Containers/Segments/CSR.hpp b/src/TNL/Containers/Segments/CSR.hpp
index 280ed6ebf2..61720869ca 100644
--- a/src/TNL/Containers/Segments/CSR.hpp
+++ b/src/TNL/Containers/Segments/CSR.hpp
@@ -221,6 +221,17 @@ allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Re
    this->segmentsReduction( 0, this->getSegmentsCount(), fetch, reduction, keeper, zero, args... );
 }
 
+template< typename Device,
+          typename Index,
+          typename IndexAllocator >
+   template< typename Device_, typename Index_, typename IndexAllocator_ >
+CSR< Device, Index, IndexAllocator >&
+CSR< Device, Index, IndexAllocator >::
+operator=( const CSR< Device_, Index_, IndexAllocator_ >& source )
+{
+   this->offsets = source.offsets;
+}
+
 template< typename Device,
           typename Index,
           typename IndexAllocator >
diff --git a/src/TNL/Containers/Segments/Ellpack.h b/src/TNL/Containers/Segments/Ellpack.h
index 9c81a84281..b9b3e63c1e 100644
--- a/src/TNL/Containers/Segments/Ellpack.h
+++ b/src/TNL/Containers/Segments/Ellpack.h
@@ -100,6 +100,11 @@ class Ellpack
       template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
       void allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const;
 
+      Ellpack& operator=( const Ellpack& source ) = default;
+
+      template< typename Device_, typename Index_, typename IndexAllocator_, bool RowMajorOrder_, int Alignment_ >
+      Ellpack& operator=( const Ellpack< Device_, Index_, IndexAllocator_, RowMajorOrder_, Alignment_ >& source );
+
       void save( File& file ) const;
 
       void load( File& file );
diff --git a/src/TNL/Containers/Segments/Ellpack.hpp b/src/TNL/Containers/Segments/Ellpack.hpp
index 482c87d4f3..97d30d3147 100644
--- a/src/TNL/Containers/Segments/Ellpack.hpp
+++ b/src/TNL/Containers/Segments/Ellpack.hpp
@@ -322,6 +322,21 @@ allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Re
    this->segmentsReduction( 0, this->getSegmentsCount(), fetch, reduction, keeper, zero, args... );
 }
 
+template< typename Device,
+          typename Index,
+          typename IndexAllocator,
+          bool RowMajorOrder,
+          int Alignment >
+   template< typename Device_, typename Index_, typename IndexAllocator_, bool RowMajorOrder_, int Alignment_ >
+Ellpack< Device, Index, IndexAllocator, RowMajorOrder, Alignment >&
+Ellpack< Device, Index, IndexAllocator, RowMajorOrder, Alignment >::
+operator=( const Ellpack< Device_, Index_, IndexAllocator_, RowMajorOrder_, Alignment_ >& source )
+{
+   this->segmentSize = source.segmentSize;
+   this->size = source.size;
+   this->alignedSize = roundUpDivision( size, this->getAlignment() ) * this->getAlignment();
+}
+
 template< typename Device,
           typename Index,
           typename IndexAllocator,
diff --git a/src/TNL/Containers/Segments/SlicedEllpack.h b/src/TNL/Containers/Segments/SlicedEllpack.h
index fc514c51f3..9c2e7157f7 100644
--- a/src/TNL/Containers/Segments/SlicedEllpack.h
+++ b/src/TNL/Containers/Segments/SlicedEllpack.h
@@ -96,6 +96,11 @@ class SlicedEllpack
       template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
       void allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const;
 
+      SlicedEllpack& operator=( const SlicedEllpack& source ) = default;
+
+      template< typename Device_, typename Index_, typename IndexAllocator_, bool RowMajorOrder_ >
+      SlicedEllpack& operator=( const SlicedEllpack< Device_, Index_, IndexAllocator_, RowMajorOrder_, SliceSize >& source );
+
       void save( File& file ) const;
 
       void load( File& file );
diff --git a/src/TNL/Containers/Segments/SlicedEllpack.hpp b/src/TNL/Containers/Segments/SlicedEllpack.hpp
index bdf28ff73b..ad83f666a0 100644
--- a/src/TNL/Containers/Segments/SlicedEllpack.hpp
+++ b/src/TNL/Containers/Segments/SlicedEllpack.hpp
@@ -356,6 +356,23 @@ allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Re
    this->segmentsReduction( 0, this->getSegmentsCount(), fetch, reduction, keeper, zero, args... );
 }
 
+template< typename Device,
+          typename Index,
+          typename IndexAllocator,
+          bool RowMajorOrder,
+          int SliceSize >
+   template< typename Device_, typename Index_, typename IndexAllocator_, bool RowMajorOrder_ >
+SlicedEllpack< Device, Index, IndexAllocator, RowMajorOrder, SliceSize >&
+SlicedEllpack< Device, Index, IndexAllocator, RowMajorOrder, SliceSize >::
+operator=( const SlicedEllpack< Device_, Index_, IndexAllocator_, RowMajorOrder_, SliceSize >& source )
+{
+   this->size = source.size;
+   this->alignedSize = source.alignedSize;
+   this->segmentsCount = source.segmentsCount;
+   this->sliceOffsets = source.sliceOffsets;
+   this->sliceSegmentSizes = source.sliceSegmentSizes;
+}
+
 template< typename Device,
           typename Index,
           typename IndexAllocator,
diff --git a/src/TNL/Matrices/SparseMatrix.h b/src/TNL/Matrices/SparseMatrix.h
index 1512f8574d..8c8fef599a 100644
--- a/src/TNL/Matrices/SparseMatrix.h
+++ b/src/TNL/Matrices/SparseMatrix.h
@@ -218,7 +218,7 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator >
 
       SegmentsType segments;
 
-      IndexAllocator indexAlloctor;
+      IndexAllocator indexAllocator;
 
       RealAllocator realAllocator;
 
diff --git a/src/TNL/Matrices/SparseMatrix.hpp b/src/TNL/Matrices/SparseMatrix.hpp
index e24ed2f448..5de4473ab5 100644
--- a/src/TNL/Matrices/SparseMatrix.hpp
+++ b/src/TNL/Matrices/SparseMatrix.hpp
@@ -680,7 +680,7 @@ operator=( const SparseMatrix& matrix )
    Matrix< Real, Device, Index >::operator=( matrix );
    this->columnIndexes = matrix.columnIndexes;
    this->segments = matrix.segments;
-   this->indexAlloctor = matrix.indexAllocator;
+   this->indexAllocator = matrix.indexAllocator;
    this->realAllocator = matrix.realAllocator;
 }
 
@@ -702,29 +702,85 @@ SparseMatrix< Real, Segments, Device, Index, RealAllocator, IndexAllocator >::
 operator=( const SparseMatrix< Real2, Segments2, Device2, Index2, RealAllocator2, IndexAllocator2 >& matrix )
 {
    using RHSMatrixType = SparseMatrix< Real2, Segments2, Device2, Index2, RealAllocator2, IndexAllocator2 >;
+   RowsCapacitiesType rowLengths;
+   matrix.getCompressedRowLengths( rowLengths );
+   this->setCompressedRowLengths( rowLengths );
+
+   // TODO: Replace this with SparseMatrixView
+   const auto matrix_columns_view = matrix.columnIndexes.getConstView();
+   const auto matrix_values_view = matrix.values.getConstView();
+   const IndexType paddingIndex = this->getPaddingIndex();
+   auto this_columns_view = this->columnIndexes.getView();
+   auto this_values_view = this->values.getView();
+
    if( std::is_same< Device, Device2 >::value )
    {
-      /*RowsCapacitiesType rowLengths;
-      matrix.getCompressedRowLengths( rowLengths );
-      this->setCompressedRowLengths( rowLengths );
-      // TODO: Replace this with SparseMatrixView
-      const auto matrix_columns_view = matrix.columnIndexes.getConstView();
-      const auto matrix_values_view = matrix.values.getConstView();
-      const auto segments_view = this->segments.getConstView();
-      auto this_columns_view = this->columnIndexes.getView();
-      auto this_values_view = this->values.getView();
-      const IndexType paddingIndex = this->getPaddingIndex();
-      auto f = [=] __cuda_callable__ ( IndexType rowIdx, IndexType localIdx, IndexType globalIdx ) {
+      const auto this_segments_view = this->segments.getView();
+      auto f = [=] __cuda_callable__ ( IndexType rowIdx, IndexType localIdx, IndexType globalIdx ) mutable {
          const IndexType column = matrix_columns_view[ globalIdx ];
          if( column != paddingIndex )
          {
             const RealType value = matrix_values_view[ globalIdx ];
-            IndexType thisGlobalIdx = segments_view.getGlobalIdx( rowIdx, localIdx );
+            IndexType thisGlobalIdx = this_segments_view.getGlobalIndex( rowIdx, localIdx );
             this_columns_view[ thisGlobalIdx ] = column;
             this_values_view[ thisGlobalIdx ] = value;
          }
       };
-      matrix.forAllRows( f );*/
+      matrix.forAllRows( f );
+   }
+   else
+   {
+      const IndexType maxRowLength = max( rowLengths );
+      const IndexType bufferRowsCount( 128 );
+      const size_t bufferSize = bufferRowsCount * maxRowLength;
+      Containers::Vector< Real2, Device2, Index2, RealAllocator2 > matrixValuesBuffer( bufferSize );
+      Containers::Vector< Index2, Device2, Index2, IndexAllocator2 > matrixColumnsBuffer( bufferSize );
+      Containers::Vector< RealType, DeviceType, IndexType, RealAllocator > thisValuesBuffer( bufferSize );
+      Containers::Vector< IndexType, DeviceType, IndexType, IndexAllocator > thisColumnsBuffer( bufferSize );
+      auto matrixValuesBuffer_view = matrixValuesBuffer.getView();
+      auto matrixColumnsBuffer_view = matrixColumnsBuffer.getView();
+      auto thisValuesBuffer_view = thisValuesBuffer.getView();
+      auto thisColumnsBuffer_view = thisColumnsBuffer.getView();
+
+      IndexType baseRow( 0 );
+      const IndexType rowsCount = this->getRows();
+      while( baseRow < rowsCount )
+      {
+         const IndexType lastRow = min( baseRow + bufferRowsCount, rowsCount );
+         thisColumnsBuffer = paddingIndex;
+
+         ////
+         // Copy matrix elements into buffer
+         auto f1 = [=] __cuda_callable__ ( IndexType rowIdx, IndexType localIdx, IndexType globalIdx ) mutable {
+            const IndexType column = matrix_columns_view[ globalIdx ];
+            if( column != paddingIndex )
+            {
+               const IndexType bufferIdx = ( rowIdx - baseRow ) * maxRowLength + localIdx;
+               matrixValuesBuffer_view[ bufferIdx ] = matrix_values_view[ globalIdx ];
+               matrixColumnsBuffer_view[ bufferIdx ] = column;
+            }
+         };
+         matrix.forRows( baseRow, lastRow, f1 );
+
+         ////
+         // Copy the source matrix buffer to this matrix buffer
+         thisValuesBuffer_view = matrixValuesBuffer_view;
+         thisColumnsBuffer_view = matrixColumnsBuffer_view;
+
+         ////
+         // Copy matrix elements from the buffer to the matrix
+         auto f2 = [=] __cuda_callable__ ( IndexType rowIdx, IndexType localIdx, IndexType globalIdx ) mutable {
+            const IndexType bufferIdx = ( rowIdx - baseRow ) * maxRowLength + localIdx;
+            const IndexType column = thisColumnsBuffer_view[ bufferIdx ];
+            if( column != paddingIndex )
+            {
+               this_columns_view[ globalIdx ] = column;
+               this_values_view[ globalIdx ] = thisValuesBuffer_view[ bufferIdx ];
+            }
+         };
+         this->forRows( baseRow, lastRow, f2 );
+         baseRow += bufferRowsCount;
+      }
    }
 }
 
diff --git a/src/UnitTests/Matrices/SparseMatrixCopyTest.h b/src/UnitTests/Matrices/SparseMatrixCopyTest.h
index 9b09ef4d45..684a6a8713 100644
--- a/src/UnitTests/Matrices/SparseMatrixCopyTest.h
+++ b/src/UnitTests/Matrices/SparseMatrixCopyTest.h
@@ -12,12 +12,31 @@
 #include <TNL/Matrices/Ellpack.h>
 #include <TNL/Matrices/SlicedEllpack.h>
 
-using CSR_host = TNL::Matrices::CSR< int, TNL::Devices::Host, int >;
+#include <TNL/Matrices/SparseMatrix.h>
+#include <TNL/Containers/Segments/CSR.h>
+#include <TNL/Containers/Segments/Ellpack.h>
+#include <TNL/Containers/Segments/SlicedEllpack.h>
+
+/*using CSR_host = TNL::Matrices::CSR< int, TNL::Devices::Host, int >;
 using CSR_cuda = TNL::Matrices::CSR< int, TNL::Devices::Cuda, int >;
 using E_host = TNL::Matrices::Ellpack< int, TNL::Devices::Host, int >;
 using E_cuda = TNL::Matrices::Ellpack< int, TNL::Devices::Cuda, int >;
 using SE_host = TNL::Matrices::SlicedEllpack< int, TNL::Devices::Host, int, 2 >;
-using SE_cuda = TNL::Matrices::SlicedEllpack< int, TNL::Devices::Cuda, int, 2 >;
+using SE_cuda = TNL::Matrices::SlicedEllpack< int, TNL::Devices::Cuda, int, 2 >;*/
+
+template< typename Device, typename Index, typename IndexAllocator >
+using EllpackSegments = TNL::Containers::Segments::Ellpack< Device, Index, IndexAllocator >;
+
+template< typename Device, typename Index, typename IndexAllocator >
+using SlicedEllpackSegments = TNL::Containers::Segments::SlicedEllpack< Device, Index, IndexAllocator >;
+
+using CSR_host = TNL::Matrices::SparseMatrix< int, TNL::Containers::Segments::CSR, TNL::Devices::Host, int >;
+using CSR_cuda = TNL::Matrices::SparseMatrix< int, TNL::Containers::Segments::CSR, TNL::Devices::Cuda, int >;
+using E_host   = TNL::Matrices::SparseMatrix< int, EllpackSegments, TNL::Devices::Host, int >;
+using E_cuda   = TNL::Matrices::SparseMatrix< int, EllpackSegments, TNL::Devices::Cuda, int >;
+using SE_host  = TNL::Matrices::SparseMatrix< int, SlicedEllpackSegments, TNL::Devices::Host, int >;
+using SE_cuda  = TNL::Matrices::SparseMatrix< int, SlicedEllpackSegments, TNL::Devices::Cuda, int >;
+
 
 #ifdef HAVE_GTEST 
 #include <gtest/gtest.h>
@@ -388,7 +407,8 @@ void testConversion()
         checkTriDiagMatrix( triDiag1 );
         
         Matrix2 triDiag2;
-        TNL::Matrices::copySparseMatrix( triDiag2, triDiag1 );
+        //TNL::Matrices::copySparseMatrix( triDiag2, triDiag1 );
+        triDiag2 = triDiag1;
         checkTriDiagMatrix( triDiag2 );
    }
    
@@ -400,7 +420,8 @@ void testConversion()
         checkAntiTriDiagMatrix( antiTriDiag1 );
         
         Matrix2 antiTriDiag2;
-        TNL::Matrices::copySparseMatrix( antiTriDiag2, antiTriDiag1 );
+        //TNL::Matrices::copySparseMatrix( antiTriDiag2, antiTriDiag1 );
+        antiTriDiag2 = antiTriDiag1;
         checkAntiTriDiagMatrix( antiTriDiag2 );
    }
    
@@ -411,7 +432,8 @@ void testConversion()
         checkUnevenRowSizeMatrix( unevenRowSize1 );
         
         Matrix2 unevenRowSize2;
-        TNL::Matrices::copySparseMatrix( unevenRowSize2, unevenRowSize1 );
+        //TNL::Matrices::copySparseMatrix( unevenRowSize2, unevenRowSize1 );
+        unevenRowSize2 = unevenRowSize1;
         checkUnevenRowSizeMatrix( unevenRowSize2 );
    }
 }
-- 
GitLab