Commit c0716242 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Fixes in the memomentum optimization method.

parent 98429a42
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -48,7 +48,7 @@ public:

protected:

   RealType relaxation = 1.0, momentum = 1.0;
   RealType relaxation = 1.0, momentum = 0.9;

   VectorType gradient, v;

+9 −2
Original line number Diff line number Diff line
@@ -20,7 +20,7 @@ configSetup( Config::ConfigDescription& config, const String& prefix )
{
   IterativeSolver< RealType, IndexType, SolverMonitor >::configSetup( config, prefix );
   config.addEntry< double >( prefix + "relaxation", "Relaxation parameter for the momentum method.", 1.0 );
   config.addEntry< double >( prefix + "momentum", "Momentum parameter for the momentum method.", 1.0 );
   config.addEntry< double >( prefix + "momentum", "Momentum parameter for the momentum method.", 0.9 );
}

template< typename Vector, typename SolverMonitor >
@@ -94,7 +94,14 @@ solve( VectorView& w, GradientGetter&& getGradient )
      v_view  = this->momentum * v_view - this->relaxation * gradient_view;

      RealType lastResidue = this->getResidue();
      this->setResidue( addAndReduceAbs( w_view, v_view, TNL::Plus(), ( RealType ) 0.0 ) / ( this->relaxation * ( RealType ) w.getSize() ) );
      this->setResidue(
         Algorithms::reduce< DeviceType >(
            ( IndexType ) 0, w_view.getSize(),
            [=] __cuda_callable__ ( IndexType i ) mutable {
               w_view[ i ] += v_view[ i ];
               return abs( v_view[ i ] );
            },
            TNL::Plus() ) / ( this->relaxation * ( RealType ) w.getSize() ) );

      if( ! this->nextIteration() )
         return this->checkConvergence();