Commit 95c7ff96 authored by Jakub Klinkovský
Browse files

Fixed DistributedSpMV for unstructured partitionings

parent b2bd2047
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -304,7 +304,7 @@ vectorProduct( const InVector& inVector,
   if( getCommunicationGroup() == CommunicatorType::NullGroup )
      return;

   const_cast< DistributedMatrix* >( this )->spmv.vectorProduct( outVector, localMatrix, inVector, getCommunicationGroup() );
   const_cast< DistributedMatrix* >( this )->spmv.vectorProduct( outVector, localMatrix, localRowRange, inVector, getCommunicationGroup() );
}

template< typename Matrix,
+40 −11
Original line number Diff line number Diff line
@@ -12,7 +12,6 @@

#pragma once

#include <TNL/Containers/Partitioner.h>
#include <TNL/Containers/DistributedVectorView.h>

// buffers
@@ -44,7 +43,7 @@ public:
   using IndexType = typename Matrix::IndexType;
   using CommunicatorType = Communicator;
   using CommunicationGroup = typename CommunicatorType::CommunicationGroup;
   using Partitioner = Containers::Partitioner< typename Matrix::IndexType, Communicator >;
   using LocalRangeType = Containers::Subrange< typename Matrix::IndexType >;

   // - communication pattern: vector components whose indices are in the range
   //   [start_ij, end_ij) are copied from the j-th process to the i-th process
@@ -56,13 +55,32 @@ public:
   // - assembly of the i-th row involves traversal of the local matrix stored
   //   in the i-th process
   // - assembly of the full matrix needs all-to-all communication
   void updateCommunicationPattern( const MatrixType& localMatrix, CommunicationGroup group )
   void updateCommunicationPattern( const MatrixType& localMatrix, const LocalRangeType& localRowRange, CommunicationGroup group )
   {
      const int rank = CommunicatorType::GetRank( group );
      const int nproc = CommunicatorType::GetSize( group );
      commPatternStarts.setDimensions( nproc, nproc );
      commPatternEnds.setDimensions( nproc, nproc );

      // exchange global offsets (i.e. beginnings of the local ranges) so that each rank can determine the owner of each index
      Containers::Array< IndexType, DeviceType, int > globalOffsets( nproc );
      {
         Containers::Array< IndexType, Devices::Host, int > sendbuf( nproc );
         sendbuf.setValue( localRowRange.getBegin() );
         CommunicatorType::Alltoall( sendbuf.getData(), 1,
                                     globalOffsets.getData(), 1,
                                     group );
      }
      const auto globalOffsetsView = globalOffsets.getConstView();
      // Maps a global index to the rank that owns it, based on the beginnings
      // of the local ranges collected in globalOffsets. The const view of the
      // offsets is captured by value ([=]).
      // Returns the first rank i for which
      //    globalOffsets[ i ] <= global_idx < globalOffsets[ i + 1 ];
      // an index not matched by any of the ranks 0 .. nproc-2 is attributed
      // to the last rank.
      // NOTE(review): this presumes that the offsets are non-decreasing in
      // rank order and that global_idx >= globalOffsets[ 0 ] - otherwise the
      // fallback silently reports the last rank. Confirm against the callers.
      auto getOwner = [=] __cuda_callable__ ( IndexType global_idx ) -> int
      {
         // number of ranks taken from the view size (shadows the nproc of the enclosing function)
         const int nproc = globalOffsetsView.getSize();
         // linear scan over consecutive pairs of offsets, i.e. the ranges of ranks 0 .. nproc-2
         for( int i = 0; i < nproc - 1; i++ )
            if( globalOffsetsView[ i ] <= global_idx && global_idx < globalOffsetsView[ i + 1 ] )
               return i;
         // no pair matched: fall back to the last rank, whose upper bound is not stored in the offsets
         return nproc - 1;
      };

      // pass the localMatrix to the device
      const Pointers::DevicePointer< const MatrixType > localMatrixPointer( localMatrix );

@@ -90,7 +108,7 @@ public:
            if( j == localMatrix->getPaddingIndex() )
               continue;
            if( j < columns ) {
               const int owner = Partitioner::getOwner( j, columns, nproc );
               const int owner = getOwner( j );
               // atomic assignment
               span_starts[ owner ].fetch_min( j );
               span_ends[ owner ].fetch_max( j + 1 );
@@ -144,21 +162,32 @@ public:
             typename OutVector >
   void vectorProduct( OutVector& outVector,
                       const MatrixType& localMatrix,
                       const LocalRangeType& localRowRange,
                       const InVector& inVector,
                       CommunicationGroup group )
   {
      const int rank = CommunicatorType::GetRank( group );
      const int nproc = CommunicatorType::GetSize( group );

      // handle trivial case
      if( nproc == 1 ) {
         const auto inVectorView = inVector.getConstLocalView();
         auto outVectorView = outVector.getLocalView();
         localMatrix.vectorProduct( inVectorView, outVectorView );
         return;
      }

      // update communication pattern
      if( commPatternStarts.getRows() != nproc || commPatternEnds.getRows() != nproc )
         updateCommunicationPattern( localMatrix, group );
         updateCommunicationPattern( localMatrix, localRowRange, group );

      // prepare buffers
      globalBuffer.init( Partitioner::getOffset( localMatrix.getColumns(), rank, nproc ),
      globalBuffer.init( localRowRange.getBegin(),
                         inVector.getConstLocalView(),
                         localMatrix.getColumns() - Partitioner::getOffset( localMatrix.getColumns(), rank, nproc ) - inVector.getConstLocalView().getSize() );
      const auto globalBufferView = globalBuffer.getConstView();
                         localMatrix.getColumns() - localRowRange.getBegin() - inVector.getConstLocalView().getSize() );

      TNL_ASSERT_EQ( outVector.getLocalView().getSize(), localMatrix.getRows(), "the output vector size does not match the number of matrix rows" );
      TNL_ASSERT_EQ( globalBuffer.getSize(), localMatrix.getColumns(), "the global buffer size does not match the number of matrix columns" );

      // buffer for asynchronous communication requests
      std::vector< typename CommunicatorType::Request > commRequests;
@@ -169,7 +198,7 @@ public:
             continue;
         if( commPatternStarts( i, rank ) < commPatternEnds( i, rank ) )
            commRequests.push_back( CommunicatorType::ISend(
                     inVector.getConstLocalView().getData() + commPatternStarts( i, rank ) - Partitioner::getOffset( localMatrix.getColumns(), rank, nproc ),
                     inVector.getConstLocalView().getData() + commPatternStarts( i, rank ) - localRowRange.getBegin(),
                     commPatternEnds( i, rank ) - commPatternStarts( i, rank ),
                     i, 0, group ) );
      }
@@ -205,8 +234,8 @@ public:
         CommunicatorType::WaitAll( commRequests.data(), commRequests.size() );

         // finish the multiplication by adding the non-local entries
         localMatrix.vectorProduct( globalBufferView, outVectorView, 1.0, 0.0, 0, localOnlySpan.first );
         localMatrix.vectorProduct( globalBufferView, outVectorView, 1.0, 0.0, localOnlySpan.second, localMatrix.getRows() );
         localMatrix.vectorProduct( globalBuffer, outVectorView, 1.0, 0.0, 0, localOnlySpan.first );
         localMatrix.vectorProduct( globalBuffer, outVectorView, 1.0, 0.0, localOnlySpan.second, localMatrix.getRows() );
      }
   }