Commit 859e6374 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Optimized distributed SpMV

parent 8831bce1
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@ SET( headers DistributedArray.h
             DistributedVectorView_impl.h
             Partitioner.h
             Subrange.h
             ThreePartVector.h
    )

INSTALL( FILES ${headers} DESTINATION ${TNL_TARGET_INCLUDE_DIRECTORY}/DistributedContainers )
+64 −37
Original line number Diff line number Diff line
@@ -18,9 +18,11 @@
// buffers
#include <vector>
#include <utility>  // std::pair
#include <limits>   // std::numeric_limits
#include <TNL/Matrices/Dense.h>
#include <TNL/Containers/Vector.h>
#include <TNL/Containers/VectorView.h>
#include <TNL/DistributedContainers/ThreePartVector.h>

// operations
#include <type_traits>  // std::add_const
@@ -43,26 +45,31 @@ public:
   using CommunicationGroup = typename CommunicatorType::CommunicationGroup;
   using Partitioner = DistributedContainers::Partitioner< typename Matrix::IndexType, Communicator >;

   // - communication pattern matrix is an nproc x nproc binary matrix C, where
   //   C_ij = 1 iff the i-th process needs data from the j-th process
   // - communication pattern: vector components whose indices are in the range
   //   [start_ij, end_ij) are copied from the j-th process to the i-th process
   //   (an empty range with start_ij == end_ij indicates that there is no
   //   communication between the i-th and j-th processes)
   // - communication pattern matrices - we need to assemble two nproc x nproc
   //   matrices commPatternStarts and commPatternEnds holding the values
   //   start_ij and end_ij respectively
   // - assembly of the i-th row involves traversal of the local matrix stored
   //   in the i-th process
   // - assembly the full matrix needs all-to-all communication
   // - assembly of the full matrix needs all-to-all communication
   void updateCommunicationPattern( const MatrixType& localMatrix, CommunicationGroup group )
   {
      const int rank = CommunicatorType::GetRank( group );
      const int nproc = CommunicatorType::GetSize( group );
      commPattern.setDimensions( nproc, nproc );
      commPatternStarts.setDimensions( nproc, nproc );
      commPatternEnds.setDimensions( nproc, nproc );

      // pass the localMatrix to the device
      const Pointers::DevicePointer< const MatrixType > localMatrixPointer( localMatrix );

      // buffer for the local row of the commPattern matrix
//      using AtomicBool = Atomic< bool, DeviceType >;
      // FIXME: CUDA does not support atomic operations for bool
      using AtomicBool = Atomic< int, DeviceType >;
      Containers::Array< AtomicBool, DeviceType > buffer( nproc );
      buffer.setValue( false );
      using AtomicIndex = Atomic< IndexType, DeviceType >;
      Containers::Array< AtomicIndex, DeviceType > span_starts( nproc ), span_ends( nproc );
      span_starts.setValue( std::numeric_limits<IndexType>::max() );
      span_ends.setValue( 0 );

      // optimization for banded matrices
      using AtomicIndex = Atomic< IndexType, DeviceType >;
@@ -71,7 +78,7 @@ public:
      local_span.setElement( 1, localMatrix.getRows() );  // span end

      auto kernel = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix,
                                            AtomicBool* buffer, AtomicIndex* local_span )
                                            AtomicIndex* span_starts, AtomicIndex* span_ends, AtomicIndex* local_span )
      {
         const IndexType columns = localMatrix->getColumns();
         const auto row = localMatrix->getRow( i );
@@ -82,8 +89,9 @@ public:
            if( j < columns ) {
               const int owner = Partitioner::getOwner( j, columns, nproc );
               // atomic assignment
               buffer[ owner ].store( true );
               // update comm_left/Right
               span_starts[ owner ].fetch_min( j );
               span_ends[ owner ].fetch_max( j + 1 );
               // update comm_left/right
               if( owner < rank )
                  comm_left = true;
               if( owner > rank )
@@ -100,7 +108,8 @@ public:
      ParallelFor< DeviceType >::exec( (IndexType) 0, localMatrix.getRows(),
                                       kernel,
                                       &localMatrixPointer.template getData< DeviceType >(),
                                       buffer.getData(),
                                       span_starts.getData(),
                                       span_ends.getData(),
                                       local_span.getData()
                                    );

@@ -108,16 +117,19 @@ public:
      localOnlySpan.first = local_span.getElement( 0 );
      localOnlySpan.second = local_span.getElement( 1 );

      // copy the buffer into all rows of the preCommPattern matrix
      Matrices::Dense< bool, Devices::Host, int > preCommPattern;
      preCommPattern.setLike( commPattern );
      // copy the buffer into all rows of the commPattern* matrices
      for( int j = 0; j < nproc; j++ )
      for( int i = 0; i < nproc; i++ )
         preCommPattern.setElementFast( j, i, buffer.getElement( i ) );
      for( int i = 0; i < nproc; i++ ) {
         commPatternStarts.setElementFast( j, i, span_starts.getElement( i ) );
         commPatternEnds.setElementFast( j, i, span_ends.getElement( i ) );
      }

      // assemble the commPattern matrix
      CommunicatorType::Alltoall( &preCommPattern(0, 0), nproc,
                                  &commPattern(0, 0), nproc,
      // assemble the commPattern* matrices
      CommunicatorType::Alltoall( &commPatternStarts(0, 0), nproc,
                                  &commPatternStarts(0, 0), nproc,
                                  group );
      CommunicatorType::Alltoall( &commPatternEnds(0, 0), nproc,
                                  &commPatternEnds(0, 0), nproc,
                                  group );
   }

@@ -132,28 +144,37 @@ public:
      const int nproc = CommunicatorType::GetSize( group );

      // update communication pattern
      if( commPattern.getRows() != nproc )
      if( commPatternStarts.getRows() != nproc || commPatternEnds.getRows() != nproc )
         updateCommunicationPattern( localMatrix, group );

      // prepare buffers
      globalBuffer.setSize( localMatrix.getColumns() );
      commRequests.clear();
      globalBuffer.init( Partitioner::getOffset( localMatrix.getColumns(), rank, nproc ),
                         inVector.getLocalVectorView(),
                         localMatrix.getColumns() - Partitioner::getOffset( localMatrix.getColumns(), rank, nproc ) - inVector.getLocalVectorView().getSize() );
      const auto globalBufferView = globalBuffer.getConstView();

      // send our data to all processes that need it
      for( int i = 0; i < commPattern.getRows(); i++ )
         if( commPattern( i, rank ) )
      for( int i = 0; i < commPatternStarts.getRows(); i++ ) {
         if( i == rank )
             continue;
         if( commPatternStarts( i, rank ) < commPatternEnds( i, rank ) )
            commRequests.push_back( CommunicatorType::ISend(
                     inVector.getLocalVectorView().getData(),
                     inVector.getLocalVectorView().getSize(),
                     inVector.getLocalVectorView().getData() + commPatternStarts( i, rank ) - Partitioner::getOffset( localMatrix.getColumns(), rank, nproc ),
                     commPatternEnds( i, rank ) - commPatternStarts( i, rank ),
                     i, 0, group ) );
      }

      // receive data that we need
      for( int j = 0; j < commPattern.getRows(); j++ )
         if( commPattern( rank, j ) )
      for( int j = 0; j < commPatternStarts.getRows(); j++ ) {
         if( j == rank )
             continue;
         if( commPatternStarts( rank, j ) < commPatternEnds( rank, j ) )
            commRequests.push_back( CommunicatorType::IRecv(
                     &globalBuffer[ Partitioner::getOffset( globalBuffer.getSize(), j, nproc ) ],
                     Partitioner::getSizeForRank( globalBuffer.getSize(), j, nproc ),
                     &globalBuffer[ commPatternStarts( rank, j ) ],
                     commPatternEnds( rank, j ) - commPatternStarts( rank, j ),
                     j, 0, group ) );
      }

      // general variant
      if( localOnlySpan.first >= localOnlySpan.second ) {
@@ -161,8 +182,14 @@ public:
         CommunicatorType::WaitAll( &commRequests[0], commRequests.size() );

         // perform matrix-vector multiplication
         auto outView = outVector.getLocalVectorView();
         localMatrix.vectorProduct( globalBuffer, outView );
         auto outVectorView = outVector.getLocalVectorView();
         const Pointers::DevicePointer< const MatrixType > localMatrixPointer( localMatrix );
         auto kernel = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable
         {
            outVectorView[ i ] = localMatrix->rowVectorProduct( i, globalBufferView );
         };
         ParallelFor< DeviceType >::exec( (IndexType) 0, localMatrix.getRows(), kernel,
                                          &localMatrixPointer.template getData< DeviceType >() );
      }
      // optimization for banded matrices
      else {
@@ -183,7 +210,6 @@ public:
         CommunicatorType::WaitAll( &commRequests[0], commRequests.size() );

         // finish the multiplication by adding the non-local entries
         Containers::VectorView< RealType, DeviceType, IndexType > globalBufferView( globalBuffer );
         auto kernel2 = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable
         {
            outVectorView[ i ] = localMatrix->rowVectorProduct( i, globalBufferView );
@@ -197,7 +223,8 @@ public:

   void reset()
   {
      commPattern.reset();
      commPatternStarts.reset();
      commPatternEnds.reset();
      localOnlySpan.first = localOnlySpan.second = 0;
      globalBuffer.reset();
      commRequests.clear();
@@ -205,13 +232,13 @@ public:

protected:
   // communication pattern
   Matrices::Dense< bool, Devices::Host, int > commPattern;
   Matrices::Dense< IndexType, Devices::Host, int > commPatternStarts, commPatternEnds;

   // span of rows with only block-diagonal entries
   std::pair< IndexType, IndexType > localOnlySpan;

   // global buffer for non-local elements of the vector
   Containers::Vector< RealType, DeviceType, IndexType > globalBuffer;
   ThreePartVector< RealType, DeviceType, IndexType > globalBuffer;

   // buffer for asynchronous communication requests
   std::vector< typename CommunicatorType::Request > commRequests;
+157 −0
Original line number Diff line number Diff line
/***************************************************************************
                          ThreePartVector.h  -  description
                             -------------------
    begin                : Dec 19, 2018
    copyright            : (C) 2018 by Tomas Oberhuber et al.
    email                : tomas.oberhuber@fjfi.cvut.cz
 ***************************************************************************/

/* See Copyright Notice in tnl/Copyright */

// Implemented by: Jakub Klinkovský

#pragma once

#include <TNL/Containers/Vector.h>
#include <TNL/Containers/VectorView.h>

namespace TNL {
namespace DistributedContainers {

template< typename Real,
          typename Device = Devices::Host,
          typename Index = int >
class ThreePartVectorView
{
public:
   using RealType = Real;
   using DeviceType = Device;
   using IndexType = Index;
   using VectorView = Containers::VectorView< Real, Device, Index >;

   ThreePartVectorView() = default;
   ThreePartVectorView( const ThreePartVectorView& ) = default;
   ThreePartVectorView( ThreePartVectorView&& ) = default;

   ThreePartVectorView( VectorView view_left, VectorView view_mid, VectorView view_right )
   {
      bind( view_left, view_mid, view_right );
   }

   void bind( VectorView view_left, VectorView view_mid, VectorView view_right )
   {
      left.bind( view_left );
      middle.bind( view_mid );
      right.bind( view_right );
   }

   void reset()
   {
      left.reset();
      middle.reset();
      right.reset();
   }

//   __cuda_callable__
//   Real& operator[]( Index i )
//   {
//      if( i < left.getSize() )
//         return left[ i ];
//      else if( i < left.getSize() + middle.getSize() )
//         return middle[ i - left.getSize() ];
//      else
//         return right[ i - left.getSize() - middle.getSize() ];
//   }

   __cuda_callable__
   const Real& operator[]( Index i ) const
   {
      if( i < left.getSize() )
         return left[ i ];
      else if( i < left.getSize() + middle.getSize() )
         return middle[ i - left.getSize() ];
      else
         return right[ i - left.getSize() - middle.getSize() ];
   }

   friend std::ostream& operator<<( std::ostream& str, const ThreePartVectorView& v )
   {
      str << "[\n\tleft: " << v.left << ",\n\tmiddle: " << v.middle << ",\n\tright: " << v.right << "\n]";
      return str;
   }

protected:
   VectorView left, middle, right;
};

template< typename Real,
          typename Device = Devices::Host,
          typename Index = int >
class ThreePartVector
{
   using ConstReal = typename std::add_const< Real >::type;
public:
   using RealType = Real;
   using DeviceType = Device;
   using IndexType = Index;
   using Vector = Containers::Vector< Real, Device, Index >;
   using VectorView = Containers::VectorView< Real, Device, Index >;
   using ConstVectorView = Containers::VectorView< ConstReal, Device, Index >;

   ThreePartVector() = default;
   ThreePartVector( ThreePartVector& ) = default;

   void init( Index size_left, ConstVectorView view_mid, Index size_right )
   {
      left.setSize( size_left );
      middle.bind( view_mid );
      right.setSize( size_right );
   }

   void reset()
   {
      left.reset();
      middle.reset();
      right.reset();
   }

   ThreePartVectorView< ConstReal, Device, Index > getConstView()
   {
      return {left, middle, right};
   }

//   __cuda_callable__
//   Real& operator[]( Index i )
//   {
//      if( i < left.getSize() )
//         return left[ i ];
//      else if( i < left.getSize() + middle.getSize() )
//         return middle[ i - left.getSize() ];
//      else
//         return right[ i - left.getSize() - middle.getSize() ];
//   }

   __cuda_callable__
   const Real& operator[]( Index i ) const
   {
      if( i < left.getSize() )
         return left[ i ];
      else if( i < left.getSize() + middle.getSize() )
         return middle[ i - left.getSize() ];
      else
         return right[ i - left.getSize() - middle.getSize() ];
   }

   friend std::ostream& operator<<( std::ostream& str, const ThreePartVector& v )
   {
      str << "[\n\tleft: " << v.left << ",\n\tmiddle: " << v.middle << ",\n\tright: " << v.right << "\n]";
      return str;
   }

protected:
   Vector left, right;
   ConstVectorView middle;
};

} // namespace DistributedContainers
} // namespace TNL
+6 −2
Original line number Diff line number Diff line
@@ -214,7 +214,9 @@ TYPED_TEST( DistributedMatrixTest, vectorProduct_globalInput )
   DistributedVector outVector( this->matrix.getLocalRowRange(), this->globalSize, this->matrix.getCommunicationGroup() );
   this->matrix.vectorProduct( inVector, outVector );

   EXPECT_EQ( outVector, this->rowLengths );
   EXPECT_EQ( outVector, this->rowLengths )
      << "outVector.getLocalVectorView() = " << outVector.getLocalVectorView()
      << ",\nthis->rowLengths.getLocalVectorView() = " << this->rowLengths.getLocalVectorView();
}

TYPED_TEST( DistributedMatrixTest, vectorProduct_distributedInput )
@@ -229,7 +231,9 @@ TYPED_TEST( DistributedMatrixTest, vectorProduct_distributedInput )
   DistributedVector outVector( this->matrix.getLocalRowRange(), this->globalSize, this->matrix.getCommunicationGroup() );
   this->matrix.vectorProduct( inVector, outVector );

   EXPECT_EQ( outVector, this->rowLengths );
   EXPECT_EQ( outVector, this->rowLengths )
      << "outVector.getLocalVectorView() = " << outVector.getLocalVectorView()
      << ",\nthis->rowLengths.getLocalVectorView() = " << this->rowLengths.getLocalVectorView();
}

#endif  // HAVE_GTEST