Commit cdb5807d authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Fixed interface of DistributedMatrix

parent 07fd608b
Loading
Loading
Loading
Loading
+12 −1
Original line number Diff line number Diff line
@@ -81,6 +81,9 @@ public:
   __cuda_callable__
   const Matrix& getLocalMatrix() const;

   __cuda_callable__
   Matrix& getLocalMatrix();


   /*
    * Some common Matrix methods follow below.
@@ -102,7 +105,8 @@ public:
   __cuda_callable__
   IndexType getColumns() const;

   void setRowCapacities( const CompressedRowLengthsVector& rowLengths );
   template< typename RowCapacitiesVector >
   void setRowCapacities( const RowCapacitiesVector& rowCapacities );

   template< typename Vector >
   void getCompressedRowLengths( Vector& rowLengths ) const;
@@ -144,6 +148,13 @@ public:
   vectorProduct( const InVector& inVector,
                  OutVector& outVector ) const;

   // FIXME: does not work for distributed matrices, here only due to common interface
   template< typename Vector1, typename Vector2 >
   bool performSORIteration( const Vector1& b,
                             const IndexType row,
                             Vector2& x,
                             const RealType& omega = 1.0 ) const;

protected:
   LocalRangeType localRowRange;
   IndexType rows = 0;  // global rows count
+29 −5
Original line number Diff line number Diff line
@@ -70,6 +70,16 @@ getLocalMatrix() const
   return localMatrix;
}

template< typename Matrix,
          typename Communicator >
__cuda_callable__
Matrix&
DistributedMatrix< Matrix, Communicator >::
getLocalMatrix()
{
   return localMatrix;
}


/*
 * Some common Matrix methods follow below.
@@ -149,16 +159,17 @@ getColumns() const

template< typename Matrix,
          typename Communicator >
   template< typename RowCapacitiesVector >
void
DistributedMatrix< Matrix, Communicator >::
setRowCapacities( const CompressedRowLengthsVector& rowLengths )
setRowCapacities( const RowCapacitiesVector& rowCapacities )
{
   TNL_ASSERT_EQ( rowLengths.getSize(), getRows(), "row lengths vector has wrong size" );
   TNL_ASSERT_EQ( rowLengths.getLocalRange(), getLocalRowRange(), "row lengths vector has wrong distribution" );
   TNL_ASSERT_EQ( rowLengths.getCommunicationGroup(), getCommunicationGroup(), "row lengths vector has wrong communication group" );
   TNL_ASSERT_EQ( rowCapacities.getSize(), getRows(), "row lengths vector has wrong size" );
   TNL_ASSERT_EQ( rowCapacities.getLocalRange(), getLocalRowRange(), "row lengths vector has wrong distribution" );
   TNL_ASSERT_EQ( rowCapacities.getCommunicationGroup(), getCommunicationGroup(), "row lengths vector has wrong communication group" );

   if( getCommunicationGroup() != CommunicatorType::NullGroup ) {
      localMatrix.setRowCapacities( rowLengths.getConstLocalView() );
      localMatrix.setRowCapacities( rowCapacities.getConstLocalView() );

      spmv.reset();
   }
@@ -296,5 +307,18 @@ vectorProduct( const InVector& inVector,
   const_cast< DistributedMatrix* >( this )->spmv.vectorProduct( outVector, localMatrix, inVector, getCommunicationGroup() );
}

template< typename Matrix,
          typename Communicator >
   template< typename Vector1, typename Vector2 >
bool
DistributedMatrix< Matrix, Communicator >::
performSORIteration( const Vector1& b,
                     const IndexType row,
                     Vector2& x,
                     const RealType& omega ) const
{
   return getLocalMatrix().performSORIteration( b, row, x, omega );
}

} // namespace Matrices
} // namespace TNL