diff --git a/src/TNL/Matrices/CSR.h b/src/TNL/Matrices/CSR.h
index 1ce7d330bbf17aec36373169389889f4b61773ee..423b40feff03c60869486c76bd330dd129b13df3 100644
--- a/src/TNL/Matrices/CSR.h
+++ b/src/TNL/Matrices/CSR.h
@@ -41,7 +41,8 @@ private:
 
 public:
 
-   typedef Real RealType;
+   using RealType = Real;
+   //typedef Real RealType;
    typedef Device DeviceType;
    typedef Index IndexType;
    typedef typename Sparse< RealType, DeviceType, IndexType >:: CompressedRowLengthsVector CompressedRowLengthsVector;
@@ -51,7 +52,10 @@ public:
    typedef CSR< Real, Devices::Cuda, Index > CudaType;
    typedef Sparse< Real, Device, Index > BaseType;
    typedef typename BaseType::MatrixRow MatrixRow;
-   typedef SparseRow< const RealType, const IndexType > ConstMatrixRow;
+
+   using ConstMatrixRow = typename BaseType::ConstMatrixRow;
+   //using typename BaseType::ConstMatrixRow;
+   //typedef SparseRow< const RealType, const IndexType > ConstMatrixRow;
 
    enum SPMVCudaKernel { scalar, vector, hybrid };
 
diff --git a/src/TNL/Matrices/CSR_impl.h b/src/TNL/Matrices/CSR_impl.h
index de95c0d78d00d3ddb6c467c290e13267d127f16f..a77e68575b13ce5933dd7fbfd10a376402af48e5 100644
--- a/src/TNL/Matrices/CSR_impl.h
+++ b/src/TNL/Matrices/CSR_impl.h
@@ -131,13 +131,53 @@ Index CSR< Real, Device, Index >::getRowLengthFast( const IndexType row ) const
    return this->rowPointers[ row + 1 ] - this->rowPointers[ row ];
 }
 
+// TODO: move to SparseRow
+// Single-thread helper kernel: thread 0 writes row.getNonZeroElementsCount()
+// to *result. The pointee type drops const so that a row type whose IndexType
+// is const-qualified (ConstMatrixRow) can still be written through.
+// NOTE(review): presumably this kernel and its launch need the project's CUDA
+// build guard so that host-only builds still compile - TODO confirm.
+template< typename MatrixRow >
+__global__ void getNonZeroRowLengthCudaKernel( const MatrixRow row,
+                                               typename std::remove_const< typename MatrixRow::IndexType >::type* result )
+{
+   int threadId = blockIdx.x * blockDim.x + threadIdx.x;
+   if( threadId == 0 )
+   {
+      // row is a by-value object and result is a pointer: use '.' and store through '*'.
+      *result = row.getNonZeroElementsCount();
+   }
+}
+
+// Returns getNonZeroElementsCount() of row `row`. Devices::Host queries the row
+// directly; Devices::Cuda launches a single-thread kernel and reads the result
+// back with Devices::Cuda::passFromDevice.
 template< typename Real,
           typename Device,
           typename Index >
 Index CSR< Real, Device, Index >::getNonZeroRowLength( const IndexType row ) const
-{
-    ConstMatrixRow matrixRow = this->getRow( row );
-    return matrixRow.getNonZeroElementsCount( TNL::String( Device::getDeviceType() ) );
+{
+   if( std::is_same< DeviceType, Devices::Host >::value )
+   {
+      ConstMatrixRow matrixRow = this->getRow( row );
+      return matrixRow.getNonZeroElementsCount();
+   }
+   if( std::is_same< DeviceType, Devices::Cuda >::value )
+   {
+      ConstMatrixRow matrixRow = this->getRow( row );
+      // resultHost is handed to passToDevice, so give it a defined value.
+      IndexType resultHost = 0;
+      IndexType* resultCuda = Devices::Cuda::passToDevice( resultHost );
+      // Pass the row object (not the row index) and the device pointer itself
+      // (not the address of the host-side pointer variable).
+      getNonZeroRowLengthCudaKernel<<< 1, 1 >>>( matrixRow, resultCuda );
+      resultHost = Devices::Cuda::passFromDevice( resultCuda );
+      Devices::Cuda::freeFromDevice( resultCuda );
+      return resultHost;
+   }
+   // Not reached for Devices::Host or Devices::Cuda; keeps every path returning a value.
+   return 0;
+
     // getRow() was throwing segmentation faults.
     // FOR THIS TO WORK, I had to change getRow() from [ rowIndex ] to .getElement( rowIndex ).
 
@@ -159,6 +199,18 @@ Index CSR< Real, Device, Index >::getNonZeroRowLength( const IndexType row ) con
 //    return elementCount;
 }
 
+// Returns getNonZeroElementsCount() of the row obtained from getRow( row ).
+// Marked __cuda_callable__; performs no host/device transfer itself.
+template< typename Real,
+          typename Device,
+          typename Index >
+__cuda_callable__
+Index CSR< Real, Device, Index >::getNonZeroRowLengthFast( const IndexType row ) const
+{
+   ConstMatrixRow matrixRow = this->getRow( row );
+   return matrixRow.getNonZeroElementsCount();
+}
+
 template< typename Real,
           typename Device,
           typename Index >
diff --git a/src/TNL/Matrices/Sparse.h b/src/TNL/Matrices/Sparse.h
index 2ee49219ee2d8fa8be6662c417dd68c3c3a6c690..069ade36cb989e18e1b6e1cc9821af5df50de8c1 100644
--- a/src/TNL/Matrices/Sparse.h
+++ b/src/TNL/Matrices/Sparse.h
@@ -30,6 +30,7 @@ class Sparse : public Matrix< Real, Device, Index >
    typedef Containers::Vector< IndexType, DeviceType, IndexType > ColumnIndexesVector;
    typedef Matrix< Real, Device, Index > BaseType;
    typedef SparseRow< RealType, IndexType > MatrixRow;
+   typedef SparseRow< const RealType, const IndexType > ConstMatrixRow;
 
    Sparse();
 
diff --git a/src/TNL/Matrices/SparseRow.h b/src/TNL/Matrices/SparseRow.h
index d70d780bdfecfb4eb346bbded1d7b06e5c2b94c3..fac855eae71a26cdb4dbf62f927f12d0f24b5af1 100644
--- a/src/TNL/Matrices/SparseRow.h
+++ b/src/TNL/Matrices/SparseRow.h
@@ -21,6 +21,9 @@ namespace Matrices {
 template< typename Real, typename Index >
 class SparseRow
 {
    public:
+      // Public so that code outside the class (e.g. a kernel templated on the row type) can name them.
+      using RealType = Real;
+      using IndexType = Index;
 
       __cuda_callable__
@@ -53,7 +56,7 @@ class SparseRow
       Index getLength() const;
 
       __cuda_callable__
-      Index getNonZeroElementsCount( TNL::String deviceType ) const;
+      Index getNonZeroElementsCount() const;
 
       void print( std::ostream& str ) const;
 
diff --git a/src/TNL/Matrices/SparseRow_impl.h b/src/TNL/Matrices/SparseRow_impl.h
index 14888669dd9ff049e1852a3527177435396ac9e0..d83aad239271c44baffb1bc5181aa65224725bf4 100644
--- a/src/TNL/Matrices/SparseRow_impl.h
+++ b/src/TNL/Matrices/SparseRow_impl.h
@@ -116,7 +116,7 @@ template< typename Real, typename Index >
 __cuda_callable__
 Index
 SparseRow< Real, Index >::
-getNonZeroElementsCount( TNL::String deviceType ) const
+getNonZeroElementsCount() const
 {
     using NonConstIndex = typename std::remove_const< Index >::type;
 