    /***************************************************************************
                              DistributedSpMV.h  -  description
                                 -------------------
        begin                : Sep 20, 2018
        copyright            : (C) 2018 by Tomas Oberhuber et al.
        email                : tomas.oberhuber@fjfi.cvut.cz
     ***************************************************************************/
    
    /* See Copyright Notice in tnl/Copyright */
    
    // Implemented by: Jakub Klinkovský
    
    #pragma once
    
    #include <TNL/DistributedContainers/Partitioner.h>
    #include <TNL/DistributedContainers/DistributedVectorView.h>
    
    // buffers
    #include <vector>
    #include <utility>  // std::pair
    #include <TNL/Matrices/Dense.h>
    #include <TNL/Containers/Vector.h>
    #include <TNL/Containers/VectorView.h>
    
    // operations
    #include <type_traits>  // std::add_const
    #include <TNL/Atomic.h>
    #include <TNL/ParallelFor.h>
    #include <TNL/Pointers/DevicePointer.h>
    
    namespace TNL {
    namespace DistributedContainers {
    
    template< typename Matrix, typename Communicator >
    class DistributedSpMV
    {
    public:
       using MatrixType = Matrix;
       using RealType = typename Matrix::RealType;
       using DeviceType = typename Matrix::DeviceType;
       using IndexType = typename Matrix::IndexType;
       using CommunicatorType = Communicator;
       using CommunicationGroup = typename CommunicatorType::CommunicationGroup;
       using Partitioner = DistributedContainers::Partitioner< typename Matrix::IndexType, Communicator >;
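
       // Minimal usage sketch (illustrative, not part of the interface):
       //
       //    DistributedSpMV< Matrix, Communicator > spmv;
       //    spmv.vectorProduct( outVector, localMatrix, inVector, group );
       //
       // The communication pattern is built lazily on the first call to
       // vectorProduct (or after reset()) and is reused by subsequent calls.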
    
       // - the communication pattern matrix is an nproc x nproc binary matrix C,
       //   where C_ij = 1 iff the i-th process needs data from the j-th process
       // - the assembly of the i-th row involves a traversal of the local matrix
       //   stored on the i-th process
       // - the assembly of the full matrix therefore requires an all-to-all
       //   communication
       void updateCommunicationPattern( const MatrixType& localMatrix, CommunicationGroup group )
       {
          const int rank = CommunicatorType::GetRank( group );
          const int nproc = CommunicatorType::GetSize( group );
          commPattern.setDimensions( nproc, nproc );
    
          // pass the localMatrix to the device
          const Pointers::DevicePointer< const MatrixType > localMatrixPointer( localMatrix );
    
          // buffer for the local row of the commPattern matrix
          // FIXME: CUDA does not support atomic operations on bool, so int is
          // used instead of the more natural
          //    using AtomicBool = Atomic< bool, DeviceType >;
          using AtomicBool = Atomic< int, DeviceType >;
          Containers::Array< AtomicBool, DeviceType > buffer( nproc );
          buffer.setValue( false );
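          // after the kernel below finishes, buffer[ j ] is nonzero iff this
          // process needs data from the j-th process, i.e. buffer holds this
          // process's row of the communication pattern matrix C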
    
          // optimization for banded matrices
          using AtomicIndex = Atomic< IndexType, DeviceType >;
          Containers::Array< AtomicIndex, DeviceType > local_span( 2 );
          local_span.setElement( 0, 0 );  // span start
          local_span.setElement( 1, localMatrix.getRows() );  // span end
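          // the kernel below shrinks this span so that every row in
          // [ local_span[ 0 ], local_span[ 1 ] ) references only columns owned
          // by this process; for banded matrices this leaves a large block of
          // rows that can be multiplied before the communication finishes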
    
          auto kernel = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix,
                                                AtomicBool* buffer, AtomicIndex* local_span )
          {
             const IndexType columns = localMatrix->getColumns();
             const auto row = localMatrix->getRow( i );
             bool comm_left = false;
             bool comm_right = false;
             for( IndexType c = 0; c < row.getLength(); c++ ) {
                const IndexType j = row.getElementColumn( c );
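                // skip invalid column indices (e.g. padding entries of the
                // sparse format)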
                if( j < columns ) {
                   const int owner = Partitioner::getOwner( j, columns, nproc );
                   // atomic assignment
                   buffer[ owner ].store( true );
                   // update the comm_left / comm_right flags
                   if( owner < rank )
                      comm_left = true;
                   if( owner > rank )
                      comm_right = true;
                }
             }
             // update local span
             if( comm_left )
                local_span[0].fetch_max( i + 1 );
             if( comm_right )
                local_span[1].fetch_min( i );
          };
    
          ParallelFor< DeviceType >::exec( (IndexType) 0, localMatrix.getRows(),
                                           kernel,
                                           &localMatrixPointer.template getData< DeviceType >(),
                                           buffer.getData(),
                                           local_span.getData() );
    
          // set the local-only span (optimization for banded matrices)
          localOnlySpan.first = local_span.getElement( 0 );
          localOnlySpan.second = local_span.getElement( 1 );
    
          // copy the buffer into all rows of the preCommPattern matrix
          Matrices::Dense< bool, Devices::Host, int > preCommPattern;
          preCommPattern.setLike( commPattern );
          for( int j = 0; j < nproc; j++ )
          for( int i = 0; i < nproc; i++ )
             preCommPattern.setElementFast( j, i, buffer.getElement( i ) );
    
          // assemble the commPattern matrix: Alltoall delivers the j-th block
          // of nproc elements (a copy of this process's row of C) to the j-th
          // process, so afterwards the i-th row of commPattern on every process
          // holds the row computed by the i-th process
          CommunicatorType::Alltoall( &preCommPattern(0, 0), nproc,
                                      &commPattern(0, 0), nproc,
                                      group );
       }
    
       template< typename InVector,
                 typename OutVector >
       void vectorProduct( OutVector& outVector,
                           const MatrixType& localMatrix,
                           const InVector& inVector,
                           CommunicationGroup group )
       {
          const int rank = CommunicatorType::GetRank( group );
          const int nproc = CommunicatorType::GetSize( group );
    
          // update communication pattern
          if( commPattern.getRows() != nproc )
             updateCommunicationPattern( localMatrix, group );
    
          // prepare the buffers: globalBuffer will hold the complete global
          // input vector, with the block owned by the j-th process stored at
          // offset Partitioner::getOffset( globalBuffer.getSize(), j, nproc )
          globalBuffer.setSize( localMatrix.getColumns() );
          commRequests.clear();
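
          // note that commPattern( rank, rank ) is set whenever the local
          // matrix references its own block, so the loops below also post a
          // send/receive pair from this process to itself, which fills the
          // local part of globalBuffer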
    
          // send our data to all processes that need it
          for( int i = 0; i < commPattern.getRows(); i++ )
             if( commPattern( i, rank ) )
             commRequests.push_back( CommunicatorType::ISend(
                      inVector.getLocalVectorView().getData(),
                      inVector.getLocalVectorView().getSize(),
                      i, group ) );

          // receive data that we need
          for( int j = 0; j < commPattern.getRows(); j++ )
             if( commPattern( rank, j ) )
             commRequests.push_back( CommunicatorType::IRecv(
                      &globalBuffer[ Partitioner::getOffset( globalBuffer.getSize(), j, nproc ) ],
                      Partitioner::getSizeForRank( globalBuffer.getSize(), j, nproc ),
                      j, group ) );

          // general variant
          if( localOnlySpan.first >= localOnlySpan.second ) {
             // wait for all communications to finish
             CommunicatorType::WaitAll( &commRequests[0], commRequests.size() );
    
             // perform the matrix-vector multiplication on the whole local
             // matrix, using the fully assembled global vector
             auto outView = outVector.getLocalVectorView();
             localMatrix.vectorProduct( globalBuffer, outView );
          }
          // optimization for banded matrices: overlap the communication with
          // the multiplication of rows that reference only local columns
          else {
             auto outVectorView = outVector.getLocalVectorView();
             const Pointers::DevicePointer< const MatrixType > localMatrixPointer( localMatrix );
             using InView = DistributedVectorView< const typename InVector::RealType, typename InVector::DeviceType, typename InVector::IndexType, typename InVector::CommunicatorType >;
             const InView inView( inVector );
    
             // matrix-vector multiplication on the local-only rows, which can
             // proceed through the distributed view of the input vector without
             // waiting for the communication
             auto kernel1 = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable
             {
                outVectorView[ i ] = localMatrix->rowVectorProduct( i, inView );
             };
             ParallelFor< DeviceType >::exec( localOnlySpan.first, localOnlySpan.second, kernel1,
                                              &localMatrixPointer.template getData< DeviceType >() );
    
             // wait for all communications to finish
             CommunicatorType::WaitAll( &commRequests[0], commRequests.size() );
    
             // finish the multiplication: process the remaining rows, which
             // reference non-local columns, using the received global buffer
             Containers::VectorView< RealType, DeviceType, IndexType > globalBufferView( globalBuffer );
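             // a plain VectorView is used here so that the buffer can be
             // captured by value in the __cuda_callable__ lambda below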
             auto kernel2 = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable
             {
                outVectorView[ i ] = localMatrix->rowVectorProduct( i, globalBufferView );
             };
             ParallelFor< DeviceType >::exec( (IndexType) 0, localOnlySpan.first, kernel2,
                                              &localMatrixPointer.template getData< DeviceType >() );
             ParallelFor< DeviceType >::exec( localOnlySpan.second, localMatrix.getRows(), kernel2,
                                              &localMatrixPointer.template getData< DeviceType >() );
          }
       }
    
       void reset()
       {
          commPattern.reset();
          localOnlySpan.first = localOnlySpan.second = 0;
          globalBuffer.reset();
          commRequests.clear();
       }
    
    protected:
       // communication pattern
       Matrices::Dense< bool, Devices::Host, int > commPattern;
    
       // span of rows with only block-diagonal entries
       std::pair< IndexType, IndexType > localOnlySpan;
    
       // global buffer for non-local elements of the vector
       Containers::Vector< RealType, DeviceType, IndexType > globalBuffer;
    
       // buffer for asynchronous communication requests
       std::vector< typename CommunicatorType::Request > commRequests;
    };
    
    } // namespace DistributedContainers
    } // namespace TNL