Commit fd9ccbd5 authored by Jakub Klinkovský's avatar Jakub Klinkovský Committed by Tomáš Oberhuber
Browse files

Added operator== and operator!= for SparseMatrix

parent b761bc3f
Loading
Loading
Loading
Loading
+6 −0
Original line number Diff line number Diff line
@@ -243,6 +243,12 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator >
      template< typename RHSMatrix >
      SparseMatrix& operator=( const RHSMatrix& matrix );

      template< typename Matrix >
      bool operator==( const Matrix& m ) const;

      template< typename Matrix >
      bool operator!=( const Matrix& m ) const;

      void save( File& file ) const;

      void load( File& file );
+30 −0
Original line number Diff line number Diff line
@@ -880,6 +880,36 @@ operator=( const RHSMatrix& matrix )
   return *this;
}

template< typename Real,
          typename Device,
          typename Index,
          typename MatrixType,
          template< typename, typename, typename > class Segments,
          typename RealAllocator,
          typename IndexAllocator >
   template< typename Matrix >
bool
SparseMatrix< Real, Device, Index, MatrixType, Segments, RealAllocator, IndexAllocator >::
operator==( const Matrix& m ) const
{
   return view == m;
}

template< typename Real,
          typename Device,
          typename Index,
          typename MatrixType,
          template< typename, typename, typename > class Segments,
          typename RealAllocator,
          typename IndexAllocator >
   template< typename Matrix >
bool
SparseMatrix< Real, Device, Index, MatrixType, Segments, RealAllocator, IndexAllocator >::
operator!=( const Matrix& m ) const
{
   return view != m;
}

template< typename Real,
          typename Device,
          typename Index,
+8 −0
Original line number Diff line number Diff line
@@ -64,6 +64,14 @@ class SparseMatrixRowView
      void setElement( const IndexType localIdx,
                       const IndexType column,
                       const RealType& value );

      template< typename _SegmentView,
                typename _ValuesView,
                typename _ColumnsIndexesView,
                bool _isBinary >
      __cuda_callable__
      bool operator==( const SparseMatrixRowView< _SegmentView, _ValuesView, _ColumnsIndexesView, _isBinary >& other ) const;

   protected:

      SegmentViewType segmentView;
+32 −0
Original line number Diff line number Diff line
@@ -123,6 +123,38 @@ setElement( const IndexType localIdx,
      values[ globalIdx ] = value;
}

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
          bool isBinary_ >
   template< typename _SegmentView,
             typename _ValuesView,
             typename _ColumnsIndexesView,
             bool _isBinary >
__cuda_callable__
bool
SparseMatrixRowView< SegmentView, ValuesView, ColumnsIndexesView, isBinary_ >::
operator==( const SparseMatrixRowView< _SegmentView, _ValuesView, _ColumnsIndexesView, _isBinary >& other ) const
{
   IndexType i = 0;
   while( i < getSize() && i < other.getSize() ) {
      if( getColumnIndex( i ) != other.getColumnIndex( i ) )
         return false;
      if( getValue( i ) != other.getValue( i ) )
         return false;
      ++i;
   }
   for( IndexType j = i; j < getSize(); j++ )
      // TODO: use ... != getPaddingIndex()
      if( getColumnIndex( j ) >= 0 )
         return false;
   for( IndexType j = i; j < other.getSize(); j++ )
      // TODO: use ... != getPaddingIndex()
      if( other.getColumnIndex( j ) >= 0 )
         return false;
   return true;
}

template< typename SegmentView,
          typename ValuesView,
          typename ColumnsIndexesView,
+6 −0
Original line number Diff line number Diff line
@@ -153,6 +153,12 @@ class SparseMatrixView : public MatrixView< Real, Device, Index >

      SparseMatrixView& operator=( const SparseMatrixView& matrix );

      template< typename Matrix >
      bool operator==( const Matrix& m ) const;

      template< typename Matrix >
      bool operator!=( const Matrix& m ) const;

      void save( File& file ) const;

      void save( const String& fileName ) const;
Loading