Commit 55339b85 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Updated DistributedMatrix, simplified usage of Partitioner

parent 4af3efc6
Loading
Loading
Loading
Loading
+16 −15
Original line number Diff line number Diff line
@@ -16,7 +16,8 @@

#include <TNL/Matrices/SparseRow.h>
#include <TNL/Communicators/MpiCommunicator.h>
#include <TNL/DistributedContainers/IndexMap.h>
#include <TNL/DistributedContainers/Subrange.h>
#include <TNL/DistributedContainers/Partitioner.h>
#include <TNL/DistributedContainers/DistributedVector.h>

// buffers for vectorProduct
@@ -31,15 +32,16 @@ namespace DistributedContainers {
// TODO: 2D distribution for dense matrices (maybe it should be in different template,
//       because e.g. setRowFast doesn't make sense for dense matrices)
template< typename Matrix,
          typename Communicator = Communicators::MpiCommunicator,
          typename IndexMap = Subrange< typename Matrix::IndexType > >
          typename Communicator = Communicators::MpiCommunicator >
class DistributedMatrix
: public Object
{
   using CommunicationGroup = typename Communicator::CommunicationGroup;

   template< typename Real >
   using DistVector = DistributedVector< Real, typename Matrix::DeviceType, Communicator, typename Matrix::IndexType, IndexMap >;
   using DistVector = DistributedVector< Real, typename Matrix::DeviceType, typename Matrix::IndexType, Communicator >;

   using Partitioner = DistributedContainers::Partitioner< typename Matrix::IndexType, Communicator >;

public:
   using MatrixType = Matrix;
@@ -47,12 +49,12 @@ public:
   using DeviceType = typename Matrix::DeviceType;
   using IndexType = typename Matrix::IndexType;
   using CommunicatorType = Communicator;
   using IndexMapType = IndexMap;
   using LocalRangeType = Subrange< typename Matrix::IndexType >;

   using HostType = DistributedMatrix< typename Matrix::HostType, Communicator, IndexMap >;
   using CudaType = DistributedMatrix< typename Matrix::CudaType, Communicator, IndexMap >;
   using HostType = DistributedMatrix< typename Matrix::HostType, Communicator >;
   using CudaType = DistributedMatrix< typename Matrix::CudaType, Communicator >;

   using CompressedRowLengthsVector = DistributedVector< IndexType, DeviceType, CommunicatorType, IndexType, IndexMapType >;
   using CompressedRowLengthsVector = DistributedVector< IndexType, DeviceType, IndexType, CommunicatorType >;

   using MatrixRow = Matrices::SparseRow< RealType, IndexType >;
   using ConstMatrixRow = Matrices::SparseRow< typename std::add_const< RealType >::type, typename std::add_const< IndexType >::type >;
@@ -61,11 +63,11 @@ public:

   DistributedMatrix( DistributedMatrix& ) = default;

   DistributedMatrix( IndexMap rowIndexMap, IndexType columns, CommunicationGroup group = Communicator::AllGroup );
   DistributedMatrix( LocalRangeType localRowRange, IndexType rows, IndexType columns, CommunicationGroup group = Communicator::AllGroup );

   void setDistribution( IndexMap rowIndexMap, IndexType columns, CommunicationGroup group = Communicator::AllGroup );
   void setDistribution( LocalRangeType localRowRange, IndexType rows, IndexType columns, CommunicationGroup group = Communicator::AllGroup );

   const IndexMap& getRowIndexMap() const;
   const LocalRangeType& getLocalRowRange() const;

   CommunicationGroup getCommunicationGroup() const;

@@ -150,19 +152,18 @@ public:
   // - assembly of the i-th row involves traversal of the local matrix stored
   //   in the i-th process
   // - assembly the full matrix needs all-to-all communication
   template< typename Partitioner >
   void updateVectorProductCommunicationPattern();

   // multiplication with a distributed vector
   // (not const because it modifies internal bufers)
   template< typename Partitioner,
             typename RealIn,
   template< typename RealIn,
             typename RealOut >
   void vectorProduct( const DistVector< RealIn >& inVector,
                       DistVector< RealOut >& outVector );

protected:
   IndexMap rowIndexMap;
   LocalRangeType localRowRange;
   IndexType rows = 0;  // global rows count
   CommunicationGroup group = Communicator::NullGroup;
   Matrix localMatrix;

+89 −116
Original line number Diff line number Diff line
@@ -23,54 +23,50 @@ namespace TNL {
namespace DistributedContainers {

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix( IndexMap rowIndexMap, IndexType columns, CommunicationGroup group )
          typename Communicator >
DistributedMatrix< Matrix, Communicator >::
DistributedMatrix( LocalRangeType localRowRange, IndexType rows, IndexType columns, CommunicationGroup group )
{
   setDistribution( rowIndexMap, columns, group );
   setDistribution( localRowRange, rows, columns, group );
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
void
DistributedMatrix< Matrix, Communicator, IndexMap >::
setDistribution( IndexMap rowIndexMap, IndexType columns, CommunicationGroup group )
DistributedMatrix< Matrix, Communicator >::
setDistribution( LocalRangeType localRowRange, IndexType rows, IndexType columns, CommunicationGroup group )
{
   this->rowIndexMap = rowIndexMap;
   this->localRowRange = localRowRange;
   this->rows = rows;
   this->group = group;
   if( group != Communicator::NullGroup )
      localMatrix.setDimensions( rowIndexMap.getLocalSize(), columns );
      localMatrix.setDimensions( localRowRange.getSize(), columns );

   resetBuffers();
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
const IndexMap&
DistributedMatrix< Matrix, Communicator, IndexMap >::
getRowIndexMap() const
          typename Communicator >
const Subrange< typename Matrix::IndexType >&
DistributedMatrix< Matrix, Communicator >::
getLocalRowRange() const
{
   return rowIndexMap;
   return localRowRange;
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
typename Communicator::CommunicationGroup
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
getCommunicationGroup() const
{
   return group;
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
const Matrix&
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
getLocalMatrix() const
{
   return localMatrix;
@@ -78,24 +74,21 @@ getLocalMatrix() const


template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
String
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
getType()
{
   return String( "DistributedContainers::DistributedMatrix< " ) +
          Matrix::getType() + ", " +
          // TODO: communicators don't have a getType method
          "<Communicator>, " +
          IndexMap::getType() + " >";
          "<Communicator>" + " >";
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
String
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
getTypeVirtual() const
{
   return getType();
@@ -107,10 +100,9 @@ getTypeVirtual() const
 */

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
DistributedMatrix< Matrix, Communicator, IndexMap >&
DistributedMatrix< Matrix, Communicator, IndexMap >::
          typename Communicator >
DistributedMatrix< Matrix, Communicator >&
DistributedMatrix< Matrix, Communicator >::
operator=( const DistributedMatrix& matrix )
{
   setLike( matrix );
@@ -119,11 +111,10 @@ operator=( const DistributedMatrix& matrix )
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
   template< typename MatrixT >
DistributedMatrix< Matrix, Communicator, IndexMap >&
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >&
DistributedMatrix< Matrix, Communicator >::
operator=( const MatrixT& matrix )
{
   setLike( matrix );
@@ -132,14 +123,14 @@ operator=( const MatrixT& matrix )
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
   template< typename MatrixT >
void
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
setLike( const MatrixT& matrix )
{
   rowIndexMap = matrix.getRowIndexMap();
   localRowRange = matrix.getLocalRowRange();
   rows = matrix.getRows();
   group = matrix.getCommunicationGroup();
   localMatrix.setLike( matrix.getLocalMatrix() );

@@ -147,13 +138,13 @@ setLike( const MatrixT& matrix )
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
void
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
reset()
{
   rowIndexMap.reset();
   localRowRange.reset();
   rows = 0;
   group = Communicator::NullGroup;
   localMatrix.reset();

@@ -161,36 +152,33 @@ reset()
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
__cuda_callable__
typename Matrix::IndexType
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
getRows() const
{
   return rowIndexMap.getGlobalSize();
   return rows;
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
__cuda_callable__
typename Matrix::IndexType
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
getColumns() const
{
   return localMatrix.getColumns();
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
void
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
setCompressedRowLengths( const CompressedRowLengthsVector& rowLengths )
{
   TNL_ASSERT_EQ( rowLengths.getSize(), getRows(), "row lengths vector has wrong size" );
   TNL_ASSERT_EQ( rowLengths.getIndexMap(), getRowIndexMap(), "row lengths vector has wrong distribution" );
   TNL_ASSERT_EQ( rowLengths.getLocalRange(), getLocalRowRange(), "row lengths vector has wrong distribution" );
   TNL_ASSERT_EQ( rowLengths.getCommunicationGroup(), getCommunicationGroup(), "row lengths vector has wrong communication group" );

   if( getCommunicationGroup() != CommunicatorType::NullGroup ) {
@@ -201,147 +189,136 @@ setCompressedRowLengths( const CompressedRowLengthsVector& rowLengths )
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
void
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
getCompressedRowLengths( CompressedRowLengthsVector& rowLengths ) const
{
   if( getCommunicationGroup() != CommunicatorType::NullGroup ) {
      rowLengths.setDistribution( getRowIndexMap(), getCommunicationGroup() );
      rowLengths.setDistribution( getLocalRowRange(), getRows(), getCommunicationGroup() );
      localMatrix.getCompressedRowLengths( rowLengths.getLocalVectorView() );
   }
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
typename Matrix::IndexType
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
getRowLength( IndexType row ) const
{
   const IndexType localRow = rowIndexMap.getLocalIndex( row );
   const IndexType localRow = localRowRange.getLocalIndex( row );
   return localMatrix.getRowLength( localRow );
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
bool
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
setElement( IndexType row,
            IndexType column,
            RealType value )
{
   const IndexType localRow = rowIndexMap.getLocalIndex( row );
   const IndexType localRow = localRowRange.getLocalIndex( row );
   return localMatrix.setElement( localRow, column, value );
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
__cuda_callable__
bool
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
setElementFast( IndexType row,
                IndexType column,
                RealType value )
{
   const IndexType localRow = rowIndexMap.getLocalIndex( row );
   const IndexType localRow = localRowRange.getLocalIndex( row );
   return localMatrix.setElementFast( localRow, column, value );
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
typename Matrix::RealType
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
getElement( IndexType row,
            IndexType column ) const
{
   const IndexType localRow = rowIndexMap.getLocalIndex( row );
   const IndexType localRow = localRowRange.getLocalIndex( row );
   return localMatrix.getElement( localRow, column );
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
__cuda_callable__
typename Matrix::RealType
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
getElementFast( IndexType row,
                IndexType column ) const
{
   const IndexType localRow = rowIndexMap.getLocalIndex( row );
   const IndexType localRow = localRowRange.getLocalIndex( row );
   return localMatrix.getElementFast( localRow, column );
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
__cuda_callable__
bool
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
setRowFast( IndexType row,
            const IndexType* columnIndexes,
            const RealType* values,
            IndexType elements )
{
   const IndexType localRow = rowIndexMap.getLocalIndex( row );
   const IndexType localRow = localRowRange.getLocalIndex( row );
   return localMatrix.setRowFast( localRow, columnIndexes, values, elements );
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
__cuda_callable__
void
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
getRowFast( IndexType row,
            IndexType* columns,
            RealType* values ) const
{
   const IndexType localRow = rowIndexMap.getLocalIndex( row );
   const IndexType localRow = localRowRange.getLocalIndex( row );
   return localMatrix.getRowFast( localRow, columns, values );
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
__cuda_callable__
typename DistributedMatrix< Matrix, Communicator, IndexMap >::MatrixRow
DistributedMatrix< Matrix, Communicator, IndexMap >::
typename DistributedMatrix< Matrix, Communicator >::MatrixRow
DistributedMatrix< Matrix, Communicator >::
getRow( IndexType row )
{
   const IndexType localRow = rowIndexMap.getLocalIndex( row );
   const IndexType localRow = localRowRange.getLocalIndex( row );
   return localMatrix.getRow( localRow );
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
__cuda_callable__
typename DistributedMatrix< Matrix, Communicator, IndexMap >::ConstMatrixRow
DistributedMatrix< Matrix, Communicator, IndexMap >::
typename DistributedMatrix< Matrix, Communicator >::ConstMatrixRow
DistributedMatrix< Matrix, Communicator >::
getRow( IndexType row ) const
{
   const IndexType localRow = rowIndexMap.getLocalIndex( row );
   const IndexType localRow = localRowRange.getLocalIndex( row );
   return localMatrix.getRow( localRow );
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
          typename Communicator >
   template< typename Vector,
             typename RealOut >
void
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
vectorProduct( const Vector& inVector,
               DistVector< RealOut >& outVector ) const
{
   TNL_ASSERT_EQ( inVector.getSize(), getColumns(), "input vector has wrong size" );
   TNL_ASSERT_EQ( outVector.getSize(), getRows(), "output vector has wrong size" );
   TNL_ASSERT_EQ( outVector.getIndexMap(), getRowIndexMap(), "output vector has wrong distribution" );
   TNL_ASSERT_EQ( outVector.getLocalRange(), getLocalRowRange(), "output vector has wrong distribution" );
   TNL_ASSERT_EQ( outVector.getCommunicationGroup(), getCommunicationGroup(), "output vector has wrong communication group" );

   auto outView = outVector.getLocalVectorView();
@@ -349,11 +326,9 @@ vectorProduct( const Vector& inVector,
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
   template< typename Partitioner >
          typename Communicator >
void
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
updateVectorProductCommunicationPattern()
{
   if( getCommunicationGroup() == CommunicatorType::NullGroup )
@@ -431,21 +406,19 @@ updateVectorProductCommunicationPattern()
}

template< typename Matrix,
          typename Communicator,
          typename IndexMap >
   template< typename Partitioner,
             typename RealIn,
          typename Communicator >
   template< typename RealIn,
             typename RealOut >
void
DistributedMatrix< Matrix, Communicator, IndexMap >::
DistributedMatrix< Matrix, Communicator >::
vectorProduct( const DistVector< RealIn >& inVector,
               DistVector< RealOut >& outVector )
{
   TNL_ASSERT_EQ( inVector.getSize(), getColumns(), "input vector has wrong size" );
   TNL_ASSERT_EQ( inVector.getIndexMap(), getRowIndexMap(), "input vector has wrong distribution" );
   TNL_ASSERT_EQ( inVector.getLocalRange(), getLocalRowRange(), "input vector has wrong distribution" );
   TNL_ASSERT_EQ( inVector.getCommunicationGroup(), getCommunicationGroup(), "input vector has wrong communication group" );
   TNL_ASSERT_EQ( outVector.getSize(), getRows(), "output vector has wrong size" );
   TNL_ASSERT_EQ( outVector.getIndexMap(), getRowIndexMap(), "output vector has wrong distribution" );
   TNL_ASSERT_EQ( outVector.getLocalRange(), getLocalRowRange(), "output vector has wrong distribution" );
   TNL_ASSERT_EQ( outVector.getCommunicationGroup(), getCommunicationGroup(), "output vector has wrong communication group" );

   if( getCommunicationGroup() == CommunicatorType::NullGroup )
@@ -456,7 +429,7 @@ vectorProduct( const DistVector< RealIn >& inVector,

   // update communication pattern
   if( commPattern.getRows() != nproc )
      updateVectorProductCommunicationPattern< Partitioner >();
      updateVectorProductCommunicationPattern();

   // prepare buffers
   globalBuffer.setSize( localMatrix.getColumns() );
+22 −24

File changed.

Preview size limit exceeded, changes collapsed.