Commit d4cbcc7c authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Fixing const view for matrices.

parent f243db3b
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -29,7 +29,7 @@ class Sparse : public TNL::Matrices::Matrix< Real, Device, Index >
   typedef Real RealType;
   typedef Device DeviceType;
   typedef Index IndexType;
   typedef typename TNL::Matrices::Matrix< RealType, DeviceType, IndexType >::ValuesVectorType ValuesVector;
   typedef typename TNL::Matrices::Matrix< RealType, DeviceType, IndexType >::ValuesType ValuesVector;
   typedef Containers::Vector< IndexType, DeviceType, IndexType > ColumnIndexesVector;
   typedef TNL::Matrices::Matrix< Real, Device, Index > BaseType;
   typedef SparseRow< RealType, IndexType > MatrixRow;
+22 −2
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ class DenseMatrix : public Matrix< Real, Device, Index, RealAllocator >
{
   protected:
      using BaseType = Matrix< Real, Device, Index, RealAllocator >;
      using ValuesVectorType = typename BaseType::ValuesVectorType;
      using ValuesVectorType = typename BaseType::ValuesType;
      using ValuesViewType = typename ValuesVectorType::ViewType;
      using SegmentsType = Algorithms::Segments::Ellpack< Device, Index, typename Allocators::Default< Device >::template Allocator< Index >, Organization, 1 >;
      using SegmentViewType = typename SegmentsType::SegmentViewType;
@@ -92,13 +92,33 @@ class DenseMatrix : public Matrix< Real, Device, Index, RealAllocator >
       *
       * See \ref DenseMatrixView.
       */
      using ConstViewType = DenseMatrixView< typename std::add_const< Real >::type, Device, Index, Organization >;
      using ConstViewType = typename DenseMatrixView< Real, Device, Index, Organization >::ConstViewType;

      /**
       * \brief Type for accessing matrix rows.
       */
      using RowView = DenseMatrixRowView< SegmentViewType, ValuesViewType >;

      /**
       * \brief Type of vector holding values of matrix elements.
       */
      using typename Matrix< Real, Device, Index, RealAllocator >::ValuesType;

      /**
       * \brief Type of constant vector holding values of matrix elements.
       */
      using typename Matrix< Real, Device, Index, RealAllocator >::ConstValuesType;

      /**
       * \brief Type of vector view holding values of matrix elements.
       */
      using typename Matrix< Real, Device, Index, RealAllocator >::ValuesView;

      /**
       * \brief Type of constant vector view holding values of matrix elements.
       */
      using typename Matrix< Real, Device, Index, RealAllocator >::ConstValuesView;

      /**
       * \brief Helper type for getting self type or its modifications.
       */
+4 −3
Original line number Diff line number Diff line
@@ -105,9 +105,11 @@ auto
DenseMatrix< Real, Device, Index, Organization, RealAllocator >::
getView() -> ViewType
{
   ValuesView values_view = this->getValues().getView();
   // note this is improtant here to avoid const qualifier to appear in - somehow :(
   return ViewType( this->getRows(),
                    this->getColumns(),
                    this->getValues().getView() );
                    values_view );
}

template< typename Real,
@@ -119,10 +121,9 @@ auto
DenseMatrix< Real, Device, Index, Organization, RealAllocator >::
getConstView() const -> ConstViewType
{
   DenseMatrix* this_ptr = const_cast< DenseMatrix* >( this );
   return ConstViewType( this->getRows(),
                         this->getColumns(),
                         this_ptr->getValues().getView() );
                         this->getValues().getConstView() );
}

template< typename Real,
+31 −4
Original line number Diff line number Diff line
@@ -43,7 +43,7 @@ class DenseMatrixView : public MatrixView< Real, Device, Index >
{
   protected:
      using BaseType = Matrix< Real, Device, Index >;
      using ValuesVectorType = typename BaseType::ValuesVectorType;
      using ValuesType = typename BaseType::ValuesType;
      using SegmentsType = Algorithms::Segments::Ellpack< Device, Index, typename Allocators::Default< Device >::template Allocator< Index >, Organization, 1 >;
      using SegmentsViewType = typename SegmentsType::ViewType;
      using SegmentViewType = typename SegmentsType::SegmentViewType;
@@ -77,7 +77,14 @@ class DenseMatrixView : public MatrixView< Real, Device, Index >
       *
       * Use this for embedding of the matrix elements values.
       */
      using ValuesViewType = typename ValuesVectorType::ViewType;
      using ValuesViewType = typename ValuesType::ViewType;

      /**
       * \brief Matrix elements container view type.
       *
       * Use this for embedding of the matrix elements values.
       */
      using ConstValuesViewType = typename ValuesType::ConstViewType;

      /**
       * \brief Matrix view type.
@@ -91,7 +98,7 @@ class DenseMatrixView : public MatrixView< Real, Device, Index >
       *
       * See \ref DenseMatrixView.
       */
      using ConstViewType = DenseMatrixView< typename std::add_const< Real >::type, Device, Index, Organization >;
      using ConstViewType = DenseMatrixView< std::add_const_t< Real >, Device, Index, Organization >;

      /**
       * \brief Type for accessing matrix row.
@@ -125,13 +132,33 @@ class DenseMatrixView : public MatrixView< Real, Device, Index >
       * \include Matrices/DenseMatrix/DenseMatrixViewExample_constructor.cpp
       * \par Output
       * \include DenseMatrixViewExample_constructor.out

       */
      __cuda_callable__
      DenseMatrixView( const IndexType rows,
                       const IndexType columns,
                       const ValuesViewType& values );

      /**
       * \brief Constructor with matrix dimensions and values.
       *
       * Organization of matrix elements values in
       *
       * \param rows number of matrix rows.
       * \param columns number of matrix columns.
       * \param values is vector view with matrix elements values.
       *
       * \par Example
       * \include Matrices/DenseMatrix/DenseMatrixViewExample_constructor.cpp
       * \par Output
       * \include DenseMatrixViewExample_constructor.out
       */
       template< typename Real_ >
      __cuda_callable__
      DenseMatrixView( const IndexType rows,
                       const IndexType columns,
                       const Containers::VectorView< Real_, Device, Index >& values );


      /**
       * \brief Copy constructor.
       *
+16 −0
Original line number Diff line number Diff line
@@ -44,6 +44,22 @@ DenseMatrixView( const IndexType rows,
   segments = a.getView();
}

template< typename Real,
          typename Device,
          typename Index,
          ElementsOrganization Organization >
   template< typename Value_ >
__cuda_callable__
DenseMatrixView< Real, Device, Index, Organization >::
DenseMatrixView( const IndexType rows,
                 const IndexType columns,
                 const Containers::VectorView< Value_, Device, Index >& values )
 : MatrixView< Real, Device, Index >( rows, columns, values )
{
   SegmentsType a( rows, columns );
   segments = a.getView();
}

template< typename Real,
          typename Device,
          typename Index,
Loading