Commit 740f7551 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Fixed getRow method for constant SparseMatrix and SparseMatrixView.

parent 3aadca3a
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -59,11 +59,14 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator >
      using ConstRowsCapacitiesView = typename RowsCapacitiesView::ConstViewType;
      using ValuesVectorType = typename Matrix< Real, Device, Index, RealAllocator >::ValuesVectorType;
      using ValuesViewType = typename ValuesVectorType::ViewType;
      using ConstValuesViewType = typename ValuesViewType::ConstViewType;
      using ColumnsIndexesVectorType = Containers::Vector< IndexType, DeviceType, IndexType, IndexAllocatorType >;
      using ColumnsIndexesViewType = typename ColumnsIndexesVectorType::ViewType;
      using ConstColumnsIndexesViewType = typename ColumnsIndexesViewType::ConstViewType;
      using ViewType = SparseMatrixView< Real, Device, Index, MatrixType, SegmentsViewTemplate >;
      using ConstViewType = SparseMatrixView< typename std::add_const< Real >::type, Device, Index, MatrixType, SegmentsViewTemplate >;
      using RowView = SparseMatrixRowView< SegmentViewType, ValuesViewType, ColumnsIndexesViewType, isBinary() >;
      using ConstRowView = typename RowView::ConstViewType;

      // TODO: remove this - it is here only for compatibility with original matrix implementation
      typedef Containers::Vector< IndexType, DeviceType, IndexType > CompressedRowLengthsVector;
@@ -135,7 +138,7 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator >
      void reset();

      __cuda_callable__
      const RowView getRow( const IndexType& rowIdx ) const;
      const ConstRowView getRow( const IndexType& rowIdx ) const;

      __cuda_callable__
      RowView getRow( const IndexType& rowIdx );
+1 −1
Original line number Diff line number Diff line
@@ -364,7 +364,7 @@ template< typename Real,
          typename IndexAllocator >
__cuda_callable__ auto
SparseMatrix< Real, Device, Index, MatrixType, Segments, RealAllocator, IndexAllocator >::
getRow( const IndexType& rowIdx ) const -> const RowView
getRow( const IndexType& rowIdx ) const -> const ConstRowView
{
   return this->view.getRow( rowIdx );
}
+3 −0
Original line number Diff line number Diff line
@@ -26,6 +26,9 @@ class SparseMatrixRowView
      using IndexType = typename SegmentViewType::IndexType;
      using ValuesViewType = ValuesView;
      using ColumnsIndexesViewType = ColumnsIndexesView;
      using ConstValuesViewType = typename ValuesViewType::ConstViewType;
      using ConstColumnsIndexesViewType = typename ColumnsIndexesViewType::ConstViewType;
      using ConstViewType = SparseMatrixRowView< SegmentView, ConstValuesViewType, ConstColumnsIndexesViewType, isBinary_ >;

      static constexpr bool isBinary() { return isBinary_; };

+4 −1
Original line number Diff line number Diff line
@@ -41,10 +41,13 @@ class SparseMatrixView : public MatrixView< Real, Device, Index >
      using RowsCapacitiesView = Containers::VectorView< IndexType, DeviceType, IndexType >;
      using ConstRowsCapacitiesView = typename RowsCapacitiesView::ConstViewType;
      using ValuesViewType = typename BaseType::ValuesView;
      using ConstValuesViewType = typename ValuesViewType::ConstViewType;
      using ColumnsIndexesViewType = Containers::VectorView< IndexType, DeviceType, IndexType >;
      using ConstColumnsIndexesViewType = typename ColumnsIndexesViewType::ConstViewType;
      using ViewType = SparseMatrixView< typename std::remove_const< Real >::type, Device, Index, MatrixType, SegmentsViewTemplate >;
      using ConstViewType = SparseMatrixView< typename std::add_const< Real >::type, Device, Index, MatrixType, SegmentsViewTemplate >;
      using RowView = SparseMatrixRowView< SegmentViewType, ValuesViewType, ColumnsIndexesViewType, isBinary() >;
      using ConstRowView = typename RowView::ConstViewType;

      // TODO: remove this - it is here only for compatibility with original matrix implementation
      typedef Containers::Vector< IndexType, DeviceType, IndexType > CompressedRowLengthsVector;
@@ -88,7 +91,7 @@ class SparseMatrixView : public MatrixView< Real, Device, Index >
      void reset();

      __cuda_callable__
      const RowView getRow( const IndexType& rowIdx ) const;
      const ConstRowView getRow( const IndexType& rowIdx ) const;

      __cuda_callable__
      RowView getRow( const IndexType& rowIdx );
+2 −2
Original line number Diff line number Diff line
@@ -193,10 +193,10 @@ template< typename Real,
          template< typename, typename > class SegmentsView >
__cuda_callable__ auto
SparseMatrixView< Real, Device, Index, MatrixType, SegmentsView >::
getRow( const IndexType& rowIdx ) const -> const RowView
getRow( const IndexType& rowIdx ) const -> const ConstRowView
{
   TNL_ASSERT_LT( rowIdx, this->getRows(), "Row index is larger than number of matrix rows." );
   return RowView( this->segments.getSegmentView( rowIdx ), this->values.getView(), this->columnIndexes.getView() );
   return ConstRowView( this->segments.getSegmentView( rowIdx ), this->values.getConstView(), this->columnIndexes.getConstView() );
}

template< typename Real,