Commit e55b52d3 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Implemented vectorProduct for distributed matrices

parent 20dc15de
Loading
Loading
Loading
Loading
+40 −5
Original line number Diff line number Diff line
@@ -14,12 +14,17 @@

#include <type_traits>  // std::add_const

#include <TNL/Containers/Vector.h>
#include <TNL/Matrices/SparseRow.h>
#include <TNL/Communicators/MpiCommunicator.h>
#include <TNL/DistributedContainers/IndexMap.h>
#include <TNL/DistributedContainers/DistributedVector.h>

// buffers for vectorProduct
#include <vector>
#include <utility>  // std::pair
#include <TNL/Matrices/Dense.h>
#include <TNL/Containers/Vector.h>

namespace TNL {
namespace DistributedContainers {

@@ -139,20 +144,50 @@ public:
   void vectorProduct( const Vector& inVector,
                       DistVector< RealOut >& outVector ) const;

   // optimization for matrix-vector multiplication
   void updateVectorProductPrefetchPattern();
   // Optimization for matrix-vector multiplication:
   // - communication pattern matrix is an nproc x nproc binary matrix C, where
   //   C_ij = 1 iff the i-th process needs data from the j-th process
   // - assembly of the i-th row involves traversal of the local matrix stored
   //   in the i-th process
   // - assembly of the full matrix needs all-to-all communication
   template< typename Partitioner >
   void updateVectorProductCommunicationPattern();

   // multiplication with a distributed vector
   template< typename RealIn,
   // (not const because it modifies internal buffers)
   template< typename Partitioner,
             typename RealIn,
             typename RealOut >
   void vectorProduct( const DistVector< RealIn >& inVector,
                       DistVector< RealOut >& outVector ) const;
                       DistVector< RealOut >& outVector );

protected:
   IndexMap rowIndexMap;
   CommunicationGroup group = Communicator::NullGroup;
   Matrix localMatrix;

   // Discards the state cached for the distributed vectorProduct: the list of
   // asynchronous communication requests, the global buffer and the
   // communication pattern matrix.
   // NOTE(review): presumably commPattern.reset() leaves the matrix with zero
   // rows, which is what makes the next distributed vectorProduct rebuild the
   // pattern (it compares commPattern.getRows() with the number of processes).
   void resetBuffers()
   {
      commRequests.clear();
      globalBuffer.reset();
      commPattern.reset();
   }

   // communication pattern for matrix-vector product:
   // nproc x nproc binary matrix C, where C_ij is set iff the i-th process
   // needs data from the j-th process; it is assembled by
   // updateVectorProductCommunicationPattern and stored on the host
   // TODO: probably should be stored elsewhere
   Matrices::Dense< bool, Devices::Host, int > commPattern;

   // span of rows with only block-diagonal entries
   // (half-open range [first, second) of local row indices; the distributed
   // vectorProduct processes these rows before the communication finishes and
   // falls back to the general variant when first >= second)
   std::pair< IndexType, IndexType > localOnlySpan;

   // global buffer for operations such as distributed matrix-vector multiplication
   // (vectorProduct resizes it to the number of columns of the local matrix and
   // receives the needed parts of the input vector into it)
   // TODO: probably should be stored elsewhere
   Containers::Vector< RealType, DeviceType, IndexType > globalBuffer;

   // buffer for asynchronous communication requests
   // (cleared and refilled by every call of the distributed vectorProduct)
   // TODO: probably should be stored elsewhere
   std::vector< typename CommunicatorType::Request > commRequests;

private:
   // TODO: disabled until they are implemented
   using Object::save;
+186 −1
Original line number Diff line number Diff line
@@ -14,6 +14,11 @@

#include "DistributedMatrix.h"

#include <TNL/Atomic.h>
#include <TNL/ParallelFor.h>
#include <TNL/Pointers/DevicePointer.h>
#include <TNL/Containers/VectorView.h>

namespace TNL {
namespace DistributedContainers {

@@ -37,6 +42,8 @@ setDistribution( IndexMap rowIndexMap, IndexType columns, CommunicationGroup gro
   this->group = group;
   if( group != Communicator::NullGroup )
      localMatrix.setDimensions( rowIndexMap.getLocalSize(), columns );

   resetBuffers();
}

template< typename Matrix,
@@ -135,6 +142,8 @@ setLike( const MatrixT& matrix )
   rowIndexMap = matrix.getRowIndexMap();
   group = matrix.getCommunicationGroup();
   localMatrix.setLike( matrix.getLocalMatrix() );

   resetBuffers();
}

template< typename Matrix,
@@ -147,6 +156,8 @@ reset()
   rowIndexMap.reset();
   group = Communicator::NullGroup;
   localMatrix.reset();

   resetBuffers();
}

template< typename Matrix,
@@ -182,8 +193,11 @@ setCompressedRowLengths( const CompressedRowLengthsVector& rowLengths )
   TNL_ASSERT_EQ( rowLengths.getIndexMap(), getRowIndexMap(), "row lengths vector has wrong distribution" );
   TNL_ASSERT_EQ( rowLengths.getCommunicationGroup(), getCommunicationGroup(), "row lengths vector has wrong communication group" );

   if( getCommunicationGroup() != CommunicatorType::NullGroup )
   if( getCommunicationGroup() != CommunicatorType::NullGroup ) {
      localMatrix.setCompressedRowLengths( rowLengths.getLocalVectorView() );

      resetBuffers();
   }
}

template< typename Matrix,
@@ -334,5 +348,176 @@ vectorProduct( const Vector& inVector,
   localMatrix.vectorProduct( inVector, outView );
}

// Assembles the communication pattern used by the distributed vectorProduct.
//
// Each process traverses its local matrix and marks the owners (according to
// Partitioner::getOwner) of all columns that occur in its rows. The resulting
// local row of flags is then exchanged with all other processes, so that every
// process ends up with the full nproc x nproc matrix commPattern, where the
// entry (i, j) is set iff the i-th process needs data from the j-th process.
//
// As a by-product, localOnlySpan is set to the half-open range of local rows
// lying after the last row that references a lower rank and before the first
// row that references a higher rank.
//
// Does nothing when the matrix has no communication group.
template< typename Matrix,
          typename Communicator,
          typename IndexMap >
   template< typename Partitioner >
void
DistributedMatrix< Matrix, Communicator, IndexMap >::
updateVectorProductCommunicationPattern()
{
   if( getCommunicationGroup() == CommunicatorType::NullGroup )
      return;

   const int rank = CommunicatorType::GetRank( getCommunicationGroup() );
   const int nproc = CommunicatorType::GetSize( getCommunicationGroup() );
   commPattern.setDimensions( nproc, nproc );

   // pass the localMatrix to the device
   Pointers::DevicePointer< MatrixType > localMatrixPointer( localMatrix );

   // buffer for the local row of the commPattern matrix
//   using AtomicBool = Atomic< bool, DeviceType >;
   // FIXME: CUDA does not support atomic operations for bool
   using AtomicBool = Atomic< int, DeviceType >;
   Containers::Array< AtomicBool, DeviceType > buffer( nproc );
   buffer.setValue( false );

   // optimization for banded matrices
   using AtomicIndex = Atomic< IndexType, DeviceType >;
   Containers::Array< AtomicIndex, DeviceType > local_span( 2 );
   local_span.setElement( 0, 0 );  // span start
   local_span.setElement( 1, localMatrix.getRows() );  // span end

   // processes the i-th local row: marks the owners of its columns in the
   // buffer and shrinks the local-only span if the row needs remote data
   auto kernel = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix,
                                         AtomicBool* buffer, AtomicIndex* local_span )
   {
      const IndexType columns = localMatrix->getColumns();
      const auto row = localMatrix->getRow( i );
      bool comm_left = false;
      bool comm_right = false;
      for( IndexType c = 0; c < row.getLength(); c++ ) {
         const IndexType j = row.getElementColumn( c );
         // entries with column index >= columns are skipped
         // NOTE(review): presumably these are padding entries - confirm
         if( j < columns ) {
            const int owner = Partitioner::getOwner( j, columns, nproc );
            // atomic assignment
            buffer[ owner ].store( true );
            // update comm_left/Right
            if( owner < rank )
               comm_left = true;
            if( owner > rank )
               comm_right = true;
         }
      }
      // update local span
      if( comm_left )
         local_span[0].fetch_max( i + 1 );
      if( comm_right )
         local_span[1].fetch_min( i );
   };

   ParallelFor< DeviceType >::exec( (IndexType) 0, localMatrix.getRows(),
                                    kernel,
                                    &localMatrixPointer.template getData< DeviceType >(),
                                    buffer.getData(),
                                    local_span.getData()
                                 );

   // set the local-only span (optimization for banded matrices)
   localOnlySpan.first = local_span.getElement( 0 );
   localOnlySpan.second = local_span.getElement( 1 );

   // copy the buffer into all rows of the preCommPattern matrix
   // (each flag is read from the buffer only once and then written into the
   // whole column, instead of being re-read for every row)
   Matrices::Dense< bool, Devices::Host, int > preCommPattern;
   preCommPattern.setLike( commPattern );
   for( int i = 0; i < nproc; i++ ) {
      const bool needed = buffer.getElement( i );
      for( int j = 0; j < nproc; j++ )
         preCommPattern.setElementFast( j, i, needed );
   }

   // assemble the commPattern matrix
   // (the j-th row of preCommPattern is sent to the j-th process and the row
   // received from the j-th process is stored as the j-th row of commPattern)
   CommunicatorType::Alltoall( &preCommPattern(0, 0), nproc,
                               &commPattern(0, 0), nproc,
                               getCommunicationGroup() );
}

// Multiplies the distributed matrix with a distributed vector.
//
// The parts of inVector needed by this process are gathered into globalBuffer
// using asynchronous sends/receives driven by commPattern (which is assembled
// on demand by updateVectorProductCommunicationPattern< Partitioner >). When
// localOnlySpan is non-empty, the rows in that span are multiplied while the
// communication is still in progress; the remaining rows are processed after
// all requests have completed.
//
// Does nothing when the matrix has no communication group.
//
// NOTE(review): the data is sent as RealIn but received into a buffer of
// RealType - this presumably requires RealIn and RealType to be the same type;
// confirm against the communicator.
template< typename Matrix,
          typename Communicator,
          typename IndexMap >
   template< typename Partitioner,
             typename RealIn,
             typename RealOut >
void
DistributedMatrix< Matrix, Communicator, IndexMap >::
vectorProduct( const DistVector< RealIn >& inVector,
               DistVector< RealOut >& outVector )
{
   TNL_ASSERT_EQ( inVector.getSize(), getColumns(), "input vector has wrong size" );
   TNL_ASSERT_EQ( inVector.getIndexMap(), getRowIndexMap(), "input vector has wrong distribution" );
   TNL_ASSERT_EQ( inVector.getCommunicationGroup(), getCommunicationGroup(), "input vector has wrong communication group" );
   TNL_ASSERT_EQ( outVector.getSize(), getRows(), "output vector has wrong size" );
   TNL_ASSERT_EQ( outVector.getIndexMap(), getRowIndexMap(), "output vector has wrong distribution" );
   TNL_ASSERT_EQ( outVector.getCommunicationGroup(), getCommunicationGroup(), "output vector has wrong communication group" );

   if( getCommunicationGroup() == CommunicatorType::NullGroup )
      return;

   const int rank = CommunicatorType::GetRank( getCommunicationGroup() );
   const int nproc = CommunicatorType::GetSize( getCommunicationGroup() );

   // update communication pattern
   // (the pattern has nproc rows once assembled, so a different number of
   // rows means that it has not been built for the current distribution yet)
   if( commPattern.getRows() != nproc )
      updateVectorProductCommunicationPattern< Partitioner >();

   // prepare buffers
   globalBuffer.setSize( localMatrix.getColumns() );
   commRequests.clear();

   // send our data to all processes that need it
   for( int i = 0; i < commPattern.getRows(); i++ )
      if( commPattern( i, rank ) )
         commRequests.push_back( CommunicatorType::ISend(
                  inVector.getLocalVectorView().getData(),
                  inVector.getLocalVectorView().getSize(),
                  i, getCommunicationGroup() ) );

   // receive data that we need
   for( int j = 0; j < commPattern.getRows(); j++ )
      if( commPattern( rank, j ) )
         commRequests.push_back( CommunicatorType::IRecv(
                  &globalBuffer[ Partitioner::getOffset( globalBuffer.getSize(), j, nproc ) ],
                  Partitioner::getSizeForRank( globalBuffer.getSize(), j, nproc ),
                  j, getCommunicationGroup() ) );

   // general variant
   if( localOnlySpan.first >= localOnlySpan.second ) {
      // wait for all communications to finish
      // (data() is used instead of &commRequests[0], because the list may be
      // empty and indexing an empty std::vector is undefined behaviour)
      CommunicatorType::WaitAll( commRequests.data(), commRequests.size() );

      // perform matrix-vector multiplication
      vectorProduct( globalBuffer, outVector );
   }
   // optimization for banded matrices
   else {
      Pointers::DevicePointer< MatrixType > localMatrixPointer( localMatrix );
      auto outVectorView = outVector.getLocalVectorView();
      // TODO
//      const auto inVectorView = DistributedVectorView( inVector );
      Pointers::DevicePointer< const DistVector< RealIn > > inVectorPointer( inVector );

      // matrix-vector multiplication using local-only rows
      auto kernel1 = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix, const DistVector< RealIn >* inVector ) mutable
      {
         outVectorView[ i ] = localMatrix->rowVectorProduct( i, *inVector );
      };
      ParallelFor< DeviceType >::exec( localOnlySpan.first, localOnlySpan.second, kernel1,
                                       &localMatrixPointer.template getData< DeviceType >(),
                                       &inVectorPointer.template getData< DeviceType >() );

      // wait for all communications to finish
      // (data() is safe even if no request was posted, unlike &commRequests[0])
      CommunicatorType::WaitAll( commRequests.data(), commRequests.size() );

      // finish the multiplication by processing the rows outside of the
      // local-only span, which need the received non-local entries
      Containers::VectorView< RealType, DeviceType, IndexType > globalBufferView( globalBuffer );
      auto kernel2 = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable
      {
         outVectorView[ i ] = localMatrix->rowVectorProduct( i, globalBufferView );
      };
      ParallelFor< DeviceType >::exec( (IndexType) 0, localOnlySpan.first, kernel2,
                                       &localMatrixPointer.template getData< DeviceType >() );
      ParallelFor< DeviceType >::exec( localOnlySpan.second, localMatrix.getRows(), kernel2,
                                       &localMatrixPointer.template getData< DeviceType >() );
   }
}

} // namespace DistributedContainers
} // namespace TNL
+23 −0
Original line number Diff line number Diff line
@@ -42,6 +42,29 @@ public:
      else
         return IndexMap( 0, 0, globalSize );
   }

   // Gets the owner of given global index, i.e. the rank r such that
   //    getOffset( globalSize, r, partitions ) <= i
   //       < getOffset( globalSize, r + 1, partitions ).
   //
   // The simple formula i * partitions / globalSize is not the inverse of
   // getOffset when globalSize is not divisible by partitions: for
   // globalSize = 10 and partitions = 3 the offsets are 0, 3, 6, so the index 3
   // is owned by rank 1, but 3 * 3 / 10 == 0.
   //
   // The owner is the largest r satisfying r * globalSize / partitions <= i
   // (integer division), which is equivalent to
   // r * globalSize < (i + 1) * partitions, hence
   // r = ((i + 1) * partitions - 1) / globalSize. This also holds when there
   // are more partitions than indices (ranks with empty ranges are skipped).
   //
   // Expects 0 <= i < globalSize and partitions > 0.
   __cuda_callable__
   static int getOwner( Index i, Index globalSize, int partitions )
   {
      return ( ( i + 1 ) * partitions - 1 ) / globalSize;
   }

   // Returns the global index at which the data assigned to the given rank
   // begins (the quotient is rounded by integer division).
   __cuda_callable__
   static Index getOffset( Index globalSize, int rank, int partitions )
   {
      const Index offset = rank * globalSize / partitions;
      return offset;
   }

   // Returns the number of elements assigned to the given rank: the distance
   // between the start of the next rank's data and the start of this rank's
   // data, both clamped to globalSize.
   __cuda_callable__
   static Index getSizeForRank( Index globalSize, int rank, int partitions )
   {
      // one past the last index owned by the rank
      const Index last = min( globalSize, (rank + 1) * globalSize / partitions );
      // first index owned by the rank
      const Index first = min( globalSize, rank * globalSize / partitions );
      return last - first;
   }
};

} // namespace DistributedContainers
+4 −2
Original line number Diff line number Diff line
@@ -70,6 +70,7 @@ protected:
   using IndexType = typename DistributedMatrix::IndexType;
   using IndexMap = typename DistributedMatrix::IndexMapType;
   using DistributedMatrixType = DistributedMatrix;
   using Partitioner = DistributedContainers::Partitioner< IndexMap, CommunicatorType >;

   using RowLengthsVector = typename DistributedMatrixType::CompressedRowLengthsVector;
   using GlobalVector = Containers::Vector< RealType, DeviceType, IndexType >;
@@ -88,7 +89,7 @@ protected:

   void SetUp() override
   {
      const IndexMap map = DistributedContainers::Partitioner< IndexMap, CommunicatorType >::splitRange( globalSize, group );
      const IndexMap map = Partitioner::splitRange( globalSize, group );
      matrix.setDistribution( map, globalSize, group );
      rowLengths.setDistribution( map, group );

@@ -220,6 +221,7 @@ TYPED_TEST( DistributedMatrixTest, vectorProduct_globalInput )
TYPED_TEST( DistributedMatrixTest, vectorProduct_distributedInput )
{
   using DistributedVector = typename TestFixture::DistributedVector;
   using Partitioner = typename TestFixture::Partitioner;

   this->matrix.setCompressedRowLengths( this->rowLengths );
   setMatrix( this->matrix, this->rowLengths );
@@ -227,7 +229,7 @@ TYPED_TEST( DistributedMatrixTest, vectorProduct_distributedInput )
   DistributedVector inVector( this->matrix.getRowIndexMap(), this->matrix.getCommunicationGroup() );
   inVector.setValue( 1 );
   DistributedVector outVector( this->matrix.getRowIndexMap(), this->matrix.getCommunicationGroup() );
   this->matrix.vectorProduct( inVector, outVector );
   this->matrix.template vectorProduct< Partitioner >( inVector, outVector );

   EXPECT_EQ( outVector, this->rowLengths );
}