diff --git a/src/TNL/Containers/Segments/Ellpack.h b/src/TNL/Containers/Segments/Ellpack.h index 429615647391a381e0e58bb3565975f60350eb59..c197c7010cb352ed10f773937665719455570b48 100644 --- a/src/TNL/Containers/Segments/Ellpack.h +++ b/src/TNL/Containers/Segments/Ellpack.h @@ -124,7 +124,7 @@ class Ellpack }; } // namespace Segements - } // namespace Conatiners + } // namespace Containers } // namespace TNL #include <TNL/Containers/Segments/Ellpack.hpp> diff --git a/src/TNL/Matrices/Dense.h b/src/TNL/Matrices/Dense.h index cff1d57b4c838e23cdb5ebd50cb950062789e7b5..c72b7edfab14917923996147287b6afac440f7b3 100644 --- a/src/TNL/Matrices/Dense.h +++ b/src/TNL/Matrices/Dense.h @@ -12,8 +12,8 @@ #include <TNL/Allocators/Default.h> #include <TNL/Devices/Host.h> +#include <TNL/Matrices/DenseMatrixRowView.h> #include <TNL/Matrices/Matrix.h> -#include <TNL/Matrices/DenseRow.h> #include <TNL/Containers/Segments/Ellpack.h> namespace TNL { @@ -42,11 +42,16 @@ public: using RealType = Real; using DeviceType = Device; using IndexType = Index; - using CompressedRowLengthsVector = typename Matrix< Real, Device, Index >::CompressedRowLengthsVector; - using ConstCompressedRowLengthsVectorView = typename Matrix< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView; using BaseType = Matrix< Real, Device, Index >; - using MatrixRow = DenseRow< Real, Index >; + using ValuesType = typename BaseType::ValuesVector; + using ValuesViewType = typename ValuesType::ViewType; using SegmentsType = Containers::Segments::Ellpack< DeviceType, IndexType, typename Allocators::Default< Device >::template Allocator< IndexType >, RowMajorOrder >; + using SegmentViewType = typename SegmentsType::SegmentViewType; + using RowView = DenseMatrixRowView< SegmentViewType, ValuesViewType >; + + // TODO: remove this + using CompressedRowLengthsVector = typename Matrix< Real, Device, Index >::CompressedRowLengthsVector; + using ConstCompressedRowLengthsVectorView = typename Matrix< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView; template< typename _Real = Real, typename _Device = Device, @@ -81,6 +86,13 @@ public: void reset(); + __cuda_callable__ + const RowView getRow( const IndexType& rowIdx ) const; + + __cuda_callable__ + RowView getRow( const IndexType& rowIdx ); + + void setValue( const RealType& v ); __cuda_callable__ @@ -103,11 +115,11 @@ public: Real getElement( const IndexType row, const IndexType column ) const; - __cuda_callable__ + /*__cuda_callable__ MatrixRow getRow( const IndexType rowIndex ); __cuda_callable__ - const MatrixRow getRow( const IndexType rowIndex ) const; + const MatrixRow getRow( const IndexType rowIndex ) const;*/ template< typename Vector > __cuda_callable__ diff --git a/src/TNL/Matrices/Dense.hpp b/src/TNL/Matrices/Dense.hpp index bed7a37b72dc64b1fccb2e211432c03802c672f6..bd0614ad08fcf6c62cb78be5612ace5f9b0e448c 100644 --- a/src/TNL/Matrices/Dense.hpp +++ b/src/TNL/Matrices/Dense.hpp @@ -158,6 +158,32 @@ void Dense< Real, Device, Index, RowMajorOrder, RealAllocator >::setValue( const this->values = value; } +template< typename Real, + typename Device, + typename Index, + bool RowMajorOrder, + typename RealAllocator > +__cuda_callable__ auto +Dense< Real, Device, Index, RowMajorOrder, RealAllocator >:: +getRow( const IndexType& rowIdx ) const -> const RowView +{ + TNL_ASSERT_LT( rowIdx, this->getRows(), "Row index is larger than number of matrix rows." ); + return RowView( this->segments.getSegmentView( rowIdx ), this->values.getView() ); +} + +template< typename Real, + typename Device, + typename Index, + bool RowMajorOrder, + typename RealAllocator > +__cuda_callable__ auto +Dense< Real, Device, Index, RowMajorOrder, RealAllocator >:: +getRow( const IndexType& rowIdx ) -> RowView +{ + TNL_ASSERT_LT( rowIdx, this->getRows(), "Row index is larger than number of matrix rows." ); + return RowView( this->segments.getSegmentView( rowIdx ), this->values.getView() ); +} + template< typename Real, typename Device, typename Index, @@ -236,46 +262,6 @@ Real Dense< Real, Device, Index, RowMajorOrder, RealAllocator >::getElement( con return this->values.getElement( this->getElementIndex( row, column ) ); } -template< typename Real, - typename Device, - typename Index, - bool RowMajorOrder, - typename RealAllocator > -__cuda_callable__ -typename Dense< Real, Device, Index, RowMajorOrder, RealAllocator >::MatrixRow -Dense< Real, Device, Index, RowMajorOrder, RealAllocator >:: -getRow( const IndexType rowIndex ) -{ - if( std::is_same< Device, Devices::Host >::value ) - return MatrixRow( &this->values.getData()[ this->getElementIndex( rowIndex, 0 ) ], - this->columns, - 1 ); - if( std::is_same< Device, Devices::Cuda >::value ) - return MatrixRow( &this->values.getData()[ this->getElementIndex( rowIndex, 0 ) ], - this->columns, - this->rows ); -} - -template< typename Real, - typename Device, - typename Index, - bool RowMajorOrder, - typename RealAllocator > -__cuda_callable__ -const typename Dense< Real, Device, Index, RowMajorOrder, RealAllocator >::MatrixRow -Dense< Real, Device, Index, RowMajorOrder, RealAllocator >:: -getRow( const IndexType rowIndex ) const -{ - if( std::is_same< Device, Devices::Host >::value ) - return MatrixRow( &this->values.getData()[ this->getElementIndex( rowIndex, 0 ) ], - this->columns, - 1 ); - if( std::is_same< Device, Devices::Cuda >::value ) - return MatrixRow( &this->values.getData()[ this->getElementIndex( rowIndex, 0 ) ], - this->columns, - this->rows ); -} - template< typename Real, typename Device, typename Index, @@ -898,13 +884,6 @@ Index Dense< Real, Device, Index, RowMajorOrder, RealAllocator >::getElementInde const IndexType column ) const { return this->segments.getGlobalIndex( row, column ); - /*TNL_ASSERT( ( std::is_same< Device, Devices::Host >::value || - std::is_same< Device, Devices::Cuda >::value ), ) - if( std::is_same< Device, Devices::Host >::value ) - return row * this->columns + column; - if( std::is_same< Device, Devices::Cuda >::value ) - return column * this->rows + row; - return -1;*/ } template<> diff --git a/src/TNL/Matrices/DenseMatrixRowView.h b/src/TNL/Matrices/DenseMatrixRowView.h new file mode 100644 index 0000000000000000000000000000000000000000..84c6b141cd7f7cdf25be8e550e573680b4cce902 --- /dev/null +++ b/src/TNL/Matrices/DenseMatrixRowView.h @@ -0,0 +1,52 @@ +/*************************************************************************** + DenseMatrixRowView.h - description + ------------------- + begin : Jan 3, 2020 + copyright : (C) 2020 by Tomas Oberhuber + email : tomas.oberhuber@fjfi.cvut.cz + ***************************************************************************/ + +/* See Copyright Notice in tnl/Copyright */ + +#pragma once + +namespace TNL { + namespace Matrices { + +template< typename SegmentView, + typename ValuesView > +class DenseMatrixRowView +{ + public: + + using RealType = typename ValuesView::RealType; + using SegmentViewType = SegmentView; + using IndexType = typename SegmentViewType::IndexType; + using ValuesViewType = ValuesView; + + __cuda_callable__ + DenseMatrixRowView( const SegmentViewType& segmentView, + const ValuesViewType& values ); + + __cuda_callable__ + IndexType getSize() const; + + __cuda_callable__ + const RealType& getValue( const IndexType column ) const; + + __cuda_callable__ + RealType& getValue( const IndexType column ); + + __cuda_callable__ + void setElement( const IndexType column, + const RealType& value ); + protected: + + SegmentViewType segmentView; + + ValuesViewType values; +}; + } // namespace Matrices +} // namespace TNL + +#include <TNL/Matrices/DenseMatrixRowView.hpp> diff --git a/src/TNL/Matrices/DenseMatrixRowView.hpp b/src/TNL/Matrices/DenseMatrixRowView.hpp new file mode 100644 index 0000000000000000000000000000000000000000..1962a4d9a8eabe80f28b2e21d1f0506792949225 --- /dev/null +++ b/src/TNL/Matrices/DenseMatrixRowView.hpp @@ -0,0 +1,71 @@ +/*************************************************************************** + DenseMatrixRowView.hpp - description + ------------------- + begin : Jan 3, 2020 + copyright : (C) 2020 by Tomas Oberhuber + email : tomas.oberhuber@fjfi.cvut.cz + ***************************************************************************/ + +/* See Copyright Notice in tnl/Copyright */ + +#pragma once + +#include <TNL/Matrices/DenseMatrixRowView.h> + +namespace TNL { + namespace Matrices { + +template< typename SegmentView, + typename ValuesView > +__cuda_callable__ +DenseMatrixRowView< SegmentView, ValuesView >:: +DenseMatrixRowView( const SegmentViewType& segmentView, + const ValuesViewType& values ) + : segmentView( segmentView ), values( values ) +{ +} + +template< typename SegmentView, + typename ValuesView > +__cuda_callable__ auto +DenseMatrixRowView< SegmentView, ValuesView >:: +getSize() const -> IndexType +{ + return segmentView.getSize(); +} + +template< typename SegmentView, + typename ValuesView > +__cuda_callable__ auto +DenseMatrixRowView< SegmentView, ValuesView >:: +getValue( const IndexType column ) const -> const RealType& +{ + TNL_ASSERT_LT( column, this->getSize(), "Column index exceeds matrix row size." ); + return values[ segmentView.getGlobalIndex( column ) ]; +} + +template< typename SegmentView, + typename ValuesView > +__cuda_callable__ auto +DenseMatrixRowView< SegmentView, ValuesView >:: +getValue( const IndexType column ) -> RealType& +{ + TNL_ASSERT_LT( column, this->getSize(), "Column index exceeds matrix row size." ); + return values[ segmentView.getGlobalIndex( column ) ]; +} + +template< typename SegmentView, + typename ValuesView > +__cuda_callable__ void +DenseMatrixRowView< SegmentView, ValuesView >:: +setElement( const IndexType column, + const RealType& value ) +{ + TNL_ASSERT_LT( column, this->getSize(), "Column index exceeds matrix row size." ); + const IndexType globalIdx = segmentView.getGlobalIndex( column ); + values[ globalIdx ] = value; +} + + + } // namespace Matrices +} // namespace TNL diff --git a/src/TNL/Matrices/SparseMatrix.h b/src/TNL/Matrices/SparseMatrix.h index 8f96af169cce7e7f54b528d58a1be6a2328102e1..c50f71612346a5ad7386f3eaa95e42fb2973a762 100644 --- a/src/TNL/Matrices/SparseMatrix.h +++ b/src/TNL/Matrices/SparseMatrix.h @@ -14,8 +14,8 @@ #include <TNL/Matrices/MatrixType.h> #include <TNL/Allocators/Default.h> #include <TNL/Containers/Segments/CSR.h> -#include <TNL/Matrices/SparseMatrixView.h> #include <TNL/Matrices/SparseMatrixRowView.h> +#include <TNL/Matrices/SparseMatrixView.h> namespace TNL { namespace Matrices {