Commit 8b0a59cb authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Fixing the TFQMR solver.

parent cd8a8ed4
Loading
Loading
Loading
Loading
+52 −49
Original line number Diff line number Diff line
@@ -87,77 +87,80 @@ bool tnlTFQMRSolver< Matrix, Preconditioner > :: solve( const Vector& b, Vector&
   this -> resetIterations();
   this -> setResidue( this -> getConvergenceResidue() + 1.0 );

   RealType tau, theta, eta, rho, alpha;
   const RealType bNorm = b. lpNorm( 2.0 );
   this -> setResidue( ResidueGetter :: getResidue( *matrix, b, x, bNorm ) );
   RealType tau, theta, eta, rho, alpha, w_norm;
   RealType b_norm = b. lpNorm( 2.0 );
   if( b_norm == 0.0 )
       b_norm = 1.0;

   dbgCout( "Computing Ax" );
   this -> matrix -> vectorProduct( x, r );

   /*if( M )
   {
   }
   else*/
   {
      r. addVector( b, -1.0, -1.0 );
   r. addVector( b, 1.0, -1.0 );
   w = u = r;
      matrix -> vectorProduct( u, v );
   matrix -> vectorProduct( u, Au );
   v = Au;
   d. setValue( 0.0 );
   tau = r. lpNorm( 2.0 );
      theta = 0.0;
      eta = 0.0;
   theta = eta = 0.0;
   r_ast = r;
      //cerr << "r_ast = " << r_ast << endl;
      rho = this -> r_ast. scalarProduct( this -> r_ast );
   }
   rho = r_ast. scalarProduct( r );
   alpha = 0.0; // TODO

   while( this -> getIterations() < this -> getMaxIterations() &&
          this -> getResidue() > this -> getConvergenceResidue() )
   {
      //dbgCout( "Starting TFQMR iteration " << iter + 1 );
   this->resetIterations();
   this -> setResidue( tau / b_norm );

      if( this -> getIterations() % 2 == 0 )
   while( this->nextIteration() )
   {
         //cerr << "rho = " << rho << endl;
      // start counting from 0
      const IndexType iter = this->getIterations() - 1;

//      cerr << "Starting TFQMR iteration " << iter << endl;

      if( iter % 2 == 0 ) {
         alpha = rho / v. scalarProduct( this -> r_ast );
         //cerr << "new alpha = " << alpha << endl;
         u_new.addVector( v, -alpha );
      }
      else {
         // not necessary in even iteration since the previous odd iteration
         // already computed v_{m+1} = A*u_{m+1}
         matrix -> vectorProduct( u, Au );
      }
      w.addVector( Au, -alpha );
//      cerr << "alpha = " << alpha << endl;
      //cerr << "theta * theta / alpha * eta = " << theta * theta / alpha * eta << endl;
      d. addVector( u, 1.0, theta * theta / alpha * eta );
      theta = w. lpNorm( 2.0 ) / tau;
      const RealType c = sqrt( 1.0 + theta * theta );
      d. addVector( u, 1.0, theta * theta * eta / alpha );
      w_norm = w. lpNorm( 2.0 );
//      cerr << "w_norm / b_norm = residue = " << w_norm / b_norm << endl;
      theta = w_norm / tau;
      const RealType c = 1.0 / sqrt( 1.0 + theta * theta );
      tau = tau * theta * c;
//      cerr << "tau * sqrt(m+1) = " << tau * sqrt(iter+1) << endl;
      eta = c * c  * alpha;
      //cerr << "eta = " << eta << endl;
      x.addVector( d, eta );
      if( this -> getIterations() % 2 == 1 )

      this->setResidue( tau * sqrt(iter+1) / b_norm );
      if( iter > this->getMinIterations() && this->getResidue() < this->getConvergenceResidue() ) {
          break;
      }

      if( iter % 2 == 1 )
      {
         const RealType rho_new  = w. scalarProduct( this -> r_ast );
         const RealType beta = rho_new / rho;
         rho = rho_new;
         matrix -> vectorProduct( u, Au );
         Au.addVector( v, beta );

         u.addVector( w, 1.0, beta );
         matrix -> vectorProduct( u, Au_new );
         v.addVectors( Au_new, 1.0, Au, beta );
         v.addVector( Au, beta, beta * beta );
         matrix -> vectorProduct( u, Au );
         v.addVector( Au, 1.0 );
      }
      else {
         u.addVector( v, -alpha );
      }
      
      //this -> setResidue( residue );
      //if( this -> getIterations() % 10 == 0 )
         this -> setResidue( ResidueGetter :: getResidue( *matrix, b, x, bNorm ) );
      if( ! this -> nextIteration() )
         return false;
      this -> refreshSolverMonitor();
   }
   this -> setResidue( ResidueGetter :: getResidue( *matrix, b, x, bNorm ) );

   this->refreshSolverMonitor( true );
   if( this -> getResidue() > this -> getConvergenceResidue() )
      return false;
   return true;
   return this->checkConvergence();
};

template< typename Matrix,