Commit cfc7a125 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Renaming aux to gradient in the gradient descent method.

parent ea91987f
Loading
Loading
Loading
Loading
+5 −9
Original line number Diff line number Diff line
/***************************************************************************
                          Timer.h  -  description
                             -------------------
    begin                : Mar 14, 2016
    copyright            : (C) 2016 by Tomas Oberhuber
    email                : tomas.oberhuber@fjfi.cvut.cz
 ***************************************************************************/

/* See Copyright Notice in tnl/Copyright */
// Copyright (c) 2004-2022 Tomáš Oberhuber et al.
//
// This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
//
// SPDX-License-Identifier: MIT

#pragma once

+13 −17
Original line number Diff line number Diff line
/***************************************************************************
                          Timer.h  -  description
                             -------------------
    begin                : Mar 14, 2016
    copyright            : (C) 2016 by Tomas Oberhuber
    email                : tomas.oberhuber@fjfi.cvut.cz
 ***************************************************************************/

/* See Copyright Notice in tnl/Copyright */
// Copyright (c) 2004-2022 Tomáš Oberhuber et al.
//
// This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
//
// SPDX-License-Identifier: MIT

#pragma once

@@ -23,7 +19,7 @@ GradientDescent< Vector, SolverMonitor >::
configSetup( Config::ConfigDescription& config, const String& prefix )
{
   IterativeSolver< RealType, IndexType, SolverMonitor >::configSetup( config, prefix );
   config.addEntry< double >( prefix + "gd-relaxation", "Relaxation parameter for the gradient descent.", 1.0 );
   config.addEntry< double >( prefix + "relaxation", "Relaxation parameter for the gradient descent.", 1.0 );
}

template< typename Vector, typename SolverMonitor >
@@ -31,7 +27,7 @@ bool
GradientDescent< Vector, SolverMonitor >::
setup( const Config::ParameterContainer& parameters, const String& prefix )
{
   this->setRelaxation( parameters.getParameter< double >( prefix + "gd-relaxation" ) );
   this->setRelaxation( parameters.getParameter< double >( prefix + "relaxation" ) );
   return IterativeSolver< RealType, IndexType, SolverMonitor >::setup( parameters, prefix );
}

@@ -57,10 +53,10 @@ bool
GradientDescent< Vector, SolverMonitor >::
solve( VectorView& w, GradientGetter&& getGradient )
{
   this->aux.setLike( w );
   auto aux_view = aux.getView();
   this->gradient.setLike( w );
   auto gradient_view = gradient.getView();
   auto w_view = w.getView();
   aux = 0.0;
   this->gradient = 0.0;

   /////
   // Set necessary parameters
@@ -73,9 +69,9 @@ solve( VectorView& w, GradientGetter&& getGradient )
   {
      /////
      // Compute the gradient
      getGradient( w_view, aux_view );
      getGradient( w_view, gradient_view );
      RealType lastResidue = this->getResidue();
      this->setResidue( addAndReduceAbs( w_view, this->relaxation * aux_view, TNL::Plus(), ( RealType ) 0.0 ) / ( this->relaxation * ( RealType ) w.getSize() ) );
      this->setResidue( addAndReduceAbs( w_view, -this->relaxation * gradient_view, TNL::Plus(), ( RealType ) 0.0 ) / ( this->relaxation * ( RealType ) w.getSize() ) );

      if( ! this->nextIteration() )
         return this->checkConvergence();