From 40cd3071254dd505d3406423d6085348c4ccdfdb Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jakub=20Klinkovsk=C3=BD?= <klinkovsky@mmg.fjfi.cvut.cz>
Date: Sat, 14 Nov 2020 11:43:02 +0100
Subject: [PATCH] Refactored BiCGStab, added restarting

---
 src/TNL/Solvers/Linear/BICGStab.h      |   4 +
 src/TNL/Solvers/Linear/BICGStab_impl.h | 138 +++++++++++++------------
 2 files changed, 74 insertions(+), 68 deletions(-)

diff --git a/src/TNL/Solvers/Linear/BICGStab.h b/src/TNL/Solvers/Linear/BICGStab.h
index 2cede824ad..474a45d023 100644
--- a/src/TNL/Solvers/Linear/BICGStab.h
+++ b/src/TNL/Solvers/Linear/BICGStab.h
@@ -37,6 +37,10 @@ public:
    bool solve( ConstVectorViewType b, VectorViewType x ) override;
 
 protected:
+   void compute_residue( VectorViewType r, ConstVectorViewType x, ConstVectorViewType b );
+
+   void preconditioned_matvec( ConstVectorViewType src, VectorViewType dst );
+
    void setSize( const VectorViewType& x );
 
    bool exact_residue = false;
diff --git a/src/TNL/Solvers/Linear/BICGStab_impl.h b/src/TNL/Solvers/Linear/BICGStab_impl.h
index baa4b6363e..ff3b42ed0c 100644
--- a/src/TNL/Solvers/Linear/BICGStab_impl.h
+++ b/src/TNL/Solvers/Linear/BICGStab_impl.h
@@ -38,111 +38,80 @@ setup( const Config::ParameterContainer& parameters,
 }
 
 template< typename Matrix >
-bool BICGStab< Matrix >::solve( ConstVectorViewType b, VectorViewType x )
+bool
+BICGStab< Matrix >::
+solve( ConstVectorViewType b, VectorViewType x )
 {
    this->setSize( x );
 
-   RealType alpha, beta, omega, aux, rho, rho_old, b_norm;
+   RealType alpha, beta, omega, rho, rho_old, b_norm, r_ast_sqnorm;
 
+   // initialize the norm of the preconditioned right-hand-side
    if( this->preconditioner ) {
       this->preconditioner->solve( b, M_tmp );
       b_norm = lpNorm( M_tmp, 2.0 );
-
-      this->matrix->vectorProduct( x, M_tmp );
-      M_tmp = b - M_tmp;
-      this->preconditioner->solve( M_tmp, r );
    }
-   else {
+   else
       b_norm = lpNorm( b, 2.0 );
-      this->matrix->vectorProduct( x, r );
-      r = b - r;
-   }
+   if( b_norm == 0.0 )
+      b_norm = 1.0;
+
+   // r = M.solve(b - A * x);
+   compute_residue( r, x, b );
 
    p = r_ast = r;
    s.setValue( 0.0 );
-   rho = (r, r_ast);
+   r_ast_sqnorm = rho = (r, r_ast);
 
-   if( b_norm == 0.0 )
-       b_norm = 1.0;
+   const RealType eps2 = std::numeric_limits<RealType>::epsilon() * std::numeric_limits<RealType>::epsilon();
 
    this->resetIterations();
    this->setResidue( std::sqrt( rho ) / b_norm );
 
    while( this->nextIteration() )
    {
-      /****
-       * alpha_j = ( r_j, r^ast_0 ) / ( A * p_j, r^ast_0 )
-       */
-      if( this->preconditioner ) {
-         this->matrix->vectorProduct( p, M_tmp );
-         this->preconditioner->solve( M_tmp, Ap );
-      }
-      else {
-         this->matrix->vectorProduct( p, Ap );
-      }
-      aux = (Ap, r_ast);
-      alpha = rho / aux;
+      // alpha_j = ( r_j, r^ast_0 ) / ( A * p_j, r^ast_0 )
+      preconditioned_matvec( p, Ap );
+      alpha = rho / (Ap, r_ast);
 
-      /****
-       * s_j = r_j - alpha_j * A p_j
-       */
+      // s_j = r_j - alpha_j * A p_j
       s = r - alpha * Ap;
 
-      /****
-       * omega_j = ( A s_j, s_j ) / ( A s_j, A s_j )
-       */
-      if( this->preconditioner ) {
-         this->matrix->vectorProduct( s, M_tmp );
-         this->preconditioner->solve( M_tmp, As );
-      }
-      else {
-         this->matrix->vectorProduct( s, As );
-      }
-      aux = lpNorm( As, 2.0 );
-      omega = (As, s) / (aux * aux);
+      // omega_j = ( A s_j, s_j ) / ( A s_j, A s_j )
+      preconditioned_matvec( s, As );
+      omega = (As, s) / (As, As);
 
-      /****
-       * x_{j+1} = x_j + alpha_j * p_j + omega_j * s_j
-       */
+      // x_{j+1} = x_j + alpha_j * p_j + omega_j * s_j
       x += alpha * p + omega * s;
 
-      /****
-       * r_{j+1} = s_j - omega_j * A s_j
-       */
+      // r_{j+1} = s_j - omega_j * A s_j
       r = s - omega * As;
 
-      /****
-       * beta = alpha_j / omega_j * ( r_{j+1}, r^ast_0 ) / ( r_j, r^ast_0 )
-       */
+      // compute scalar product of the residual vectors
       rho_old = rho;
       rho = (r, r_ast);
+      if( abs(rho) < eps2 * r_ast_sqnorm ) {
+         // The new residual vector has become too orthogonal to the arbitrarily chosen direction r_ast.
+         // Let's restart with a new r0:
+         compute_residue( r, x, b );
+         r_ast = r;
+         r_ast_sqnorm = rho = (r, r_ast);
+      }
+
+      // beta = alpha_j / omega_j * ( r_{j+1}, r^ast_0 ) / ( r_j, r^ast_0 )
       beta = (rho / rho_old) * (alpha / omega);
 
-      /****
-       * p_{j+1} = r_{j+1} + beta_j * ( p_j - omega_j * A p_j )
-       */
+      // p_{j+1} = r_{j+1} + beta_j * ( p_j - omega_j * A p_j )
       p = r + beta * p - (beta * omega) * Ap;
 
       if( exact_residue ) {
-         /****
-          * Compute the exact preconditioned residue into the 's' vector.
-          */
-         if( this->preconditioner ) {
-            this->matrix->vectorProduct( x, M_tmp );
-            M_tmp = b - M_tmp;
-            this->preconditioner->solve( M_tmp, s );
-         }
-         else {
-            this->matrix->vectorProduct( x, s );
-            s = b - s;
-         }
+         // Compute the exact preconditioned residue into the 's' vector.
+         compute_residue( s, x, b );
          const RealType residue = lpNorm( s, 2.0 );
          this->setResidue( residue / b_norm );
       }
       else {
-         /****
-          * Use the "orthogonal residue vector" for stopping.
-          */
+         // Use the "orthogonal residue vector" for stopping.
          const RealType residue = lpNorm( r, 2.0 );
          this->setResidue( residue / b_norm );
       }
@@ -153,7 +122,40 @@ bool BICGStab< Matrix >::solve( ConstVectorViewType b, VectorViewType x )
 }
 
 template< typename Matrix >
-void BICGStab< Matrix > :: setSize( const VectorViewType& x )
+void
+BICGStab< Matrix >::
+compute_residue( VectorViewType r, ConstVectorViewType x, ConstVectorViewType b )
+{
+   // r = M.solve(b - A * x);
+   if( this->preconditioner ) {
+      this->matrix->vectorProduct( x, M_tmp );
+      M_tmp = b - M_tmp;
+      this->preconditioner->solve( M_tmp, r );
+   }
+   else {
+      this->matrix->vectorProduct( x, r );
+      r = b - r;
+   }
+}
+
+template< typename Matrix >
+void
+BICGStab< Matrix >::
+preconditioned_matvec( ConstVectorViewType src, VectorViewType dst )
+{
+   if( this->preconditioner ) {
+      this->matrix->vectorProduct( src, M_tmp );
+      this->preconditioner->solve( M_tmp, dst );
+   }
+   else {
+      this->matrix->vectorProduct( src, dst );
+   }
+}
+
+template< typename Matrix >
+void
+BICGStab< Matrix >::
+setSize( const VectorViewType& x )
 {
    r.setLike( x );
    r_ast.setLike( x );
-- 
GitLab