From 9bbe9ac8c3d5ef21df1abdeea65b323198e7a155 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jakub=20Klinkovsk=C3=BD?= <klinkjak@fjfi.cvut.cz>
Date: Thu, 20 Sep 2018 16:20:21 +0200
Subject: [PATCH] Refactoring distributed SpMV

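The buffers and the communication logic of the distributed sparse
matrix-vector product are moved from DistributedMatrix into a new
DistributedSpMV class. DistributedMatrix now holds a DistributedSpMV member
and delegates updateVectorProductCommunicationPattern and the distributed
vectorProduct to it.

The two vectorProduct overloads are selected via the has_communicator type
trait (SFINAE on a nested CommunicatorType) instead of being hard-coded for
DistributedVector, both overloads are const, and the getLocalRowRange,
getCommunicationGroup and getLocalMatrix accessors are marked
__cuda_callable__.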
---
 .../DistributedContainers/DistributedMatrix.h |  84 +++----
 .../DistributedMatrix_impl.h                  | 174 ++------------
 .../DistributedContainers/DistributedSpMV.h   | 221 ++++++++++++++++++
 3 files changed, 274 insertions(+), 205 deletions(-)
 create mode 100644 src/TNL/DistributedContainers/DistributedSpMV.h

diff --git a/src/TNL/DistributedContainers/DistributedMatrix.h b/src/TNL/DistributedContainers/DistributedMatrix.h
index cdf2c4dfe2..52d98b2642 100644
--- a/src/TNL/DistributedContainers/DistributedMatrix.h
+++ b/src/TNL/DistributedContainers/DistributedMatrix.h
@@ -12,23 +12,33 @@
 
 #pragma once
 
-#include <type_traits>  // std::add_const
+#include <type_traits>
 
 #include <TNL/Matrices/SparseRow.h>
 #include <TNL/Communicators/MpiCommunicator.h>
 #include <TNL/DistributedContainers/Subrange.h>
-#include <TNL/DistributedContainers/Partitioner.h>
 #include <TNL/DistributedContainers/DistributedVector.h>
-
-// buffers for vectorProduct
-#include <vector>
-#include <utility>  // std::pair
-#include <TNL/Matrices/Dense.h>
-#include <TNL/Containers/Vector.h>
+#include <TNL/DistributedContainers/DistributedVectorView.h>
+#include <TNL/DistributedContainers/DistributedSpMV.h>
 
 namespace TNL {
 namespace DistributedContainers {
 
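+// Type traits for detecting whether a vector type is distributed, i.e. whether
+// it provides a nested CommunicatorType. DistributedMatrix uses them below to
+// select between the overloads of vectorProduct for global and for distributed
+// input vectors.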
+template< typename T, typename R = void >
+struct enable_if_type
+{
+   using type = R;
+};
+
+template< typename T, typename Enable = void >
+struct has_communicator : std::false_type {};
+
+template< typename T >
+struct has_communicator< T, typename enable_if_type< typename T::CommunicatorType >::type >
+: std::true_type
+{};
+
+
 // TODO: 2D distribution for dense matrices (maybe it should be in different template,
 //       because e.g. setRowFast doesn't make sense for dense matrices)
 template< typename Matrix,
@@ -37,12 +47,6 @@ class DistributedMatrix
 : public Object
 {
    using CommunicationGroup = typename Communicator::CommunicationGroup;
-
-   template< typename Real >
-   using DistVector = DistributedVector< Real, typename Matrix::DeviceType, typename Matrix::IndexType, Communicator >;
-
-   using Partitioner = DistributedContainers::Partitioner< typename Matrix::IndexType, Communicator >;
-
 public:
    using MatrixType = Matrix;
    using RealType = typename Matrix::RealType;
@@ -67,10 +71,13 @@ public:
 
    void setDistribution( LocalRangeType localRowRange, IndexType rows, IndexType columns, CommunicationGroup group = Communicator::AllGroup );
 
+   __cuda_callable__
    const LocalRangeType& getLocalRowRange() const;
 
+   __cuda_callable__
    CommunicationGroup getCommunicationGroup() const;
 
+   __cuda_callable__
    const Matrix& getLocalMatrix() const;
 
 
@@ -141,25 +148,22 @@ public:
    ConstMatrixRow getRow( IndexType row ) const;
 
    // multiplication with a global vector
-   template< typename Vector,
-             typename RealOut >
-   void vectorProduct( const Vector& inVector,
-                       DistVector< RealOut >& outVector ) const;
-
-   // Optimization for matrix-vector multiplication:
-   // - communication pattern matrix is an nproc x nproc binary matrix C, where
-   //   C_ij = 1 iff the i-th process needs data from the j-th process
-   // - assembly of the i-th row involves traversal of the local matrix stored
-   //   in the i-th process
-   // - assembly the full matrix needs all-to-all communication
+   template< typename InVector,
+             typename OutVector >
+   typename std::enable_if< ! has_communicator< InVector >::value >::type
+   vectorProduct( const InVector& inVector,
+                  OutVector& outVector ) const;
+
+   // Optimization for the distributed matrix-vector product: precompute the communication pattern (see DistributedSpMV)
    void updateVectorProductCommunicationPattern();
 
    // multiplication with a distributed vector
-   // (not const because it modifies internal bufers)
+   // (const; it only updates communication buffers owned by the spmv member)
-   template< typename RealIn,
-             typename RealOut >
-   void vectorProduct( const DistVector< RealIn >& inVector,
-                       DistVector< RealOut >& outVector );
+   template< typename InVector,
+             typename OutVector >
+   typename std::enable_if< has_communicator< InVector >::value >::type
+   vectorProduct( const InVector& inVector,
+                  OutVector& outVector ) const;
 
 protected:
    LocalRangeType localRowRange;
@@ -167,27 +171,7 @@ protected:
    CommunicationGroup group = Communicator::NullGroup;
    Matrix localMatrix;
 
-   void resetBuffers()
-   {
-      commPattern.reset();
-      globalBuffer.reset();
-      commRequests.clear();
-   }
-
-   // communication pattern for matrix-vector product
-   // TODO: probably should be stored elsewhere
-   Matrices::Dense< bool, Devices::Host, int > commPattern;
-
-   // span of rows with only block-diagonal entries
-   std::pair< IndexType, IndexType > localOnlySpan;
-
-   // global buffer for operations such as distributed matrix-vector multiplication
-   // TODO: probably should be stored elsewhere
-   Containers::Vector< RealType, DeviceType, IndexType > globalBuffer;
-
-   // buffer for asynchronous communication requests
-   // TODO: probably should be stored elsewhere
-   std::vector< typename CommunicatorType::Request > commRequests;
+   DistributedSpMV< Matrix, Communicator > spmv;
 
 private:
    // TODO: disabled until they are implemented
diff --git a/src/TNL/DistributedContainers/DistributedMatrix_impl.h b/src/TNL/DistributedContainers/DistributedMatrix_impl.h
index 9f0c5e4ef5..1fad8486c0 100644
--- a/src/TNL/DistributedContainers/DistributedMatrix_impl.h
+++ b/src/TNL/DistributedContainers/DistributedMatrix_impl.h
@@ -14,11 +14,6 @@
 
 #include "DistributedMatrix.h"
 
-#include <TNL/Atomic.h>
-#include <TNL/ParallelFor.h>
-#include <TNL/Pointers/DevicePointer.h>
-#include <TNL/Containers/VectorView.h>
-
 namespace TNL {
 namespace DistributedContainers {
 
@@ -42,11 +37,12 @@ setDistribution( LocalRangeType localRowRange, IndexType rows, IndexType columns
    if( group != Communicator::NullGroup )
       localMatrix.setDimensions( localRowRange.getSize(), columns );
 
-   resetBuffers();
+   spmv.reset();
 }
 
 template< typename Matrix,
           typename Communicator >
+__cuda_callable__
 const Subrange< typename Matrix::IndexType >&
 DistributedMatrix< Matrix, Communicator >::
 getLocalRowRange() const
@@ -56,6 +52,7 @@ getLocalRowRange() const
 
 template< typename Matrix,
           typename Communicator >
+__cuda_callable__
 typename Communicator::CommunicationGroup
 DistributedMatrix< Matrix, Communicator >::
 getCommunicationGroup() const
@@ -65,6 +62,7 @@ getCommunicationGroup() const
 
 template< typename Matrix,
           typename Communicator >
+__cuda_callable__
 const Matrix&
 DistributedMatrix< Matrix, Communicator >::
 getLocalMatrix() const
@@ -134,7 +132,7 @@ setLike( const MatrixT& matrix )
    group = matrix.getCommunicationGroup();
    localMatrix.setLike( matrix.getLocalMatrix() );
 
-   resetBuffers();
+   spmv.reset();
 }
 
 template< typename Matrix,
@@ -148,7 +146,7 @@ reset()
    group = Communicator::NullGroup;
    localMatrix.reset();
 
-   resetBuffers();
+   spmv.reset();
 }
 
 template< typename Matrix,
@@ -184,7 +182,7 @@ setCompressedRowLengths( const CompressedRowLengthsVector& rowLengths )
    if( getCommunicationGroup() != CommunicatorType::NullGroup ) {
       localMatrix.setCompressedRowLengths( rowLengths.getLocalVectorView() );
 
-      resetBuffers();
+      spmv.reset();
    }
 }
 
@@ -309,12 +307,12 @@ getRow( IndexType row ) const
 
 template< typename Matrix,
           typename Communicator >
-   template< typename Vector,
-             typename RealOut >
-void
+   template< typename InVector,
+             typename OutVector >
+typename std::enable_if< ! has_communicator< InVector >::value >::type
 DistributedMatrix< Matrix, Communicator >::
-vectorProduct( const Vector& inVector,
-               DistVector< RealOut >& outVector ) const
+vectorProduct( const InVector& inVector,
+               OutVector& outVector ) const
 {
    TNL_ASSERT_EQ( inVector.getSize(), getColumns(), "input vector has wrong size" );
    TNL_ASSERT_EQ( outVector.getSize(), getRows(), "output vector has wrong size" );
@@ -333,86 +331,17 @@ updateVectorProductCommunicationPattern()
 {
    if( getCommunicationGroup() == CommunicatorType::NullGroup )
       return;
-
-   const int rank = CommunicatorType::GetRank( getCommunicationGroup() );
-   const int nproc = CommunicatorType::GetSize( getCommunicationGroup() );
-   commPattern.setDimensions( nproc, nproc );
-
-   // pass the localMatrix to the device
-   Pointers::DevicePointer< MatrixType > localMatrixPointer( localMatrix );
-
-   // buffer for the local row of the commPattern matrix
-//   using AtomicBool = Atomic< bool, DeviceType >;
-   // FIXME: CUDA does not support atomic operations for bool
-   using AtomicBool = Atomic< int, DeviceType >;
-   Containers::Array< AtomicBool, DeviceType > buffer( nproc );
-   buffer.setValue( false );
-
-   // optimization for banded matrices
-   using AtomicIndex = Atomic< IndexType, DeviceType >;
-   Containers::Array< AtomicIndex, DeviceType > local_span( 2 );
-   local_span.setElement( 0, 0 );  // span start
-   local_span.setElement( 1, localMatrix.getRows() );  // span end
-
-   auto kernel = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix,
-                                         AtomicBool* buffer, AtomicIndex* local_span )
-   {
-      const IndexType columns = localMatrix->getColumns();
-      const auto row = localMatrix->getRow( i );
-      bool comm_left = false;
-      bool comm_right = false;
-      for( IndexType c = 0; c < row.getLength(); c++ ) {
-         const IndexType j = row.getElementColumn( c );
-         if( j < columns ) {
-            const int owner = Partitioner::getOwner( j, columns, nproc );
-            // atomic assignment
-            buffer[ owner ].store( true );
-            // update comm_left/Right
-            if( owner < rank )
-               comm_left = true;
-            if( owner > rank )
-               comm_right = true;
-         }
-      }
-      // update local span
-      if( comm_left )
-         local_span[0].fetch_max( i + 1 );
-      if( comm_right )
-         local_span[1].fetch_min( i );
-   };
-
-   ParallelFor< DeviceType >::exec( (IndexType) 0, localMatrix.getRows(),
-                                    kernel,
-                                    &localMatrixPointer.template getData< DeviceType >(),
-                                    buffer.getData(),
-                                    local_span.getData()
-                                 );
-
-   // set the local-only span (optimization for banded matrices)
-   localOnlySpan.first = local_span.getElement( 0 );
-   localOnlySpan.second = local_span.getElement( 1 );
-
-   // copy the buffer into all rows of the preCommPattern matrix
-   Matrices::Dense< bool, Devices::Host, int > preCommPattern;
-   preCommPattern.setLike( commPattern );
-   for( int j = 0; j < nproc; j++ )
-   for( int i = 0; i < nproc; i++ )
-      preCommPattern.setElementFast( j, i, buffer.getElement( i ) );
-
-   // assemble the commPattern matrix
-   CommunicatorType::Alltoall( &preCommPattern(0, 0), nproc,
-                               &commPattern(0, 0), nproc,
-                               getCommunicationGroup() );
+   spmv.updateCommunicationPattern( getLocalMatrix(), getCommunicationGroup() );
 }
 
 template< typename Matrix,
           typename Communicator >
-   template< typename RealIn,
-             typename RealOut >
-void
+   template< typename InVector,
+             typename OutVector >
+typename std::enable_if< has_communicator< InVector >::value >::type
 DistributedMatrix< Matrix, Communicator >::
-vectorProduct( const DistVector< RealIn >& inVector,
-               DistVector< RealOut >& outVector )
+vectorProduct( const InVector& inVector,
+               OutVector& outVector ) const
 {
    TNL_ASSERT_EQ( inVector.getSize(), getColumns(), "input vector has wrong size" );
    TNL_ASSERT_EQ( inVector.getLocalRange(), getLocalRowRange(), "input vector has wrong distribution" );
@@ -424,72 +353,7 @@ vectorProduct( const DistVector< RealIn >& inVector,
    if( getCommunicationGroup() == CommunicatorType::NullGroup )
       return;
 
-   const int rank = CommunicatorType::GetRank( getCommunicationGroup() );
-   const int nproc = CommunicatorType::GetSize( getCommunicationGroup() );
-
-   // update communication pattern
-   if( commPattern.getRows() != nproc )
-      updateVectorProductCommunicationPattern();
-
-   // prepare buffers
-   globalBuffer.setSize( localMatrix.getColumns() );
-   commRequests.clear();
-
-   // send our data to all processes that need it
-   for( int i = 0; i < commPattern.getRows(); i++ )
-      if( commPattern( i, rank ) )
-         commRequests.push_back( CommunicatorType::ISend(
-                  inVector.getLocalVectorView().getData(),
-                  inVector.getLocalVectorView().getSize(),
-                  i, getCommunicationGroup() ) );
-
-   // receive data that we need
-   for( int j = 0; j < commPattern.getRows(); j++ )
-      if( commPattern( rank, j ) )
-         commRequests.push_back( CommunicatorType::IRecv(
-                  &globalBuffer[ Partitioner::getOffset( globalBuffer.getSize(), j, nproc ) ],
-                  Partitioner::getSizeForRank( globalBuffer.getSize(), j, nproc ),
-                  j, getCommunicationGroup() ) );
-
-   // general variant
-   if( localOnlySpan.first >= localOnlySpan.second ) {
-      // wait for all communications to finish
-      CommunicatorType::WaitAll( &commRequests[0], commRequests.size() );
-
-      // perform matrix-vector multiplication
-      vectorProduct( globalBuffer, outVector );
-   }
-   // optimization for banded matrices
-   else {
-      Pointers::DevicePointer< MatrixType > localMatrixPointer( localMatrix );
-      auto outVectorView = outVector.getLocalVectorView();
-      // TODO
-//      const auto inVectorView = DistributedVectorView( inVector );
-      Pointers::DevicePointer< const DistVector< RealIn > > inVectorPointer( inVector );
-
-      // matrix-vector multiplication using local-only rows
-      auto kernel1 = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix, const DistVector< RealIn >* inVector ) mutable
-      {
-         outVectorView[ i ] = localMatrix->rowVectorProduct( i, *inVector );
-      };
-      ParallelFor< DeviceType >::exec( localOnlySpan.first, localOnlySpan.second, kernel1,
-                                       &localMatrixPointer.template getData< DeviceType >(),
-                                       &inVectorPointer.template getData< DeviceType >() );
-
-      // wait for all communications to finish
-      CommunicatorType::WaitAll( &commRequests[0], commRequests.size() );
-
-      // finish the multiplication by adding the non-local entries
-      Containers::VectorView< RealType, DeviceType, IndexType > globalBufferView( globalBuffer );
-      auto kernel2 = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable
-      {
-         outVectorView[ i ] = localMatrix->rowVectorProduct( i, globalBufferView );
-      };
-      ParallelFor< DeviceType >::exec( (IndexType) 0, localOnlySpan.first, kernel2,
-                                       &localMatrixPointer.template getData< DeviceType >() );
-      ParallelFor< DeviceType >::exec( localOnlySpan.second, localMatrix.getRows(), kernel2,
-                                       &localMatrixPointer.template getData< DeviceType >() );
-   }
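+   // The spmv engine updates only its internal communication buffers, which do
+   // not affect the logical constness of the matrix; hence the const_cast.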
+   const_cast< DistributedMatrix* >( this )->spmv.vectorProduct( outVector, localMatrix, inVector, getCommunicationGroup() );
 }
 
 } // namespace DistributedContainers
diff --git a/src/TNL/DistributedContainers/DistributedSpMV.h b/src/TNL/DistributedContainers/DistributedSpMV.h
new file mode 100644
index 0000000000..9bc47d7391
--- /dev/null
+++ b/src/TNL/DistributedContainers/DistributedSpMV.h
@@ -0,0 +1,221 @@
+/***************************************************************************
+                          DistributedSpMV.h  -  description
+                             -------------------
+    begin                : Sep 20, 2018
+    copyright            : (C) 2018 by Tomas Oberhuber et al.
+    email                : tomas.oberhuber@fjfi.cvut.cz
+ ***************************************************************************/
+
+/* See Copyright Notice in tnl/Copyright */
+
+// Implemented by: Jakub Klinkovský
+
+#pragma once
+
+#include <TNL/DistributedContainers/Partitioner.h>
+#include <TNL/DistributedContainers/DistributedVectorView.h>
+
+// buffers
+#include <vector>
+#include <utility>  // std::pair
+#include <TNL/Matrices/Dense.h>
+#include <TNL/Containers/Vector.h>
+#include <TNL/Containers/VectorView.h>
+
+// operations
+#include <type_traits>  // std::add_const
+#include <TNL/Atomic.h>
+#include <TNL/ParallelFor.h>
+#include <TNL/Pointers/DevicePointer.h>
+
+namespace TNL {
+namespace DistributedContainers {
+
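+// Engine for the distributed sparse matrix-vector product. It holds the
+// communication pattern, the buffer for non-local elements of the input vector
+// and the pending communication requests, and for banded matrices it overlaps
+// the communication with the multiplication of the local-only rows.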
+template< typename Matrix, typename Communicator >
+class DistributedSpMV
+{
+public:
+   using MatrixType = Matrix;
+   using RealType = typename Matrix::RealType;
+   using DeviceType = typename Matrix::DeviceType;
+   using IndexType = typename Matrix::IndexType;
+   using CommunicatorType = Communicator;
+   using CommunicationGroup = typename CommunicatorType::CommunicationGroup;
+   using Partitioner = DistributedContainers::Partitioner< typename Matrix::IndexType, Communicator >;
+
+   // - communication pattern matrix is an nproc x nproc binary matrix C, where
+   //   C_ij = 1 iff the i-th process needs data from the j-th process
+   // - assembly of the i-th row involves traversal of the local matrix stored
+   //   in the i-th process
+   // - assembling the full matrix requires all-to-all communication
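+   //   (e.g. for a block-tridiagonal matrix distributed by contiguous row
+   //    blocks over 3 processes, each process needs data only from its
+   //    neighbours, so
+   //       C = 1 1 0
+   //           1 1 1
+   //           0 1 1 )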
+   void updateCommunicationPattern( const MatrixType& localMatrix, CommunicationGroup group )
+   {
+      const int rank = CommunicatorType::GetRank( group );
+      const int nproc = CommunicatorType::GetSize( group );
+      commPattern.setDimensions( nproc, nproc );
+
+      // pass the localMatrix to the device
+      const Pointers::DevicePointer< const MatrixType > localMatrixPointer( localMatrix );
+
+      // buffer for the local row of the commPattern matrix
+//      using AtomicBool = Atomic< bool, DeviceType >;
+      // FIXME: CUDA does not support atomic operations for bool
+      using AtomicBool = Atomic< int, DeviceType >;
+      Containers::Array< AtomicBool, DeviceType > buffer( nproc );
+      buffer.setValue( false );
+
+      // optimization for banded matrices
+      using AtomicIndex = Atomic< IndexType, DeviceType >;
+      Containers::Array< AtomicIndex, DeviceType > local_span( 2 );
+      local_span.setElement( 0, 0 );  // span start
+      local_span.setElement( 1, localMatrix.getRows() );  // span end
+
+      auto kernel = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix,
+                                            AtomicBool* buffer, AtomicIndex* local_span )
+      {
+         const IndexType columns = localMatrix->getColumns();
+         const auto row = localMatrix->getRow( i );
+         bool comm_left = false;
+         bool comm_right = false;
+         for( IndexType c = 0; c < row.getLength(); c++ ) {
+            const IndexType j = row.getElementColumn( c );
+            if( j < columns ) {
+               const int owner = Partitioner::getOwner( j, columns, nproc );
+               // atomic assignment
+               buffer[ owner ].store( true );
+               // update comm_left / comm_right
+               if( owner < rank )
+                  comm_left = true;
+               if( owner > rank )
+                  comm_right = true;
+            }
+         }
+         // update local span
+         if( comm_left )
+            local_span[0].fetch_max( i + 1 );
+         if( comm_right )
+            local_span[1].fetch_min( i );
+      };
+
+      ParallelFor< DeviceType >::exec( (IndexType) 0, localMatrix.getRows(),
+                                       kernel,
+                                       &localMatrixPointer.template getData< DeviceType >(),
+                                       buffer.getData(),
+                                       local_span.getData()
+                                    );
+
+      // set the local-only span (optimization for banded matrices)
+      localOnlySpan.first = local_span.getElement( 0 );
+      localOnlySpan.second = local_span.getElement( 1 );
+
+      // copy the buffer into all rows of the preCommPattern matrix
+      Matrices::Dense< bool, Devices::Host, int > preCommPattern;
+      preCommPattern.setLike( commPattern );
+      for( int j = 0; j < nproc; j++ )
+      for( int i = 0; i < nproc; i++ )
+         preCommPattern.setElementFast( j, i, buffer.getElement( i ) );
+
+      // assemble the commPattern matrix
+      CommunicatorType::Alltoall( &preCommPattern(0, 0), nproc,
+                                  &commPattern(0, 0), nproc,
+                                  group );
+   }
+
+   template< typename InVector,
+             typename OutVector >
+   void vectorProduct( OutVector& outVector,
+                       const MatrixType& localMatrix,
+                       const InVector& inVector,
+                       CommunicationGroup group )
+   {
+      const int rank = CommunicatorType::GetRank( group );
+      const int nproc = CommunicatorType::GetSize( group );
+
+      // update communication pattern
+      if( commPattern.getRows() != nproc )
+         updateCommunicationPattern( localMatrix, group );
+
+      // prepare buffers
+      globalBuffer.setSize( localMatrix.getColumns() );
+      commRequests.clear();
+
+      // send our data to all processes that need it
+      for( int i = 0; i < commPattern.getRows(); i++ )
+         if( commPattern( i, rank ) )
+            commRequests.push_back( CommunicatorType::ISend(
+                     inVector.getLocalVectorView().getData(),
+                     inVector.getLocalVectorView().getSize(),
+                     i, group ) );
+
+      // receive data that we need
+      for( int j = 0; j < commPattern.getRows(); j++ )
+         if( commPattern( rank, j ) )
+            commRequests.push_back( CommunicatorType::IRecv(
+                     &globalBuffer[ Partitioner::getOffset( globalBuffer.getSize(), j, nproc ) ],
+                     Partitioner::getSizeForRank( globalBuffer.getSize(), j, nproc ),
+                     j, group ) );
+
+      // general variant
+      if( localOnlySpan.first >= localOnlySpan.second ) {
+         // wait for all communications to finish
+         CommunicatorType::WaitAll( &commRequests[0], commRequests.size() );
+
+         // perform matrix-vector multiplication
+         auto outView = outVector.getLocalVectorView();
+         localMatrix.vectorProduct( globalBuffer, outView );
+      }
+      // optimization for banded matrices
+      else {
+         auto outVectorView = outVector.getLocalVectorView();
+         const Pointers::DevicePointer< const MatrixType > localMatrixPointer( localMatrix );
+         using InView = DistributedVectorView< const typename InVector::RealType, typename InVector::DeviceType, typename InVector::IndexType, typename InVector::CommunicatorType >;
+         const InView inView( inVector );
+
+         // matrix-vector multiplication using local-only rows
+         auto kernel1 = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable
+         {
+            outVectorView[ i ] = localMatrix->rowVectorProduct( i, inView );
+         };
+         ParallelFor< DeviceType >::exec( localOnlySpan.first, localOnlySpan.second, kernel1,
+                                          &localMatrixPointer.template getData< DeviceType >() );
+
+         // wait for all communications to finish
+         CommunicatorType::WaitAll( &commRequests[0], commRequests.size() );
+
+         // finish the multiplication by adding the non-local entries
+         Containers::VectorView< RealType, DeviceType, IndexType > globalBufferView( globalBuffer );
+         auto kernel2 = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable
+         {
+            outVectorView[ i ] = localMatrix->rowVectorProduct( i, globalBufferView );
+         };
+         ParallelFor< DeviceType >::exec( (IndexType) 0, localOnlySpan.first, kernel2,
+                                          &localMatrixPointer.template getData< DeviceType >() );
+         ParallelFor< DeviceType >::exec( localOnlySpan.second, localMatrix.getRows(), kernel2,
+                                          &localMatrixPointer.template getData< DeviceType >() );
+      }
+   }
+
+   void reset()
+   {
+      commPattern.reset();
+      localOnlySpan.first = localOnlySpan.second = 0;
+      globalBuffer.reset();
+      commRequests.clear();
+   }
+
+protected:
+   // communication pattern
+   Matrices::Dense< bool, Devices::Host, int > commPattern;
+
+   // span of rows with only block-diagonal entries
+   std::pair< IndexType, IndexType > localOnlySpan;
+
+   // global buffer for non-local elements of the vector
+   Containers::Vector< RealType, DeviceType, IndexType > globalBuffer;
+
+   // buffer for asynchronous communication requests
+   std::vector< typename CommunicatorType::Request > commRequests;
+};
+
+} // namespace DistributedContainers
+} // namespace TNL
-- 
GitLab