From 71a1f300e71ef90e6d81f07e553c11eb1d227b94 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com>
Date: Tue, 14 Jan 2020 22:25:05 +0100
Subject: [PATCH] Fixing multidiagonal matrix with CUDA.

---
 src/TNL/Matrices/Matrix.h                    | 10 ++++++----
 src/TNL/Matrices/Multidiagonal.hpp           |  2 +-
 src/TNL/Matrices/MultidiagonalMatrixView.h   |  8 ++++----
 src/TNL/Matrices/MultidiagonalMatrixView.hpp |  2 +-
 4 files changed, 12 insertions(+), 10 deletions(-)

diff --git a/src/TNL/Matrices/Matrix.h b/src/TNL/Matrices/Matrix.h
index ebe7ccc21f..0b34a5a57f 100644
--- a/src/TNL/Matrices/Matrix.h
+++ b/src/TNL/Matrices/Matrix.h
@@ -76,14 +76,16 @@ public:
    __cuda_callable__
    IndexType getColumns() const;
 
-   virtual void setElement( const IndexType row,
+   //virtual TODO: uncomment
+   void setElement( const IndexType row,
                             const IndexType column,
-                            const RealType& value ) = 0;
+                            const RealType& value );// = 0;
 
-   virtual void addElement( const IndexType row,
+   //virtual TODO: uncomment
+   void addElement( const IndexType row,
                             const IndexType column,
                             const RealType& value,
-                            const RealType& thisElementMultiplicator = 1.0 ) = 0;
+                            const RealType& thisElementMultiplicator = 1.0 );// = 0;
 
    virtual Real getElement( const IndexType row,
                             const IndexType column ) const = 0;
diff --git a/src/TNL/Matrices/Multidiagonal.hpp b/src/TNL/Matrices/Multidiagonal.hpp
index b885115012..7bc83f2d49 100644
--- a/src/TNL/Matrices/Multidiagonal.hpp
+++ b/src/TNL/Matrices/Multidiagonal.hpp
@@ -707,7 +707,7 @@ operator=( const Multidiagonal< Real_, Device_, Index_, RowMajorOrder_, RealAllo
 
             ////
             // Copy matrix elements from the buffer to the matrix
-            auto f2 = [=] __cuda_callable__ ( IndexType rowIdx, IndexType localIdx, IndexType& columnIndex, RealType& value  ) mutable {
+            auto f2 = [=] __cuda_callable__ ( const IndexType rowIdx, const IndexType localIdx, const IndexType columnIndex, RealType& value  ) mutable {
                const IndexType bufferIdx = ( rowIdx - baseRow ) * maxRowLength + localIdx;
                   value = thisValuesBuffer_view[ bufferIdx ];
             };
diff --git a/src/TNL/Matrices/MultidiagonalMatrixView.h b/src/TNL/Matrices/MultidiagonalMatrixView.h
index 3d33ac0aea..1e5a9bd28e 100644
--- a/src/TNL/Matrices/MultidiagonalMatrixView.h
+++ b/src/TNL/Matrices/MultidiagonalMatrixView.h
@@ -30,10 +30,10 @@ class MultidiagonalMatrixView : public MatrixView< Real, Device, Index >
       using DeviceType = Device;
       using IndexType = Index;
       using BaseType = MatrixView< Real, Device, Index >;
-      using DiagonalsShiftsType = Containers::Vector< IndexType, DeviceType, IndexType >;
-      using DiagonalsShiftsView = typename DiagonalsShiftsType::ViewType;
-      using HostDiagonalsShiftsType = Containers::Vector< IndexType, Devices::Host, IndexType >;
-      using HostDiagonalsShiftsView = typename DiagonalsShiftsType::ViewType;
+      //using DiagonalsShiftsType = Containers::Vector< IndexType, DeviceType, IndexType >;
+      using DiagonalsShiftsView = Containers::VectorView< IndexType, DeviceType, IndexType >;
+      //using HostDiagonalsShiftsType = Containers::Vector< IndexType, Devices::Host, IndexType >;
+      using HostDiagonalsShiftsView = Containers::VectorView< IndexType, Devices::Host, IndexType >;
       using IndexerType = details::MultidiagonalMatrixIndexer< IndexType, RowMajorOrder >;
       using ValuesViewType = typename BaseType::ValuesView;
       using ViewType = MultidiagonalMatrixView< Real, Device, Index, RowMajorOrder >;
diff --git a/src/TNL/Matrices/MultidiagonalMatrixView.hpp b/src/TNL/Matrices/MultidiagonalMatrixView.hpp
index 1ba8dc34d0..2839c997aa 100644
--- a/src/TNL/Matrices/MultidiagonalMatrixView.hpp
+++ b/src/TNL/Matrices/MultidiagonalMatrixView.hpp
@@ -398,7 +398,7 @@ forRows( IndexType first, IndexType last, Function& function ) const
       {
          const IndexType columnIdx = rowIdx + diagonalsShifts_view[ localIdx ];
          if( columnIdx >= 0 && columnIdx < columns )
-            function( rowIdx, localIdx, columnIdx, values_view[ indexer.getGlobalIndex( rowIdx, localIdx, 0 ) ] );
+            function( rowIdx, localIdx, columnIdx, values_view[ indexer.getGlobalIndex( rowIdx, localIdx ) ] );
       }
    };
    Algorithms::ParallelFor< DeviceType >::exec( first, last, f );
-- 
GitLab