diff --git a/src/TNL/Solvers/Linear/CG_impl.h b/src/TNL/Solvers/Linear/CG_impl.h index 8889451f2d32809a7dcf2a1dcb245b8ecfb25cf9..4b0a707d4adbf676c6c2abcb446543fac5bd7797 100644 --- a/src/TNL/Solvers/Linear/CG_impl.h +++ b/src/TNL/Solvers/Linear/CG_impl.h @@ -12,8 +12,6 @@ #include "CG.h" -#include <TNL/Solvers/Linear/LinearResidueGetter.h> - namespace TNL { namespace Solvers { namespace Linear { @@ -32,9 +30,7 @@ CG< Matrix >:: solve( ConstVectorViewType b, VectorViewType x ) { this->setSize( this->matrix->getRows() ); - this->resetIterations(); - this->setResidue( this->getConvergenceResidue() + 1.0 ); RealType alpha, beta, s1, s2; RealType bNorm = b.lpNorm( ( RealType ) 2.0 ); @@ -43,23 +39,26 @@ solve( ConstVectorViewType b, VectorViewType x ) * r_0 = b - A x_0, p_0 = r_0 */ this->matrix->vectorProduct( x, r ); - r. addVector( b, 1.0, -1.0 ); + r.addVector( b, 1.0, -1.0 ); p = r; + s1 = r.scalarProduct( r ); + // TODO + //this->setResidue( std::sqrt(s1) / bNorm ); + this->setResidue( std::sqrt(s1) ); + while( this->nextIteration() ) { /**** * 1. alpha_j = ( r_j, r_j ) / ( A * p_j, p_j ) */ this->matrix->vectorProduct( p, Ap ); - - s1 = r.scalarProduct( r ); s2 = Ap.scalarProduct( p ); /**** * if s2 = 0 => p = 0 => r = 0 => we have the solution (provided A != 0) */ - if( s2 == 0.0 ) alpha = 0.0; + if( s2 == 0.0 ) break; else alpha = s1 / s2; /**** @@ -70,14 +69,13 @@ solve( ConstVectorViewType b, VectorViewType x ) /**** * 3. r_{j+1} = r_j - \alpha_j A * p_j */ - new_r = r; - new_r.addVector( Ap, -alpha ); + new_r.addVectors( r, 1, Ap, -alpha, 0 ); /**** * 4. beta_j = ( r_{j+1}, r_{j+1} ) / ( r_j, r_j ) */ - s1 = new_r. scalarProduct( new_r ); - s2 = r. scalarProduct( r ); + s2 = s1; + s1 = new_r.scalarProduct( new_r ); /**** * if s2 = 0 => r = 0 => we have the solution @@ -88,17 +86,17 @@ solve( ConstVectorViewType b, VectorViewType x ) /**** * 5. p_{j+1} = r_{j+1} + beta_j * p_j */ - p. addVector( new_r, 1.0, beta ); + p.addVector( new_r, 1.0, beta ); /**** * 6. r_{j+1} = new_r */ new_r.swap( r ); - if( this->getIterations() % 10 == 0 ) - this->setResidue( LinearResidueGetter::getResidue( *this->matrix, x, b, bNorm ) ); + // TODO + //this->setResidue( std::sqrt(s1) / bNorm ); + this->setResidue( std::sqrt(s1) ); } - this->setResidue( LinearResidueGetter::getResidue( *this->matrix, x, b, bNorm ) ); this->refreshSolverMonitor( true ); return this->checkConvergence(); }