From d4fa1079adca7fe491e071a0618edce1587c54e6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com>
Date: Tue, 25 Jun 2019 22:57:50 +0200
Subject: [PATCH] Refactoring the Euler solver.

---
 .../Algorithms/CommonVectorOperations.hpp     |  2 +-
 src/TNL/Solvers/IterativeSolver.h             |  2 +-
 src/TNL/Solvers/ODE/Euler.h                   | 21 +++++++--------
 src/TNL/Solvers/ODE/Euler.hpp                 | 27 +++++++++++++------
 4 files changed, 30 insertions(+), 22 deletions(-)

diff --git a/src/TNL/Containers/Algorithms/CommonVectorOperations.hpp b/src/TNL/Containers/Algorithms/CommonVectorOperations.hpp
index a9e9161cf1..b8646af9e6 100644
--- a/src/TNL/Containers/Algorithms/CommonVectorOperations.hpp
+++ b/src/TNL/Containers/Algorithms/CommonVectorOperations.hpp
@@ -48,7 +48,7 @@ getVectorMin( const Vector& v )
    using IndexType = typename Vector::IndexType;
 
    const auto* data = v.getData();
-   auto fetch = [=] __cuda_callable__ ( IndexType i ) { return data[ i ]; };
+   auto fetch = [=] __cuda_callable__ ( IndexType i ) -> RealType { return data[ i ]; };
    auto reduction = [=] __cuda_callable__ ( ResultType& a, const ResultType& b ) { a =  TNL::min( a, b ); };
    auto volatileReduction = [=] __cuda_callable__ ( volatile ResultType& a, volatile ResultType& b ) { a =  TNL::min( a, b ); };
    return Reduction< DeviceType >::reduce( v.getSize(), reduction, volatileReduction, fetch, std::numeric_limits< ResultType >::max() );
diff --git a/src/TNL/Solvers/IterativeSolver.h b/src/TNL/Solvers/IterativeSolver.h
index 201b20ae63..be0ee9f1d2 100644
--- a/src/TNL/Solvers/IterativeSolver.h
+++ b/src/TNL/Solvers/IterativeSolver.h
@@ -15,7 +15,7 @@
 #include <TNL/Solvers/IterativeSolverMonitor.h>
 
 namespace TNL {
-namespace Solvers {   
+namespace Solvers {
 
 template< typename Real, typename Index >
 class IterativeSolver
diff --git a/src/TNL/Solvers/ODE/Euler.h b/src/TNL/Solvers/ODE/Euler.h
index 5971fb8c04..c782b08cd9 100644
--- a/src/TNL/Solvers/ODE/Euler.h
+++ b/src/TNL/Solvers/ODE/Euler.h
@@ -25,13 +25,13 @@ class Euler : public ExplicitSolver< Problem >
 {
    public:
 
-   typedef Problem  ProblemType;
-   typedef typename Problem :: DofVectorType DofVectorType;
-   typedef typename Problem :: RealType RealType;
-   typedef typename Problem :: DeviceType DeviceType;
-   typedef typename Problem :: IndexType IndexType;
-   typedef Pointers::SharedPointer<  DofVectorType, DeviceType > DofVectorPointer;
-
+   using ProblemType = Problem;
+   using DofVectorType = typename ProblemType::DofVectorType;
+   using RealType = typename ProblemType::RealType;
+   using DeviceType = typename ProblemType::DeviceType;
+   using IndexType  = typename ProblemType::IndexType;
+   using DofVectorView = typename ViewTypeGetter< DofVectorType >::Type;
+   using DofVectorPointer = Pointers::SharedPointer<  DofVectorType, DeviceType >;
 
    Euler();
 
@@ -54,16 +54,13 @@ class Euler : public ExplicitSolver< Problem >
                              RealType tau,
                              RealType& currentResidue );
 
-   
-   DofVectorPointer k1;
+   DofVectorPointer _k1;
 
    RealType cflCondition;
- 
-   //Timer timer, updateTimer;
 };
 
 } // namespace ODE
 } // namespace Solvers
 } // namespace TNL
 
-#include <TNL/Solvers/ODE/Euler_impl.h>
+#include <TNL/Solvers/ODE/Euler.hpp>
diff --git a/src/TNL/Solvers/ODE/Euler.hpp b/src/TNL/Solvers/ODE/Euler.hpp
index 0b9eed1f8e..5e2c4979e7 100644
--- a/src/TNL/Solvers/ODE/Euler.hpp
+++ b/src/TNL/Solvers/ODE/Euler.hpp
@@ -72,14 +72,16 @@ const typename Problem :: RealType& Euler< Problem > :: getCFLCondition() const
 }
 
 template< typename Problem >
-bool Euler< Problem > :: solve( DofVectorPointer& u )
+bool Euler< Problem > :: solve( DofVectorPointer& _u )
 {
    /****
     * First setup the supporting meshes k1...k5 and k_tmp.
     */
    //timer.start();
-   k1->setLike( *u );
-   k1->setValue( 0.0 );
+   _k1->setLike( *_u );
+   DofVectorView k1 = _k1->getView();
+   DofVectorView u = _u->getView();
+   k1 = 0.0;
 
 
    /****
@@ -103,14 +105,14 @@ bool Euler< Problem > :: solve( DofVectorPointer& u )
        * Compute the RHS
        */
       //timer.stop();
-      this->problem->getExplicitUpdate( time, currentTau, u, k1 );
+      this->problem->getExplicitUpdate( time, currentTau, _u, _k1 );
       //timer.start();
 
       RealType lastResidue = this->getResidue();
       RealType maxResidue( 0.0 );
       if( this -> cflCondition != 0.0 )
       {
-         maxResidue = k1->absMax();
+         maxResidue = max( abs( k1 ) ); //k1->absMax();
          if( currentTau * maxResidue > this->cflCondition )
          {
             currentTau *= 0.9;
@@ -118,7 +120,15 @@ bool Euler< Problem > :: solve( DofVectorPointer& u )
          }
       }
       RealType newResidue( 0.0 );
-      computeNewTimeLevel( u, currentTau, newResidue );
+      //computeNewTimeLevel( u, currentTau, newResidue );
+     /* auto fetch = [ currentTau, k1, u ] __cuda_callable__ ( IndexType i ) -> RealType {
+         const RealType add = currentTau * k1[ i ];
+         u[ i ] += add;
+         return TNL::abs( add ); };
+      auto reduction = [=] __cuda_callable__ ( RealType& a , const RealType& b ) { a += b; };
+      auto volatileReduction = [=] __cuda_callable__ ( volatile RealType& a , const volatile RealType& b ) { a += b; };
+      return Containers::Algorithms::Reduction< DeviceType >::reduce( u.getSize(), reduction, volatileReduction, fetch, 0.0 );*/
+      
       this->setResidue( newResidue );
 
       /****
@@ -127,7 +137,7 @@ bool Euler< Problem > :: solve( DofVectorPointer& u )
        */
       if( currentTau + time == this -> stopTime ) this->setResidue( lastResidue );
       time += currentTau;
-      this->problem->applyBoundaryConditions( time, u );
+      this->problem->applyBoundaryConditions( time, _u );
 
       if( ! this->nextIteration() )
          return this->checkConvergence();
@@ -159,7 +169,7 @@ bool Euler< Problem > :: solve( DofVectorPointer& u )
    }
 };
 
-template< typename Problem >
+/*template< typename Problem >
 void Euler< Problem > :: computeNewTimeLevel( DofVectorPointer& u,
                                               RealType tau,
                                               RealType& currentResidue )
@@ -262,6 +272,7 @@ __global__ void updateUEuler( const Index size,
                         n );
 }
 #endif
+*/
 
 } // namespace ODE
 } // namespace Solvers
-- 
GitLab