Commit c218dd6d authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Added preconditioning to the CG solver

parent ca5b9f96
Loading
Loading
Loading
Loading
+1 −1
Original line number Original line Diff line number Diff line
@@ -37,7 +37,7 @@ public:
protected:
protected:
   void setSize( IndexType size );
   void setSize( IndexType size );


   Containers::Vector< RealType, DeviceType, IndexType >  r, new_r, p, Ap;
   Containers::Vector< RealType, DeviceType, IndexType >  r, p, Ap, z;
};
};


} // namespace Linear
} // namespace Linear
+64 −45
Original line number Original line Diff line number Diff line
@@ -33,81 +33,100 @@ solve( ConstVectorViewType b, VectorViewType x )
   this->resetIterations();
   this->resetIterations();


   RealType alpha, beta, s1, s2;
   RealType alpha, beta, s1, s2;
   RealType bNorm = b.lpNorm( ( RealType ) 2.0 );

   // initialize the norm of the preconditioned right-hand-side
   RealType normb;
   if( this->preconditioner ) {
      this->preconditioner->solve( b, r );
      normb = r.lpNorm( 2.0 );
   }
   else
      normb = b.lpNorm( 2.0 );
   if( normb == 0.0 )
      normb = 1.0;


   /****
   /****
    * r_0 = b - A x_0, p_0 = r_0
    * r_0 = b - A x_0, p_0 = r_0
    */
    */
   this->matrix->vectorProduct( x, r );
   this->matrix->vectorProduct( x, r );
   r.addVector( b, 1.0, -1.0 );
   r.addVector( b, 1.0, -1.0 );
   p = r;


   if( this->preconditioner ) {
      // z_0 = M^{-1} r_0
      this->preconditioner->solve( r, z );
      // p_0 = z_0
      p = z;
      // s1 = (r_0, z_0)
      s1 = r.scalarProduct( z );
   }
   else {
      // p_0 = r_0
      p = r;
      // s1 = (r_0, r_0)
      s1 = r.scalarProduct( r );
      s1 = r.scalarProduct( r );
   // TODO
   }
   //this->setResidue( std::sqrt(s1) / bNorm );

   this->setResidue( std::sqrt(s1) );
   this->setResidue( std::sqrt(s1) / normb );


   while( this->nextIteration() )
   while( this->nextIteration() )
   {
   {
      /****
      // s2 = (A * p_j, p_j)
       * 1. alpha_j = ( r_j, r_j ) / ( A * p_j, p_j )
       */
      this->matrix->vectorProduct( p, Ap );
      this->matrix->vectorProduct( p, Ap );
      s2 = Ap.scalarProduct( p );
      s2 = Ap.scalarProduct( p );


      /****
      // if s2 = 0 => p = 0 => r = 0 => we have the solution (provided A != 0)
       * if s2 = 0 => p = 0 => r = 0 => we have the solution (provided A != 0)
      if( s2 == 0.0 ) {
       */
         this->setResidue( 0.0 );
      if( s2 == 0.0 ) break;
         break;
      else alpha = s1 / s2;
      }


      /****
      // alpha_j = (r_j, z_j) / (A * p_j, p_j)
       * 2. x_{j+1} = x_j + \alpha_j p_j
      alpha = s1 / s2;
       */

      // x_{j+1} = x_j + alpha_j p_j
      x.addVector( p, alpha );
      x.addVector( p, alpha );


      /****
      // r_{j+1} = r_j - alpha_j A * p_j
       * 3. r_{j+1} = r_j - \alpha_j A * p_j
      r.addVector( Ap, -alpha );
       */
      new_r.addVectors( r, 1, Ap, -alpha, 0 );


      /****
      if( this->preconditioner ) {
       * 4. beta_j = ( r_{j+1}, r_{j+1} ) / ( r_j, r_j )
         // z_{j+1} = M^{-1} * r_{j+1}
       */
         this->preconditioner->solve( r, z );
         // beta_j = (r_{j+1}, z_{j+1}) / (r_j, z_j)
         s2 = s1;
         s2 = s1;
      s1 = new_r.scalarProduct( new_r );
         s1 = r.scalarProduct( z );
      }
      else {
         // beta_j = (r_{j+1}, r_{j+1}) / (r_j, r_j)
         s2 = s1;
         s1 = r.scalarProduct( r );
      }


      /****
      // if s2 = 0 => r = 0 => we have the solution
       * if s2 = 0 => r = 0 => we have the solution
       */
      if( s2 == 0.0 ) beta = 0.0;
      if( s2 == 0.0 ) beta = 0.0;
      else beta = s1 / s2;
      else beta = s1 / s2;


      /****
      if( this->preconditioner )
       * 5. p_{j+1} = r_{j+1} + beta_j * p_j
         // p_{j+1} = z_{j+1} + beta_j * p_j
       */
         p.addVector( z, 1.0, beta );
      p.addVector( new_r, 1.0, beta );
      else

         // p_{j+1} = r_{j+1} + beta_j * p_j
      /****
         p.addVector( r, 1.0, beta );
       * 6. r_{j+1} = new_r
       */
      new_r.swap( r );


      // TODO
      this->setResidue( std::sqrt(s1) / normb );
      //this->setResidue( std::sqrt(s1) / bNorm );
      this->setResidue( std::sqrt(s1) );
   }
   }
   this->refreshSolverMonitor( true );
   this->refreshSolverMonitor( true );
   return this->checkConvergence();
   return this->checkConvergence();
}
}


template< typename Matrix >
template< typename Matrix >
void CG< Matrix > :: setSize( IndexType size )
void CG< Matrix >::
setSize( IndexType size )
{
{
   r.setSize( size );
   r.setSize( size );
   new_r.setSize( size );
   p.setSize( size );
   p.setSize( size );
   Ap.setSize( size );
   Ap.setSize( size );
   z.setSize( size );
}
}


} // namespace Linear
} // namespace Linear