    /***************************************************************************
                              DistributedSpMV.h  -  description
                                 -------------------
        begin                : Sep 20, 2018
        copyright            : (C) 2018 by Tomas Oberhuber et al.
        email                : tomas.oberhuber@fjfi.cvut.cz
     ***************************************************************************/
    
    /* See Copyright Notice in tnl/Copyright */
    
    // Implemented by: Jakub Klinkovský
    
    #pragma once
    
    
    #include <TNL/Containers/Partitioner.h>
    #include <TNL/Containers/DistributedVectorView.h>
    
    
    // buffers
    #include <vector>
    #include <utility>  // std::pair
    
    #include <limits>   // std::numeric_limits
    
    #include <TNL/Matrices/Dense.h>
    #include <TNL/Containers/Vector.h>
    #include <TNL/Containers/VectorView.h>
    
    #include <TNL/Matrices/ThreePartVector.h>
    
    
    // operations
    #include <type_traits>  // std::add_const
    #include <TNL/Atomic.h>
    
    #include <TNL/Algorithms/ParallelFor.h>
    
    #include <TNL/Pointers/DevicePointer.h>
    
    namespace TNL {
    
    
    template< typename Matrix, typename Communicator >
    class DistributedSpMV
    {
    public:
       using MatrixType = Matrix;
       using RealType = typename Matrix::RealType;
       using DeviceType = typename Matrix::DeviceType;
       using IndexType = typename Matrix::IndexType;
       using CommunicatorType = Communicator;
       using CommunicationGroup = typename CommunicatorType::CommunicationGroup;
    
       using Partitioner = Containers::Partitioner< typename Matrix::IndexType, Communicator >;
    
       // - communication pattern: vector components whose indices are in the range
       //   [start_ij, end_ij) are copied from the j-th process to the i-th process
       //   (an empty range with start_ij == end_ij indicates that there is no
       //   communication between the i-th and j-th processes)
       // - communication pattern matrices - we need to assemble two nproc x nproc
       //   matrices commPatternStarts and commPatternEnds holding the values
       //   start_ij and end_ij respectively
    
       // - assembly of the i-th row involves traversal of the local matrix stored
       //   in the i-th process
    
       // - assembly of the full matrix needs all-to-all communication
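       //
       // - for illustration (a hypothetical example, not derived from any concrete
       //   matrix): with 3 processes and a banded matrix whose couplings reach only
       //   into the neighboring blocks, the only non-empty off-diagonal ranges are
       //   [start_01, end_01), [start_10, end_10), [start_12, end_12) and
       //   [start_21, end_21), i.e. each rank exchanges data only with its neighbors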
    
       void updateCommunicationPattern( const MatrixType& localMatrix, CommunicationGroup group )
       {
          const int rank = CommunicatorType::GetRank( group );
          const int nproc = CommunicatorType::GetSize( group );
    
          commPatternStarts.setDimensions( nproc, nproc );
          commPatternEnds.setDimensions( nproc, nproc );
    
    
          // pass the localMatrix to the device
          const Pointers::DevicePointer< const MatrixType > localMatrixPointer( localMatrix );
    
          // buffer for the local row of the commPattern matrix
    
          using AtomicIndex = Atomic< IndexType, DeviceType >;
          Containers::Array< AtomicIndex, DeviceType > span_starts( nproc ), span_ends( nproc );
          span_starts.setValue( std::numeric_limits<IndexType>::max() );
          span_ends.setValue( 0 );

          // optimization for banded matrices
          Containers::Array< AtomicIndex, DeviceType > local_span( 2 );
          local_span.setElement( 0, 0 );  // span start
          local_span.setElement( 1, localMatrix.getRows() );  // span end
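
          // The kernel below scans the i-th local row: for each referenced column it
          // finds the owning rank and extends that rank's requested column span
          // (span_starts/span_ends), and it shrinks local_span towards the rows which
          // reference only columns owned by this rank.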
    
          auto kernel = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix,
                                                AtomicIndex* span_starts, AtomicIndex* span_ends, AtomicIndex* local_span )
          {
             const IndexType columns = localMatrix->getColumns();
             const auto row = localMatrix->getRow( i );
             bool comm_left = false;
             bool comm_right = false;
             for( IndexType c = 0; c < row.getLength(); c++ ) {
                const IndexType j = row.getElementColumn( c );
                if( j < columns ) {
                   const int owner = Partitioner::getOwner( j, columns, nproc );
                   // atomically extend the column span requested from the owner rank
                   span_starts[ owner ].fetch_min( j );
                   span_ends[ owner ].fetch_max( j + 1 );
                   // record whether this row needs data from a lower or higher rank
                   if( owner < rank )
                      comm_left = true;
                   if( owner > rank )
                      comm_right = true;
                }
             }
             // update local span
             if( comm_left )
                local_span[0].fetch_max( i + 1 );
             if( comm_right )
                local_span[1].fetch_min( i );
          };
    
    
          Algorithms::ParallelFor< DeviceType >::exec( (IndexType) 0, localMatrix.getRows(),
                                                       kernel,
                                                       &localMatrixPointer.template getData< DeviceType >(),
                                                       span_starts.getData(),
                                                       span_ends.getData(),
                                                       local_span.getData()
                                                    );
    
    
          // set the local-only span (optimization for banded matrices)
          localOnlySpan.first = local_span.getElement( 0 );
          localOnlySpan.second = local_span.getElement( 1 );
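          // For banded matrices, rows in [localOnlySpan.first, localOnlySpan.second)
          // reference only columns owned by this rank, so vectorProduct can process
          // them before the communication finishes.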
    
    
          // copy the buffer into all rows of the preCommPattern* matrices
          // (in-place copy does not work with some OpenMPI configurations)
          Matrices::Dense< IndexType, Devices::Host, int > preCommPatternStarts, preCommPatternEnds;
          preCommPatternStarts.setLike( commPatternStarts );
          preCommPatternEnds.setLike( commPatternEnds );
    
          for( int j = 0; j < nproc; j++ )
          for( int i = 0; i < nproc; i++ ) {
             preCommPatternStarts.setElement( j, i, span_starts.getElement( i ) );
             preCommPatternEnds.setElement( j, i, span_ends.getElement( i ) );
          }

          // assemble the commPattern* matrices
    
          CommunicatorType::Alltoall( &preCommPatternStarts(0, 0), nproc,
                                      &commPatternStarts(0, 0), nproc,
                                      group );
          CommunicatorType::Alltoall( &preCommPatternEnds(0, 0), nproc,
                                      &commPatternEnds(0, 0), nproc,
                                      group );
       }
    
       template< typename InVector,
                 typename OutVector >
       void vectorProduct( OutVector& outVector,
                           const MatrixType& localMatrix,
                           const InVector& inVector,
                           CommunicationGroup group )
       {
          const int rank = CommunicatorType::GetRank( group );
          const int nproc = CommunicatorType::GetSize( group );
    
          // build the communication pattern unless it is already set up for this
          // number of processes
          if( commPatternStarts.getRows() != nproc || commPatternEnds.getRows() != nproc )
             updateCommunicationPattern( localMatrix, group );
    
          // prepare buffers
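          // globalBuffer is a ThreePartVector: the middle part is the local part of
          // inVector, while the left and right parts receive the elements owned by
          // lower and higher ranks, respectively (their sizes are the init arguments)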
    
          globalBuffer.init( Partitioner::getOffset( localMatrix.getColumns(), rank, nproc ),
                             inVector.getConstLocalView(),
                             localMatrix.getColumns() - Partitioner::getOffset( localMatrix.getColumns(), rank, nproc ) - inVector.getConstLocalView().getSize() );
    
          const auto globalBufferView = globalBuffer.getConstView();
    
          // buffer for asynchronous communication requests
          std::vector< typename CommunicatorType::Request > commRequests;
    
    
          // send our data to all processes that need it
          for( int i = 0; i < commPatternStarts.getRows(); i++ ) {
             if( i == rank )
                continue;
             if( commPatternStarts( i, rank ) < commPatternEnds( i, rank ) )
                commRequests.push_back( CommunicatorType::ISend(
                         inVector.getConstLocalView().getData() + commPatternStarts( i, rank ) - Partitioner::getOffset( localMatrix.getColumns(), rank, nproc ),
                         commPatternEnds( i, rank ) - commPatternStarts( i, rank ),
                         // the trailing ISend arguments (destination rank, message tag,
                         // communication group) are an assumption
                         i, 0, group ) );
          }
    
          // receive the data that we need from the other processes
          for( int j = 0; j < commPatternStarts.getRows(); j++ ) {
             if( j == rank )
                continue;
             if( commPatternStarts( rank, j ) < commPatternEnds( rank, j ) )
                commRequests.push_back( CommunicatorType::IRecv(
                         globalBuffer.getPointer( commPatternStarts( rank, j ) ),
                         commPatternEnds( rank, j ) - commPatternStarts( rank, j ),
                         // the trailing IRecv arguments (source rank, message tag,
                         // communication group) are an assumption
                         j, 0, group ) );
          }
    
          // general variant
          if( localOnlySpan.first >= localOnlySpan.second ) {
             // wait for all communications to finish
             CommunicatorType::WaitAll( &commRequests[0], commRequests.size() );
    
             // perform matrix-vector multiplication
    
             auto outVectorView = outVector.getLocalView();
    
             const Pointers::DevicePointer< const MatrixType > localMatrixPointer( localMatrix );
             auto kernel = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable
             {
                outVectorView[ i ] = localMatrix->rowVectorProduct( i, globalBufferView );
             };
    
             Algorithms::ParallelFor< DeviceType >::exec( (IndexType) 0, localMatrix.getRows(), kernel,
                                                          &localMatrixPointer.template getData< DeviceType >() );
    
          }
          // optimization for banded matrices
          else {
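             // Overlap communication with computation: multiply the rows which
             // reference only local columns while the ghost data is in flight, then
             // finish the remaining rows near the beginning and end of the local block.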
    
             auto outVectorView = outVector.getLocalView();
             const Pointers::DevicePointer< const MatrixType > localMatrixPointer( localMatrix );
             const auto inView = inVector.getConstView();

             // matrix-vector multiplication on the rows which reference only local columns
             auto kernel1 = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable
             {
                outVectorView[ i ] = localMatrix->rowVectorProduct( i, inView );
             };
    
             Algorithms::ParallelFor< DeviceType >::exec( localOnlySpan.first, localOnlySpan.second, kernel1,
                                                          &localMatrixPointer.template getData< DeviceType >() );
    
    
             // wait for all communications to finish
             CommunicatorType::WaitAll( &commRequests[0], commRequests.size() );
    
             // finish the remaining rows, which reference non-local columns, using the global buffer
             auto kernel2 = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable
             {
                outVectorView[ i ] = localMatrix->rowVectorProduct( i, globalBufferView );
             };
    
             Algorithms::ParallelFor< DeviceType >::exec( (IndexType) 0, localOnlySpan.first, kernel2,
                                                          &localMatrixPointer.template getData< DeviceType >() );
             Algorithms::ParallelFor< DeviceType >::exec( localOnlySpan.second, localMatrix.getRows(), kernel2,
                                                          &localMatrixPointer.template getData< DeviceType >() );
    
          }
       }

       // reset the cached communication pattern and auxiliary buffers
       void reset()
       {
          commPatternStarts.reset();
          commPatternEnds.reset();
          localOnlySpan.first = localOnlySpan.second = 0;
          globalBuffer.reset();
       }
    
    protected:
       // communication pattern
    
       Matrices::Dense< IndexType, Devices::Host, int > commPatternStarts, commPatternEnds;
    
    
       // span of rows with only block-diagonal entries
       std::pair< IndexType, IndexType > localOnlySpan;
    
       // global buffer for non-local elements of the vector
    
       __DistributedSpMV_impl::ThreePartVector< RealType, DeviceType, IndexType > globalBuffer;
    
    };

    } // namespace TNL
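
    // Usage sketch (illustrative only -- the concrete local matrix type, the MPI
    // communicator type and the AllGroup constant below are assumptions, not
    // something this header prescribes):
    //
    //    using Communicator = Communicators::MpiCommunicator;
    //    DistributedSpMV< LocalMatrixType, Communicator > spmv;
    //    // computes the local part of y = A * x; the non-local parts of x needed
    //    // by the local rows are exchanged internally via non-blocking communication
    //    spmv.vectorProduct( y, localMatrix, x, Communicator::AllGroup );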