Commit 7c00255a authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Added iterator for sparse matrix row.

parent 06a34e5c
Loading
Loading
Loading
Loading
+69 −0
Original line number Diff line number Diff line
/***************************************************************************
                          SparseMatrixElement.h -  description
                             -------------------
    begin                : Mar 21, 2021
    copyright            : (C) 2021 by Tomas Oberhuber
    email                : tomas.oberhuber@fjfi.cvut.cz
 ***************************************************************************/

/* See Copyright Notice in tnl/Copyright */

#pragma once

#include <ostream>

#include <TNL/Cuda/CudaCallable.h>

namespace TNL {
namespace Matrices {


template< typename Real,
          typename Index,
          bool isBinary_ = false >
class SparseMatrixElement
{
   public:

      using RealType = Real;

      using IndexType = Index;

      __cuda_callable__
      SparseMatrixElement( RealType& value,
                           const IndexType& rowIdx,
                           IndexType& columnIdx,
                           const IndexType& localIdx )
      : value_( value ), rowIdx( rowIdx ), columnIdx( columnIdx ), localIdx( localIdx ) {};

      __cuda_callable__
      RealType& value() { return value_; };

      __cuda_callable__
      const RealType& value() const { return value_; };

      __cuda_callable__
      const IndexType& rowIndex() const { return rowIdx; };

      __cuda_callable__
      IndexType& columnIndex() { return columnIdx; };

      __cuda_callable__
      const IndexType& columnIndex() const { return columnIdx; };

      __cuda_callable__
      const IndexType& localIndex() const { return localIdx; };

   protected:

      RealType& value_;

      const IndexType& rowIdx;

      IndexType& columnIdx;

      const IndexType& localIdx;
};

   } // namespace Matrices
} // namespace TNL
+17 −3
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@
#include <ostream>

#include <TNL/Cuda/CudaCallable.h>
#include <TNL/Matrices/SparseMatrixRowViewIterator.h>

namespace TNL {
namespace Matrices {
@@ -83,13 +84,14 @@ class SparseMatrixRowView
      /**
       * \brief Type of sparse matrix row view.
       */
      using RowViewType = SparseMatrixRowView< SegmentView, ValuesViewType, ColumnsIndexesViewType, isBinary_ >;

      using RowView = SparseMatrixRowView< SegmentView, ValuesViewType, ColumnsIndexesViewType, isBinary_ >;

      /**
       * \brief Type of constant sparse matrix row view.
       */
      using ConstRowViewType = SparseMatrixRowView< SegmentView, ConstValuesViewType, ConstColumnsIndexesViewType, isBinary_ >;
      using ConstView = SparseMatrixRowView< SegmentView, ConstValuesViewType, ConstColumnsIndexesViewType, isBinary_ >;

      using IteratorType = SparseMatrixRowViewIterator< RowView >;

      /**
       * \brief Tells whether the parent matrix is a binary matrix.
@@ -212,6 +214,18 @@ class SparseMatrixRowView
      __cuda_callable__
      bool operator==( const SparseMatrixRowView< _SegmentView, _ValuesView, _ColumnsIndexesView, _isBinary >& other ) const;

      __cuda_callable__
      IteratorType begin();

      __cuda_callable__
      IteratorType end();

      __cuda_callable__
      const IteratorType cbegin() const;

      __cuda_callable__
      const IteratorType cend() const;


   protected:

+45 −0
Original line number Diff line number Diff line
@@ -181,6 +181,50 @@ operator==( const SparseMatrixRowView< _SegmentView, _ValuesView, _ColumnsIndexe
   return true;
}

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
__cuda_callable__ auto
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
begin() -> IteratorType
{
   return IteratorType( *this, 0 );
}

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
__cuda_callable__ auto
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
end() -> IteratorType
{
   return IteratorType( *this, this->getSize() );
}

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
__cuda_callable__ auto
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
cbegin() const -> const IteratorType
{
   return IteratorType( *this, 0 );
}

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
__cuda_callable__ auto
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
cend() const -> const IteratorType
{
   return IteratorType( *this, this->getSize() );
}

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
@@ -197,5 +241,6 @@ std::ostream& operator<<( std::ostream& str, const SparseMatrixRowView< SegmentV
   return str;
}


} // namespace Matrices
} // namespace TNL
+98 −0
Original line number Diff line number Diff line
 /***************************************************************************
                          SparseMatrixRowView.h -  description
                             -------------------
    begin                : Dec 28, 2019
    copyright            : (C) 2019 by Tomas Oberhuber
    email                : tomas.oberhuber@fjfi.cvut.cz
 ***************************************************************************/

/* See Copyright Notice in tnl/Copyright */

#pragma once

#include <ostream>

#include <TNL/Cuda/CudaCallable.h>
#include <TNL/Matrices/SparseMatrixElement.h>

namespace TNL {
namespace Matrices {

template< typename RowView >
class SparseMatrixRowViewIterator
{

   public:

      /**
       * \brief Type of SparseMatrixRowView
       */
      using RowViewType = RowView;

      /**
       * \brief The type of matrix elements.
       */
      using RealType = typename RowViewType::RealType;

      /**
       * \brief The type used for matrix elements indexing.
       */
      using IndexType = typename RowViewType::IndexType;

      /**
       * \brief The type of related matrix element.
       */
      using MatrixElementType = SparseMatrixElement< RealType, IndexType >;

      /**
       * \brief Tells whether the parent matrix is a binary matrix.
       * @return `true` if the matrix is binary.
       */
      static constexpr bool isBinary() { return RowViewType::isBinary(); };

      __cuda_callable__
      SparseMatrixRowViewIterator( RowViewType& rowView,
                                   const IndexType& localIdx );

      /**
       * \brief Comparison of two matrix row iterators.
       *
       * \param other is another matrix row iterator.
       * \return \e true if both iterators points at the same point of the same matrix, \e false otherwise.
       */
      __cuda_callable__
      bool operator==( const SparseMatrixRowViewIterator& other ) const;

      /**
       * \brief Comparison of two matrix row iterators.
       *
       * \param other is another matrix row iterator.
       * \return \e false if both iterators points at the same point of the same matrix, \e true otherwise.
       */
      __cuda_callable__
      bool operator!=( const SparseMatrixRowViewIterator& other ) const;

      __cuda_callable__
      SparseMatrixRowViewIterator& operator++();

      __cuda_callable__
      SparseMatrixRowViewIterator& operator--();

      __cuda_callable__
      MatrixElementType operator*();

      __cuda_callable__
      const MatrixElementType operator*() const;

   protected:

      RowViewType& rowView;

      IndexType localIdx = 0;
};


   } // namespace Matrices
} // namespace TNL

#include <TNL/Matrices/SparseMatrixRowViewIterator.hpp>
+95 −0
Original line number Diff line number Diff line
/***************************************************************************
                          SparseMatrixRowView.hpp -  description
                             -------------------
    begin                : Dec 28, 2019
    copyright            : (C) 2019 by Tomas Oberhuber
    email                : tomas.oberhuber@fjfi.cvut.cz
 ***************************************************************************/

/* See Copyright Notice in tnl/Copyright */

#pragma once

#include <TNL/Matrices/SparseMatrixRowView.h>
#include <TNL/Assert.h>

namespace TNL {
namespace Matrices {

template< typename RowView >
__cuda_callable__
SparseMatrixRowViewIterator< RowView >::
SparseMatrixRowViewIterator( RowViewType& rowView,
                             const IndexType& localIdx )
: rowView( rowView ), localIdx( localIdx )
{
}

template< typename RowView >
__cuda_callable__ bool
SparseMatrixRowViewIterator< RowView >::
operator==( const SparseMatrixRowViewIterator& other ) const
{
   if( &this->rowView == &other.rowView &&
       localIdx == other.localIdx )
      return true;
   return false;
}

template< typename RowView >
__cuda_callable__ bool
SparseMatrixRowViewIterator< RowView >::
operator!=( const SparseMatrixRowViewIterator& other ) const
{
   return ! ( other == *this );
}

template< typename RowView >
__cuda_callable__
SparseMatrixRowViewIterator< RowView >&
SparseMatrixRowViewIterator< RowView >::
operator++()
{
   if( localIdx < rowView.getSize() )
      localIdx ++;
   return *this;
}

template< typename RowView >
__cuda_callable__
SparseMatrixRowViewIterator< RowView >&
SparseMatrixRowViewIterator< RowView >::
operator--()
{
   if( localIdx > 0 )
      localIdx --;
   return *this;
}

template< typename RowView >
__cuda_callable__ auto
SparseMatrixRowViewIterator< RowView >::
operator*() -> MatrixElementType
{
   return MatrixElementType(
      this->rowView.getValue( this->localIdx ),
      this->rowView.getRowIndex(),
      this->rowView.getColumnIndex( this->localIdx ),
      this->localIdx );
}

template< typename RowView >
__cuda_callable__ auto
SparseMatrixRowViewIterator< RowView >::
operator*() const -> const MatrixElementType
{
   return MatrixElementType(
      this->rowView.getValue( this->localIdx ),
      this->rowView.getRowIndex( this->localIdx ),
      this->rowView.getColumnIndex( this->localIdx ),
      this->localIdx );
}


   } // namespace Matrices
} // namespace TNL
Loading