From 127b3bc9b6108aba737b64368d9d3c16e941f556 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com>
Date: Sun, 22 Dec 2019 23:40:13 +0100
Subject: [PATCH] Fixing sparse matrix assignment operator.

---
 src/TNL/Containers/Segments/SlicedEllpack.hpp |  4 +--
 src/TNL/Matrices/SparseMatrix.h               |  8 ++++-
 src/TNL/Matrices/SparseMatrix.hpp             | 29 ++++++++++++-------
 3 files changed, 28 insertions(+), 13 deletions(-)

diff --git a/src/TNL/Containers/Segments/SlicedEllpack.hpp b/src/TNL/Containers/Segments/SlicedEllpack.hpp
index ad83f666a0..c9c1d85608 100644
--- a/src/TNL/Containers/Segments/SlicedEllpack.hpp
+++ b/src/TNL/Containers/Segments/SlicedEllpack.hpp
@@ -255,7 +255,7 @@ forSegments( IndexType first, IndexType last, Function& f, Args... args ) const
    const auto sliceOffsets_view = this->sliceOffsets.getConstView();
    if( RowMajorOrder )
    {
-      auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) {
+      auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable {
          const IndexType sliceIdx = segmentIdx / SliceSize;
          const IndexType segmentInSliceIdx = segmentIdx % SliceSize;
          const IndexType segmentSize = sliceSegmentSizes_view[ sliceIdx ];
@@ -270,7 +270,7 @@ forSegments( IndexType first, IndexType last, Function& f, Args... args ) const
    }
    else
    {
-      auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) {
+      auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable {
          const IndexType sliceIdx = segmentIdx / SliceSize;
          const IndexType segmentInSliceIdx = segmentIdx % SliceSize;
          const IndexType segmentSize = sliceSegmentSizes_view[ sliceIdx ];
diff --git a/src/TNL/Matrices/SparseMatrix.h b/src/TNL/Matrices/SparseMatrix.h
index 8c8fef599a..44ded93a65 100644
--- a/src/TNL/Matrices/SparseMatrix.h
+++ b/src/TNL/Matrices/SparseMatrix.h
@@ -62,7 +62,13 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator >
 
       virtual String getSerializationTypeVirtual() const;
 
-      void setCompressedRowLengths( ConstCompressedRowLengthsVectorView rowLengths );
+      template< typename RowsCapacitiesVector >
+      void setCompressedRowLengths( const RowsCapacitiesVector& rowCapacities );
+
+      // TODO: Remove this when possible
+      void setCompressedRowLengths( ConstCompressedRowLengthsVectorView rowLengths ) {
+         this->setCompressedRowLengths( rowLengths );
+      };
 
       template< typename Vector >
       void getCompressedRowLengths( Vector& rowLengths ) const;
diff --git a/src/TNL/Matrices/SparseMatrix.hpp b/src/TNL/Matrices/SparseMatrix.hpp
index 5de4473ab5..964e9eb221 100644
--- a/src/TNL/Matrices/SparseMatrix.hpp
+++ b/src/TNL/Matrices/SparseMatrix.hpp
@@ -104,12 +104,21 @@ template< typename Real,
           typename Index,
           typename RealAllocator,
           typename IndexAllocator >
+   template< typename RowsCapacitiesVector >
 void
 SparseMatrix< Real, Segments, Device, Index, RealAllocator, IndexAllocator >::
-setCompressedRowLengths( ConstCompressedRowLengthsVectorView rowLengths )
+setCompressedRowLengths( const RowsCapacitiesVector& rowsCapacities )
 {
-   TNL_ASSERT_EQ( rowLengths.getSize(), this->getRows(), "Number of matrix rows does not fit with rowLengths vector size." );
-   this->segments.setSegmentsSizes( rowLengths );
+   TNL_ASSERT_EQ( rowsCapacities.getSize(), this->getRows(), "Number of matrix rows does not fit with rowLengths vector size." );
+   using RowsCapacitiesVectorDevice = typename RowsCapacitiesVector::DeviceType;
+   if( std::is_same< DeviceType, RowsCapacitiesVectorDevice >::value )
+      this->segments.setSegmentsSizes( rowsCapacities );
+   else
+   {
+      RowsCapacitiesType thisRowsCapacities;
+      thisRowsCapacities = rowsCapacities;
+      this->segments.setSegmentsSizes( thisRowsCapacities );
+   }
    this->values.setSize( this->segments.getStorageSize() );
    this->values = ( RealType ) 0;
    this->columnIndexes.setSize( this->segments.getStorageSize() );
@@ -594,13 +603,11 @@ forRows( IndexType first, IndexType last, Function& function ) const
    const auto columns_view = this->columnIndexes.getConstView();
    const auto values_view = this->values.getConstView();
    const IndexType paddingIndex_ = this->getPaddingIndex();
-   /*auto fetch_ = [=] __cuda_callable__ ( IndexType rowIdx, IndexType localIdx, IndexType globalIdx ) mutable -> decltype( fetch( IndexType(), IndexType(), RealType() ) ) {
-      IndexType columnIdx = columns_view[ globalIdx ];
-      if( columnIdx != paddingIndex_ )
-         return fetch( rowIdx, columnIdx, values_view[ globalIdx ] );
-      return zero;
+   auto f = [=] __cuda_callable__ ( IndexType rowIdx, IndexType localIdx, IndexType globalIdx ) mutable -> bool {
+      function( rowIdx, localIdx, globalIdx );
+      return true;
    };
-   this->segments.segmentsReduction( first, last, fetch_, reduce, keep, zero );*/
+   this->segments.forSegments( first, last, f );
 
 }
 
@@ -702,8 +709,9 @@ SparseMatrix< Real, Segments, Device, Index, RealAllocator, IndexAllocator >::
 operator=( const SparseMatrix< Real2, Segments2, Device2, Index2, RealAllocator2, IndexAllocator2 >& matrix )
 {
    using RHSMatrixType = SparseMatrix< Real2, Segments2, Device2, Index2, RealAllocator2, IndexAllocator2 >;
-   RowsCapacitiesType rowLengths;
+   typename RHSMatrixType::RowsCapacitiesType rowLengths;
    matrix.getCompressedRowLengths( rowLengths );
+   this->setDimensions( matrix.getRows(), matrix.getColumns() );
    this->setCompressedRowLengths( rowLengths );
 
    // TODO: Replace this with SparseMatrixView
@@ -712,6 +720,7 @@ operator=( const SparseMatrix< Real2, Segments2, Device2, Index2, RealAllocator2
    const IndexType paddingIndex = this->getPaddingIndex();
    auto this_columns_view = this->columnIndexes.getView();
    auto this_values_view = this->values.getView();
+   this_columns_view = paddingIndex;
 
    if( std::is_same< Device, Device2 >::value )
    {
-- 
GitLab