From 71a1f300e71ef90e6d81f07e553c11eb1d227b94 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com> Date: Tue, 14 Jan 2020 22:25:05 +0100 Subject: [PATCH] Fixing multidiagonal matrix with CUDA. --- src/TNL/Matrices/Matrix.h | 10 ++++++---- src/TNL/Matrices/Multidiagonal.hpp | 2 +- src/TNL/Matrices/MultidiagonalMatrixView.h | 8 ++++---- src/TNL/Matrices/MultidiagonalMatrixView.hpp | 2 +- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/TNL/Matrices/Matrix.h b/src/TNL/Matrices/Matrix.h index ebe7ccc21f..0b34a5a57f 100644 --- a/src/TNL/Matrices/Matrix.h +++ b/src/TNL/Matrices/Matrix.h @@ -76,14 +76,16 @@ public: __cuda_callable__ IndexType getColumns() const; - virtual void setElement( const IndexType row, + //virtual TODO: uncomment + void setElement( const IndexType row, const IndexType column, - const RealType& value ) = 0; + const RealType& value );// = 0; - virtual void addElement( const IndexType row, + //virtual TODO: uncomment + void addElement( const IndexType row, const IndexType column, const RealType& value, - const RealType& thisElementMultiplicator = 1.0 ) = 0; + const RealType& thisElementMultiplicator = 1.0 );// = 0; virtual Real getElement( const IndexType row, const IndexType column ) const = 0; diff --git a/src/TNL/Matrices/Multidiagonal.hpp b/src/TNL/Matrices/Multidiagonal.hpp index b885115012..7bc83f2d49 100644 --- a/src/TNL/Matrices/Multidiagonal.hpp +++ b/src/TNL/Matrices/Multidiagonal.hpp @@ -707,7 +707,7 @@ operator=( const Multidiagonal< Real_, Device_, Index_, RowMajorOrder_, RealAllo //// // Copy matrix elements from the buffer to the matrix - auto f2 = [=] __cuda_callable__ ( IndexType rowIdx, IndexType localIdx, IndexType& columnIndex, RealType& value ) mutable { + auto f2 = [=] __cuda_callable__ ( const IndexType rowIdx, const IndexType localIdx, const IndexType columnIndex, RealType& value ) mutable { const IndexType bufferIdx = ( rowIdx - baseRow ) * maxRowLength + localIdx; value = thisValuesBuffer_view[ bufferIdx ]; }; diff --git a/src/TNL/Matrices/MultidiagonalMatrixView.h b/src/TNL/Matrices/MultidiagonalMatrixView.h index 3d33ac0aea..1e5a9bd28e 100644 --- a/src/TNL/Matrices/MultidiagonalMatrixView.h +++ b/src/TNL/Matrices/MultidiagonalMatrixView.h @@ -30,10 +30,10 @@ class MultidiagonalMatrixView : public MatrixView< Real, Device, Index > using DeviceType = Device; using IndexType = Index; using BaseType = MatrixView< Real, Device, Index >; - using DiagonalsShiftsType = Containers::Vector< IndexType, DeviceType, IndexType >; - using DiagonalsShiftsView = typename DiagonalsShiftsType::ViewType; - using HostDiagonalsShiftsType = Containers::Vector< IndexType, Devices::Host, IndexType >; - using HostDiagonalsShiftsView = typename DiagonalsShiftsType::ViewType; + //using DiagonalsShiftsType = Containers::Vector< IndexType, DeviceType, IndexType >; + using DiagonalsShiftsView = Containers::VectorView< IndexType, DeviceType, IndexType >; + //using HostDiagonalsShiftsType = Containers::Vector< IndexType, Devices::Host, IndexType >; + using HostDiagonalsShiftsView = Containers::VectorView< IndexType, Devices::Host, IndexType >; using IndexerType = details::MultidiagonalMatrixIndexer< IndexType, RowMajorOrder >; using ValuesViewType = typename BaseType::ValuesView; using ViewType = MultidiagonalMatrixView< Real, Device, Index, RowMajorOrder >; diff --git a/src/TNL/Matrices/MultidiagonalMatrixView.hpp b/src/TNL/Matrices/MultidiagonalMatrixView.hpp index 1ba8dc34d0..2839c997aa 100644 --- a/src/TNL/Matrices/MultidiagonalMatrixView.hpp +++ b/src/TNL/Matrices/MultidiagonalMatrixView.hpp @@ -398,7 +398,7 @@ forRows( IndexType first, IndexType last, Function& function ) const { const IndexType columnIdx = rowIdx + diagonalsShifts_view[ localIdx ]; if( columnIdx >= 0 && columnIdx < columns ) - function( rowIdx, localIdx, columnIdx, values_view[ indexer.getGlobalIndex( rowIdx, localIdx, 0 ) ] ); + function( rowIdx, localIdx, columnIdx, values_view[ indexer.getGlobalIndex( rowIdx, localIdx ) ] ); } }; Algorithms::ParallelFor< DeviceType >::exec( first, last, f ); -- GitLab