Loading src/TNL/Solvers/Optimization/NesterovMomentum.h 0 → 100644 +61 −0 Original line number Diff line number Diff line // Copyright (c) 2004-2022 Tomáš Oberhuber et al. // // This file is part of TNL - Template Numerical Library (https://tnl-project.org/) // // SPDX-License-Identifier: MIT #pragma once #include <TNL/Solvers/IterativeSolver.h> namespace TNL { namespace Solvers { namespace Optimization { template< typename Vector, typename SolverMonitor = IterativeSolverMonitor< typename Vector::RealType, typename Vector::IndexType > > class NesterovMomentum : public IterativeSolver< typename Vector::RealType, typename Vector::IndexType, SolverMonitor > { public: using RealType = typename Vector::RealType; using DeviceType = typename Vector::DeviceType; using IndexType = typename Vector::IndexType; using VectorType = Vector; using VectorView = typename Vector::ViewType; NesterovMomentum() = default; static void configSetup( Config::ConfigDescription& config, const String& prefix = "" ); bool setup( const Config::ParameterContainer& parameters, const String& prefix = "" ); void setRelaxation( const RealType& lambda ); const RealType& getRelaxation() const; void setMomentum( const RealType& beta ); const RealType& getMomentum() const; template< typename GradientGetter > bool solve( VectorView& w, GradientGetter&& getGradient ); protected: RealType relaxation = 1.0, momentum = 0.9; VectorType gradient, v, aux; }; } //namespace Optimization } //namespace Solvers } //namespace TNL #include <TNL/Solvers/Optimization/Momentum.hpp> src/TNL/Solvers/Optimization/NesterovMomentum.hpp 0 → 100644 +122 −0 Original line number Diff line number Diff line // Copyright (c) 2004-2022 Tomáš Oberhuber et al. // // This file is part of TNL - Template Numerical Library (https://tnl-project.org/) // // SPDX-License-Identifier: MIT #pragma once #include <TNL/Solvers/Optimization/NesterovMomentum.h> namespace TNL { namespace Solvers { namespace Optimization { template< typename Vector, typename SolverMonitor > void NesterovMomentum< Vector, SolverMonitor >:: configSetup( Config::ConfigDescription& config, const String& prefix ) { IterativeSolver< RealType, IndexType, SolverMonitor >::configSetup( config, prefix ); config.addEntry< double >( prefix + "relaxation", "Relaxation parameter for the momentum method.", 1.0 ); config.addEntry< double >( prefix + "momentum", "Momentum parameter for the momentum method.", 0.9 ); } template< typename Vector, typename SolverMonitor > bool NesterovMomentum< Vector, SolverMonitor >:: setup( const Config::ParameterContainer& parameters, const String& prefix ) { this->setRelaxation( parameters.getParameter< double >( prefix + "relaxation" ) ); this->setMomentum( parameters.getParameter< double >( prefix + "momentum" ) ); return IterativeSolver< RealType, IndexType, SolverMonitor >::setup( parameters, prefix ); } template< typename Vector, typename SolverMonitor > void NesterovMomentum< Vector, SolverMonitor >:: setRelaxation( const RealType& lambda ) { this->relaxation = lambda; } template< typename Vector, typename SolverMonitor > auto NesterovMomentum< Vector, SolverMonitor >:: getRelaxation() const -> const RealType& { return this->relaxation; } template< typename Vector, typename SolverMonitor > void NesterovMomentum< Vector, SolverMonitor >:: setMomentum( const RealType& beta ) { this->momentum = beta; } template< typename Vector, typename SolverMonitor > auto NesterovMomentum< Vector, SolverMonitor >:: getMomentum() const -> const RealType& { return this->momentum; } template< typename Vector, typename SolverMonitor > template< typename GradientGetter > bool NesterovMomentum< Vector, SolverMonitor >:: solve( VectorView& w, GradientGetter&& getGradient ) { this->gradient.setLike( w ); this->v.setLike( w ); this->aux.setLike( w ); auto gradient_view = gradient.getView(); auto w_view = w.getView(); auto v_view = v.getView(); auto aux_view = aux.getView(); this->gradient = 0.0; this->v = 0.0; ///// // Set necessary parameters this->resetIterations(); this->setResidue( this->getConvergenceResidue() + 1.0 ); ///// // Start the main loop while( 1 ) { ///// // Compute the gradient aux_view = w_view + v_view; getGradient( aux_view, gradient_view ); v_view = this->momentum * v_view - this->relaxation * gradient_view; RealType lastResidue = this->getResidue(); this->setResidue( Algorithms::reduce< DeviceType >( ( IndexType ) 0, w_view.getSize(), [=] __cuda_callable__ ( IndexType i ) mutable { w_view[ i ] += v_view[ i ]; return abs( v_view[ i ] ); }, TNL::Plus() ) / ( this->relaxation * ( RealType ) w.getSize() ) ); if( ! this->nextIteration() ) return this->checkConvergence(); ///// // Check the stop condition if( this->getConvergenceResidue() != 0.0 && this->getResidue() < this -> getConvergenceResidue() ) return true; } return false; // just to avoid warnings } } //namespace Optimization } //namespace Solvers } //namespace TNL Loading
src/TNL/Solvers/Optimization/NesterovMomentum.h 0 → 100644 +61 −0 Original line number Diff line number Diff line // Copyright (c) 2004-2022 Tomáš Oberhuber et al. // // This file is part of TNL - Template Numerical Library (https://tnl-project.org/) // // SPDX-License-Identifier: MIT #pragma once #include <TNL/Solvers/IterativeSolver.h> namespace TNL { namespace Solvers { namespace Optimization { template< typename Vector, typename SolverMonitor = IterativeSolverMonitor< typename Vector::RealType, typename Vector::IndexType > > class NesterovMomentum : public IterativeSolver< typename Vector::RealType, typename Vector::IndexType, SolverMonitor > { public: using RealType = typename Vector::RealType; using DeviceType = typename Vector::DeviceType; using IndexType = typename Vector::IndexType; using VectorType = Vector; using VectorView = typename Vector::ViewType; NesterovMomentum() = default; static void configSetup( Config::ConfigDescription& config, const String& prefix = "" ); bool setup( const Config::ParameterContainer& parameters, const String& prefix = "" ); void setRelaxation( const RealType& lambda ); const RealType& getRelaxation() const; void setMomentum( const RealType& beta ); const RealType& getMomentum() const; template< typename GradientGetter > bool solve( VectorView& w, GradientGetter&& getGradient ); protected: RealType relaxation = 1.0, momentum = 0.9; VectorType gradient, v, aux; }; } //namespace Optimization } //namespace Solvers } //namespace TNL #include <TNL/Solvers/Optimization/Momentum.hpp>
src/TNL/Solvers/Optimization/NesterovMomentum.hpp 0 → 100644 +122 −0 Original line number Diff line number Diff line // Copyright (c) 2004-2022 Tomáš Oberhuber et al. // // This file is part of TNL - Template Numerical Library (https://tnl-project.org/) // // SPDX-License-Identifier: MIT #pragma once #include <TNL/Solvers/Optimization/NesterovMomentum.h> namespace TNL { namespace Solvers { namespace Optimization { template< typename Vector, typename SolverMonitor > void NesterovMomentum< Vector, SolverMonitor >:: configSetup( Config::ConfigDescription& config, const String& prefix ) { IterativeSolver< RealType, IndexType, SolverMonitor >::configSetup( config, prefix ); config.addEntry< double >( prefix + "relaxation", "Relaxation parameter for the momentum method.", 1.0 ); config.addEntry< double >( prefix + "momentum", "Momentum parameter for the momentum method.", 0.9 ); } template< typename Vector, typename SolverMonitor > bool NesterovMomentum< Vector, SolverMonitor >:: setup( const Config::ParameterContainer& parameters, const String& prefix ) { this->setRelaxation( parameters.getParameter< double >( prefix + "relaxation" ) ); this->setMomentum( parameters.getParameter< double >( prefix + "momentum" ) ); return IterativeSolver< RealType, IndexType, SolverMonitor >::setup( parameters, prefix ); } template< typename Vector, typename SolverMonitor > void NesterovMomentum< Vector, SolverMonitor >:: setRelaxation( const RealType& lambda ) { this->relaxation = lambda; } template< typename Vector, typename SolverMonitor > auto NesterovMomentum< Vector, SolverMonitor >:: getRelaxation() const -> const RealType& { return this->relaxation; } template< typename Vector, typename SolverMonitor > void NesterovMomentum< Vector, SolverMonitor >:: setMomentum( const RealType& beta ) { this->momentum = beta; } template< typename Vector, typename SolverMonitor > auto NesterovMomentum< Vector, SolverMonitor >:: getMomentum() const -> const RealType& { return this->momentum; } template< typename Vector, typename SolverMonitor > template< typename GradientGetter > bool NesterovMomentum< Vector, SolverMonitor >:: solve( VectorView& w, GradientGetter&& getGradient ) { this->gradient.setLike( w ); this->v.setLike( w ); this->aux.setLike( w ); auto gradient_view = gradient.getView(); auto w_view = w.getView(); auto v_view = v.getView(); auto aux_view = aux.getView(); this->gradient = 0.0; this->v = 0.0; ///// // Set necessary parameters this->resetIterations(); this->setResidue( this->getConvergenceResidue() + 1.0 ); ///// // Start the main loop while( 1 ) { ///// // Compute the gradient aux_view = w_view + v_view; getGradient( aux_view, gradient_view ); v_view = this->momentum * v_view - this->relaxation * gradient_view; RealType lastResidue = this->getResidue(); this->setResidue( Algorithms::reduce< DeviceType >( ( IndexType ) 0, w_view.getSize(), [=] __cuda_callable__ ( IndexType i ) mutable { w_view[ i ] += v_view[ i ]; return abs( v_view[ i ] ); }, TNL::Plus() ) / ( this->relaxation * ( RealType ) w.getSize() ) ); if( ! this->nextIteration() ) return this->checkConvergence(); ///// // Check the stop condition if( this->getConvergenceResidue() != 0.0 && this->getResidue() < this -> getConvergenceResidue() ) return true; } return false; // just to avoid warnings } } //namespace Optimization } //namespace Solvers } //namespace TNL