From c218dd6d211dcc8ffc8aa9e996c1ab21c80ba059 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Klinkovsk=C3=BD?= <klinkjak@fjfi.cvut.cz> Date: Fri, 8 Mar 2019 17:42:56 +0100 Subject: [PATCH] Added preconditioning to the CG solver --- src/TNL/Solvers/Linear/CG.h | 2 +- src/TNL/Solvers/Linear/CG_impl.h | 109 ++++++++++++++++++------------- 2 files changed, 65 insertions(+), 46 deletions(-) diff --git a/src/TNL/Solvers/Linear/CG.h b/src/TNL/Solvers/Linear/CG.h index 146dd79472..b87caf2478 100644 --- a/src/TNL/Solvers/Linear/CG.h +++ b/src/TNL/Solvers/Linear/CG.h @@ -37,7 +37,7 @@ public: protected: void setSize( IndexType size ); - Containers::Vector< RealType, DeviceType, IndexType > r, new_r, p, Ap; + Containers::Vector< RealType, DeviceType, IndexType > r, p, Ap, z; }; } // namespace Linear diff --git a/src/TNL/Solvers/Linear/CG_impl.h b/src/TNL/Solvers/Linear/CG_impl.h index 4b0a707d4a..c8d6b2de6d 100644 --- a/src/TNL/Solvers/Linear/CG_impl.h +++ b/src/TNL/Solvers/Linear/CG_impl.h @@ -33,81 +33,100 @@ solve( ConstVectorViewType b, VectorViewType x ) this->resetIterations(); RealType alpha, beta, s1, s2; - RealType bNorm = b.lpNorm( ( RealType ) 2.0 ); + + // initialize the norm of the preconditioned right-hand-side + RealType normb; + if( this->preconditioner ) { + this->preconditioner->solve( b, r ); + normb = r.lpNorm( 2.0 ); + } + else + normb = b.lpNorm( 2.0 ); + if( normb == 0.0 ) + normb = 1.0; /**** * r_0 = b - A x_0, p_0 = r_0 */ this->matrix->vectorProduct( x, r ); r.addVector( b, 1.0, -1.0 ); - p = r; - s1 = r.scalarProduct( r ); - // TODO - //this->setResidue( std::sqrt(s1) / bNorm ); - this->setResidue( std::sqrt(s1) ); + if( this->preconditioner ) { + // z_0 = M^{-1} r_0 + this->preconditioner->solve( r, z ); + // p_0 = z_0 + p = z; + // s1 = (r_0, z_0) + s1 = r.scalarProduct( z ); + } + else { + // p_0 = r_0 + p = r; + // s1 = (r_0, r_0) + s1 = r.scalarProduct( r ); + } + + this->setResidue( std::sqrt(s1) / normb ); while( this->nextIteration() ) { - /**** - * 1. alpha_j = ( r_j, r_j ) / ( A * p_j, p_j ) - */ + // s2 = (A * p_j, p_j) this->matrix->vectorProduct( p, Ap ); s2 = Ap.scalarProduct( p ); - /**** - * if s2 = 0 => p = 0 => r = 0 => we have the solution (provided A != 0) - */ - if( s2 == 0.0 ) break; - else alpha = s1 / s2; + // if s2 = 0 => p = 0 => r = 0 => we have the solution (provided A != 0) + if( s2 == 0.0 ) { + this->setResidue( 0.0 ); + break; + } - /**** - * 2. x_{j+1} = x_j + \alpha_j p_j - */ - x.addVector( p, alpha ); - - /**** - * 3. r_{j+1} = r_j - \alpha_j A * p_j - */ - new_r.addVectors( r, 1, Ap, -alpha, 0 ); + // alpha_j = (r_j, z_j) / (A * p_j, p_j) + alpha = s1 / s2; - /**** - * 4. beta_j = ( r_{j+1}, r_{j+1} ) / ( r_j, r_j ) - */ - s2 = s1; - s1 = new_r.scalarProduct( new_r ); + // x_{j+1} = x_j + alpha_j p_j + x.addVector( p, alpha ); - /**** - * if s2 = 0 => r = 0 => we have the solution - */ + // r_{j+1} = r_j - alpha_j A * p_j + r.addVector( Ap, -alpha ); + + if( this->preconditioner ) { + // z_{j+1} = M^{-1} * r_{j+1} + this->preconditioner->solve( r, z ); + // beta_j = (r_{j+1}, z_{j+1}) / (r_j, z_j) + s2 = s1; + s1 = r.scalarProduct( z ); + } + else { + // beta_j = (r_{j+1}, r_{j+1}) / (r_j, r_j) + s2 = s1; + s1 = r.scalarProduct( r ); + } + + // if s2 = 0 => r = 0 => we have the solution if( s2 == 0.0 ) beta = 0.0; else beta = s1 / s2; - /**** - * 5. p_{j+1} = r_{j+1} + beta_j * p_j - */ - p.addVector( new_r, 1.0, beta ); - - /**** - * 6. r_{j+1} = new_r - */ - new_r.swap( r ); + if( this->preconditioner ) + // p_{j+1} = z_{j+1} + beta_j * p_j + p.addVector( z, 1.0, beta ); + else + // p_{j+1} = r_{j+1} + beta_j * p_j + p.addVector( r, 1.0, beta ); - // TODO - //this->setResidue( std::sqrt(s1) / bNorm ); - this->setResidue( std::sqrt(s1) ); + this->setResidue( std::sqrt(s1) / normb ); } this->refreshSolverMonitor( true ); return this->checkConvergence(); } template< typename Matrix > -void CG< Matrix > :: setSize( IndexType size ) +void CG< Matrix >:: +setSize( IndexType size ) { r.setSize( size ); - new_r.setSize( size ); p.setSize( size ); Ap.setSize( size ); + z.setSize( size ); } } // namespace Linear -- GitLab