From 8b0a59cb887c907c4db8f7bd7e9460cadf6d78d6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Klinkovsk=C3=BD?= <klinkjak@fjfi.cvut.cz> Date: Tue, 1 Dec 2015 14:32:02 +0100 Subject: [PATCH] Fixing the TFQMR solver. --- .../linear/krylov/tnlTFQMRSolver_impl.h | 101 +++++++++--------- 1 file changed, 52 insertions(+), 49 deletions(-) diff --git a/src/solvers/linear/krylov/tnlTFQMRSolver_impl.h b/src/solvers/linear/krylov/tnlTFQMRSolver_impl.h index 52adc5f30d..b4c7f7c271 100644 --- a/src/solvers/linear/krylov/tnlTFQMRSolver_impl.h +++ b/src/solvers/linear/krylov/tnlTFQMRSolver_impl.h @@ -87,77 +87,80 @@ bool tnlTFQMRSolver< Matrix, Preconditioner > :: solve( const Vector& b, Vector& this -> resetIterations(); this -> setResidue( this -> getConvergenceResidue() + 1.0 ); - RealType tau, theta, eta, rho, alpha; - const RealType bNorm = b. lpNorm( 2.0 ); - this -> setResidue( ResidueGetter :: getResidue( *matrix, b, x, bNorm ) ); + RealType tau, theta, eta, rho, alpha, w_norm; + RealType b_norm = b. lpNorm( 2.0 ); + if( b_norm == 0.0 ) + b_norm = 1.0; - dbgCout( "Computing Ax" ); this -> matrix -> vectorProduct( x, r ); - - /*if( M ) - { - } - else*/ + r. addVector( b, 1.0, -1.0 ); + w = u = r; + matrix -> vectorProduct( u, Au ); + v = Au; + d. setValue( 0.0 ); + tau = r. lpNorm( 2.0 ); + theta = eta = 0.0; + r_ast = r; + rho = r_ast. scalarProduct( r ); + alpha = 0.0; // TODO + + this->resetIterations(); + this -> setResidue( tau / b_norm ); + + while( this->nextIteration() ) { - r. addVector( b, -1.0, -1.0 ); - w = u = r; - matrix -> vectorProduct( u, v ); - d. setValue( 0.0 ); - tau = r. lpNorm( 2.0 ); - theta = 0.0; - eta = 0.0; - r_ast = r; - //cerr << "r_ast = " << r_ast << endl; - rho = this -> r_ast. scalarProduct( this -> r_ast ); - } + // start counting from 0 + const IndexType iter = this->getIterations() - 1; - while( this -> getIterations() < this -> getMaxIterations() && - this -> getResidue() > this -> getConvergenceResidue() ) - { - //dbgCout( "Starting TFQMR iteration " << iter + 1 ); +// cerr << "Starting TFQMR iteration " << iter << endl; - if( this -> getIterations() % 2 == 0 ) - { - //cerr << "rho = " << rho << endl; + if( iter % 2 == 0 ) { alpha = rho / v. scalarProduct( this -> r_ast ); - //cerr << "new alpha = " << alpha << endl; - u_new.addVector( v, -alpha ); } - matrix -> vectorProduct( u, Au ); + else { + // not necessary in even iteration since the previous odd iteration + // already computed v_{m+1} = A*u_{m+1} + matrix -> vectorProduct( u, Au ); + } w.addVector( Au, -alpha ); - //cerr << "alpha = " << alpha << endl; +// cerr << "alpha = " << alpha << endl; //cerr << "theta * theta / alpha * eta = " << theta * theta / alpha * eta << endl; - d. addVector( u, 1.0, theta * theta / alpha * eta ); - theta = w. lpNorm( 2.0 ) / tau; - const RealType c = sqrt( 1.0 + theta * theta ); + d. addVector( u, 1.0, theta * theta * eta / alpha ); + w_norm = w. lpNorm( 2.0 ); +// cerr << "w_norm / b_norm = residue = " << w_norm / b_norm << endl; + theta = w_norm / tau; + const RealType c = 1.0 / sqrt( 1.0 + theta * theta ); tau = tau * theta * c; +// cerr << "tau * sqrt(m+1) = " << tau * sqrt(iter+1) << endl; eta = c * c * alpha; //cerr << "eta = " << eta << endl; x.addVector( d, eta ); - if( this -> getIterations() % 2 == 1 ) + + this->setResidue( tau * sqrt(iter+1) / b_norm ); + if( iter > this->getMinIterations() && this->getResidue() < this->getConvergenceResidue() ) { + break; + } + + if( iter % 2 == 1 ) { const RealType rho_new = w. scalarProduct( this -> r_ast ); const RealType beta = rho_new / rho; rho = rho_new; - matrix -> vectorProduct( u, Au ); - Au.addVector( v, beta ); + u.addVector( w, 1.0, beta ); - matrix -> vectorProduct( u, Au_new ); - v.addVectors( Au_new, 1.0, Au, beta ); + v.addVector( Au, beta, beta * beta ); + matrix -> vectorProduct( u, Au ); + v.addVector( Au, 1.0 ); + } + else { + u.addVector( v, -alpha ); } - //this -> setResidue( residue ); - //if( this -> getIterations() % 10 == 0 ) - this -> setResidue( ResidueGetter :: getResidue( *matrix, b, x, bNorm ) ); - if( ! this -> nextIteration() ) - return false; this -> refreshSolverMonitor(); } - this -> setResidue( ResidueGetter :: getResidue( *matrix, b, x, bNorm ) ); - this -> refreshSolverMonitor( true ); - if( this -> getResidue() > this -> getConvergenceResidue() ) - return false; - return true; + + this->refreshSolverMonitor( true ); + return this->checkConvergence(); }; template< typename Matrix, -- GitLab