diff --git a/src/TNL/DistributedContainers/DistributedMatrix.h b/src/TNL/DistributedContainers/DistributedMatrix.h
index b936f0fc3b41b90fa1e214eb286648eb19c9ab15..7fde46154025a9d74a2338edb5e24a348dda7761 100644
--- a/src/TNL/DistributedContainers/DistributedMatrix.h
+++ b/src/TNL/DistributedContainers/DistributedMatrix.h
@@ -14,12 +14,17 @@
 
 #include <type_traits>  // std::add_const
 
-#include <TNL/Containers/Vector.h>
 #include <TNL/Matrices/SparseRow.h>
 #include <TNL/Communicators/MpiCommunicator.h>
 #include <TNL/DistributedContainers/IndexMap.h>
 #include <TNL/DistributedContainers/DistributedVector.h>
 
+// buffers for vectorProduct
+#include <vector>
+#include <utility>  // std::pair
+#include <TNL/Matrices/Dense.h>
+#include <TNL/Containers/Vector.h>
+
 namespace TNL {
 namespace DistributedContainers {
 
@@ -139,20 +144,50 @@ public:
    void vectorProduct( const Vector& inVector,
                        DistVector< RealOut >& outVector ) const;
 
-   // optimization for matrix-vector multiplication
-   void updateVectorProductPrefetchPattern();
+   // Optimization for matrix-vector multiplication:
+   // - communication pattern matrix is an nproc x nproc binary matrix C, where
+   //   C_ij = 1 iff the i-th process needs data from the j-th process
+   // - assembling the i-th row involves a traversal of the local matrix stored
+   //   on the i-th process
+   // - assembling the full matrix therefore needs an all-to-all communication
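+   // - e.g. for a block-tridiagonal matrix distributed over 4 processes, each
+   //   process typically needs data only from itself and its neighbours, so C
+   //   would look like
+   //      1 1 0 0
+   //      1 1 1 0
+   //      0 1 1 1
+   //      0 0 1 1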
+   template< typename Partitioner >
+   void updateVectorProductCommunicationPattern();
 
    // multiplication with a distributed vector
-   template< typename RealIn,
+   // (not const because it modifies internal buffers)
+   template< typename Partitioner,
+             typename RealIn,
              typename RealOut >
    void vectorProduct( const DistVector< RealIn >& inVector,
-                       DistVector< RealOut >& outVector ) const;
+                       DistVector< RealOut >& outVector );
 
 protected:
    IndexMap rowIndexMap;
    CommunicationGroup group = Communicator::NullGroup;
    Matrix localMatrix;
 
+   void resetBuffers()
+   {
+      commPattern.reset();
+      globalBuffer.reset();
+      commRequests.clear();
+   }
+
+   // communication pattern for matrix-vector product
+   // TODO: probably should be stored elsewhere
+   Matrices::Dense< bool, Devices::Host, int > commPattern;
+
+   // span of rows which have entries only in the local (block-diagonal) part
+   std::pair< IndexType, IndexType > localOnlySpan;
+
+   // global buffer for operations such as distributed matrix-vector multiplication
+   // TODO: probably should be stored elsewhere
+   Containers::Vector< RealType, DeviceType, IndexType > globalBuffer;
+
+   // buffer for asynchronous communication requests
+   // TODO: probably should be stored elsewhere
+   std::vector< typename CommunicatorType::Request > commRequests;
+
 private:
    // TODO: disabled until they are implemented
    using Object::save;
diff --git a/src/TNL/DistributedContainers/DistributedMatrix_impl.h b/src/TNL/DistributedContainers/DistributedMatrix_impl.h
index 805dbd22d58fe205dc396d07d91bf663b4c295c9..8aecf91a7040a661b964e810a573f4a2f72522b6 100644
--- a/src/TNL/DistributedContainers/DistributedMatrix_impl.h
+++ b/src/TNL/DistributedContainers/DistributedMatrix_impl.h
@@ -14,6 +14,11 @@
 
 #include "DistributedMatrix.h"
 
+#include <TNL/Atomic.h>
+#include <TNL/ParallelFor.h>
+#include <TNL/Pointers/DevicePointer.h>
+#include <TNL/Containers/VectorView.h>
+
 namespace TNL {
 namespace DistributedContainers {
 
@@ -37,6 +42,8 @@ setDistribution( IndexMap rowIndexMap, IndexType columns, CommunicationGroup gro
    this->group = group;
    if( group != Communicator::NullGroup )
       localMatrix.setDimensions( rowIndexMap.getLocalSize(), columns );
+
+   resetBuffers();
 }
 
 template< typename Matrix,
@@ -135,6 +142,8 @@ setLike( const MatrixT& matrix )
    rowIndexMap = matrix.getRowIndexMap();
    group = matrix.getCommunicationGroup();
    localMatrix.setLike( matrix.getLocalMatrix() );
+
+   resetBuffers();
 }
 
 template< typename Matrix,
@@ -147,6 +156,8 @@ reset()
    rowIndexMap.reset();
    group = Communicator::NullGroup;
    localMatrix.reset();
+
+   resetBuffers();
 }
 
 template< typename Matrix,
@@ -182,8 +193,11 @@ setCompressedRowLengths( const CompressedRowLengthsVector& rowLengths )
    TNL_ASSERT_EQ( rowLengths.getIndexMap(), getRowIndexMap(), "row lengths vector has wrong distribution" );
    TNL_ASSERT_EQ( rowLengths.getCommunicationGroup(), getCommunicationGroup(), "row lengths vector has wrong communication group" );
 
-   if( getCommunicationGroup() != CommunicatorType::NullGroup )
+   if( getCommunicationGroup() != CommunicatorType::NullGroup ) {
       localMatrix.setCompressedRowLengths( rowLengths.getLocalVectorView() );
+
+      resetBuffers();
+   }
 }
 
 template< typename Matrix,
@@ -334,5 +348,176 @@ vectorProduct( const Vector& inVector,
    localMatrix.vectorProduct( inVector, outView );
 }
 
+template< typename Matrix,
+          typename Communicator,
+          typename IndexMap >
+   template< typename Partitioner >
+void
+DistributedMatrix< Matrix, Communicator, IndexMap >::
+updateVectorProductCommunicationPattern()
+{
+   if( getCommunicationGroup() == CommunicatorType::NullGroup )
+      return;
+
+   const int rank = CommunicatorType::GetRank( getCommunicationGroup() );
+   const int nproc = CommunicatorType::GetSize( getCommunicationGroup() );
+   commPattern.setDimensions( nproc, nproc );
+
+   // pass the localMatrix to the device
+   Pointers::DevicePointer< MatrixType > localMatrixPointer( localMatrix );
+
+   // buffer for the local row of the commPattern matrix
+   // FIXME: CUDA does not support atomic operations for bool
+//   using AtomicBool = Atomic< bool, DeviceType >;
+   using AtomicBool = Atomic< int, DeviceType >;
+   Containers::Array< AtomicBool, DeviceType > buffer( nproc );
+   buffer.setValue( false );
+
+   // optimization for banded matrices
+   using AtomicIndex = Atomic< IndexType, DeviceType >;
+   Containers::Array< AtomicIndex, DeviceType > local_span( 2 );
+   local_span.setElement( 0, 0 );  // span start
+   local_span.setElement( 1, localMatrix.getRows() );  // span end
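+   // the span starts as the whole local row range; the kernel below shrinks it
+   // so that every row needing data from a lower rank falls below the span
+   // start (fetch_max) and every row needing data from a higher rank falls at
+   // or above the span end (fetch_min)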
+
+   auto kernel = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix,
+                                         AtomicBool* buffer, AtomicIndex* local_span )
+   {
+      const IndexType columns = localMatrix->getColumns();
+      const auto row = localMatrix->getRow( i );
+      bool comm_left = false;
+      bool comm_right = false;
+      for( IndexType c = 0; c < row.getLength(); c++ ) {
+         const IndexType j = row.getElementColumn( c );
+         if( j < columns ) {
+            const int owner = Partitioner::getOwner( j, columns, nproc );
+            // atomic assignment
+            buffer[ owner ].store( true );
+            // update comm_left/comm_right
+            if( owner < rank )
+               comm_left = true;
+            if( owner > rank )
+               comm_right = true;
+         }
+      }
+      // update local span
+      if( comm_left )
+         local_span[0].fetch_max( i + 1 );
+      if( comm_right )
+         local_span[1].fetch_min( i );
+   };
+
+   ParallelFor< DeviceType >::exec( (IndexType) 0, localMatrix.getRows(),
+                                    kernel,
+                                    &localMatrixPointer.template getData< DeviceType >(),
+                                    buffer.getData(),
+                                    local_span.getData()
+                                 );
+
+   // set the local-only span (optimization for banded matrices)
+   localOnlySpan.first = local_span.getElement( 0 );
+   localOnlySpan.second = local_span.getElement( 1 );
+
+   // copy the buffer into all rows of the preCommPattern matrix
+   Matrices::Dense< bool, Devices::Host, int > preCommPattern;
+   preCommPattern.setLike( commPattern );
+   for( int j = 0; j < nproc; j++ )
+   for( int i = 0; i < nproc; i++ )
+      preCommPattern.setElementFast( j, i, buffer.getElement( i ) );
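+   // (every row of preCommPattern is this rank's row of C, so the all-to-all
+   // exchange below gives each process the complete commPattern matrix)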
+
+   // assemble the commPattern matrix
+   CommunicatorType::Alltoall( &preCommPattern(0, 0), nproc,
+                               &commPattern(0, 0), nproc,
+                               getCommunicationGroup() );
+}
+
+template< typename Matrix,
+          typename Communicator,
+          typename IndexMap >
+   template< typename Partitioner,
+             typename RealIn,
+             typename RealOut >
+void
+DistributedMatrix< Matrix, Communicator, IndexMap >::
+vectorProduct( const DistVector< RealIn >& inVector,
+               DistVector< RealOut >& outVector )
+{
+   TNL_ASSERT_EQ( inVector.getSize(), getColumns(), "input vector has wrong size" );
+   TNL_ASSERT_EQ( inVector.getIndexMap(), getRowIndexMap(), "input vector has wrong distribution" );
+   TNL_ASSERT_EQ( inVector.getCommunicationGroup(), getCommunicationGroup(), "input vector has wrong communication group" );
+   TNL_ASSERT_EQ( outVector.getSize(), getRows(), "output vector has wrong size" );
+   TNL_ASSERT_EQ( outVector.getIndexMap(), getRowIndexMap(), "output vector has wrong distribution" );
+   TNL_ASSERT_EQ( outVector.getCommunicationGroup(), getCommunicationGroup(), "output vector has wrong communication group" );
+
+   if( getCommunicationGroup() == CommunicatorType::NullGroup )
+      return;
+
+   const int rank = CommunicatorType::GetRank( getCommunicationGroup() );
+   const int nproc = CommunicatorType::GetSize( getCommunicationGroup() );
+
+   // update communication pattern
+   if( commPattern.getRows() != nproc )
+      updateVectorProductCommunicationPattern< Partitioner >();
+
+   // prepare buffers
+   globalBuffer.setSize( localMatrix.getColumns() );
+   commRequests.clear();
+
+   // send our data to all processes that need it
+   for( int i = 0; i < commPattern.getRows(); i++ )
+      if( commPattern( i, rank ) )
+         commRequests.push_back( CommunicatorType::ISend(
+                  inVector.getLocalVectorView().getData(),
+                  inVector.getLocalVectorView().getSize(),
+                  i, getCommunicationGroup() ) );
+
+   // receive data that we need
+   for( int j = 0; j < commPattern.getRows(); j++ )
+      if( commPattern( rank, j ) )
+         commRequests.push_back( CommunicatorType::IRecv(
+                  &globalBuffer[ Partitioner::getOffset( globalBuffer.getSize(), j, nproc ) ],
+                  Partitioner::getSizeForRank( globalBuffer.getSize(), j, nproc ),
+                  j, getCommunicationGroup() ) );
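+
+   // (each receive is placed at the owner's offset in globalBuffer, so after
+   // the communication finishes the buffer holds every part of the global
+   // input vector that this process needs)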
+
+   // general variant
+   if( localOnlySpan.first >= localOnlySpan.second ) {
+      // wait for all communications to finish
+      CommunicatorType::WaitAll( &commRequests[0], commRequests.size() );
+
+      // perform matrix-vector multiplication
+      vectorProduct( globalBuffer, outVector );
+   }
+   // optimization for banded matrices
+   else {
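+      // rows in [localOnlySpan.first, localOnlySpan.second) reference only
+      // columns owned by this process, so they can be multiplied while the
+      // communication of the remote data is still in progress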
+      Pointers::DevicePointer< MatrixType > localMatrixPointer( localMatrix );
+      auto outVectorView = outVector.getLocalVectorView();
+      // TODO
+//      const auto inVectorView = DistributedVectorView( inVector );
+      Pointers::DevicePointer< const DistVector< RealIn > > inVectorPointer( inVector );
+
+      // matrix-vector multiplication using local-only rows
+      auto kernel1 = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix, const DistVector< RealIn >* inVector ) mutable
+      {
+         outVectorView[ i ] = localMatrix->rowVectorProduct( i, *inVector );
+      };
+      ParallelFor< DeviceType >::exec( localOnlySpan.first, localOnlySpan.second, kernel1,
+                                       &localMatrixPointer.template getData< DeviceType >(),
+                                       &inVectorPointer.template getData< DeviceType >() );
+
+      // wait for all communications to finish
+      CommunicatorType::WaitAll( &commRequests[0], commRequests.size() );
+
+      // finish the multiplication by processing the rows which reference the received non-local data
+      Containers::VectorView< RealType, DeviceType, IndexType > globalBufferView( globalBuffer );
+      auto kernel2 = [=] __cuda_callable__ ( IndexType i, const MatrixType* localMatrix ) mutable
+      {
+         outVectorView[ i ] = localMatrix->rowVectorProduct( i, globalBufferView );
+      };
+      ParallelFor< DeviceType >::exec( (IndexType) 0, localOnlySpan.first, kernel2,
+                                       &localMatrixPointer.template getData< DeviceType >() );
+      ParallelFor< DeviceType >::exec( localOnlySpan.second, localMatrix.getRows(), kernel2,
+                                       &localMatrixPointer.template getData< DeviceType >() );
+   }
+}
+
 } // namespace DistributedContainers
 } // namespace TNL
diff --git a/src/TNL/DistributedContainers/Partitioner.h b/src/TNL/DistributedContainers/Partitioner.h
index 4635fbd17ce925851aebc25baff325f1ff1ee087..68d0c9d3fa20ac084cf9f60cd6e10045c0fdd252 100644
--- a/src/TNL/DistributedContainers/Partitioner.h
+++ b/src/TNL/DistributedContainers/Partitioner.h
@@ -42,6 +42,29 @@ public:
       else
          return IndexMap( 0, 0, globalSize );
    }
+
+   // Gets the owner of the given global index.
+   __cuda_callable__
+   static int getOwner( Index i, Index globalSize, int partitions )
+   {
+      return i * partitions / globalSize;
+   }
+
+   // Gets the offset of the data owned by the given rank.
+   __cuda_callable__
+   static Index getOffset( Index globalSize, int rank, int partitions )
+   {
+      return rank * globalSize / partitions;
+   }
+
+   // Gets the size of the data assigned to the given rank.
+   __cuda_callable__
+   static Index getSizeForRank( Index globalSize, int rank, int partitions )
+   {
+      const Index begin = min( globalSize, rank * globalSize / partitions );
+      const Index end = min( globalSize, (rank + 1) * globalSize / partitions );
+      return end - begin;
+   }
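+
+   // Illustrative example of the splitting above: for globalSize = 12 and
+   // partitions = 3, getOffset yields 0, 4, 8, getSizeForRank yields 4, 4, 4,
+   // and e.g. getOwner( 5, 12, 3 ) == 1.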
 };
 
 } // namespace DistributedContainers
diff --git a/src/UnitTests/DistributedContainers/DistributedMatrixTest.h b/src/UnitTests/DistributedContainers/DistributedMatrixTest.h
index f4b56fc6e3f7d3dbe837eaa0b0d5ea2159334b96..0bd95d6792a02112e048d1c03517ca28f6b7d7d6 100644
--- a/src/UnitTests/DistributedContainers/DistributedMatrixTest.h
+++ b/src/UnitTests/DistributedContainers/DistributedMatrixTest.h
@@ -70,6 +70,7 @@ protected:
    using IndexType = typename DistributedMatrix::IndexType;
    using IndexMap = typename DistributedMatrix::IndexMapType;
    using DistributedMatrixType = DistributedMatrix;
+   using Partitioner = DistributedContainers::Partitioner< IndexMap, CommunicatorType >;
 
    using RowLengthsVector = typename DistributedMatrixType::CompressedRowLengthsVector;
    using GlobalVector = Containers::Vector< RealType, DeviceType, IndexType >;
@@ -88,7 +89,7 @@ protected:
 
    void SetUp() override
    {
-      const IndexMap map = DistributedContainers::Partitioner< IndexMap, CommunicatorType >::splitRange( globalSize, group );
+      const IndexMap map = Partitioner::splitRange( globalSize, group );
       matrix.setDistribution( map, globalSize, group );
       rowLengths.setDistribution( map, group );
 
@@ -220,6 +221,7 @@ TYPED_TEST( DistributedMatrixTest, vectorProduct_globalInput )
 TYPED_TEST( DistributedMatrixTest, vectorProduct_distributedInput )
 {
    using DistributedVector = typename TestFixture::DistributedVector;
+   using Partitioner = typename TestFixture::Partitioner;
 
    this->matrix.setCompressedRowLengths( this->rowLengths );
    setMatrix( this->matrix, this->rowLengths );
@@ -227,7 +229,7 @@ TYPED_TEST( DistributedMatrixTest, vectorProduct_distributedInput )
    DistributedVector inVector( this->matrix.getRowIndexMap(), this->matrix.getCommunicationGroup() );
    inVector.setValue( 1 );
    DistributedVector outVector( this->matrix.getRowIndexMap(), this->matrix.getCommunicationGroup() );
-   this->matrix.vectorProduct( inVector, outVector );
+   this->matrix.template vectorProduct< Partitioner >( inVector, outVector );
 
    EXPECT_EQ( outVector, this->rowLengths );
 }