Commit db588b4f authored by Tomáš Oberhuber's avatar Tomáš Oberhuber Committed by Jakub Klinkovský
Browse files

Added more comparison operators for dense matrices.

parent c67f1184
Loading
Loading
Loading
Loading
+65 −2
Original line number Diff line number Diff line
@@ -974,7 +974,7 @@ class DenseMatrix : public Matrix< Real, Device, Index, RealAllocator >
       * \return \e true if the RHS matrix is equal, \e false otherwise.
       */
      template< typename Real_, typename Device_, typename Index_, typename RealAllocator_ >
      bool operator==( const DenseMatrix< Real_, Device_, Index_, Organization >& matrix ) const;
      bool operator==( const DenseMatrix< Real_, Device_, Index_, Organization, RealAllocator_ >& matrix ) const;

      /**
       * \brief Comparison operator with another dense matrix.
@@ -983,7 +983,43 @@ class DenseMatrix : public Matrix< Real, Device, Index, RealAllocator >
       * \return \e false if the RHS matrix is equal, \e true otherwise.
       */
      template< typename Real_, typename Device_, typename Index_, typename RealAllocator_ >
      bool operator!=( const DenseMatrix< Real_, Device_, Index_, Organization >& matrix ) const;
      bool operator!=( const DenseMatrix< Real_, Device_, Index_, Organization, RealAllocator_ >& matrix ) const;

      /**
       * \brief Comparison operator with another dense matrix view.
       *
       * \param matrix is the right-hand side matrix view.
       * \return \e true if the RHS matrix view is equal, \e false otherwise.
       */
      template< typename Real_, typename Device_, typename Index_ >
      bool operator==( const DenseMatrixView< Real_, Device_, Index_, Organization >& matrix ) const;

      /**
       * \brief Comparison operator with another dense matrix view.
       *
       * \param matrix is the right-hand side matrix view.
       * \return \e false if the RHS matrix view is equal, \e true otherwise.
       */
      template< typename Real_, typename Device_, typename Index_ >
      bool operator!=( const DenseMatrixView< Real_, Device_, Index_, Organization >& matrix ) const;

      /**
       * \brief Comparison operator with another arbitrary matrix type.
       *
       * \param matrix is the right-hand side matrix.
       * \return \e true if the RHS matrix is equal, \e false otherwise.
       */
      template< typename Matrix >
      bool operator==( const Matrix& m ) const;

      /**
       * \brief Comparison operator with another arbitrary matrix type.
       *
       * \param matrix is the right-hand side matrix.
       * \return \e true if the RHS matrix is equal, \e false otherwise.
       */
      template< typename Matrix >
      bool operator!=( const Matrix& m ) const;

      /**
       * \brief Method for saving the matrix to the file with given filename.
@@ -1045,6 +1081,33 @@ template< typename Real,
          typename RealAllocator >
std::ostream& operator<< ( std::ostream& str, const DenseMatrix< Real, Device, Index, Organization, RealAllocator >& matrix );

/**
 * \brief Comparison operator with another dense matrix view.
 *
 * \param leftMatrix is the left-hand side matrix view.
 * \param rightMatrix is the right-hand side matrix.
 * \return \e true if the both matrices are is equal, \e false otherwise.
 */
template< typename Real, typename Device, typename Index,
          typename Real_, typename Device_, typename Index_,
          ElementsOrganization Organization, typename RealAllocator >
bool operator==( const DenseMatrixView< Real, Device, Index, Organization >& leftMatrix,
                 const DenseMatrix< Real_, Device_, Index_, Organization, RealAllocator >& rightMatrix );

/**
 * \brief Comparison operator with another dense matrix view.
 *
 * \param leftMatrix is the left-hand side matrix view.
 * \param rightMatrix is the right-hand side matrix.
 * \return \e false if the both matrices are is equal, \e true otherwise.
 */
template< typename Real, typename Device, typename Index,
          typename Real_, typename Device_, typename Index_,
          ElementsOrganization Organization, typename RealAllocator >
bool operator!=( const DenseMatrixView< Real, Device, Index, Organization >& leftMatrix,
                 const DenseMatrix< Real_, Device_, Index_, Organization, RealAllocator >& rightMatrix );


} // namespace Matrices
} // namespace TNL

+74 −2
Original line number Diff line number Diff line
@@ -1284,7 +1284,7 @@ template< typename Real,
   template< typename Real_, typename Device_, typename Index_, typename RealAllocator_ >
bool
DenseMatrix< Real, Device, Index, Organization, RealAllocator >::
operator==( const DenseMatrix< Real_, Device_, Index_, Organization >& matrix ) const
operator==( const DenseMatrix< Real_, Device_, Index_, Organization, RealAllocator_ >& matrix ) const
{
   return( this->getRows() == matrix.getRows() &&
           this->getColumns() == matrix.getColumns() &&
@@ -1299,11 +1299,65 @@ template< typename Real,
   template< typename Real_, typename Device_, typename Index_, typename RealAllocator_ >
bool
DenseMatrix< Real, Device, Index, Organization, RealAllocator >::
operator!=( const DenseMatrix< Real_, Device_, Index_, Organization >& matrix ) const
operator!=( const DenseMatrix< Real_, Device_, Index_, Organization, RealAllocator_ >& matrix ) const
{
   return ! ( *this == matrix );
}

template< typename Real,
          typename Device,
          typename Index,
          ElementsOrganization Organization,
          typename RealAllocator >
   template< typename Real_, typename Device_, typename Index_ >
bool
DenseMatrix< Real, Device, Index, Organization, RealAllocator >::
operator==( const DenseMatrixView< Real_, Device_, Index_, Organization >& matrix ) const
{
   return( this->getRows() == matrix.getRows() &&
           this->getColumns() == matrix.getColumns() &&
           this->getValues() == matrix.getValues() );
}

template< typename Real,
          typename Device,
          typename Index,
          ElementsOrganization Organization,
          typename RealAllocator >
   template< typename Real_, typename Device_, typename Index_ >
bool
DenseMatrix< Real, Device, Index, Organization, RealAllocator >::
operator!=( const DenseMatrixView< Real_, Device_, Index_, Organization >& matrix ) const
{
   return ! ( *this == matrix );
}

template< typename Real,
          typename Device,
          typename Index,
          ElementsOrganization Organization,
          typename RealAllocator >
   template< typename Matrix >
bool
DenseMatrix< Real, Device, Index, Organization, RealAllocator >::
operator==( const Matrix& m ) const
{
   return ( this->view == m );
}

template< typename Real,
          typename Device,
          typename Index,
          ElementsOrganization Organization,
          typename RealAllocator >
   template< typename Matrix >
bool
DenseMatrix< Real, Device, Index, Organization, RealAllocator >::
operator!=( const Matrix& m ) const
{
   return ( this->view != m );
}

template< typename Real,
          typename Device,
          typename Index,
@@ -1380,5 +1434,23 @@ std::ostream& operator<< ( std::ostream& str, const DenseMatrix< Real, Device, I
   return str;
}

template< typename Real, typename Device, typename Index,
          typename Real_, typename Device_, typename Index_,
          ElementsOrganization Organization, typename RealAllocator >
bool operator==( const DenseMatrixView< Real, Device, Index, Organization >& leftMatrix,
                 const DenseMatrix< Real_, Device_, Index_, Organization, RealAllocator >& rightMatrix )
{
   return rightMatrix == leftMatrix;
}

template< typename Real, typename Device, typename Index,
          typename Real_, typename Device_, typename Index_,
          ElementsOrganization Organization, typename RealAllocator >
bool operator!=( const DenseMatrixView< Real, Device, Index, Organization >& leftMatrix,
                 const DenseMatrix< Real_, Device_, Index_, Organization, RealAllocator >& rightMatrix )
{
   return rightMatrix != leftMatrix;
}

} // namespace Matrices
} // namespace TNL
+36 −0
Original line number Diff line number Diff line
@@ -870,6 +870,42 @@ class DenseMatrixView : public MatrixView< Real, Device, Index >
       */
      DenseMatrixView& operator=( const DenseMatrixView& matrix );

      /**
       * \brief Comparison operator with another dense matrix view.
       *
       * \param matrix is the right-hand side matrix view.
       * \return \e true if the RHS matrix view is equal, \e false otherwise.
       */
      template< typename Real_, typename Device_, typename Index_ >
      bool operator==( const DenseMatrixView< Real_, Device_, Index_, Organization >& matrix ) const;

      /**
       * \brief Comparison operator with another dense matrix view.
       *
       * \param matrix is the right-hand side matrix.
       * \return \e false if the RHS matrix view is equal, \e true otherwise.
       */
      template< typename Real_, typename Device_, typename Index_ >
      bool operator!=( const DenseMatrixView< Real_, Device_, Index_, Organization >& matrix ) const;

      /**
       * \brief Comparison operator with another arbitrary matrix type.
       *
       * \param matrix is the right-hand side matrix.
       * \return \e true if the RHS matrix is equal, \e false otherwise.
       */
      template< typename Matrix >
      bool operator==( const Matrix& m ) const;

      /**
       * \brief Comparison operator with another arbitrary matrix type.
       *
       * \param matrix is the right-hand side matrix.
       * \return \e true if the RHS matrix is equal, \e false otherwise.
       */
      template< typename Matrix >
      bool operator!=( const Matrix& m ) const;

      /**
       * \brief Method for saving the matrix view to the file with given filename.
       *
+57 −0
Original line number Diff line number Diff line
@@ -808,6 +808,63 @@ operator=( const DenseMatrixView& matrix )
   return *this;
}

template< typename Real,
          typename Device,
          typename Index,
          ElementsOrganization Organization >
   template< typename Real_, typename Device_, typename Index_ >
bool
DenseMatrixView< Real, Device, Index, Organization >::
operator==( const DenseMatrixView< Real_, Device_, Index_, Organization >& matrix ) const
{
   return( this->getRows() == matrix.getRows() &&
           this->getColumns() == matrix.getColumns() &&
           this->getValues() == matrix.getValues() );
}

template< typename Real,
          typename Device,
          typename Index,
          ElementsOrganization Organization >
   template< typename Real_, typename Device_, typename Index_ >
bool
DenseMatrixView< Real, Device, Index, Organization >::
operator!=( const DenseMatrixView< Real_, Device_, Index_, Organization >& matrix ) const
{
   return ! ( *this == matrix );
}

template< typename Real,
          typename Device,
          typename Index,
          ElementsOrganization Organization >
   template< typename Matrix >
bool
DenseMatrixView< Real, Device, Index, Organization >::
operator==( const Matrix& m ) const
{
   const auto& view1 = *this;
   const auto view2 = m.getConstView();
   auto fetch = [=] __cuda_callable__ ( const IndexType i ) -> bool
   {
      return view1.getRow( i ) == view2.getRow( i );
   };
   return Algorithms::Reduction< DeviceType >::reduce( ( IndexType ) 0, this->getRows(), fetch, std::logical_and<>{}, true );
}

template< typename Real,
          typename Device,
          typename Index,
          ElementsOrganization Organization >
   template< typename Matrix >
bool
DenseMatrixView< Real, Device, Index, Organization >::
operator!=( const Matrix& m ) const
{
   return ! ( *this == m );
}


template< typename Real,
          typename Device,
          typename Index,