Commit a220cfb4 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Fixed and optimized the CG solver

parent 30b25f63
Loading
Loading
Loading
Loading
+14 −16
Original line number Original line Diff line number Diff line
@@ -12,8 +12,6 @@


#include "CG.h"
#include "CG.h"


#include <TNL/Solvers/Linear/LinearResidueGetter.h>

namespace TNL {
namespace TNL {
namespace Solvers {
namespace Solvers {
namespace Linear {
namespace Linear {
@@ -32,9 +30,7 @@ CG< Matrix >::
solve( ConstVectorViewType b, VectorViewType x )
solve( ConstVectorViewType b, VectorViewType x )
{
{
   this->setSize( this->matrix->getRows() );
   this->setSize( this->matrix->getRows() );

   this->resetIterations();
   this->resetIterations();
   this->setResidue( this->getConvergenceResidue() + 1.0 );


   RealType alpha, beta, s1, s2;
   RealType alpha, beta, s1, s2;
   RealType bNorm = b.lpNorm( ( RealType ) 2.0 );
   RealType bNorm = b.lpNorm( ( RealType ) 2.0 );
@@ -46,20 +42,23 @@ solve( ConstVectorViewType b, VectorViewType x )
   r.addVector( b, 1.0, -1.0 );
   r.addVector( b, 1.0, -1.0 );
   p = r;
   p = r;


   s1 = r.scalarProduct( r );
   // TODO
   //this->setResidue( std::sqrt(s1) / bNorm );
   this->setResidue( std::sqrt(s1) );

   while( this->nextIteration() )
   while( this->nextIteration() )
   {
   {
      /****
      /****
       * 1. alpha_j = ( r_j, r_j ) / ( A * p_j, p_j )
       * 1. alpha_j = ( r_j, r_j ) / ( A * p_j, p_j )
       */
       */
      this->matrix->vectorProduct( p, Ap );
      this->matrix->vectorProduct( p, Ap );

      s1 = r.scalarProduct( r );
      s2 = Ap.scalarProduct( p );
      s2 = Ap.scalarProduct( p );


      /****
      /****
       * if s2 = 0 => p = 0 => r = 0 => we have the solution (provided A != 0)
       * if s2 = 0 => p = 0 => r = 0 => we have the solution (provided A != 0)
       */
       */
      if( s2 == 0.0 ) alpha = 0.0;
      if( s2 == 0.0 ) break;
      else alpha = s1 / s2;
      else alpha = s1 / s2;


      /****
      /****
@@ -70,14 +69,13 @@ solve( ConstVectorViewType b, VectorViewType x )
      /****
      /****
       * 3. r_{j+1} = r_j - \alpha_j A * p_j
       * 3. r_{j+1} = r_j - \alpha_j A * p_j
       */
       */
      new_r = r;
      new_r.addVectors( r, 1, Ap, -alpha, 0 );
      new_r.addVector( Ap, -alpha );


      /****
      /****
       * 4. beta_j = ( r_{j+1}, r_{j+1} ) / ( r_j, r_j )
       * 4. beta_j = ( r_{j+1}, r_{j+1} ) / ( r_j, r_j )
       */
       */
      s2 = s1;
      s1 = new_r.scalarProduct( new_r );
      s1 = new_r.scalarProduct( new_r );
      s2 = r. scalarProduct( r );


      /****
      /****
       * if s2 = 0 => r = 0 => we have the solution
       * if s2 = 0 => r = 0 => we have the solution
@@ -95,10 +93,10 @@ solve( ConstVectorViewType b, VectorViewType x )
       */
       */
      new_r.swap( r );
      new_r.swap( r );


      if( this->getIterations() % 10 == 0 )
      // TODO
         this->setResidue( LinearResidueGetter::getResidue( *this->matrix, x, b, bNorm ) );
      //this->setResidue( std::sqrt(s1) / bNorm );
      this->setResidue( std::sqrt(s1) );
   }
   }
   this->setResidue( LinearResidueGetter::getResidue( *this->matrix, x, b, bNorm ) );
   this->refreshSolverMonitor( true );
   this->refreshSolverMonitor( true );
   return this->checkConvergence();
   return this->checkConvergence();
}
}