Commit 0e118aae authored by Tomáš Oberhuber's avatar Tomáš Oberhuber Committed by Jakub Klinkovský
Browse files

Fixing SparseMatrixRowViewValueGetter for binary matrices.

parent 362ecdd4
Loading
Loading
Loading
Loading
+8 −2
Original line number Diff line number Diff line
@@ -25,12 +25,18 @@ namespace Matrices {
 * \tparam Index is a type of matrix elements column indexes.
 */
template< typename Real,
          typename Index,
          bool isBinary_ = false >
          typename Index >
class SparseMatrixElement
{
   public:

      /**
       * \brief Test of binary matrix type.
       *
       * \return \e true if the matrix is stored as binary and \e false otherwise.
       */
      static constexpr bool isBinary() { return std::is_same< std::remove_const_t< Real >, bool >::value; };

      /**
       * \brief Type of matrix elements values.
       */
+14 −18
Original line number Diff line number Diff line
@@ -25,7 +25,6 @@ namespace Matrices {
 * \tparam SegmentView is a segment view of segments representing the matrix format.
 * \tparam ValuesView is a vector view storing the matrix elements values.
 * \tparam ColumnsIndexesView is a vector view storing the column indexes of the matrix element.
 * \tparam isBinary tells if the the parent matrix is a binary matrix.
 *
 * See \ref SparseMatrix and \ref SparseMatrixView.
 *
@@ -41,12 +40,17 @@ namespace Matrices {
 */
template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
          typename ColumnsIndexesView >
class SparseMatrixRowView
{
   public:

      /**
       * \brief Tells whether the parent matrix is a binary matrix.
       * @return `true` if the matrix is binary.
       */
      static constexpr bool isBinary() { return std::is_same< std::remove_const_t< RealType >, bool >::value; };

      /**
       * \brief The type of matrix elements.
       */
@@ -85,12 +89,12 @@ class SparseMatrixRowView
      /**
       * \brief Type of sparse matrix row view.
       */
      using RowView = SparseMatrixRowView< SegmentView, ValuesViewType, ColumnsIndexesViewType, isBinary_ >;
      using RowView = SparseMatrixRowView< SegmentView, ValuesViewType, ColumnsIndexesViewType >;

      /**
       * \brief Type of constant sparse matrix row view.
       */
      using ConstView = SparseMatrixRowView< SegmentView, ConstValuesViewType, ConstColumnsIndexesViewType, isBinary_ >;
      using ConstView = SparseMatrixRowView< SegmentView, ConstValuesViewType, ConstColumnsIndexesViewType >;

      /**
       * \brief The type of related matrix element.
@@ -102,13 +106,7 @@ class SparseMatrixRowView
       */
      using IteratorType = MatrixRowViewIterator< RowView >;

      using ValueGetterType = details::SparseMatrixRowViewValueGetter< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >;

      /**
       * \brief Tells whether the parent matrix is a binary matrix.
       * @return `true` if the matrix is binary.
       */
      static constexpr bool isBinary() { return isBinary_; };
      using ValueGetterType = details::SparseMatrixRowViewValueGetter< SegmentView, ValuesView, ColumnsIndexesView >;

      /**
       * \brief Constructor with \e segmentView, \e values and \e columnIndexes.
@@ -220,10 +218,9 @@ class SparseMatrixRowView
       */
      template< typename _SegmentView,
                typename _ValuesView,
                typename _ColumnsIndexesView,
                bool _isBinary >
                typename _ColumnsIndexesView >
      __cuda_callable__
      bool operator==( const SparseMatrixRowView< _SegmentView, _ValuesView, _ColumnsIndexesView, _isBinary >& other ) const;
      bool operator==( const SparseMatrixRowView< _SegmentView, _ValuesView, _ColumnsIndexesView >& other ) const;

      /**
       * \brief Returns iterator pointing at the beginning of the matrix row.
@@ -278,9 +275,8 @@ class SparseMatrixRowView
 */
template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
std::ostream& operator<<( std::ostream& str, const SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >& row );
          typename ColumnsIndexesView >
std::ostream& operator<<( std::ostream& str, const SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >& row );

} // namespace Matrices
} // namespace TNL
+37 −56
Original line number Diff line number Diff line
@@ -18,10 +18,9 @@ namespace Matrices {

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
          typename ColumnsIndexesView >
__cuda_callable__
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >::
SparseMatrixRowView( const SegmentViewType& segmentView,
                     const ValuesViewType& values,
                     const ColumnsIndexesViewType& columnIndexes )
@@ -31,10 +30,9 @@ SparseMatrixRowView( const SegmentViewType& segmentView,

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
          typename ColumnsIndexesView >
__cuda_callable__ auto
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >::
getSize() const -> IndexType
{
   return segmentView.getSize();
@@ -42,11 +40,10 @@ getSize() const -> IndexType

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
          typename ColumnsIndexesView >
__cuda_callable__
auto
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >::
getRowIndex() const -> const IndexType&
{
   return segmentView.getSegmentIndex();
@@ -54,10 +51,9 @@ getRowIndex() const -> const IndexType&

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
          typename ColumnsIndexesView >
__cuda_callable__ auto
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >::
getColumnIndex( const IndexType localIdx ) const -> const IndexType&
{
   TNL_ASSERT_LT( localIdx, this->getSize(), "Local index exceeds matrix row capacity." );
@@ -66,10 +62,9 @@ getColumnIndex( const IndexType localIdx ) const -> const IndexType&

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
          typename ColumnsIndexesView >
__cuda_callable__ auto
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >::
getColumnIndex( const IndexType localIdx ) -> IndexType&
{
   TNL_ASSERT_LT( localIdx, this->getSize(), "Local index exceeds matrix row capacity." );
@@ -78,14 +73,12 @@ getColumnIndex( const IndexType localIdx ) -> IndexType&

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
          typename ColumnsIndexesView >
__cuda_callable__ auto
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >::
getValue( const IndexType localIdx ) const -> typename ValueGetterType::ConstResultType
{
   TNL_ASSERT_LT( localIdx, this->getSize(), "Local index exceeds matrix row capacity." );
   //TNL_ASSERT_FALSE( isBinary(), "Cannot call this method for binary matrix row." );
   return ValueGetterType::getValue( segmentView.getGlobalIndex( localIdx ),
                                     values,
                                     columnIndexes,
@@ -94,14 +87,12 @@ getValue( const IndexType localIdx ) const -> typename ValueGetterType::ConstRes

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
          typename ColumnsIndexesView >
__cuda_callable__ auto
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >::
getValue( const IndexType localIdx ) -> typename ValueGetterType::ResultType
{
   TNL_ASSERT_LT( localIdx, this->getSize(), "Local index exceeds matrix row capacity." );
   //TNL_ASSERT_FALSE( isBinary(), "Cannot call this method for binary matrix row." );
   return ValueGetterType::getValue( segmentView.getGlobalIndex( localIdx ),
                                     values,
                                     columnIndexes,
@@ -110,10 +101,9 @@ getValue( const IndexType localIdx ) -> typename ValueGetterType::ResultType

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
          typename ColumnsIndexesView >
__cuda_callable__ void
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >::
setValue( const IndexType localIdx,
          const RealType& value )
{
@@ -126,10 +116,9 @@ setValue( const IndexType localIdx,

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
          typename ColumnsIndexesView >
__cuda_callable__ void
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >::
setColumnIndex( const IndexType localIdx,
                const IndexType& columnIndex )
{
@@ -140,10 +129,9 @@ setColumnIndex( const IndexType localIdx,

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
          typename ColumnsIndexesView >
__cuda_callable__ void
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >::
setElement( const IndexType localIdx,
            const IndexType column,
            const RealType& value )
@@ -157,22 +145,20 @@ setElement( const IndexType localIdx,

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
          typename ColumnsIndexesView >
   template< typename _SegmentView,
             typename _ValuesView,
             typename _ColumnsIndexesView,
             bool _isBinary >
             typename _ColumnsIndexesView >
__cuda_callable__
bool
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
operator==( const SparseMatrixRowView< _SegmentView, _ValuesView, _ColumnsIndexesView, _isBinary >& other ) const
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >::
operator==( const SparseMatrixRowView< _SegmentView, _ValuesView, _ColumnsIndexesView >& other ) const
{
   IndexType i = 0;
   while( i < getSize() && i < other.getSize() ) {
      if( getColumnIndex( i ) != other.getColumnIndex( i ) )
         return false;
      if( ! _isBinary && getValue( i ) != other.getValue( i ) )
      if( ! isBinary() && getValue( i ) != other.getValue( i ) )
         return false;
      ++i;
   }
@@ -189,10 +175,9 @@ operator==( const SparseMatrixRowView< _SegmentView, _ValuesView, _ColumnsIndexe

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
          typename ColumnsIndexesView >
__cuda_callable__ auto
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >::
begin() -> IteratorType
{
   return IteratorType( *this, 0 );
@@ -200,10 +185,9 @@ begin() -> IteratorType

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
          typename ColumnsIndexesView >
__cuda_callable__ auto
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >::
end() -> IteratorType
{
   return IteratorType( *this, this->getSize() );
@@ -211,10 +195,9 @@ end() -> IteratorType

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
          typename ColumnsIndexesView >
__cuda_callable__ auto
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >::
cbegin() const -> const IteratorType
{
   return IteratorType( *this, 0 );
@@ -222,10 +205,9 @@ cbegin() const -> const IteratorType

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
          typename ColumnsIndexesView >
__cuda_callable__ auto
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >::
cend() const -> const IteratorType
{
   return IteratorType( *this, this->getSize() );
@@ -233,13 +215,12 @@ cend() const -> const IteratorType

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
std::ostream& operator<<( std::ostream& str, const SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >& row )
          typename ColumnsIndexesView >
std::ostream& operator<<( std::ostream& str, const SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >& row )
{
   using NonConstIndex = std::remove_const_t< typename SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::IndexType >;
   using NonConstIndex = std::remove_const_t< typename SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView >::IndexType >;
   for( NonConstIndex i = 0; i < row.getSize(); i++ )
      if( isBinary_ )
      if( row.isBinary() )
         // TODO: check getPaddingIndex(), print only the column indices of non-zeros but not the values
         str << " [ " << row.getColumnIndex( i ) << " ] = " << (row.getColumnIndex( i ) >= 0) << ", ";
      else
+2 −2
Original line number Diff line number Diff line
@@ -136,12 +136,12 @@ class SparseMatrixView : public MatrixView< Real, Device, Index >
      /**
       * \brief Type for accessing matrix rows.
       */
      using RowView = SparseMatrixRowView< typename SegmentsViewType::SegmentViewType, ValuesViewType, ColumnsIndexesViewType, isBinary() >;
      using RowView = SparseMatrixRowView< typename SegmentsViewType::SegmentViewType, ValuesViewType, ColumnsIndexesViewType >;

      /**
       * \brief Type for accessing constant matrix rows.
       */
      using ConstRowView = SparseMatrixRowView< typename SegmentsViewType::SegmentViewType, ConstValuesViewType, ConstColumnsIndexesViewType, isBinary() >;;
      using ConstRowView = SparseMatrixRowView< typename SegmentsViewType::SegmentViewType, ConstValuesViewType, ConstColumnsIndexesViewType >;;

      /**
       * \brief Helper type for getting self type or its modifications.
+8 −5
Original line number Diff line number Diff line
@@ -20,13 +20,15 @@ namespace TNL {
template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
          typename Real = std::remove_const_t<typename ValuesView::RealType >,
          bool isBinary_ = std::is_same< std::remove_const_t<typename ValuesView::RealType >, bool >::value >
struct SparseMatrixRowViewValueGetter {};

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView >
struct SparseMatrixRowViewValueGetter< SegmentView, ValuesView, ColumnsIndexesView, true >
          typename ColumnsIndexesView,
          typename Real >
struct SparseMatrixRowViewValueGetter< SegmentView, ValuesView, ColumnsIndexesView, Real, true >
{
   using RealType = typename ValuesView::RealType;

@@ -47,8 +49,9 @@ struct SparseMatrixRowViewValueGetter< SegmentView, ValuesView, ColumnsIndexesVi

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView >
struct SparseMatrixRowViewValueGetter< SegmentView, ValuesView, ColumnsIndexesView, false >
          typename ColumnsIndexesView,
          typename Real >
struct SparseMatrixRowViewValueGetter< SegmentView, ValuesView, ColumnsIndexesView, Real, false >
{
   using RealType = typename ValuesView::RealType;