Commit 1c600882 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Added RMSProp optimizer.

parent 6c4f81bc
Loading
Loading
Loading
Loading
+59 −0
Original line number Diff line number Diff line
// Copyright (c) 2004-2022 Tomáš Oberhuber et al.
//
// This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
//
// SPDX-License-Identifier: MIT

#pragma once

#include <TNL/Solvers/IterativeSolver.h>

namespace TNL {
   namespace Solvers {
      namespace Optimization {

/***
 * https://arxiv.org/pdf/1609.04747.pdf
 *
 */
template< typename Vector, typename SolverMonitor =  IterativeSolverMonitor< typename Vector::RealType, typename Vector::IndexType > >
class RMSProp : public IterativeSolver< typename Vector::RealType, typename Vector::IndexType, SolverMonitor >
{
public:
   using RealType = typename Vector::RealType;
   using DeviceType = typename Vector::DeviceType;
   using IndexType = typename Vector::IndexType;
   using VectorType = Vector;
   using VectorView = typename Vector::ViewType;

   RMSProp() = default;

   static void
   configSetup( Config::ConfigDescription& config, const String& prefix = "" );

   bool
   setup( const Config::ParameterContainer& parameters, const String& prefix = "" );

   void
   setRelaxation( const RealType& lambda );

   const RealType&
   getRelaxation() const;

   template< typename GradientGetter >
   bool
   solve( VectorView& w, GradientGetter&& getGradient );

protected:

   RealType relaxation = 1.0, epsilon = 1.0e-8, beta = 0.9;

   VectorType gradient, a;

};

      } //namespace Optimization
   } //namespace Solvers
} //namespace TNL

#include <TNL/Solvers/Optimization/RMSProp.hpp>
+95 −0
Original line number Diff line number Diff line
// Copyright (c) 2004-2022 Tomáš Oberhuber et al.
//
// This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
//
// SPDX-License-Identifier: MIT

#pragma once

#include <TNL/Solvers/Optimization/RMSProp.h>

namespace TNL {
   namespace Solvers {
      namespace Optimization {


template< typename Vector, typename SolverMonitor >
void
RMSProp< Vector, SolverMonitor >::
configSetup( Config::ConfigDescription& config, const String& prefix )
{
   IterativeSolver< RealType, IndexType, SolverMonitor >::configSetup( config, prefix );
   config.addEntry< double >( prefix + "relaxation", "Relaxation parameter for the gradient descent.", 1.0 );
   config.addEntry< double >( prefix + "beta", "Momentum parameter for computing sum of squared gradients.", 0.9 );
}

template< typename Vector, typename SolverMonitor >
bool
RMSProp< Vector, SolverMonitor >::
setup( const Config::ParameterContainer& parameters, const String& prefix )
{
   this->setRelaxation( parameters.getParameter< double >( prefix + "relaxation" ) );
   this->beta = parameters.getParameter< double >( prefix + "beta" );
   return IterativeSolver< RealType, IndexType, SolverMonitor >::setup( parameters, prefix );
}

template< typename Vector, typename SolverMonitor >
void
RMSProp< Vector, SolverMonitor >::
setRelaxation( const RealType& lambda )
{
   this->relaxation = lambda;
}

template< typename Vector, typename SolverMonitor >
auto
RMSProp< Vector, SolverMonitor >::
getRelaxation() const -> const RealType&
{
   return this->relaxation;
}

template< typename Vector, typename SolverMonitor >
   template< typename GradientGetter >
bool
RMSProp< Vector, SolverMonitor >::
solve( VectorView& w, GradientGetter&& getGradient )
{
   this->gradient.setLike( w );
   this->a.setLike( w );
   auto gradient_view = gradient.getView();
   auto w_view = w.getView();
   this->gradient = 0.0;
   this->a = 0.0;

   /////
   // Set necessary parameters
   this->resetIterations();
   this->setResidue( this->getConvergenceResidue() + 1.0 );

   /////
   // Start the main loop
   while( 1 )
   {
      /////
      // Compute the gradient
      getGradient( w_view, gradient_view );
      RealType lastResidue = this->getResidue();
      // a_i = beta * a_i + ( 1- beta ) * grad_i^2
      a = this->beta * a + ( 1.0  - this->beta ) * gradient_view * gradient_view;
      this->setResidue( addAndReduceAbs( w_view, -this->relaxation / sqrt( this->a + this->epsilon  ) * gradient_view, TNL::Plus(), ( RealType ) 0.0 ) / ( this->relaxation * ( RealType ) w.getSize() ) );

      if( ! this->nextIteration() )
         return this->checkConvergence();

      /////
      // Check the stop condition
      if( this->getConvergenceResidue() != 0.0 && this->getResidue() < this -> getConvergenceResidue() )
         return true;
   }
   return false; // just to avoid warnings
}

      } //namespace Optimization
   } //namespace Solvers
} //namespace TNL