Skip to content
Snippets Groups Projects
Commit 8b0a59cb authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Fixing the TFQMR solver.

parent cd8a8ed4
No related branches found
No related tags found
No related merge requests found
...@@ -87,77 +87,80 @@ bool tnlTFQMRSolver< Matrix, Preconditioner > :: solve( const Vector& b, Vector& ...@@ -87,77 +87,80 @@ bool tnlTFQMRSolver< Matrix, Preconditioner > :: solve( const Vector& b, Vector&
this -> resetIterations(); this -> resetIterations();
this -> setResidue( this -> getConvergenceResidue() + 1.0 ); this -> setResidue( this -> getConvergenceResidue() + 1.0 );
RealType tau, theta, eta, rho, alpha; RealType tau, theta, eta, rho, alpha, w_norm;
const RealType bNorm = b. lpNorm( 2.0 ); RealType b_norm = b. lpNorm( 2.0 );
this -> setResidue( ResidueGetter :: getResidue( *matrix, b, x, bNorm ) ); if( b_norm == 0.0 )
b_norm = 1.0;
dbgCout( "Computing Ax" );
this -> matrix -> vectorProduct( x, r ); this -> matrix -> vectorProduct( x, r );
r. addVector( b, 1.0, -1.0 );
/*if( M ) w = u = r;
{ matrix -> vectorProduct( u, Au );
} v = Au;
else*/ d. setValue( 0.0 );
tau = r. lpNorm( 2.0 );
theta = eta = 0.0;
r_ast = r;
rho = r_ast. scalarProduct( r );
alpha = 0.0; // TODO
this->resetIterations();
this -> setResidue( tau / b_norm );
while( this->nextIteration() )
{ {
r. addVector( b, -1.0, -1.0 ); // start counting from 0
w = u = r; const IndexType iter = this->getIterations() - 1;
matrix -> vectorProduct( u, v );
d. setValue( 0.0 );
tau = r. lpNorm( 2.0 );
theta = 0.0;
eta = 0.0;
r_ast = r;
//cerr << "r_ast = " << r_ast << endl;
rho = this -> r_ast. scalarProduct( this -> r_ast );
}
while( this -> getIterations() < this -> getMaxIterations() && // cerr << "Starting TFQMR iteration " << iter << endl;
this -> getResidue() > this -> getConvergenceResidue() )
{
//dbgCout( "Starting TFQMR iteration " << iter + 1 );
if( this -> getIterations() % 2 == 0 ) if( iter % 2 == 0 ) {
{
//cerr << "rho = " << rho << endl;
alpha = rho / v. scalarProduct( this -> r_ast ); alpha = rho / v. scalarProduct( this -> r_ast );
//cerr << "new alpha = " << alpha << endl;
u_new.addVector( v, -alpha );
} }
matrix -> vectorProduct( u, Au ); else {
// not necessary in even iteration since the previous odd iteration
// already computed v_{m+1} = A*u_{m+1}
matrix -> vectorProduct( u, Au );
}
w.addVector( Au, -alpha ); w.addVector( Au, -alpha );
//cerr << "alpha = " << alpha << endl; // cerr << "alpha = " << alpha << endl;
//cerr << "theta * theta / alpha * eta = " << theta * theta / alpha * eta << endl; //cerr << "theta * theta / alpha * eta = " << theta * theta / alpha * eta << endl;
d. addVector( u, 1.0, theta * theta / alpha * eta ); d. addVector( u, 1.0, theta * theta * eta / alpha );
theta = w. lpNorm( 2.0 ) / tau; w_norm = w. lpNorm( 2.0 );
const RealType c = sqrt( 1.0 + theta * theta ); // cerr << "w_norm / b_norm = residue = " << w_norm / b_norm << endl;
theta = w_norm / tau;
const RealType c = 1.0 / sqrt( 1.0 + theta * theta );
tau = tau * theta * c; tau = tau * theta * c;
// cerr << "tau * sqrt(m+1) = " << tau * sqrt(iter+1) << endl;
eta = c * c * alpha; eta = c * c * alpha;
//cerr << "eta = " << eta << endl; //cerr << "eta = " << eta << endl;
x.addVector( d, eta ); x.addVector( d, eta );
if( this -> getIterations() % 2 == 1 )
this->setResidue( tau * sqrt(iter+1) / b_norm );
if( iter > this->getMinIterations() && this->getResidue() < this->getConvergenceResidue() ) {
break;
}
if( iter % 2 == 1 )
{ {
const RealType rho_new = w. scalarProduct( this -> r_ast ); const RealType rho_new = w. scalarProduct( this -> r_ast );
const RealType beta = rho_new / rho; const RealType beta = rho_new / rho;
rho = rho_new; rho = rho_new;
matrix -> vectorProduct( u, Au );
Au.addVector( v, beta );
u.addVector( w, 1.0, beta ); u.addVector( w, 1.0, beta );
matrix -> vectorProduct( u, Au_new ); v.addVector( Au, beta, beta * beta );
v.addVectors( Au_new, 1.0, Au, beta ); matrix -> vectorProduct( u, Au );
v.addVector( Au, 1.0 );
}
else {
u.addVector( v, -alpha );
} }
//this -> setResidue( residue );
//if( this -> getIterations() % 10 == 0 )
this -> setResidue( ResidueGetter :: getResidue( *matrix, b, x, bNorm ) );
if( ! this -> nextIteration() )
return false;
this -> refreshSolverMonitor(); this -> refreshSolverMonitor();
} }
this -> setResidue( ResidueGetter :: getResidue( *matrix, b, x, bNorm ) );
this -> refreshSolverMonitor( true ); this->refreshSolverMonitor( true );
if( this -> getResidue() > this -> getConvergenceResidue() ) return this->checkConvergence();
return false;
return true;
}; };
template< typename Matrix, template< typename Matrix,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment