diff --git a/src/TNL/Matrices/Dense.h b/src/TNL/Matrices/Dense.h index a2c6a7eda607431929ef1be36d945712248d18b2..cff1d57b4c838e23cdb5ebd50cb950062789e7b5 100644 --- a/src/TNL/Matrices/Dense.h +++ b/src/TNL/Matrices/Dense.h @@ -14,10 +14,10 @@ #include <TNL/Devices/Host.h> #include <TNL/Matrices/Matrix.h> #include <TNL/Matrices/DenseRow.h> -#include <TNL/Containers/Array.h> +#include <TNL/Containers/Segments/Ellpack.h> namespace TNL { -namespace Matrices { +namespace Matrices { template< typename Device > class DenseDeviceDependentCode; @@ -46,6 +46,7 @@ public: using ConstCompressedRowLengthsVectorView = typename Matrix< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView; using BaseType = Matrix< Real, Device, Index >; using MatrixRow = DenseRow< Real, Index >; + using SegmentsType = Containers::Segments::Ellpack< DeviceType, IndexType, typename Allocators::Default< Device >::template Allocator< IndexType >, RowMajorOrder >; template< typename _Real = Real, typename _Device = Device, @@ -164,6 +165,8 @@ protected: typedef DenseDeviceDependentCode< DeviceType > DeviceDependentCode; friend class DenseDeviceDependentCode< DeviceType >; + + SegmentsType segments; }; } // namespace Matrices diff --git a/src/TNL/Matrices/Dense.hpp b/src/TNL/Matrices/Dense.hpp index 1900523908290731ff71b60ebfc33678857f5e64..bed7a37b72dc64b1fccb2e211432c03802c672f6 100644 --- a/src/TNL/Matrices/Dense.hpp +++ b/src/TNL/Matrices/Dense.hpp @@ -31,7 +31,9 @@ template< typename Real, typename Index, bool RowMajorOrder, typename RealAllocator > -String Dense< Real, Device, Index, RowMajorOrder, RealAllocator >::getSerializationType() +String +Dense< Real, Device, Index, RowMajorOrder, RealAllocator >:: +getSerializationType() { return String( "Matrices::Dense< " ) + getType< RealType >() + ", " + @@ -44,7 +46,9 @@ template< typename Real, typename Index, bool RowMajorOrder, typename RealAllocator > -String Dense< Real, Device, Index, RowMajorOrder, RealAllocator >::getSerializationTypeVirtual() const +String +Dense< Real, Device, Index, RowMajorOrder, RealAllocator >:: +getSerializationTypeVirtual() const { return this->getSerializationType(); } @@ -54,12 +58,15 @@ template< typename Real, typename Index, bool RowMajorOrder, typename RealAllocator > -void Dense< Real, Device, Index, RowMajorOrder, RealAllocator >::setDimensions( const IndexType rows, - const IndexType columns ) +void +Dense< Real, Device, Index, RowMajorOrder, RealAllocator >:: +setDimensions( const IndexType rows, + const IndexType columns ) { Matrix< Real, Device, Index >::setDimensions( rows, columns ); + this->segments.setSegmentsSizes( rows, columns ); this->values.setSize( rows * columns ); - this->values.setValue( 0.0 ); + this->values = 0.0; } template< typename Real, @@ -68,7 +75,9 @@ template< typename Real, bool RowMajorOrder, typename RealAllocator > template< typename Matrix_ > -void Dense< Real, Device, Index, RowMajorOrder, RealAllocator >::setLike( const Matrix_& matrix ) +void +Dense< Real, Device, Index, RowMajorOrder, RealAllocator >:: +setLike( const Matrix_& matrix ) { Matrix< Real, Device, Index, RealAllocator >::setLike( matrix ); } @@ -78,8 +87,11 @@ template< typename Real, typename Index, bool RowMajorOrder, typename RealAllocator > -void Dense< Real, Device, Index, RowMajorOrder, RealAllocator >::setCompressedRowLengths( ConstCompressedRowLengthsVectorView rowLengths ) +void +Dense< Real, Device, Index, RowMajorOrder, RealAllocator >:: +setCompressedRowLengths( ConstCompressedRowLengthsVectorView rowLengths ) { + this->setDimensions( rowLengths.getSize(), max( rowLengths ) ); } template< typename Real, @@ -92,17 +104,6 @@ Index Dense< Real, Device, Index, RowMajorOrder, RealAllocator >::getRowLength( return this->getColumns(); } -/*template< typename Real, - typename Device, - typename Index, - bool RowMajorOrder, - typename RealAllocator > -__cuda_callable__ -Index Dense< Real, Device, Index, RowMajorOrder, RealAllocator >::getRowLengthFast( const IndexType row ) const -{ - return this->getColumns(); -}*/ - template< typename Real, typename Device, typename Index, @@ -896,13 +897,14 @@ __cuda_callable__ Index Dense< Real, Device, Index, RowMajorOrder, RealAllocator >::getElementIndex( const IndexType row, const IndexType column ) const { - TNL_ASSERT( ( std::is_same< Device, Devices::Host >::value || + return this->segments.getGlobalIndex( row, column ); + /*TNL_ASSERT( ( std::is_same< Device, Devices::Host >::value || std::is_same< Device, Devices::Cuda >::value ), ) if( std::is_same< Device, Devices::Host >::value ) return row * this->columns + column; if( std::is_same< Device, Devices::Cuda >::value ) return column * this->rows + row; - return -1; + return -1;*/ } template<>