Commit 6c4f81bc authored by Tomáš Oberhuber

Added AdaGrad optimizer.

parent 376063bd
Merge request !134: To/optimization
// Copyright (c) 2004-2022 Tomáš Oberhuber et al.
//
// This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
//
// SPDX-License-Identifier: MIT
#pragma once
#include <TNL/Solvers/IterativeSolver.h>
namespace TNL {
namespace Solvers {
namespace Optimization {
/**
 * \brief Optimizer implementing the AdaGrad method.
 *
 * AdaGrad accumulates the squared gradients, a_i += grad_i^2, and scales each
 * component of the update by this history: w_i -= lambda / sqrt( a_i + epsilon ) * grad_i.
 *
 * See S. Ruder: An overview of gradient descent optimization algorithms,
 * https://arxiv.org/pdf/1609.04747.pdf
 */
template< typename Vector,
          typename SolverMonitor = IterativeSolverMonitor< typename Vector::RealType, typename Vector::IndexType > >
class AdaGrad : public IterativeSolver< typename Vector::RealType, typename Vector::IndexType, SolverMonitor >
{
public:
   using RealType = typename Vector::RealType;
   using DeviceType = typename Vector::DeviceType;
   using IndexType = typename Vector::IndexType;
   using VectorType = Vector;
   using VectorView = typename Vector::ViewType;

   AdaGrad() = default;

   static void
   configSetup( Config::ConfigDescription& config, const String& prefix = "" );

   bool
   setup( const Config::ParameterContainer& parameters, const String& prefix = "" );

   void
   setRelaxation( const RealType& lambda );

   const RealType&
   getRelaxation() const;

   template< typename GradientGetter >
   bool
   solve( VectorView& w, GradientGetter&& getGradient );

protected:
   RealType relaxation = 1.0, epsilon = 1.0e-8;

   VectorType gradient, a;
};
} //namespace Optimization
} //namespace Solvers
} //namespace TNL
#include <TNL/Solvers/Optimization/AdaGrad.hpp>
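A minimal usage sketch for the interface above, assuming a host vector of doubles and a made-up quadratic loss 0.5 * || w - target ||^2; the vector size, the parameter values, and the gradient functor are illustrative assumptions, not part of the commit:

// Hypothetical usage sketch for AdaGrad::solve(); the toy loss and all
// constants below are illustrative assumptions.
#include <TNL/Containers/Vector.h>
#include <TNL/Devices/Host.h>
#include <TNL/Solvers/Optimization/AdaGrad.h>

int main()
{
   using Vector = TNL::Containers::Vector< double, TNL::Devices::Host >;
   Vector w( 10 ), target( 10 );
   w = 0.0;       // initial guess
   target = 1.0;  // minimizer of the toy loss

   TNL::Solvers::Optimization::AdaGrad< Vector > optimizer;
   optimizer.setRelaxation( 0.1 );
   optimizer.setMaxIterations( 1000 );
   optimizer.setConvergenceResidue( 1.0e-6 );

   // solve() calls the functor as getGradient( w_view, gradient_view );
   // the gradient of 0.5 * || w - target ||^2 is w - target.
   auto getGradient = [ & ]( auto v, auto g )
   {
      g = v - target.getView();
   };
   auto w_view = w.getView();
   return optimizer.solve( w_view, getGradient ) ? 0 : 1;
}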
// Copyright (c) 2004-2022 Tomáš Oberhuber et al.
//
// This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
//
// SPDX-License-Identifier: MIT
#pragma once
#include <TNL/Solvers/Optimization/AdaGrad.h>
namespace TNL {
namespace Solvers {
namespace Optimization {
template< typename Vector, typename SolverMonitor >
void
AdaGrad< Vector, SolverMonitor >::
configSetup( Config::ConfigDescription& config, const String& prefix )
{
   IterativeSolver< RealType, IndexType, SolverMonitor >::configSetup( config, prefix );
   config.addEntry< double >( prefix + "relaxation", "Relaxation parameter for the gradient descent.", 1.0 );
}
template< typename Vector, typename SolverMonitor >
bool
AdaGrad< Vector, SolverMonitor >::
setup( const Config::ParameterContainer& parameters, const String& prefix )
{
   this->setRelaxation( parameters.getParameter< double >( prefix + "relaxation" ) );
   return IterativeSolver< RealType, IndexType, SolverMonitor >::setup( parameters, prefix );
}
template< typename Vector, typename SolverMonitor >
void
AdaGrad< Vector, SolverMonitor >::
setRelaxation( const RealType& lambda )
{
   this->relaxation = lambda;
}
template< typename Vector, typename SolverMonitor >
auto
AdaGrad< Vector, SolverMonitor >::
getRelaxation() const -> const RealType&
{
   return this->relaxation;
}
template< typename Vector, typename SolverMonitor >
template< typename GradientGetter >
bool
AdaGrad< Vector, SolverMonitor >::
solve( VectorView& w, GradientGetter&& getGradient )
{
   this->gradient.setLike( w );
   this->a.setLike( w );
   auto gradient_view = gradient.getView();
   auto w_view = w.getView();
   this->gradient = 0.0;
   this->a = 0.0;

   /////
   // Set necessary parameters
   this->resetIterations();
   this->setResidue( this->getConvergenceResidue() + 1.0 );

   /////
   // Start the main loop
   while( true )
   {
      /////
      // Compute the gradient
      getGradient( w_view, gradient_view );

      // Accumulate the squared gradients: a_i += grad_i^2
      a += gradient_view * gradient_view;

      // Update the weights, w_i += -relaxation / sqrt( a_i + epsilon ) * grad_i,
      // and take the mean absolute update scaled by the relaxation as the residue.
      this->setResidue( addAndReduceAbs( w_view,
                                         -this->relaxation / sqrt( this->a + this->epsilon ) * gradient_view,
                                         TNL::Plus(),
                                         (RealType) 0.0 )
                        / ( this->relaxation * (RealType) w.getSize() ) );

      if( ! this->nextIteration() )
         return this->checkConvergence();

      /////
      // Check the stop condition
      if( this->getConvergenceResidue() != 0.0 && this->getResidue() < this->getConvergenceResidue() )
         return true;
   }
   return false;  // unreachable; silences compiler warnings
}
} //namespace Optimization
} //namespace Solvers
} //namespace TNL
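To make the update rule concrete outside of TNL's expression templates, the following self-contained scalar sketch performs the same iteration as solve() above; the toy objective 0.5 * (w_i - 1)^2 and all constants are made up for illustration:

// Scalar sketch of the AdaGrad iteration: accumulate squared gradients
// and scale each component's step by its own accumulated history.
#include <cmath>
#include <cstdio>
#include <vector>

int main()
{
   const double relaxation = 0.1, epsilon = 1.0e-8;
   std::vector< double > w( 10, 0.0 ), a( 10, 0.0 ), g( 10 );
   for( int iter = 0; iter < 1000; iter++ ) {
      double residue = 0.0;
      for( std::size_t i = 0; i < w.size(); i++ ) {
         g[ i ] = w[ i ] - 1.0;      // gradient of the toy loss 0.5 * (w_i - 1)^2
         a[ i ] += g[ i ] * g[ i ];  // a_i += grad_i^2
         const double delta = -relaxation / std::sqrt( a[ i ] + epsilon ) * g[ i ];
         w[ i ] += delta;
         residue += std::fabs( delta );
      }
      // Same stopping criterion as solve(): mean absolute update scaled by the relaxation.
      if( residue / ( relaxation * w.size() ) < 1.0e-6 )
         break;
   }
   std::printf( "w[ 0 ] = %g\n", w[ 0 ] );
}

The per-component step size relaxation / sqrt( a_i + epsilon ) shrinks for coordinates whose gradients have historically been large, which is the defining property of AdaGrad.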