Skip to content
Snippets Groups Projects
Commit a220cfb4 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Fixed and optimized the CG solver

parent 30b25f63
No related branches found
No related tags found
No related merge requests found
......@@ -12,8 +12,6 @@
#include "CG.h"
#include <TNL/Solvers/Linear/LinearResidueGetter.h>
namespace TNL {
namespace Solvers {
namespace Linear {
......@@ -32,9 +30,7 @@ CG< Matrix >::
solve( ConstVectorViewType b, VectorViewType x )
{
this->setSize( this->matrix->getRows() );
this->resetIterations();
this->setResidue( this->getConvergenceResidue() + 1.0 );
RealType alpha, beta, s1, s2;
RealType bNorm = b.lpNorm( ( RealType ) 2.0 );
......@@ -43,23 +39,26 @@ solve( ConstVectorViewType b, VectorViewType x )
* r_0 = b - A x_0, p_0 = r_0
*/
this->matrix->vectorProduct( x, r );
r. addVector( b, 1.0, -1.0 );
r.addVector( b, 1.0, -1.0 );
p = r;
s1 = r.scalarProduct( r );
// TODO
//this->setResidue( std::sqrt(s1) / bNorm );
this->setResidue( std::sqrt(s1) );
while( this->nextIteration() )
{
/****
* 1. alpha_j = ( r_j, r_j ) / ( A * p_j, p_j )
*/
this->matrix->vectorProduct( p, Ap );
s1 = r.scalarProduct( r );
s2 = Ap.scalarProduct( p );
/****
* if s2 = 0 => p = 0 => r = 0 => we have the solution (provided A != 0)
*/
if( s2 == 0.0 ) alpha = 0.0;
if( s2 == 0.0 ) break;
else alpha = s1 / s2;
/****
......@@ -70,14 +69,13 @@ solve( ConstVectorViewType b, VectorViewType x )
/****
* 3. r_{j+1} = r_j - \alpha_j A * p_j
*/
new_r = r;
new_r.addVector( Ap, -alpha );
new_r.addVectors( r, 1, Ap, -alpha, 0 );
/****
* 4. beta_j = ( r_{j+1}, r_{j+1} ) / ( r_j, r_j )
*/
s1 = new_r. scalarProduct( new_r );
s2 = r. scalarProduct( r );
s2 = s1;
s1 = new_r.scalarProduct( new_r );
/****
* if s2 = 0 => r = 0 => we have the solution
......@@ -88,17 +86,17 @@ solve( ConstVectorViewType b, VectorViewType x )
/****
* 5. p_{j+1} = r_{j+1} + beta_j * p_j
*/
p. addVector( new_r, 1.0, beta );
p.addVector( new_r, 1.0, beta );
/****
* 6. r_{j+1} = new_r
*/
new_r.swap( r );
if( this->getIterations() % 10 == 0 )
this->setResidue( LinearResidueGetter::getResidue( *this->matrix, x, b, bNorm ) );
// TODO
//this->setResidue( std::sqrt(s1) / bNorm );
this->setResidue( std::sqrt(s1) );
}
this->setResidue( LinearResidueGetter::getResidue( *this->matrix, x, b, bNorm ) );
this->refreshSolverMonitor( true );
return this->checkConvergence();
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment