Commit c749c8a9 authored by Jakub Klinkovský's avatar Jakub Klinkovský Committed by Tomáš Oberhuber
Browse files

Added operator<< and setValue for SparseMatrixRowView

parent de51fb4d
Loading
Loading
Loading
Loading
+17 −2
Original line number Diff line number Diff line
@@ -10,6 +10,10 @@

#pragma once

#include <ostream>

#include <TNL/Cuda/CudaCallable.h>

namespace TNL {
namespace Matrices {

@@ -52,6 +56,10 @@ class SparseMatrixRowView
      __cuda_callable__
      RealType& getValue( const IndexType localIdx );

      __cuda_callable__
      void setValue( const IndexType localIdx,
                     const RealType& value );

      __cuda_callable__
      void setElement( const IndexType localIdx,
                       const IndexType column,
@@ -64,6 +72,13 @@ class SparseMatrixRowView

      ColumnsIndexesViewType columnIndexes;
};

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
std::ostream& operator<<( std::ostream& str, const SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >& row );

} // namespace Matrices
} // namespace TNL

+30 −2
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@
#pragma once

#include <TNL/Matrices/SparseMatrixRowView.h>
#include <TNL/Assert.h>

namespace TNL {
namespace Matrices {
@@ -89,6 +90,22 @@ getValue( const IndexType localIdx ) -> RealType&
   return values[ segmentView.getGlobalIndex( localIdx ) ];
}

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
__cuda_callable__ void
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
setValue( const IndexType localIdx,
          const RealType& value )
{
   TNL_ASSERT_LT( localIdx, this->getSize(), "Local index exceeds matrix row capacity." );
   if( ! isBinary() ) {
      const IndexType globalIdx = segmentView.getGlobalIndex( localIdx );
      values[ globalIdx ] = value;
   }
}

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
@@ -106,6 +123,17 @@ setElement( const IndexType localIdx,
      values[ globalIdx ] = value;
}

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
std::ostream& operator<<( std::ostream& str, const SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >& row )
{
   using NonConstIndex = std::remove_const_t< typename SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::IndexType >;
   for( NonConstIndex i = 0; i < row.getSize(); i++ )
      str << " [ " << row.getColumnIndex( i ) << " ] = " << row.getValue( i ) << ", ";
   return str;
}

} // namespace Matrices
} // namespace TNL