Skip to content
Snippets Groups Projects
Commit a16d4530 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Fixed access to constant SparseRow in CSR and Ellpack matrices

parent 092a025d
No related branches found
No related tags found
No related merge requests found
...@@ -41,6 +41,7 @@ class CSR : public Sparse< Real, Device, Index > ...@@ -41,6 +41,7 @@ class CSR : public Sparse< Real, Device, Index >
typedef CSR< Real, Devices::Cuda, Index > CudaType; typedef CSR< Real, Devices::Cuda, Index > CudaType;
typedef Sparse< Real, Device, Index > BaseType; typedef Sparse< Real, Device, Index > BaseType;
typedef typename BaseType::MatrixRow MatrixRow; typedef typename BaseType::MatrixRow MatrixRow;
typedef SparseRow< const RealType, const IndexType > ConstMatrixRow;
enum SPMVCudaKernel { scalar, vector, hybrid }; enum SPMVCudaKernel { scalar, vector, hybrid };
...@@ -125,7 +126,7 @@ class CSR : public Sparse< Real, Device, Index > ...@@ -125,7 +126,7 @@ class CSR : public Sparse< Real, Device, Index >
MatrixRow getRow( const IndexType rowIndex ); MatrixRow getRow( const IndexType rowIndex );
__cuda_callable__ __cuda_callable__
const MatrixRow getRow( const IndexType rowIndex ) const; ConstMatrixRow getRow( const IndexType rowIndex ) const;
template< typename Vector > template< typename Vector >
__cuda_callable__ __cuda_callable__
......
...@@ -406,16 +406,16 @@ template< typename Real, ...@@ -406,16 +406,16 @@ template< typename Real,
typename Device, typename Device,
typename Index > typename Index >
__cuda_callable__ __cuda_callable__
const typename CSR< Real, Device, Index >::MatrixRow typename CSR< Real, Device, Index >::ConstMatrixRow
CSR< Real, Device, Index >:: CSR< Real, Device, Index >::
getRow( const IndexType rowIndex ) const getRow( const IndexType rowIndex ) const
{ {
const IndexType rowOffset = this->rowPointers[ rowIndex ]; const IndexType rowOffset = this->rowPointers[ rowIndex ];
const IndexType rowLength = this->rowPointers[ rowIndex + 1 ] - rowOffset; const IndexType rowLength = this->rowPointers[ rowIndex + 1 ] - rowOffset;
return MatrixRow( &this->columnIndexes[ rowOffset ], return ConstMatrixRow( &this->columnIndexes[ rowOffset ],
&this->values[ rowOffset ], &this->values[ rowOffset ],
rowLength, rowLength,
1 ); 1 );
} }
template< typename Real, template< typename Real,
......
...@@ -35,6 +35,7 @@ class Ellpack : public Sparse< Real, Device, Index > ...@@ -35,6 +35,7 @@ class Ellpack : public Sparse< Real, Device, Index >
typedef Ellpack< Real, Devices::Cuda, Index > CudaType; typedef Ellpack< Real, Devices::Cuda, Index > CudaType;
typedef Sparse< Real, Device, Index > BaseType; typedef Sparse< Real, Device, Index > BaseType;
typedef typename BaseType::MatrixRow MatrixRow; typedef typename BaseType::MatrixRow MatrixRow;
typedef SparseRow< const RealType, const IndexType > ConstMatrixRow;
Ellpack(); Ellpack();
...@@ -128,7 +129,7 @@ class Ellpack : public Sparse< Real, Device, Index > ...@@ -128,7 +129,7 @@ class Ellpack : public Sparse< Real, Device, Index >
MatrixRow getRow( const IndexType rowIndex ); MatrixRow getRow( const IndexType rowIndex );
__cuda_callable__ __cuda_callable__
const MatrixRow getRow( const IndexType rowIndex ) const; ConstMatrixRow getRow( const IndexType rowIndex ) const;
template< typename Vector > template< typename Vector >
__cuda_callable__ __cuda_callable__
......
...@@ -455,16 +455,16 @@ template< typename Real, ...@@ -455,16 +455,16 @@ template< typename Real,
typename Device, typename Device,
typename Index > typename Index >
__cuda_callable__ __cuda_callable__
const typename Ellpack< Real, Device, Index >::MatrixRow typename Ellpack< Real, Device, Index >::ConstMatrixRow
Ellpack< Real, Device, Index >:: Ellpack< Real, Device, Index >::
getRow( const IndexType rowIndex ) const getRow( const IndexType rowIndex ) const
{ {
//printf( "this->rowLengths = %d this = %p \n", this->rowLengths, this ); //printf( "this->rowLengths = %d this = %p \n", this->rowLengths, this );
IndexType rowBegin = DeviceDependentCode::getRowBegin( *this, rowIndex ); IndexType rowBegin = DeviceDependentCode::getRowBegin( *this, rowIndex );
return MatrixRow( &this->columnIndexes[ rowBegin ], return ConstMatrixRow( &this->columnIndexes[ rowBegin ],
&this->values[ rowBegin ], &this->values[ rowBegin ],
this->rowLengths, this->rowLengths,
DeviceDependentCode::getElementStep( *this ) ); DeviceDependentCode::getElementStep( *this ) );
} }
template< typename Real, template< typename Real,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment