Loading src/TNL/Solvers/Optimization/GradientDescent.h +5 −9 Original line number Diff line number Diff line /*************************************************************************** Timer.h - description ------------------- begin : Mar 14, 2016 copyright : (C) 2016 by Tomas Oberhuber email : tomas.oberhuber@fjfi.cvut.cz ***************************************************************************/ /* See Copyright Notice in tnl/Copyright */ // Copyright (c) 2004-2022 Tomáš Oberhuber et al. // // This file is part of TNL - Template Numerical Library (https://tnl-project.org/) // // SPDX-License-Identifier: MIT #pragma once Loading src/TNL/Solvers/Optimization/GradientDescent.hpp +13 −17 Original line number Diff line number Diff line /*************************************************************************** Timer.h - description ------------------- begin : Mar 14, 2016 copyright : (C) 2016 by Tomas Oberhuber email : tomas.oberhuber@fjfi.cvut.cz ***************************************************************************/ /* See Copyright Notice in tnl/Copyright */ // Copyright (c) 2004-2022 Tomáš Oberhuber et al. // // This file is part of TNL - Template Numerical Library (https://tnl-project.org/) // // SPDX-License-Identifier: MIT #pragma once Loading @@ -23,7 +19,7 @@ GradientDescent< Vector, SolverMonitor >:: configSetup( Config::ConfigDescription& config, const String& prefix ) { IterativeSolver< RealType, IndexType, SolverMonitor >::configSetup( config, prefix ); config.addEntry< double >( prefix + "gd-relaxation", "Relaxation parameter for the gradient descent.", 1.0 ); config.addEntry< double >( prefix + "relaxation", "Relaxation parameter for the gradient descent.", 1.0 ); } template< typename Vector, typename SolverMonitor > Loading @@ -31,7 +27,7 @@ bool GradientDescent< Vector, SolverMonitor >:: setup( const Config::ParameterContainer& parameters, const String& prefix ) { this->setRelaxation( parameters.getParameter< double >( prefix + "gd-relaxation" ) ); this->setRelaxation( parameters.getParameter< double >( prefix + "relaxation" ) ); return IterativeSolver< RealType, IndexType, SolverMonitor >::setup( parameters, prefix ); } Loading @@ -57,10 +53,10 @@ bool GradientDescent< Vector, SolverMonitor >:: solve( VectorView& w, GradientGetter&& getGradient ) { this->aux.setLike( w ); auto aux_view = aux.getView(); this->gradient.setLike( w ); auto gradient_view = gradient.getView(); auto w_view = w.getView(); aux = 0.0; this->gradient = 0.0; ///// // Set necessary parameters Loading @@ -73,9 +69,9 @@ solve( VectorView& w, GradientGetter&& getGradient ) { ///// // Compute the gradient getGradient( w_view, aux_view ); getGradient( w_view, gradient_view ); RealType lastResidue = this->getResidue(); this->setResidue( addAndReduceAbs( w_view, this->relaxation * aux_view, TNL::Plus(), ( RealType ) 0.0 ) / ( this->relaxation * ( RealType ) w.getSize() ) ); this->setResidue( addAndReduceAbs( w_view, -this->relaxation * gradient_view, TNL::Plus(), ( RealType ) 0.0 ) / ( this->relaxation * ( RealType ) w.getSize() ) ); if( ! this->nextIteration() ) return this->checkConvergence(); Loading Loading
src/TNL/Solvers/Optimization/GradientDescent.h +5 −9 Original line number Diff line number Diff line /*************************************************************************** Timer.h - description ------------------- begin : Mar 14, 2016 copyright : (C) 2016 by Tomas Oberhuber email : tomas.oberhuber@fjfi.cvut.cz ***************************************************************************/ /* See Copyright Notice in tnl/Copyright */ // Copyright (c) 2004-2022 Tomáš Oberhuber et al. // // This file is part of TNL - Template Numerical Library (https://tnl-project.org/) // // SPDX-License-Identifier: MIT #pragma once Loading
src/TNL/Solvers/Optimization/GradientDescent.hpp +13 −17 Original line number Diff line number Diff line /*************************************************************************** Timer.h - description ------------------- begin : Mar 14, 2016 copyright : (C) 2016 by Tomas Oberhuber email : tomas.oberhuber@fjfi.cvut.cz ***************************************************************************/ /* See Copyright Notice in tnl/Copyright */ // Copyright (c) 2004-2022 Tomáš Oberhuber et al. // // This file is part of TNL - Template Numerical Library (https://tnl-project.org/) // // SPDX-License-Identifier: MIT #pragma once Loading @@ -23,7 +19,7 @@ GradientDescent< Vector, SolverMonitor >:: configSetup( Config::ConfigDescription& config, const String& prefix ) { IterativeSolver< RealType, IndexType, SolverMonitor >::configSetup( config, prefix ); config.addEntry< double >( prefix + "gd-relaxation", "Relaxation parameter for the gradient descent.", 1.0 ); config.addEntry< double >( prefix + "relaxation", "Relaxation parameter for the gradient descent.", 1.0 ); } template< typename Vector, typename SolverMonitor > Loading @@ -31,7 +27,7 @@ bool GradientDescent< Vector, SolverMonitor >:: setup( const Config::ParameterContainer& parameters, const String& prefix ) { this->setRelaxation( parameters.getParameter< double >( prefix + "gd-relaxation" ) ); this->setRelaxation( parameters.getParameter< double >( prefix + "relaxation" ) ); return IterativeSolver< RealType, IndexType, SolverMonitor >::setup( parameters, prefix ); } Loading @@ -57,10 +53,10 @@ bool GradientDescent< Vector, SolverMonitor >:: solve( VectorView& w, GradientGetter&& getGradient ) { this->aux.setLike( w ); auto aux_view = aux.getView(); this->gradient.setLike( w ); auto gradient_view = gradient.getView(); auto w_view = w.getView(); aux = 0.0; this->gradient = 0.0; ///// // Set necessary parameters Loading @@ -73,9 +69,9 @@ solve( VectorView& w, GradientGetter&& getGradient ) { ///// // Compute the gradient getGradient( w_view, aux_view ); getGradient( w_view, gradient_view ); RealType lastResidue = this->getResidue(); this->setResidue( addAndReduceAbs( w_view, this->relaxation * aux_view, TNL::Plus(), ( RealType ) 0.0 ) / ( this->relaxation * ( RealType ) w.getSize() ) ); this->setResidue( addAndReduceAbs( w_view, -this->relaxation * gradient_view, TNL::Plus(), ( RealType ) 0.0 ) / ( this->relaxation * ( RealType ) w.getSize() ) ); if( ! this->nextIteration() ) return this->checkConvergence(); Loading