Commit 228a552b authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Added iterator to Dense, Lambda, Multidiagonal and Tridiagonal matrix row.

parent 7c00255a
Loading
Loading
Loading
Loading
+8 −8
Original line number Diff line number Diff line
@@ -472,7 +472,7 @@ class DenseMatrix : public Matrix< Real, Device, Index, RealAllocator >
       * \include DenseMatrixExample_forRows.out
       */
      template< typename Function >
      void forElements( IndexType begin, IndexType end, Function& function ) const;
      void forElements( IndexType begin, IndexType end, Function&& function ) const;

      /**
       * \brief Method for iteration over all matrix rows for non-constant instances.
@@ -494,7 +494,7 @@ class DenseMatrix : public Matrix< Real, Device, Index, RealAllocator >
       * \include DenseMatrixExample_forRows.out
       */
      template< typename Function >
      void forElements( IndexType begin, IndexType end, Function& function );
      void forElements( IndexType begin, IndexType end, Function&& function );

      /**
       * \brief This method calls \e forElements for all matrix rows (for constant instances).
@@ -510,7 +510,7 @@ class DenseMatrix : public Matrix< Real, Device, Index, RealAllocator >
       * \include DenseMatrixExample_forAllRows.out
       */
      template< typename Function >
      void forAllElements( Function& function ) const;
      void forAllElements( Function&& function ) const;

      /**
       * \brief This method calls \e forElements for all matrix rows.
@@ -526,7 +526,7 @@ class DenseMatrix : public Matrix< Real, Device, Index, RealAllocator >
       * \include DenseMatrixExample_forAllRows.out
       */
      template< typename Function >
      void forAllElements( Function& function );
      void forAllElements( Function&& function );

      /**
       * \brief Method for parallel iteration over matrix rows from interval [ \e begin, \e end).
@@ -643,7 +643,7 @@ class DenseMatrix : public Matrix< Real, Device, Index, RealAllocator >
       * \param function is an instance of the lambda function to be called in each row.
       */
      template< typename Function >
      void sequentialForRows( IndexType begin, IndexType end, Function& function ) const;
      void sequentialForRows( IndexType begin, IndexType end, Function&& function ) const;

      /**
       * \brief Method for sequential iteration over all matrix rows for non-constant instances.
@@ -660,7 +660,7 @@ class DenseMatrix : public Matrix< Real, Device, Index, RealAllocator >
       * \param function is an instance of the lambda function to be called in each row.
       */
      template< typename Function >
      void sequentialForRows( IndexType begin, IndexType end, Function& function );
      void sequentialForRows( IndexType begin, IndexType end, Function&& function );

      /**
       * \brief This method calls \e sequentialForRows for all matrix rows (for constant instances).
@@ -671,7 +671,7 @@ class DenseMatrix : public Matrix< Real, Device, Index, RealAllocator >
       * \param function  is an instance of the lambda function to be called in each row.
       */
      template< typename Function >
      void sequentialForAllRows( Function& function ) const;
      void sequentialForAllRows( Function&& function ) const;

      /**
       * \brief This method calls \e sequentialForRows for all matrix rows.
@@ -682,7 +682,7 @@ class DenseMatrix : public Matrix< Real, Device, Index, RealAllocator >
       * \param function  is an instance of the lambda function to be called in each row.
       */
      template< typename Function >
      void sequentialForAllRows( Function& function );
      void sequentialForAllRows( Function&& function );

      /**
       * \brief Method for performing general reduction on matrix rows.
+8 −8
Original line number Diff line number Diff line
@@ -407,7 +407,7 @@ template< typename Real,
   template< typename Function >
void
DenseMatrix< Real, Device, Index, Organization, RealAllocator >::
forElements( IndexType begin, IndexType end, Function& function ) const
forElements( IndexType begin, IndexType end, Function&& function ) const
{
   this->view.forElements( begin, end, function );
}
@@ -420,7 +420,7 @@ template< typename Real,
   template< typename Function >
void
DenseMatrix< Real, Device, Index, Organization, RealAllocator >::
forElements( IndexType first, IndexType last, Function& function )
forElements( IndexType first, IndexType last, Function&& function )
{
   this->view.forElements( first, last, function );
}
@@ -433,7 +433,7 @@ template< typename Real,
   template< typename Function >
void
DenseMatrix< Real, Device, Index, Organization, RealAllocator >::
forAllElements( Function& function ) const
forAllElements( Function&& function ) const
{
   this->forElements( 0, this->getRows(), function );
}
@@ -446,7 +446,7 @@ template< typename Real,
   template< typename Function >
void
DenseMatrix< Real, Device, Index, Organization, RealAllocator >::
forAllElements( Function& function )
forAllElements( Function&& function )
{
   this->forElements( 0, this->getRows(), function );
}
@@ -511,7 +511,7 @@ template< typename Real,
   template< typename Function >
void
DenseMatrix< Real, Device, Index, Organization, RealAllocator >::
sequentialForRows( IndexType begin, IndexType end, Function& function ) const
sequentialForRows( IndexType begin, IndexType end, Function&& function ) const
{
   this->view.sequentialForRows( begin, end, function );
}
@@ -524,7 +524,7 @@ template< typename Real,
   template< typename Function >
void
DenseMatrix< Real, Device, Index, Organization, RealAllocator >::
sequentialForRows( IndexType first, IndexType last, Function& function )
sequentialForRows( IndexType first, IndexType last, Function&& function )
{
   this->view.sequentialForRows( first, last, function );
}
@@ -537,7 +537,7 @@ template< typename Real,
   template< typename Function >
void
DenseMatrix< Real, Device, Index, Organization, RealAllocator >::
sequentialForAllRows( Function& function ) const
sequentialForAllRows( Function&& function ) const
{
   this->sequentialForRows( 0, this->getRows(), function );
}
@@ -550,7 +550,7 @@ template< typename Real,
   template< typename Function >
void
DenseMatrix< Real, Device, Index, Organization, RealAllocator >::
sequentialForAllRows( Function& function )
sequentialForAllRows( Function&& function )
{
   this->sequentialForRows( 0, this->getRows(), function );
}
+63 −0
Original line number Diff line number Diff line
/***************************************************************************
                          DenseMatrixElement.h -  description
                             -------------------
    begin                : Mar 22, 2021
    copyright            : (C) 2021 by Tomas Oberhuber
    email                : tomas.oberhuber@fjfi.cvut.cz
 ***************************************************************************/

/* See Copyright Notice in tnl/Copyright */

#pragma once

#include <ostream>

#include <TNL/Cuda/CudaCallable.h>

namespace TNL {
namespace Matrices {


template< typename Real,
          typename Index >
class DenseMatrixElement
{
   public:

      using RealType = Real;

      using IndexType = Index;

      __cuda_callable__
      DenseMatrixElement( RealType& value,
                          const IndexType& rowIdx,
                          const IndexType& columnIdx,
                          const IndexType& localIdx )  // localIdx is here only for compatibility with SparseMatrixElement
      : value_( value ), rowIdx( rowIdx ), columnIdx( columnIdx ) {};

      __cuda_callable__
      RealType& value() { return value_; };

      __cuda_callable__
      const RealType& value() const { return value_; };

      __cuda_callable__
      const IndexType& rowIndex() const { return rowIdx; };

      __cuda_callable__
      const IndexType& columnIndex() const { return columnIdx; };

      __cuda_callable__
      const IndexType& localIndex() const { return columnIdx; };

   protected:

      RealType& value_;

      const IndexType& rowIdx;

      const IndexType& columnIdx;
};

   } // namespace Matrices
} // namespace TNL
+76 −4
Original line number Diff line number Diff line
@@ -10,6 +10,10 @@

#pragma once

#include <TNL/Cuda/CudaCallable.h>
#include <TNL/Matrices/SparseMatrixRowViewIterator.h>
#include <TNL/Matrices/DenseMatrixElement.h>

namespace TNL {
   namespace Matrices {

@@ -57,6 +61,31 @@ class DenseMatrixRowView
       */
      using ValuesViewType = ValuesView;

      /**
       * \brief Type of constant container view used for storing the matrix elements values.
       */
      using ConstValuesViewType = typename ValuesViewType::ConstViewType;

      /**
       * \brief Type of dense matrix row view.
       */
      using RowView = DenseMatrixRowView< SegmentView, ValuesViewType >;

      /**
       * \brief Type of constant sparse matrix row view.
       */
      using ConstView = DenseMatrixRowView< SegmentView, ConstValuesViewType >;

      /**
       * \brief The type of related matrix element.
       */
      using MatrixElementType = DenseMatrixElement< RealType, IndexType >;

      /**
       * \brief Type of iterator for the matrix row.
       */
      using IteratorType = SparseMatrixRowViewIterator< RowView >;

      /**
       * \brief Constructor with \e segmentView and \e values
       *
@@ -91,7 +120,7 @@ class DenseMatrixRowView
       * \return constant reference to the matrix element.
       */
      __cuda_callable__
      const RealType& getElement( const IndexType column ) const;
      const RealType& getValue( const IndexType column ) const;

      /**
       * \brief Returns non-constants reference to an element with given column index.
@@ -101,7 +130,17 @@ class DenseMatrixRowView
       * \return non-constant reference to the matrix element.
       */
      __cuda_callable__
      RealType& getElement( const IndexType column );
      RealType& getValue( const IndexType column );

      /**
       * \brief This method is only for compatibility with sparse matrix row.
       *
       * \param localIdx is the rank of the matrix element in given row.
       *
       * \return the value of \ref localIdx as column index.
       */
      __cuda_callable__
      IndexType getColumnIndex( const IndexType localIdx ) const;

      /**
       * \brief Sets value of matrix element with given column index
@@ -110,7 +149,7 @@ class DenseMatrixRowView
       * \param value is a value the matrix element will be set to.
       */
      __cuda_callable__
      void setElement( const IndexType column,
      void setValue( const IndexType column,
                     const RealType& value );

      /**
@@ -126,6 +165,39 @@ class DenseMatrixRowView
      void setElement( const IndexType localIdx,
                       const IndexType column,
                       const RealType& value );

      /**
       * \brief Returns iterator pointing at the beginning of the matrix row.
       *
       * \return iterator pointing at the beginning.
       */
      __cuda_callable__
      IteratorType begin();

      /**
       * \brief Returns iterator pointing at the end of the matrix row.
       *
       * \return iterator pointing at the end.
       */
      __cuda_callable__
      IteratorType end();

      /**
       * \brief Returns constant iterator pointing at the beginning of the matrix row.
       *
       * \return iterator pointing at the beginning.
       */
      __cuda_callable__
      const IteratorType cbegin() const;

      /**
       * \brief Returns constant iterator pointing at the end of the matrix row.
       *
       * \return iterator pointing at the end.
       */
      __cuda_callable__
      const IteratorType cend() const;

   protected:

      SegmentViewType segmentView;
+51 −4
Original line number Diff line number Diff line
@@ -47,7 +47,7 @@ template< typename SegmentView,
          typename ValuesView >
__cuda_callable__ auto
DenseMatrixRowView< SegmentView, ValuesView >::
getElement( const IndexType column ) const -> const RealType&
getValue( const IndexType column ) const -> const RealType&
{
   TNL_ASSERT_LT( column, this->getSize(), "Column index exceeds matrix row size." );
   return values[ segmentView.getGlobalIndex( column ) ];
@@ -57,17 +57,28 @@ template< typename SegmentView,
          typename ValuesView >
__cuda_callable__ auto
DenseMatrixRowView< SegmentView, ValuesView >::
getElement( const IndexType column ) -> RealType&
getValue( const IndexType column ) -> RealType&
{
   TNL_ASSERT_LT( column, this->getSize(), "Column index exceeds matrix row size." );
   return values[ segmentView.getGlobalIndex( column ) ];
}

template< typename SegmentView,
          typename ValuesView >
__cuda_callable__ auto
DenseMatrixRowView< SegmentView, ValuesView >::
getColumnIndex( const IndexType localIdx ) const -> IndexType
{
   TNL_ASSERT_LT( localIdx, this->getSize(), "Column index exceeds matrix row size." );
   return localIdx;
}


template< typename SegmentView,
          typename ValuesView >
__cuda_callable__ void
DenseMatrixRowView< SegmentView, ValuesView >::
setElement( const IndexType column,
setValue( const IndexType column,
          const RealType& value )
{
   TNL_ASSERT_LT( column, this->getSize(), "Column index exceeds matrix row size." );
@@ -88,5 +99,41 @@ setElement( const IndexType localIdx,
   values[ globalIdx ] = value;
}

template< typename SegmentView,
          typename ValuesView >
__cuda_callable__ auto
DenseMatrixRowView< SegmentView, ValuesView >::
begin() -> IteratorType
{
   return IteratorType( *this, 0 );
}

template< typename SegmentView,
          typename ValuesView >
__cuda_callable__ auto
DenseMatrixRowView< SegmentView, ValuesView >::
end() -> IteratorType
{
   return IteratorType( *this, this->getSize() );
}

template< typename SegmentView,
          typename ValuesView >
__cuda_callable__ auto
DenseMatrixRowView< SegmentView, ValuesView >::
cbegin() const -> const IteratorType
{
   return IteratorType( *this, 0 );
}

template< typename SegmentView,
          typename ValuesView >
__cuda_callable__ auto
DenseMatrixRowView< SegmentView, ValuesView >::
cend() const -> const IteratorType
{
   return IteratorType( *this, this->getSize() );
}

   } // namespace Matrices
} // namespace TNL
Loading