Commit 8b58f8cb authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Merge branch 'TO/optimization' into 'develop'

To/optimization

See merge request !134
parents 947efd55 1c600882
Loading
Loading
Loading
Loading
+59 −0
Original line number Diff line number Diff line
// Copyright (c) 2004-2022 Tomáš Oberhuber et al.
//
// This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
//
// SPDX-License-Identifier: MIT

#pragma once

#include <TNL/Solvers/IterativeSolver.h>

namespace TNL {
   namespace Solvers {
      namespace Optimization {

/***
 * https://arxiv.org/pdf/1609.04747.pdf
 *
 */
template< typename Vector, typename SolverMonitor =  IterativeSolverMonitor< typename Vector::RealType, typename Vector::IndexType > >
class AdaGrad : public IterativeSolver< typename Vector::RealType, typename Vector::IndexType, SolverMonitor >
{
public:
   using RealType = typename Vector::RealType;
   using DeviceType = typename Vector::DeviceType;
   using IndexType = typename Vector::IndexType;
   using VectorType = Vector;
   using VectorView = typename Vector::ViewType;

   AdaGrad() = default;

   static void
   configSetup( Config::ConfigDescription& config, const String& prefix = "" );

   bool
   setup( const Config::ParameterContainer& parameters, const String& prefix = "" );

   void
   setRelaxation( const RealType& lambda );

   const RealType&
   getRelaxation() const;

   template< typename GradientGetter >
   bool
   solve( VectorView& w, GradientGetter&& getGradient );

protected:

   RealType relaxation = 1.0, epsilon = 1.0e-8;

   VectorType gradient, a;

};

      } //namespace Optimization
   } //namespace Solvers
} //namespace TNL

#include <TNL/Solvers/Optimization/AdaGrad.hpp>
+93 −0
Original line number Diff line number Diff line
// Copyright (c) 2004-2022 Tomáš Oberhuber et al.
//
// This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
//
// SPDX-License-Identifier: MIT

#pragma once

#include <TNL/Solvers/Optimization/AdaGrad.h>

namespace TNL {
   namespace Solvers {
      namespace Optimization {


template< typename Vector, typename SolverMonitor >
void
AdaGrad< Vector, SolverMonitor >::
configSetup( Config::ConfigDescription& config, const String& prefix )
{
   IterativeSolver< RealType, IndexType, SolverMonitor >::configSetup( config, prefix );
   config.addEntry< double >( prefix + "relaxation", "Relaxation parameter for the gradient descent.", 1.0 );
}

template< typename Vector, typename SolverMonitor >
bool
AdaGrad< Vector, SolverMonitor >::
setup( const Config::ParameterContainer& parameters, const String& prefix )
{
   this->setRelaxation( parameters.getParameter< double >( prefix + "relaxation" ) );
   return IterativeSolver< RealType, IndexType, SolverMonitor >::setup( parameters, prefix );
}

template< typename Vector, typename SolverMonitor >
void
AdaGrad< Vector, SolverMonitor >::
setRelaxation( const RealType& lambda )
{
   this->relaxation = lambda;
}

template< typename Vector, typename SolverMonitor >
auto
AdaGrad< Vector, SolverMonitor >::
getRelaxation() const -> const RealType&
{
   return this->relaxation;
}

template< typename Vector, typename SolverMonitor >
   template< typename GradientGetter >
bool
AdaGrad< Vector, SolverMonitor >::
solve( VectorView& w, GradientGetter&& getGradient )
{
   this->gradient.setLike( w );
   this->a.setLike( w );
   auto gradient_view = gradient.getView();
   auto w_view = w.getView();
   this->gradient = 0.0;
   this->a = 0.0;

   /////
   // Set necessary parameters
   this->resetIterations();
   this->setResidue( this->getConvergenceResidue() + 1.0 );

   /////
   // Start the main loop
   while( 1 )
   {
      /////
      // Compute the gradient
      getGradient( w_view, gradient_view );
      RealType lastResidue = this->getResidue();
      // a_i += grad_i^2
      a += gradient_view * gradient_view;
      this->setResidue( addAndReduceAbs( w_view, -this->relaxation / sqrt( this->a + this->epsilon  ) * gradient_view, TNL::Plus(), ( RealType ) 0.0 ) / ( this->relaxation * ( RealType ) w.getSize() ) );

      if( ! this->nextIteration() )
         return this->checkConvergence();

      /////
      // Check the stop condition
      if( this->getConvergenceResidue() != 0.0 && this->getResidue() < this -> getConvergenceResidue() )
         return true;
   }
   return false; // just to avoid warnings
}

      } //namespace Optimization
   } //namespace Solvers
} //namespace TNL
+55 −0
Original line number Diff line number Diff line
// Copyright (c) 2004-2022 Tomáš Oberhuber et al.
//
// This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
//
// SPDX-License-Identifier: MIT

#pragma once

#include <TNL/Solvers/IterativeSolver.h>

namespace TNL {
   namespace Solvers {
      namespace Optimization {

template< typename Vector, typename SolverMonitor =  IterativeSolverMonitor< typename Vector::RealType, typename Vector::IndexType > >
class GradientDescent : public IterativeSolver< typename Vector::RealType, typename Vector::IndexType, SolverMonitor >
{
public:
   using RealType = typename Vector::RealType;
   using DeviceType = typename Vector::DeviceType;
   using IndexType = typename Vector::IndexType;
   using VectorType = Vector;
   using VectorView = typename Vector::ViewType;

   GradientDescent() = default;

   static void
   configSetup( Config::ConfigDescription& config, const String& prefix = "" );

   bool
   setup( const Config::ParameterContainer& parameters, const String& prefix = "" );

   void
   setRelaxation( const RealType& lambda );

   const RealType&
   getRelaxation() const;

   template< typename GradientGetter >
   bool
   solve( VectorView& w, GradientGetter&& getGradient );

protected:

   RealType relaxation = 1.0;

   VectorType gradient;

};

      } //namespace Optimization
   } //namespace Solvers
} //namespace TNL

#include <TNL/Solvers/Optimization/GradientDescent.hpp>
+89 −0
Original line number Diff line number Diff line
// Copyright (c) 2004-2022 Tomáš Oberhuber et al.
//
// This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
//
// SPDX-License-Identifier: MIT

#pragma once

#include <TNL/Solvers/Optimization/GradientDescent.h>

namespace TNL {
   namespace Solvers {
      namespace Optimization {


template< typename Vector, typename SolverMonitor >
void
GradientDescent< Vector, SolverMonitor >::
configSetup( Config::ConfigDescription& config, const String& prefix )
{
   IterativeSolver< RealType, IndexType, SolverMonitor >::configSetup( config, prefix );
   config.addEntry< double >( prefix + "relaxation", "Relaxation parameter for the gradient descent.", 1.0 );
}

template< typename Vector, typename SolverMonitor >
bool
GradientDescent< Vector, SolverMonitor >::
setup( const Config::ParameterContainer& parameters, const String& prefix )
{
   this->setRelaxation( parameters.getParameter< double >( prefix + "relaxation" ) );
   return IterativeSolver< RealType, IndexType, SolverMonitor >::setup( parameters, prefix );
}

template< typename Vector, typename SolverMonitor >
void
GradientDescent< Vector, SolverMonitor >::
setRelaxation( const RealType& lambda )
{
   this->relaxation = lambda;
}

template< typename Vector, typename SolverMonitor >
auto
GradientDescent< Vector, SolverMonitor >::
getRelaxation() const -> const RealType&
{
   return this->relaxation;
}

template< typename Vector, typename SolverMonitor >
   template< typename GradientGetter >
bool
GradientDescent< Vector, SolverMonitor >::
solve( VectorView& w, GradientGetter&& getGradient )
{
   this->gradient.setLike( w );
   auto gradient_view = gradient.getView();
   auto w_view = w.getView();
   this->gradient = 0.0;

   /////
   // Set necessary parameters
   this->resetIterations();
   this->setResidue( this->getConvergenceResidue() + 1.0 );

   /////
   // Start the main loop
   while( 1 )
   {
      /////
      // Compute the gradient
      getGradient( w_view, gradient_view );
      RealType lastResidue = this->getResidue();
      this->setResidue( addAndReduceAbs( w_view, -this->relaxation * gradient_view, TNL::Plus(), ( RealType ) 0.0 ) / ( this->relaxation * ( RealType ) w.getSize() ) );

      if( ! this->nextIteration() )
         return this->checkConvergence();

      /////
      // Check the stop condition
      if( this->getConvergenceResidue() != 0.0 && this->getResidue() < this -> getConvergenceResidue() )
         return true;
   }
   return false; // just to avoid warnings
}

      } //namespace Optimization
   } //namespace Solvers
} //namespace TNL
+61 −0
Original line number Diff line number Diff line
// Copyright (c) 2004-2022 Tomáš Oberhuber et al.
//
// This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
//
// SPDX-License-Identifier: MIT

#pragma once

#include <TNL/Solvers/IterativeSolver.h>

namespace TNL {
   namespace Solvers {
      namespace Optimization {

template< typename Vector, typename SolverMonitor =  IterativeSolverMonitor< typename Vector::RealType, typename Vector::IndexType > >
class Momentum : public IterativeSolver< typename Vector::RealType, typename Vector::IndexType, SolverMonitor >
{
public:
   using RealType = typename Vector::RealType;
   using DeviceType = typename Vector::DeviceType;
   using IndexType = typename Vector::IndexType;
   using VectorType = Vector;
   using VectorView = typename Vector::ViewType;

   Momentum() = default;

   static void
   configSetup( Config::ConfigDescription& config, const String& prefix = "" );

   bool
   setup( const Config::ParameterContainer& parameters, const String& prefix = "" );

   void
   setRelaxation( const RealType& lambda );

   const RealType&
   getRelaxation() const;

   void
   setMomentum( const RealType& beta );

   const RealType&
   getMomentum() const;

   template< typename GradientGetter >
   bool
   solve( VectorView& w, GradientGetter&& getGradient );

protected:

   RealType relaxation = 1.0, momentum = 0.9;

   VectorType gradient, v;

};

      } //namespace Optimization
   } //namespace Solvers
} //namespace TNL

#include <TNL/Solvers/Optimization/Momentum.hpp>
Loading