From 3bfc83cce078241e3b656ebd06dcca10f85db09e Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com>
Date: Sun, 19 Jan 2020 20:33:43 +0100
Subject: [PATCH] Fixing MultidiagonalMatrix.

---
 src/TNL/Matrices/Multidiagonal.hpp           | 6 +++---
 src/TNL/Matrices/MultidiagonalMatrixView.hpp | 9 ++++++---
 2 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/src/TNL/Matrices/Multidiagonal.hpp b/src/TNL/Matrices/Multidiagonal.hpp
index 7bc83f2d49..94470d3d1c 100644
--- a/src/TNL/Matrices/Multidiagonal.hpp
+++ b/src/TNL/Matrices/Multidiagonal.hpp
@@ -668,7 +668,7 @@ operator=( const Multidiagonal< Real_, Device_, Index_, RowMajorOrder_, RealAllo
       if( std::is_same< Device, Device_ >::value )
       {
          const auto matrix_view = matrix.getView();
-         auto f = [=] __cuda_callable__ ( const IndexType& rowIdx, const IndexType& localIdx, const IndexType& column, Real& value ) mutable {
+         auto f = [=] __cuda_callable__ ( const IndexType& rowIdx, const IndexType& localIdx, const IndexType& column, Real& value, bool& compute ) mutable {
             value = matrix_view.getValues()[ matrix_view.getIndexer().getGlobalIndex( rowIdx, localIdx ) ];
          };
          this->forAllRows( f );
@@ -695,7 +695,7 @@ operator=( const Multidiagonal< Real_, Device_, Index_, RowMajorOrder_, RealAllo
 
             ////
             // Copy matrix elements into buffer
-            auto f1 = [=] __cuda_callable__ ( RHSIndexType rowIdx, RHSIndexType localIdx, RHSIndexType columnIndex, const RHSRealType& value ) mutable {
+            auto f1 = [=] __cuda_callable__ ( RHSIndexType rowIdx, RHSIndexType localIdx, RHSIndexType columnIndex, const RHSRealType& value, bool& compute ) mutable {
                   const IndexType bufferIdx = ( rowIdx - baseRow ) * maxRowLength + localIdx;
                   matrixValuesBuffer_view[ bufferIdx ] = value;
             };
@@ -707,7 +707,7 @@ operator=( const Multidiagonal< Real_, Device_, Index_, RowMajorOrder_, RealAllo
 
             ////
             // Copy matrix elements from the buffer to the matrix
-            auto f2 = [=] __cuda_callable__ ( const IndexType rowIdx, const IndexType localIdx, const IndexType columnIndex, RealType& value  ) mutable {
+            auto f2 = [=] __cuda_callable__ ( const IndexType rowIdx, const IndexType localIdx, const IndexType columnIndex, RealType& value, bool& compute  ) mutable {
                const IndexType bufferIdx = ( rowIdx - baseRow ) * maxRowLength + localIdx;
                   value = thisValuesBuffer_view[ bufferIdx ];
             };
diff --git a/src/TNL/Matrices/MultidiagonalMatrixView.hpp b/src/TNL/Matrices/MultidiagonalMatrixView.hpp
index 96312d03cf..2243684654 100644
--- a/src/TNL/Matrices/MultidiagonalMatrixView.hpp
+++ b/src/TNL/Matrices/MultidiagonalMatrixView.hpp
@@ -216,8 +216,10 @@ void
 MultidiagonalMatrixView< Real, Device, Index, RowMajorOrder >::
 setValue( const RealType& v )
 {
+   // we dont do this->values = v here because it would set even elements 'outside' the matrix
+   // method getNumberOfNonzeroElements would not well
    const RealType newValue = v;
-   auto f = [=] __cuda_callable__ ( const IndexType& rowIdx, const IndexType& localIdx, const IndexType columnIdx, RealType& value ) mutable {
+   auto f = [=] __cuda_callable__ ( const IndexType& rowIdx, const IndexType& localIdx, const IndexType columnIdx, RealType& value, bool& compute ) mutable {
       value = newValue;
    };
    this->forAllRows( f );
@@ -419,12 +421,13 @@ forRows( IndexType first, IndexType last, Function& function )
    const IndexType diagonalsCount = this->diagonalsShifts.getSize();
    const IndexType columns = this->getColumns();
    const auto indexer = this->indexer;
+   bool compute( true );
    auto f = [=] __cuda_callable__ ( IndexType rowIdx ) mutable {
-      for( IndexType localIdx = 0; localIdx < diagonalsCount; localIdx++ )
+      for( IndexType localIdx = 0; localIdx < diagonalsCount && compute; localIdx++ )
       {
          const IndexType columnIdx = rowIdx + diagonalsShifts_view[ localIdx ];
          if( columnIdx >= 0 && columnIdx < columns )
-            function( rowIdx, localIdx, columnIdx, values_view[ indexer.getGlobalIndex( rowIdx, localIdx ) ] );
+            function( rowIdx, localIdx, columnIdx, values_view[ indexer.getGlobalIndex( rowIdx, localIdx ) ], compute );
       }
    };
    Algorithms::ParallelFor< DeviceType >::exec( first, last, f );
-- 
GitLab