From 8b0a59cb887c907c4db8f7bd7e9460cadf6d78d6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jakub=20Klinkovsk=C3=BD?= <klinkjak@fjfi.cvut.cz>
Date: Tue, 1 Dec 2015 14:32:02 +0100
Subject: [PATCH] Fixing the TFQMR solver.

---
 .../linear/krylov/tnlTFQMRSolver_impl.h       | 101 +++++++++---------
 1 file changed, 52 insertions(+), 49 deletions(-)

diff --git a/src/solvers/linear/krylov/tnlTFQMRSolver_impl.h b/src/solvers/linear/krylov/tnlTFQMRSolver_impl.h
index 52adc5f30d..b4c7f7c271 100644
--- a/src/solvers/linear/krylov/tnlTFQMRSolver_impl.h
+++ b/src/solvers/linear/krylov/tnlTFQMRSolver_impl.h
@@ -87,77 +87,80 @@ bool tnlTFQMRSolver< Matrix, Preconditioner > :: solve( const Vector& b, Vector&
    this -> resetIterations();
    this -> setResidue( this -> getConvergenceResidue() + 1.0 );
 
-   RealType tau, theta, eta, rho, alpha;
-   const RealType bNorm = b. lpNorm( 2.0 );
-   this -> setResidue( ResidueGetter :: getResidue( *matrix, b, x, bNorm ) );
+   RealType tau, theta, eta, rho, alpha, w_norm;
+   RealType b_norm = b. lpNorm( 2.0 );
+   if( b_norm == 0.0 )
+       b_norm = 1.0;
 
-   dbgCout( "Computing Ax" );
    this -> matrix -> vectorProduct( x, r );
-
-   /*if( M )
-   {
-   }
-   else*/
+   r. addVector( b, 1.0, -1.0 );
+   w = u = r;
+   matrix -> vectorProduct( u, Au );
+   v = Au;
+   d. setValue( 0.0 );
+   tau = r. lpNorm( 2.0 );
+   theta = eta = 0.0;
+   r_ast = r;
+   rho = r_ast. scalarProduct( r );
+   alpha = 0.0; // TODO
+
+   this->resetIterations();
+   this -> setResidue( tau / b_norm );
+
+   while( this->nextIteration() )
    {
-      r. addVector( b, -1.0, -1.0 );
-      w = u = r;
-      matrix -> vectorProduct( u, v );
-      d. setValue( 0.0 );
-      tau = r. lpNorm( 2.0 );
-      theta = 0.0;
-      eta = 0.0;
-      r_ast = r;
-      //cerr << "r_ast = " << r_ast << endl;
-      rho = this -> r_ast. scalarProduct( this -> r_ast );
-   }
+      // start counting from 0
+      const IndexType iter = this->getIterations() - 1;
 
-   while( this -> getIterations() < this -> getMaxIterations() &&
-          this -> getResidue() > this -> getConvergenceResidue() )
-   {
-      //dbgCout( "Starting TFQMR iteration " << iter + 1 );
+//      cerr << "Starting TFQMR iteration " << iter << endl;
 
-      if( this -> getIterations() % 2 == 0 )
-      {
-         //cerr << "rho = " << rho << endl;
+      if( iter % 2 == 0 ) {
          alpha = rho / v. scalarProduct( this -> r_ast );
-         //cerr << "new alpha = " << alpha << endl;
-         u_new.addVector( v, -alpha );
       }
-      matrix -> vectorProduct( u, Au );
+      else {
+         // not necessary in even iteration since the previous odd iteration
+         // already computed v_{m+1} = A*u_{m+1}
+         matrix -> vectorProduct( u, Au );
+      }
       w.addVector( Au, -alpha );
-      //cerr << "alpha = " << alpha << endl;
+//      cerr << "alpha = " << alpha << endl;
       //cerr << "theta * theta / alpha * eta = " << theta * theta / alpha * eta << endl;
-      d. addVector( u, 1.0, theta * theta / alpha * eta );
-      theta = w. lpNorm( 2.0 ) / tau;
-      const RealType c = sqrt( 1.0 + theta * theta );
+      d. addVector( u, 1.0, theta * theta * eta / alpha );
+      w_norm = w. lpNorm( 2.0 );
+//      cerr << "w_norm / b_norm = residue = " << w_norm / b_norm << endl;
+      theta = w_norm / tau;
+      const RealType c = 1.0 / sqrt( 1.0 + theta * theta );
       tau = tau * theta * c;
+//      cerr << "tau * sqrt(m+1) = " << tau * sqrt(iter+1) << endl;
       eta = c * c  * alpha;
       //cerr << "eta = " << eta << endl;
       x.addVector( d, eta );
-      if( this -> getIterations() % 2 == 1 )
+
+      this->setResidue( tau * sqrt(iter+1) / b_norm );
+      if( iter > this->getMinIterations() && this->getResidue() < this->getConvergenceResidue() ) {
+          break;
+      }
+
+      if( iter % 2 == 1 )
       {
          const RealType rho_new  = w. scalarProduct( this -> r_ast );
          const RealType beta = rho_new / rho;
          rho = rho_new;
-         matrix -> vectorProduct( u, Au );
-         Au.addVector( v, beta );
+
          u.addVector( w, 1.0, beta );
-         matrix -> vectorProduct( u, Au_new );
-         v.addVectors( Au_new, 1.0, Au, beta );
+         v.addVector( Au, beta, beta * beta );
+         matrix -> vectorProduct( u, Au );
+         v.addVector( Au, 1.0 );
+      }
+      else {
+         u.addVector( v, -alpha );
       }
       
-      //this -> setResidue( residue );
-      //if( this -> getIterations() % 10 == 0 )
-         this -> setResidue( ResidueGetter :: getResidue( *matrix, b, x, bNorm ) );
-      if( ! this -> nextIteration() )
-         return false;
       this -> refreshSolverMonitor();
    }
-   this -> setResidue( ResidueGetter :: getResidue( *matrix, b, x, bNorm ) );
-   this -> refreshSolverMonitor( true );
-   if( this -> getResidue() > this -> getConvergenceResidue() )
-      return false;
-   return true;
+
+   this->refreshSolverMonitor( true );
+   return this->checkConvergence();
 };
 
 template< typename Matrix,
-- 
GitLab