From 8ae818f14b121ec5ebbdbe087cd1801126b15f05 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jakub=20Klinkovsk=C3=BD?= <klinkjak@fjfi.cvut.cz>
Date: Thu, 20 Sep 2018 17:05:40 +0200
Subject: [PATCH] Updating linear solvers and preconditioners for distributed
 matrices/vectors

---
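Note: the core of this change is the new Traits< Matrix > class in src/TNL/Solvers/Linear/Traits.h, which centralizes the vector/view typedefs and the getLocalMatrix()/getLocalVectorView() wrappers so that the same solver and preconditioner code compiles for both regular matrices and DistributedContainers::DistributedMatrix. The following is only a minimal, self-contained sketch of that dispatch pattern, with simplified stand-in types in place of the TNL containers; it is not part of the patch itself.

  #include <cstddef>
  #include <iostream>
  #include <vector>

  // simplified stand-in for a sequential sparse matrix
  struct SequentialMatrix
  {
     std::vector< double > diag;   // pretend storage: one value per row
  };

  // simplified stand-in for DistributedContainers::DistributedMatrix
  template< typename LocalMatrix >
  struct DistributedMatrix
  {
     LocalMatrix localMatrix;   // block of rows owned by this rank
     int localRowBegin;         // global index of the first local row
  };

  // primary template: a regular matrix is already "local"
  template< typename Matrix >
  struct Traits
  {
     static const Matrix& getLocalMatrix( const Matrix& m ) { return m; }
  };

  // partial specialization: unwrap the distributed matrix
  template< typename LocalMatrix >
  struct Traits< DistributedMatrix< LocalMatrix > >
  {
     static const LocalMatrix& getLocalMatrix( const DistributedMatrix< LocalMatrix >& m )
     { return m.localMatrix; }
  };

  // generic code written once against Traits compiles for both matrix kinds
  template< typename Matrix >
  std::size_t localRows( const Matrix& m )
  {
     return Traits< Matrix >::getLocalMatrix( m ).diag.size();
  }

  int main()
  {
     SequentialMatrix a{ { 1.0, 2.0, 3.0 } };
     DistributedMatrix< SequentialMatrix > d{ a, 3 };
     std::cout << localRows( a ) << " " << localRows( d ) << std::endl;   // prints "3 3"
  }

The same idea is why BICGStab and TFQMR now allocate their auxiliary vectors with setLike( x ) instead of setSize( size ): a distributed vector's layout is not determined by a single size, so the auxiliary vectors copy the layout of the solution vector.
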
 .../tnl-benchmark-linear-solvers.h            | 92 +++++++++----------
 src/TNL/Solvers/Linear/BICGStab.h             |  6 +-
 src/TNL/Solvers/Linear/BICGStab_impl.h        | 18 ++--
 src/TNL/Solvers/Linear/LinearSolver.h         |  7 +-
 .../Solvers/Linear/Preconditioners/Diagonal.h | 29 ++++++
 .../Linear/Preconditioners/Diagonal_impl.h    | 40 ++++++++
 src/TNL/Solvers/Linear/Preconditioners/ILU0.h | 38 ++++++++
 .../Linear/Preconditioners/ILU0_impl.h        | 40 ++++++--
 src/TNL/Solvers/Linear/Preconditioners/ILUT.h | 14 +++
 .../Linear/Preconditioners/ILUT_impl.h        | 32 +++++--
 .../Linear/Preconditioners/Preconditioner.h   |  6 +-
 src/TNL/Solvers/Linear/TFQMR.h                |  8 +-
 src/TNL/Solvers/Linear/TFQMR_impl.h           | 23 ++---
 src/TNL/Solvers/Linear/Traits.h               | 91 ++++++++++++++++++
 14 files changed, 341 insertions(+), 103 deletions(-)
 create mode 100644 src/TNL/Solvers/Linear/Traits.h
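
A similar note for the ILU0_impl.h and ILUT_impl.h hunks: the local block of a distributed matrix stores global column indices, so the factorization skips entries with column < minColumn (couplings to rows owned by other ranks) and shifts the remaining indices by minColumn to local numbering. The snippet below is only a stand-alone illustration of that index handling, using plain standard-library containers rather than the TNL matrix API; the values are made up.

  #include <cstdio>
  #include <vector>

  int main()
  {
     // minColumn plays the role of getLocalRowRange().getBegin() on this rank:
     // columns smaller than it belong to rows owned by other ranks.
     const int minColumn = 3;
     const std::vector< int > columns{ 1, 2, 3, 5, 7 };   // global column indices of one local row

     std::vector< int > localColumns;
     for( const int c : columns ) {
        if( c < minColumn )
           continue;                                      // skip non-local elements
        localColumns.push_back( c - minColumn );          // shift to local numbering
     }

     for( const int c : localColumns )
        std::printf( "%d ", c );                          // prints "0 2 4 "
     std::printf( "\n" );
  }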

diff --git a/src/Benchmarks/LinearSolvers/tnl-benchmark-linear-solvers.h b/src/Benchmarks/LinearSolvers/tnl-benchmark-linear-solvers.h
index fe10fd72e1..04911eb867 100644
--- a/src/Benchmarks/LinearSolvers/tnl-benchmark-linear-solvers.h
+++ b/src/Benchmarks/LinearSolvers/tnl-benchmark-linear-solvers.h
@@ -99,17 +99,17 @@ benchmarkIterativeSolvers( Benchmark& benchmark,
    benchmarkPreconditionerUpdate< Diagonal >( benchmark, parameters, cudaMatrixPointer );
 #endif
 
-   benchmark.setOperation("GMRES (Jacobi)");
-   benchmarkSolver< GMRES, Diagonal >( benchmark, parameters, matrixPointer, x0, b );
-#ifdef HAVE_CUDA
-   benchmarkSolver< GMRES, Diagonal >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
-#endif
-
-   benchmark.setOperation("CWYGMRES (Jacobi)");
-   benchmarkSolver< CWYGMRES, Diagonal >( benchmark, parameters, matrixPointer, x0, b );
-#ifdef HAVE_CUDA
-   benchmarkSolver< CWYGMRES, Diagonal >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
-#endif
+//   benchmark.setOperation("GMRES (Jacobi)");
+//   benchmarkSolver< GMRES, Diagonal >( benchmark, parameters, matrixPointer, x0, b );
+//#ifdef HAVE_CUDA
+//   benchmarkSolver< GMRES, Diagonal >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
+//#endif
+
+//   benchmark.setOperation("CWYGMRES (Jacobi)");
+//   benchmarkSolver< CWYGMRES, Diagonal >( benchmark, parameters, matrixPointer, x0, b );
+//#ifdef HAVE_CUDA
+//   benchmarkSolver< CWYGMRES, Diagonal >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
+//#endif
 
    benchmark.setOperation("TFQMR (Jacobi)");
    benchmarkSolver< TFQMR, Diagonal >( benchmark, parameters, matrixPointer, x0, b );
@@ -125,11 +125,11 @@ benchmarkIterativeSolvers( Benchmark& benchmark,
 
    for( int ell = 1; ell <= ell_max; ell++ ) {
       parameters.template setParameter< int >( "bicgstab-ell", ell );
-      benchmark.setOperation("BiCGstab(" + String(ell) + ") (Jacobi)");
-      benchmarkSolver< BICGStabL, Diagonal >( benchmark, parameters, matrixPointer, x0, b );
-#ifdef HAVE_CUDA
-      benchmarkSolver< BICGStabL, Diagonal >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
-#endif
+//      benchmark.setOperation("BiCGstab(" + String(ell) + ") (Jacobi)");
+//      benchmarkSolver< BICGStabL, Diagonal >( benchmark, parameters, matrixPointer, x0, b );
+//#ifdef HAVE_CUDA
+//      benchmarkSolver< BICGStabL, Diagonal >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
+//#endif
    }
 
 
@@ -139,17 +139,17 @@ benchmarkIterativeSolvers( Benchmark& benchmark,
    benchmarkPreconditionerUpdate< ILU0 >( benchmark, parameters, cudaMatrixPointer );
 #endif
 
-   benchmark.setOperation("GMRES (ILU0)");
-   benchmarkSolver< GMRES, ILU0 >( benchmark, parameters, matrixPointer, x0, b );
-#ifdef HAVE_CUDA
-   benchmarkSolver< GMRES, ILU0 >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
-#endif
+//   benchmark.setOperation("GMRES (ILU0)");
+//   benchmarkSolver< GMRES, ILU0 >( benchmark, parameters, matrixPointer, x0, b );
+//#ifdef HAVE_CUDA
+//   benchmarkSolver< GMRES, ILU0 >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
+//#endif
 
-   benchmark.setOperation("CWYGMRES (ILU0)");
-   benchmarkSolver< CWYGMRES, ILU0 >( benchmark, parameters, matrixPointer, x0, b );
-#ifdef HAVE_CUDA
-   benchmarkSolver< CWYGMRES, ILU0 >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
-#endif
+//   benchmark.setOperation("CWYGMRES (ILU0)");
+//   benchmarkSolver< CWYGMRES, ILU0 >( benchmark, parameters, matrixPointer, x0, b );
+//#ifdef HAVE_CUDA
+//   benchmarkSolver< CWYGMRES, ILU0 >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
+//#endif
 
    benchmark.setOperation("TFQMR (ILU0)");
    benchmarkSolver< TFQMR, ILU0 >( benchmark, parameters, matrixPointer, x0, b );
@@ -165,11 +165,11 @@ benchmarkIterativeSolvers( Benchmark& benchmark,
 
    for( int ell = 1; ell <= ell_max; ell++ ) {
       parameters.template setParameter< int >( "bicgstab-ell", ell );
-      benchmark.setOperation("BiCGstab(" + String(ell) + ") (ILU0)");
-      benchmarkSolver< BICGStabL, ILU0 >( benchmark, parameters, matrixPointer, x0, b );
-#ifdef HAVE_CUDA
-      benchmarkSolver< BICGStabL, ILU0 >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
-#endif
+//      benchmark.setOperation("BiCGstab(" + String(ell) + ") (ILU0)");
+//      benchmarkSolver< BICGStabL, ILU0 >( benchmark, parameters, matrixPointer, x0, b );
+//#ifdef HAVE_CUDA
+//      benchmarkSolver< BICGStabL, ILU0 >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
+//#endif
    }
 
 
@@ -179,17 +179,17 @@ benchmarkIterativeSolvers( Benchmark& benchmark,
    benchmarkPreconditionerUpdate< ILUT >( benchmark, parameters, cudaMatrixPointer );
 #endif
 
-   benchmark.setOperation("GMRES (ILUT)");
-   benchmarkSolver< GMRES, ILUT >( benchmark, parameters, matrixPointer, x0, b );
-#ifdef HAVE_CUDA
-   benchmarkSolver< GMRES, ILUT >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
-#endif
+//   benchmark.setOperation("GMRES (ILUT)");
+//   benchmarkSolver< GMRES, ILUT >( benchmark, parameters, matrixPointer, x0, b );
+//#ifdef HAVE_CUDA
+//   benchmarkSolver< GMRES, ILUT >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
+//#endif
 
-   benchmark.setOperation("CWYGMRES (ILUT)");
-   benchmarkSolver< CWYGMRES, ILUT >( benchmark, parameters, matrixPointer, x0, b );
-#ifdef HAVE_CUDA
-   benchmarkSolver< CWYGMRES, ILUT >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
-#endif
+//   benchmark.setOperation("CWYGMRES (ILUT)");
+//   benchmarkSolver< CWYGMRES, ILUT >( benchmark, parameters, matrixPointer, x0, b );
+//#ifdef HAVE_CUDA
+//   benchmarkSolver< CWYGMRES, ILUT >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
+//#endif
 
    benchmark.setOperation("TFQMR (ILUT)");
    benchmarkSolver< TFQMR, ILUT >( benchmark, parameters, matrixPointer, x0, b );
@@ -205,11 +205,11 @@ benchmarkIterativeSolvers( Benchmark& benchmark,
 
    for( int ell = 1; ell <= ell_max; ell++ ) {
       parameters.template setParameter< int >( "bicgstab-ell", ell );
-      benchmark.setOperation("BiCGstab(" + String(ell) + ") (ILUT)");
-      benchmarkSolver< BICGStabL, ILUT >( benchmark, parameters, matrixPointer, x0, b );
-#ifdef HAVE_CUDA
-      benchmarkSolver< BICGStabL, ILUT >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
-#endif
+//      benchmark.setOperation("BiCGstab(" + String(ell) + ") (ILUT)");
+//      benchmarkSolver< BICGStabL, ILUT >( benchmark, parameters, matrixPointer, x0, b );
+//#ifdef HAVE_CUDA
+//      benchmarkSolver< BICGStabL, ILUT >( benchmark, parameters, cudaMatrixPointer, cuda_x0, cuda_b );
+//#endif
    }
 }
 
diff --git a/src/TNL/Solvers/Linear/BICGStab.h b/src/TNL/Solvers/Linear/BICGStab.h
index 03b5f70b74..686d6f4503 100644
--- a/src/TNL/Solvers/Linear/BICGStab.h
+++ b/src/TNL/Solvers/Linear/BICGStab.h
@@ -12,8 +12,6 @@
 
 #include "LinearSolver.h"
 
-#include <TNL/Containers/Vector.h>
-
 namespace TNL {
 namespace Solvers {
 namespace Linear {
@@ -41,11 +39,11 @@ public:
    bool solve( ConstVectorViewType b, VectorViewType x ) override;
 
 protected:
-   void setSize( IndexType size );
+   void setSize( const VectorViewType& x );
 
    bool exact_residue = false;
 
-   Containers::Vector< RealType, DeviceType, IndexType > r, r_ast, p, s, Ap, As, M_tmp;
+   typename Traits< Matrix >::VectorType r, r_ast, p, s, Ap, As, M_tmp;
 };
 
 } // namespace Linear
diff --git a/src/TNL/Solvers/Linear/BICGStab_impl.h b/src/TNL/Solvers/Linear/BICGStab_impl.h
index 86702b3103..e0313e0048 100644
--- a/src/TNL/Solvers/Linear/BICGStab_impl.h
+++ b/src/TNL/Solvers/Linear/BICGStab_impl.h
@@ -48,7 +48,7 @@ setup( const Config::ParameterContainer& parameters,
 template< typename Matrix >
 bool BICGStab< Matrix >::solve( ConstVectorViewType b, VectorViewType x )
 {
-   this->setSize( this->matrix->getRows() );
+   this->setSize( x );
 
    RealType alpha, beta, omega, aux, rho, rho_old, b_norm;
 
@@ -161,15 +161,15 @@ bool BICGStab< Matrix >::solve( ConstVectorViewType b, VectorViewType x )
 }
 
 template< typename Matrix >
-void BICGStab< Matrix > :: setSize( IndexType size )
+void BICGStab< Matrix > :: setSize( const VectorViewType& x )
 {
-   r.setSize( size );
-   r_ast.setSize( size );
-   p.setSize( size );
-   s.setSize( size );
-   Ap.setSize( size );
-   As.setSize( size );
-   M_tmp.setSize( size );
+   r.setLike( x );
+   r_ast.setLike( x );
+   p.setLike( x );
+   s.setLike( x );
+   Ap.setLike( x );
+   As.setLike( x );
+   M_tmp.setLike( x );
 }
 
 } // namespace Linear
diff --git a/src/TNL/Solvers/Linear/LinearSolver.h b/src/TNL/Solvers/Linear/LinearSolver.h
index 7cf9e9665a..66f5f0be0f 100644
--- a/src/TNL/Solvers/Linear/LinearSolver.h
+++ b/src/TNL/Solvers/Linear/LinearSolver.h
@@ -17,9 +17,10 @@
 
 #include <TNL/Solvers/IterativeSolver.h>
 #include <TNL/Solvers/Linear/Preconditioners/Preconditioner.h>
-#include <TNL/Containers/VectorView.h>
 #include <TNL/Pointers/SharedPointer.h>
 
+#include "Traits.h"
+
 namespace TNL {
 namespace Solvers {
 namespace Linear {
@@ -32,8 +33,8 @@ public:
    using RealType = typename Matrix::RealType;
    using DeviceType = typename Matrix::DeviceType;
    using IndexType = typename Matrix::IndexType;
-   using VectorViewType = Containers::VectorView< RealType, DeviceType, IndexType >;
-   using ConstVectorViewType = Containers::VectorView< typename std::add_const< RealType >::type, DeviceType, IndexType >;
+   using VectorViewType = typename Traits< Matrix >::VectorViewType;
+   using ConstVectorViewType = typename Traits< Matrix >::ConstVectorViewType;
    using MatrixType = Matrix;
    using MatrixPointer = Pointers::SharedPointer< typename std::add_const< MatrixType >::type >;
    using PreconditionerType = Preconditioners::Preconditioner< MatrixType >;
diff --git a/src/TNL/Solvers/Linear/Preconditioners/Diagonal.h b/src/TNL/Solvers/Linear/Preconditioners/Diagonal.h
index 16dda9eea9..0b09814f95 100644
--- a/src/TNL/Solvers/Linear/Preconditioners/Diagonal.h
+++ b/src/TNL/Solvers/Linear/Preconditioners/Diagonal.h
@@ -47,6 +47,35 @@ protected:
    VectorType diagonal;
 };
 
+template< typename Matrix, typename Communicator >
+class Diagonal< DistributedContainers::DistributedMatrix< Matrix, Communicator > >
+: public Preconditioner< DistributedContainers::DistributedMatrix< Matrix, Communicator > >
+{
+public:
+   using MatrixType = DistributedContainers::DistributedMatrix< Matrix, Communicator >;
+   using RealType = typename MatrixType::RealType;
+   using DeviceType = typename MatrixType::DeviceType;
+   using IndexType = typename MatrixType::IndexType;
+   using typename Preconditioner< MatrixType >::VectorViewType;
+   using typename Preconditioner< MatrixType >::ConstVectorViewType;
+   using typename Preconditioner< MatrixType >::MatrixPointer;
+   using VectorType = Containers::Vector< RealType, DeviceType, IndexType >;
+   using LocalVectorViewType = Containers::VectorView< RealType, DeviceType, IndexType >;
+   using ConstLocalVectorViewType = Containers::VectorView< typename std::add_const< RealType >::type, DeviceType, IndexType >;
+
+   virtual void update( const MatrixPointer& matrixPointer ) override;
+
+   virtual void solve( ConstVectorViewType b, VectorViewType x ) const override;
+
+   String getType() const
+   {
+      return String( "Diagonal" );
+   }
+
+protected:
+   VectorType diagonal;
+};
+
 } // namespace Preconditioners
 } // namespace Linear
 } // namespace Solvers
diff --git a/src/TNL/Solvers/Linear/Preconditioners/Diagonal_impl.h b/src/TNL/Solvers/Linear/Preconditioners/Diagonal_impl.h
index d72263d3c6..42724077e9 100644
--- a/src/TNL/Solvers/Linear/Preconditioners/Diagonal_impl.h
+++ b/src/TNL/Solvers/Linear/Preconditioners/Diagonal_impl.h
@@ -57,6 +57,46 @@ solve( ConstVectorViewType b, VectorViewType x ) const
    ParallelFor< DeviceType >::exec( (IndexType) 0, diagonal.getSize(), kernel );
 }
 
+
+template< typename Matrix, typename Communicator >
+void
+Diagonal< DistributedContainers::DistributedMatrix< Matrix, Communicator > >::
+update( const MatrixPointer& matrixPointer )
+{
+   TNL_ASSERT_GT( matrixPointer->getRows(), 0, "empty matrix" );
+   TNL_ASSERT_EQ( matrixPointer->getRows(), matrixPointer->getColumns(), "matrix must be square" );
+
+   diagonal.setSize( matrixPointer->getLocalMatrix().getRows() );
+
+   LocalVectorViewType diag_view( diagonal );
+   const MatrixType* kernel_matrix = &matrixPointer.template getData< DeviceType >();
+
+   auto kernel = [=] __cuda_callable__ ( IndexType i ) mutable
+   {
+      const IndexType gi = kernel_matrix->getLocalRowRange().getGlobalIndex( i );
+      diag_view[ i ] = kernel_matrix->getLocalMatrix().getElementFast( i, gi );
+   };
+
+   ParallelFor< DeviceType >::exec( (IndexType) 0, diagonal.getSize(), kernel );
+}
+
+template< typename Matrix, typename Communicator >
+void
+Diagonal< DistributedContainers::DistributedMatrix< Matrix, Communicator > >::
+solve( ConstVectorViewType b, VectorViewType x ) const
+{
+   ConstLocalVectorViewType diag_view( diagonal );
+   const auto b_view = b.getLocalVectorView();
+   auto x_view = x.getLocalVectorView();
+
+   auto kernel = [=] __cuda_callable__ ( IndexType i ) mutable
+   {
+      x_view[ i ] = b_view[ i ] / diag_view[ i ];
+   };
+
+   ParallelFor< DeviceType >::exec( (IndexType) 0, diagonal.getSize(), kernel );
+}
+
 } // namespace Preconditioners
 } // namespace Linear
 } // namespace Solvers
diff --git a/src/TNL/Solvers/Linear/Preconditioners/ILU0.h b/src/TNL/Solvers/Linear/Preconditioners/ILU0.h
index b1d835a3d9..0cbf21a75a 100644
--- a/src/TNL/Solvers/Linear/Preconditioners/ILU0.h
+++ b/src/TNL/Solvers/Linear/Preconditioners/ILU0.h
@@ -64,6 +64,20 @@ public:
 protected:
    // The factors L and U are stored separately and the rows of U are reversed.
    Matrices::CSR< RealType, DeviceType, IndexType > L, U;
+
+   // Specialized methods to distinguish between normal and distributed matrices
+   // in the implementation.
+   template< typename M >
+   static IndexType getMinColumn( const M& m )
+   {
+      return 0;
+   }
+
+   template< typename M >
+   static IndexType getMinColumn( const DistributedContainers::DistributedMatrix< M >& m )
+   {
+      return m.getLocalRowRange().getBegin();
+   }
 };
 
 template< typename Matrix >
@@ -171,6 +185,30 @@ protected:
 #endif
 };
 
+template< typename Matrix, typename Communicator >
+class ILU0_impl< DistributedContainers::DistributedMatrix< Matrix, Communicator >, double, Devices::Cuda, int >
+: public Preconditioner< DistributedContainers::DistributedMatrix< Matrix, Communicator > >
+{
+   using MatrixType = DistributedContainers::DistributedMatrix< Matrix, Communicator >;
+public:
+   using RealType = double;
+   using DeviceType = Devices::Cuda;
+   using IndexType = int;
+   using typename Preconditioner< MatrixType >::VectorViewType;
+   using typename Preconditioner< MatrixType >::ConstVectorViewType;
+   using typename Preconditioner< MatrixType >::MatrixPointer;
+
+   virtual void update( const MatrixPointer& matrixPointer ) override
+   {
+      throw std::runtime_error("ILU0 is not implemented yet for CUDA and distributed matrices.");
+   }
+
+   virtual void solve( ConstVectorViewType b, VectorViewType x ) const override
+   {
+      throw std::runtime_error("ILU0 is not implemented yet for CUDA and distributed matrices.");
+   }
+};
+
 template< typename Matrix, typename Real, typename Index >
 class ILU0_impl< Matrix, Real, Devices::MIC, Index >
 : public Preconditioner< Matrix >
diff --git a/src/TNL/Solvers/Linear/Preconditioners/ILU0_impl.h b/src/TNL/Solvers/Linear/Preconditioners/ILU0_impl.h
index 0d867f2b72..e598609ae5 100644
--- a/src/TNL/Solvers/Linear/Preconditioners/ILU0_impl.h
+++ b/src/TNL/Solvers/Linear/Preconditioners/ILU0_impl.h
@@ -30,7 +30,9 @@ update( const MatrixPointer& matrixPointer )
    TNL_ASSERT_GT( matrixPointer->getRows(), 0, "empty matrix" );
    TNL_ASSERT_EQ( matrixPointer->getRows(), matrixPointer->getColumns(), "matrix must be square" );
 
-   const IndexType N = matrixPointer->getRows();
+   const auto& localMatrix = Traits< Matrix >::getLocalMatrix( *matrixPointer );
+   const IndexType N = localMatrix.getRows();
+   const IndexType minColumn = getMinColumn( *matrixPointer );
 
    L.setDimensions( N, N );
    U.setDimensions( N, N );
@@ -41,15 +43,17 @@ update( const MatrixPointer& matrixPointer )
    L_rowLengths.setSize( N );
    U_rowLengths.setSize( N );
    for( IndexType i = 0; i < N; i++ ) {
-      const auto row = matrixPointer->getRow( i );
-      const auto max_length = matrixPointer->getRowLength( i );
+      const auto row = localMatrix.getRow( i );
+      const auto max_length = localMatrix.getRowLength( i );
       IndexType L_entries = 0;
       IndexType U_entries = 0;
       for( IndexType j = 0; j < max_length; j++ ) {
          const auto column = row.getElementColumn( j );
-         if( column < i )
+         if( column < minColumn )
+            continue;
+         if( column < i + minColumn )
             L_entries++;
-         else if( column < N )
+         else if( column < N + minColumn )
             U_entries++;
          else
             break;
@@ -64,10 +68,23 @@ update( const MatrixPointer& matrixPointer )
    // The factors L and U are stored separately and the rows of U are reversed.
    for( IndexType i = 0; i < N; i++ ) {
       // copy all non-zero entries from A into L and U
-      const auto max_length = matrixPointer->getRowLength( i );
-      IndexType columns[ max_length ];
-      RealType values[ max_length ];
-      matrixPointer->getRowFast( i, columns, values );
+      const auto max_length = localMatrix.getRowLength( i );
+      IndexType all_columns[ max_length ];
+      RealType all_values[ max_length ];
+      localMatrix.getRowFast( i, all_columns, all_values );
+
+      // skip non-local elements
+      IndexType* columns = all_columns;
+      RealType* values = all_values;
+      while( columns[0] < minColumn ) {
+         columns++;
+         values++;
+      }
+
+      // shift the column indices to local numbering
+      if( minColumn > 0 )
+         for( IndexType c_j = 0; c_j < max_length; c_j++ )
+            all_columns[ c_j ] -= minColumn;
 
       const auto L_entries = L_rowLengths[ i ];
       const auto U_entries = U_rowLengths[ N - 1 - i ];
@@ -106,8 +123,11 @@ update( const MatrixPointer& matrixPointer )
 template< typename Matrix, typename Real, typename Index >
 void
 ILU0_impl< Matrix, Real, Devices::Host, Index >::
-solve( ConstVectorViewType b, VectorViewType x ) const
+solve( ConstVectorViewType _b, VectorViewType _x ) const
 {
+   const auto b = Traits< Matrix >::getLocalVectorView( _b );
+   auto x = Traits< Matrix >::getLocalVectorView( _x );
+
    // Step 1: solve y from Ly = b
    triangularSolveLower< true >( L, x, b );
 
diff --git a/src/TNL/Solvers/Linear/Preconditioners/ILUT.h b/src/TNL/Solvers/Linear/Preconditioners/ILUT.h
index fff6fd2967..5f1654ecf0 100644
--- a/src/TNL/Solvers/Linear/Preconditioners/ILUT.h
+++ b/src/TNL/Solvers/Linear/Preconditioners/ILUT.h
@@ -72,6 +72,20 @@ protected:
 
    // The factors L and U are stored separately and the rows of U are reversed.
    Matrices::CSR< RealType, DeviceType, IndexType > L, U;
+
+   // Specialized methods to distinguish between normal and distributed matrices
+   // in the implementation.
+   template< typename M >
+   static IndexType getMinColumn( const M& m )
+   {
+      return 0;
+   }
+
+   template< typename M >
+   static IndexType getMinColumn( const DistributedContainers::DistributedMatrix< M >& m )
+   {
+      return m.getLocalRowRange().getBegin();
+   }
 };
 
 template< typename Matrix, typename Real, typename Index >
diff --git a/src/TNL/Solvers/Linear/Preconditioners/ILUT_impl.h b/src/TNL/Solvers/Linear/Preconditioners/ILUT_impl.h
index d0e6708351..67d4aa627f 100644
--- a/src/TNL/Solvers/Linear/Preconditioners/ILUT_impl.h
+++ b/src/TNL/Solvers/Linear/Preconditioners/ILUT_impl.h
@@ -43,7 +43,9 @@ update( const MatrixPointer& matrixPointer )
    TNL_ASSERT_GT( matrixPointer->getRows(), 0, "empty matrix" );
    TNL_ASSERT_EQ( matrixPointer->getRows(), matrixPointer->getColumns(), "matrix must be square" );
 
-   const IndexType N = matrixPointer->getRows();
+   const auto& localMatrix = Traits< Matrix >::getLocalMatrix( *matrixPointer );
+   const IndexType N = localMatrix.getRows();
+   const IndexType minColumn = getMinColumn( *matrixPointer );
 
    L.setDimensions( N, N );
    U.setDimensions( N, N );
@@ -59,15 +61,17 @@ update( const MatrixPointer& matrixPointer )
    L_rowLengths.setSize( N );
    U_rowLengths.setSize( N );
    for( IndexType i = 0; i < N; i++ ) {
-      const auto row = matrixPointer->getRow( i );
-      const auto max_length = matrixPointer->getRowLength( i );
+      const auto row = localMatrix.getRow( i );
+      const auto max_length = localMatrix.getRowLength( i );
       IndexType L_entries = 0;
       IndexType U_entries = 0;
       for( IndexType j = 0; j < max_length; j++ ) {
          const auto column = row.getElementColumn( j );
-         if( column < i )
+         if( column < minColumn )
+            continue;
+         if( column < i + minColumn )
             L_entries++;
-         else if( column < N )
+         else if( column < N + minColumn )
             U_entries++;
          else
             break;
@@ -103,15 +107,20 @@ update( const MatrixPointer& matrixPointer )
    // Incomplete LU factorization with threshold
    // (see Saad - Iterative methods for sparse linear systems, section 10.4)
    for( IndexType i = 0; i < N; i++ ) {
-      const auto max_length = matrixPointer->getRowLength( i );
-      const auto A_i = matrixPointer->getRow( i );
+      const auto max_length = localMatrix.getRowLength( i );
+      const auto A_i = localMatrix.getRow( i );
 
       RealType A_i_norm = 0.0;
 
       // copy A_i into the full vector w
       timer_copy_into_w.start();
       for( IndexType c_j = 0; c_j < max_length; c_j++ ) {
-         const auto j = A_i.getElementColumn( c_j );
+         auto j = A_i.getElementColumn( c_j );
+         if( minColumn > 0 ) {
+            // skip non-local elements
+            if( j < minColumn ) continue;
+            j -= minColumn;
+         }
          // handle ellpack dummy entries
          if( j >= N ) break;
          w[ j ] = A_i.getElementValue( c_j );
@@ -132,7 +141,7 @@ update( const MatrixPointer& matrixPointer )
          if( w_k == 0.0 )
             continue;
 
-         w_k /= matrixPointer->getElementFast( k, k );
+         w_k /= localMatrix.getElementFast( k, k + minColumn );
 
          // apply dropping rule to w_k
          if( std::abs( w_k ) < tau_i )
@@ -245,8 +254,11 @@ update( const MatrixPointer& matrixPointer )
 template< typename Matrix, typename Real, typename Index >
 void
 ILUT_impl< Matrix, Real, Devices::Host, Index >::
-solve( ConstVectorViewType b, VectorViewType x ) const
+solve( ConstVectorViewType _b, VectorViewType _x ) const
 {
+   const auto b = Traits< Matrix >::getLocalVectorView( _b );
+   auto x = Traits< Matrix >::getLocalVectorView( _x );
+
    // Step 1: solve y from Ly = b
    triangularSolveLower< false >( L, x, b );
 
diff --git a/src/TNL/Solvers/Linear/Preconditioners/Preconditioner.h b/src/TNL/Solvers/Linear/Preconditioners/Preconditioner.h
index 70c5d7cf84..2efc01001b 100644
--- a/src/TNL/Solvers/Linear/Preconditioners/Preconditioner.h
+++ b/src/TNL/Solvers/Linear/Preconditioners/Preconditioner.h
@@ -16,6 +16,8 @@
 #include <TNL/Pointers/SharedPointer.h>
 #include <TNL/Config/ParameterContainer.h>
 
+#include "../Traits.h"
+
 namespace TNL {
 namespace Solvers {
 namespace Linear {
@@ -28,8 +30,8 @@ public:
    using RealType = typename Matrix::RealType;
    using DeviceType = typename Matrix::DeviceType;
    using IndexType = typename Matrix::IndexType;
-   using VectorViewType = Containers::VectorView< RealType, DeviceType, IndexType >;
-   using ConstVectorViewType = Containers::VectorView< typename std::add_const< RealType >::type, DeviceType, IndexType >;
+   using VectorViewType = typename Traits< Matrix >::VectorViewType;
+   using ConstVectorViewType = typename Traits< Matrix >::ConstVectorViewType;
    using MatrixType = Matrix;
    using MatrixPointer = Pointers::SharedPointer< typename std::add_const< MatrixType >::type >;
 
diff --git a/src/TNL/Solvers/Linear/TFQMR.h b/src/TNL/Solvers/Linear/TFQMR.h
index e693032a3f..73d0894aad 100644
--- a/src/TNL/Solvers/Linear/TFQMR.h
+++ b/src/TNL/Solvers/Linear/TFQMR.h
@@ -12,8 +12,6 @@
 
 #include "LinearSolver.h"
 
-#include <TNL/Containers/Vector.h>
-
 namespace TNL {
 namespace Solvers {
 namespace Linear {
@@ -35,11 +33,9 @@ public:
    bool solve( ConstVectorViewType b, VectorViewType x ) override;
 
 protected:
-   void setSize( IndexType size );
-
-   Containers::Vector< RealType, DeviceType, IndexType > d, r, w, u, v, r_ast, Au, M_tmp;
+   void setSize( const VectorViewType& x );
 
-   IndexType size = 0;
+   typename Traits< Matrix >::VectorType d, r, w, u, v, r_ast, Au, M_tmp;
 };
 
 } // namespace Linear
diff --git a/src/TNL/Solvers/Linear/TFQMR_impl.h b/src/TNL/Solvers/Linear/TFQMR_impl.h
index f87d961520..3f144875dd 100644
--- a/src/TNL/Solvers/Linear/TFQMR_impl.h
+++ b/src/TNL/Solvers/Linear/TFQMR_impl.h
@@ -29,7 +29,7 @@ String TFQMR< Matrix > :: getType() const
 template< typename Matrix >
 bool TFQMR< Matrix >::solve( ConstVectorViewType b, VectorViewType x )
 {
-   this->setSize( this->matrix->getRows() );
+   this->setSize( x );
 
    RealType tau, theta, eta, rho, alpha, b_norm, w_norm;
 
@@ -129,19 +129,16 @@ bool TFQMR< Matrix >::solve( ConstVectorViewType b, VectorViewType x )
 }
 
 template< typename Matrix >
-void TFQMR< Matrix > :: setSize( IndexType size )
+void TFQMR< Matrix > :: setSize( const VectorViewType& x )
 {
-   if( this->size == size )
-      return;
-   this->size = size;
-   d.setSize( size );
-   r.setSize( size );
-   w.setSize( size );
-   u.setSize( size );
-   v.setSize( size );
-   r_ast.setSize( size );
-   Au.setSize( size );
-   M_tmp.setSize( size );
+   d.setLike( x );
+   r.setLike( x );
+   w.setLike( x );
+   u.setLike( x );
+   v.setLike( x );
+   r_ast.setLike( x );
+   Au.setLike( x );
+   M_tmp.setLike( x );
 }
 
 } // namespace Linear
diff --git a/src/TNL/Solvers/Linear/Traits.h b/src/TNL/Solvers/Linear/Traits.h
new file mode 100644
index 0000000000..343fae5c8a
--- /dev/null
+++ b/src/TNL/Solvers/Linear/Traits.h
@@ -0,0 +1,91 @@
+/***************************************************************************
+                          Traits.h  -  description
+                             -------------------
+    begin                : Sep 20, 2018
+    copyright            : (C) 2018 by Tomas Oberhuber et al.
+    email                : tomas.oberhuber@fjfi.cvut.cz
+ ***************************************************************************/
+
+/* See Copyright Notice in tnl/Copyright */
+
+// Implemented by: Jakub Klinkovský
+
+#pragma once
+
+#include <TNL/Containers/VectorView.h>
+#include <TNL/DistributedContainers/DistributedVectorView.h>
+#include <TNL/DistributedContainers/DistributedMatrix.h>
+
+namespace TNL {
+namespace Solvers {
+namespace Linear {
+
+template< typename Matrix >
+struct Traits
+{
+   using VectorType = Containers::Vector
+         < typename Matrix::RealType,
+           typename Matrix::DeviceType,
+           typename Matrix::IndexType >;
+   using VectorViewType = Containers::VectorView
+         < typename Matrix::RealType,
+           typename Matrix::DeviceType,
+           typename Matrix::IndexType >;
+   using ConstVectorViewType = Containers::VectorView
+         < typename std::add_const< typename Matrix::RealType >::type,
+           typename Matrix::DeviceType,
+           typename Matrix::IndexType >;
+
+   // compatibility aliases
+   using LocalVectorType = VectorType;
+   using LocalVectorViewType = VectorViewType;
+   using ConstLocalVectorViewType = ConstVectorViewType;
+
+   // compatibility wrappers for some DistributedMatrix methods
+   static const Matrix& getLocalMatrix( const Matrix& m ) { return m; }
+   static ConstLocalVectorViewType getLocalVectorView( ConstVectorViewType v ) { return v; }
+   static LocalVectorViewType getLocalVectorView( VectorViewType v ) { return v; }
+};
+
+template< typename Matrix, typename Communicator >
+struct Traits< DistributedContainers::DistributedMatrix< Matrix, Communicator > >
+{
+   using VectorType = DistributedContainers::DistributedVector
+         < typename Matrix::RealType,
+           typename Matrix::DeviceType,
+           typename Matrix::IndexType,
+           Communicator >;
+   using VectorViewType = DistributedContainers::DistributedVectorView
+         < typename Matrix::RealType,
+           typename Matrix::DeviceType,
+           typename Matrix::IndexType,
+           Communicator >;
+   using ConstVectorViewType = DistributedContainers::DistributedVectorView
+         < typename std::add_const< typename Matrix::RealType >::type,
+           typename Matrix::DeviceType,
+           typename Matrix::IndexType,
+           Communicator >;
+
+   using LocalVectorType = Containers::Vector
+         < typename Matrix::RealType,
+           typename Matrix::DeviceType,
+           typename Matrix::IndexType >;
+   using LocalVectorViewType = Containers::VectorView
+         < typename Matrix::RealType,
+           typename Matrix::DeviceType,
+           typename Matrix::IndexType >;
+   using ConstLocalVectorViewType = Containers::VectorView
+         < typename std::add_const< typename Matrix::RealType >::type,
+           typename Matrix::DeviceType,
+           typename Matrix::IndexType >;
+
+   // compatibility wrappers for some DistributedMatrix methods
+   static const Matrix& getLocalMatrix( const DistributedContainers::DistributedMatrix< Matrix, Communicator >& m )
+   { return m.getLocalMatrix(); }
+   static ConstLocalVectorViewType getLocalVectorView( ConstVectorViewType v ) { return v.getLocalVectorView(); }
+   static LocalVectorViewType getLocalVectorView( VectorViewType v ) { return v.getLocalVectorView(); }
+};
+
+} // namespace Linear
+} // namespace Solvers
+} // namespace TNL
-- 
GitLab