Commit a3e6550b authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Refactoring the linear solvers.

parent 674aebce
Loading
Loading
Loading
Loading
+0 −4
Original line number Diff line number Diff line
TODO: doladit iterativni resice 
      * nepocitaji se iterace
      * iterovat podle metody nextIteration
      * naspat metodu checkConvergence - ta se vola na konci a napise hlaksu, proc resic pripadne neskonvergoval
TODO: doladit vse s CUDA
TODO: doplnit mesh travelsals pro jine mesh entity nez cell
TODO: implementace maticovych resicu
+33 −27
Original line number Diff line number Diff line
@@ -121,12 +121,11 @@ bool tnlBICGStabSolver< Matrix, Preconditioner > :: solve( const Vector& b, Vect
      bNorm = b. lpNorm( 2.0 );
   }

   while( this -> getIterations() < this -> getMaxIterations() &&
          this -> getResidue() > this -> getConvergenceResidue() )
   while( this->nextIteration() )
   {
      //dbgCout( "Starting BiCGStab iteration " << iter + 1 );

      // alpha_j = ( r_j, r^ast_0 ) / ( A * p_j, r^ast_0 )
      /****
       * alpha_j = ( r_j, r^ast_0 ) / ( A * p_j, r^ast_0 )
       */
      /*if( M ) // preconditioner
      {
         A. vectorProduct( p, M_tmp );
@@ -140,11 +139,14 @@ bool tnlBICGStabSolver< Matrix, Preconditioner > :: solve( const Vector& b, Vect
      if( s2 == 0.0 ) alpha = 0.0;
      else alpha = rho / s2;

      // s_j = r_j - alpha_j * A p_j
      /****
       * s_j = r_j - alpha_j * A p_j
       */
      s. alphaXPlusBetaZ( 1.0, r, -alpha, Ap );

      // omega_j = ( A s_j, s_j ) / ( A s_j, A s_j )
      //dbgCout( "Computing As" );
      /****
       * omega_j = ( A s_j, s_j ) / ( A s_j, A s_j )
       */
      /*if( M ) // preconditioner
      {
         A. vectorProduct( s, M_tmp );
@@ -153,40 +155,44 @@ bool tnlBICGStabSolver< Matrix, Preconditioner > :: solve( const Vector& b, Vect
      }
      else*/
          this -> matrix -> vectorProduct( s, As );
      s1 = s2 = 0.0;

      s1 = As. scalarProduct( s );
      s2 = As. scalarProduct( As );
      if( s2 == 0.0 ) omega = 0.0;
      else omega = s1 / s2;

      // x_{j+1} = x_j + alpha_j * p_j + omega_j * s_j
      // r_{j+1} = s_j - omega_j * A * s_j
      //dbgCout( "Computing new x and new r." );
      /****
       * x_{j+1} = x_j + alpha_j * p_j + omega_j * s_j
       */
      x. alphaXPlusBetaZPlusY( alpha, p, omega, s );
      
      /****
       * r_{j+1} = s_j - omega_j * A * s_j
       */
      r. alphaXPlusBetaZ( 1.0, s, -omega, As );

      // beta = alpha_j / omega_j * ( r_{j+1}, r^ast_0 ) / ( r_j, r^ast_0 )
      /****
       * beta = alpha_j / omega_j * ( r_{j+1}, r^ast_0 ) / ( r_j, r^ast_0 )
       */
      s1 = 0.0;
      s1 = r. scalarProduct( r_ast );
      if( rho == 0.0 ) beta = 0.0;
      else beta = ( s1 / rho ) * ( alpha / omega );
      rho = s1;

      // p_{j+1} = r_{j+1} + beta_j * ( p_j - omega_j * A p_j )
      /****
       * p_{j+1} = r_{j+1} + beta_j * ( p_j - omega_j * A p_j )
       */
      RealType residue = computeBICGStabNewP( p, r, beta, omega, Ap );

      residue /= bNorm;
      this->setResidue( residue );
      if( this->getIterations() % 10 == 0 )
         this->setResidue( ResidueGetter :: getResidue( *matrix, b, x, bNorm ) );
      if( ! this -> nextIteration() )
         return false;
      this -> refreshSolverMonitor();
   }
   this -> setResidue( ResidueGetter :: getResidue( *matrix, b, x, bNorm ) );
   //this->setResidue( ResidueGetter :: getResidue( *matrix, b, x, bNorm ) );
   this->refreshSolverMonitor();
      if( this -> getResidue() > this -> getConvergenceResidue() ) return false;
   return true;
   return this->checkConvergence();
};

template< typename Matrix,
+44 −28
Original line number Diff line number Diff line
@@ -70,7 +70,9 @@ void tnlCGSolver< Matrix, Preconditioner > :: setPreconditioner( const Precondit
template< typename Matrix,
          typename Preconditioner >
   template< typename Vector, typename ResidueGetter >
bool tnlCGSolver< Matrix, Preconditioner > :: solve( const Vector& b, Vector& x )
bool
tnlCGSolver< Matrix, Preconditioner >::
solve( const Vector& b, Vector& x )
{
   if( ! this->setSize( matrix->getRows() ) ) return false;

@@ -80,54 +82,68 @@ bool tnlCGSolver< Matrix, Preconditioner > :: solve( const Vector& b, Vector& x
   RealType alpha, beta, s1, s2;
   RealType bNorm = b. lpNorm( ( RealType ) 2.0 );

   // r_0 = b - A x_0, p_0 = r_0
   /****
    * r_0 = b - A x_0, p_0 = r_0
    */
   this->matrix->vectorProduct( x, r );
   r. addVector( b, 1.0, -1.0 );
   p = r;

   while( this -> getIterations() < this -> getMaxIterations() &&
          this -> getResidue() > this -> getConvergenceResidue() )
   while( this->nextIteration() )
   {
      // 1. alpha_j = ( r_j, r_j ) / ( A * p_j, p_j )
      /****
       * 1. alpha_j = ( r_j, r_j ) / ( A * p_j, p_j )
       */
      this->matrix->vectorProduct( p, Ap );

      s1 = r.scalarProduct( r );
      s2 = Ap.scalarProduct( p );
      s1 = s2 = 0.0;
      // if s2 = 0 => p = 0 => r = 0 => we have the solution (provided A != 0)

      /****
       * if s2 = 0 => p = 0 => r = 0 => we have the solution (provided A != 0)
       */
      if( s2 == 0.0 ) alpha = 0.0;
      else alpha = s1 / s2;
      
      // 2. x_{j+1} = x_j + \alpha_j p_j
      /****
       * 2. x_{j+1} = x_j + \alpha_j p_j
       */
      x.addVector( p, alpha );
      
      // 3. r_{j+1} = r_j - \alpha_j A * p_j
      /****
       * 3. r_{j+1} = r_j - \alpha_j A * p_j
       */
      new_r = r;
      new_r.addVector( Ap, -alpha );

      //4. beta_j = ( r_{j+1}, r_{j+1} ) / ( r_j, r_j )
      /****
       * 4. beta_j = ( r_{j+1}, r_{j+1} ) / ( r_j, r_j )
       */
      s1 = new_r. scalarProduct( new_r );
      s2 = r. scalarProduct( r );
      // if s2 = 0 => r = 0 => we have the solution

      /****
       * if s2 = 0 => r = 0 => we have the solution
       */
      if( s2 == 0.0 ) beta = 0.0;
      else beta = s1 / s2;

      // 5. p_{j+1} = r_{j+1} + beta_j * p_j
      /****
       * 5. p_{j+1} = r_{j+1} + beta_j * p_j
       */
      p. addVector( new_r, 1.0, beta );

      // 6. r_{j+1} = new_r
      /****
       * 6. r_{j+1} = new_r
       */
      new_r.swap( r );
      
      if( this -> getIterations() % 10 == 0 )
         this -> setResidue( ResidueGetter :: getResidue( *matrix, b, x, bNorm ) );
      if( ! this -> nextIteration() )
         return false;
      this -> refreshSolverMonitor();
   }
   this -> setResidue( ResidueGetter :: getResidue( *matrix, b, x, bNorm ) );
   this -> refreshSolverMonitor();
      if( this -> getResidue() > this -> getConvergenceResidue() ) return false;
   return true;
   return this->checkConvergence();
};

template< typename Matrix,
+14 −21
Original line number Diff line number Diff line
@@ -156,8 +156,7 @@ bool tnlGMRESSolver< Matrix, Preconditioner > :: solve( const Vector& b, Vector&
   vi. setName( "tnlGMRESSolver::vi" );
   tnlSharedVector< RealType, DeviceType, IndexType > vk;
   vk. setName( "tnlGMRESSolver::vk" );
   while( this -> getIterations() < this -> getMaxIterations() &&
          this -> getResidue() > this -> getConvergenceResidue() )
   while( this->nextIteration() )
   {
      const IndexType m = restarting;
      for( i = 0; i < m + 1; i ++ )
@@ -250,8 +249,6 @@ bool tnlGMRESSolver< Matrix, Preconditioner > :: solve( const Vector& b, Vector&
         if( this->getResidue() < this->getConvergenceResidue() )
         {
            update( i, m, _H, _s, _v, x );
            //if( this -> verbosity > 0 )
            //   this -> printOut();
            return true;
         }
         if( ! this->nextIteration() )
@@ -279,13 +276,9 @@ bool tnlGMRESSolver< Matrix, Preconditioner > :: solve( const Vector& b, Vector&
         beta = _r. lpNorm( ( RealType ) 2.0 );
      }
      this->setResidue( beta / normb );
      this -> refreshSolverMonitor();
      if( ! this -> nextIteration() )
         return false;
   }
   this->refreshSolverMonitor();
   if( this -> getResidue() > this -> getConvergenceResidue() ) return false;
   return true;
   return this->checkConvergence();
};

template< typename Matrix,
+17 −5
Original line number Diff line number Diff line
@@ -101,14 +101,26 @@ bool tnlIterativeSolver< Real, Index> :: nextIteration()
   }

   if( std::isnan( this->getResidue() ) )
   {
      //cerr << endl << "RES is Nan" << endl;
      return false;
   }
   if(( this->getResidue() > this->getDivergenceResidue() &&
         this->getIterations() > this->minIterations ) )
   {
      ///cerr << endl << "RES is over the divergence residue." << endl;
      return false;
   }
   if( this->getIterations() > this->getMaxIterations() )
   {
      //cerr << endl << "Max. iterations exceeded." << endl;
      return false;
   }
   if( this->getResidue() < this->getConvergenceResidue() )
   {
      //cerr << endl << "The solver has. converged." <<  endl;
      return false;
   }
   return true;
}

@@ -119,23 +131,23 @@ checkConvergence()
{
   if( std::isnan( this->getResidue() ) )
   {
      cerr << "The residue is NaN." << endl;
      cerr << endl << "The residue is NaN." << endl;
      return false;
   }
   if(( this->getResidue() > this->getDivergenceResidue() &&
         this->getIterations() > this->minIterations ) )
   {
      cerr << "The residue has exceeded allowed tolerance " << this->getDivergenceResidue() << "." << endl;
      cerr << endl  << "The residue has exceeded allowed tolerance " << this->getDivergenceResidue() << "." << endl;
      return false;
   }
   if( this->getIterations() > this->getMaxIterations() )
   if( this->getIterations() >= this->getMaxIterations() )
   {
      cerr << "The solver has exceeded maximal allowed number of iterations " << this->getMaxIterations() << "." << endl;
      cerr << endl  << "The solver has exceeded maximal allowed number of iterations " << this->getMaxIterations() << "." << endl;
      return false;
   }
   if( this->getResidue() > this->getConvergenceResidue() )
   {
      cerr << "The residue ( = " << this->getResidue() << " ) is too large." << endl;
      cerr << endl  << "The residue ( = " << this->getResidue() << " ) is too large( > " << this->getConvergenceResidue() << " )." << endl;
      return false;
   }
   return true;