/***************************************************************************
                          ExplicitUpdater.h  -  description
                             -------------------
    begin                : Jul 29, 2014
    copyright            : (C) 2014 by Tomas Oberhuber
    email                : tomas.oberhuber@fjfi.cvut.cz
 ***************************************************************************/

/* See Copyright Notice in tnl/Copyright */

#pragma once

#include <TNL/Functions/FunctionAdapter.h>
#include <TNL/Timer.h>
#include <TNL/Pointers/SharedPointer.h>
#include <type_traits>
#include <TNL/Meshes/GridDetails/Traverser_Grid1D.h>
#include <TNL/Meshes/GridDetails/Traverser_Grid2D.h>
#include <TNL/Meshes/GridDetails/Traverser_Grid3D.h>

namespace TNL {
namespace Solvers {
namespace PDE {   

template< typename Real,
          typename MeshFunction,
          typename DifferentialOperator,
          typename BoundaryConditions,
          typename RightHandSide >
class ExplicitUpdaterTraverserUserData
{
   public:
      
      Real time;

      const DifferentialOperator* differentialOperator;

      const BoundaryConditions* boundaryConditions;

      const RightHandSide* rightHandSide;

      MeshFunction *u, *fu;
      
      ExplicitUpdaterTraverserUserData()
      : time( 0.0 ),
        differentialOperator( NULL ),
        boundaryConditions( NULL ),
        rightHandSide( NULL ),
        u( NULL ),
        fu( NULL )
      {}
      
      
      /*void setUserData( const Real& time,
                        const DifferentialOperator* differentialOperator,
                        const BoundaryConditions* boundaryConditions,
                        const RightHandSide* rightHandSide,
                        MeshFunction* u,
                        MeshFunction* fu )
      {
         this->time = time;
         this->differentialOperator = differentialOperator;
         this->boundaryConditions = boundaryConditions;
         this->rightHandSide = rightHandSide;
         this->u = u;
         this->fu = fu;
      }*/
};


template< typename Mesh,
          typename MeshFunction,
          typename DifferentialOperator,
          typename BoundaryConditions,
          typename RightHandSide >
class ExplicitUpdater
{
   public:
      typedef Mesh MeshType;
      typedef Pointers::SharedPointer<  MeshType > MeshPointer;
      typedef typename MeshFunction::RealType RealType;
      typedef typename MeshFunction::DeviceType DeviceType;
      typedef typename MeshFunction::IndexType IndexType;
      typedef ExplicitUpdaterTraverserUserData< RealType,
                                                MeshFunction,
                                                DifferentialOperator,
                                                BoundaryConditions,
                                                RightHandSide > TraverserUserData;
      typedef Pointers::SharedPointer<  DifferentialOperator, DeviceType > DifferentialOperatorPointer;
      typedef Pointers::SharedPointer<  BoundaryConditions, DeviceType > BoundaryConditionsPointer;
      typedef Pointers::SharedPointer<  RightHandSide, DeviceType > RightHandSidePointer;
      typedef Pointers::SharedPointer<  MeshFunction, DeviceType > MeshFunctionPointer;
      typedef Pointers::SharedPointer<  TraverserUserData, DeviceType > TraverserUserDataPointer;
      
      void setDifferentialOperator( const DifferentialOperatorPointer& differentialOperatorPointer )
      {
         this->userData.differentialOperator = &differentialOperatorPointer.template getData< DeviceType >();
      }
      
      void setBoundaryConditions( const BoundaryConditionsPointer& boundaryConditionsPointer )
      {
         this->userData.boundaryConditions = &boundaryConditionsPointer.template getData< DeviceType >();
      }
      
      void setRightHandSide( const RightHandSidePointer& rightHandSidePointer )
      {
         this->userData.rightHandSide = &rightHandSidePointer.template getData< DeviceType >();
      }
            
      template< typename EntityType,
                typename CommunicatorType >
      void update( const RealType& time,
                   const RealType& tau,
                   const MeshPointer& meshPointer,
                   MeshFunctionPointer& uPointer,
                   MeshFunctionPointer& fuPointer )
      {
         static_assert( std::is_same< MeshFunction,
                                      Containers::Vector< typename MeshFunction::RealType,
                                                 typename MeshFunction::DeviceType,
                                                 typename MeshFunction::IndexType > >::value != true,
            "Error: I am getting Vector instead of MeshFunction or similar object. You might forget to bind DofVector into MeshFunction in you method getExplicitUpdate."  );
         TNL_ASSERT_GT( uPointer->getData().getSize(), 0, "The first MeshFunction in the parameters was not bound." );
         TNL_ASSERT_GT( fuPointer->getData().getSize(), 0, "The second MeshFunction in the parameters was not bound." );

         TNL_ASSERT_EQ( uPointer->getData().getSize(), meshPointer->template getEntitiesCount< EntityType >(),
                        "The first MeshFunction in the parameters was not bound properly." );
         TNL_ASSERT_EQ( fuPointer->getData().getSize(), meshPointer->template getEntitiesCount< EntityType >(),
                        "The second MeshFunction in the parameters was not bound properly." );
            
         TNL_ASSERT_TRUE( this->userData.differentialOperator,
                          "The differential operator is not correctly set-up. Use method setDifferentialOperator() to do it." );
         TNL_ASSERT_TRUE( this->userData.rightHandSide,
                          "The right-hand side is not correctly set-up. Use method setRightHandSide() to do it." );

         this->userData.time = time;
         this->userData.u = &uPointer.template modifyData< DeviceType >();
         this->userData.fu = &fuPointer.template modifyData< DeviceType >();
         Meshes::Traverser< MeshType, EntityType > meshTraverser;
         meshTraverser.template processInteriorEntities< TraverserInteriorEntitiesProcessor >
                                                       ( meshPointer,
                                                         userData );
      }
      
      template< typename EntityType >
      void applyBoundaryConditions( const MeshPointer& meshPointer,
                                    const RealType& time,
                                    MeshFunctionPointer& uPointer )
      {
         TNL_ASSERT_TRUE( this->userData.boundaryConditions,
                          "The boundary conditions are not correctly set-up. Use method setBoundaryCondtions() to do it." );         
         TNL_ASSERT_TRUE( &uPointer.template modifyData< DeviceType >(),
                          "The function u is not correctly set-up. It was not bound probably with DOFs." );

         this->userData.time = time;
         this->userData.u = &uPointer.template modifyData< DeviceType >();         
         Meshes::Traverser< MeshType, EntityType > meshTraverser;
         meshTraverser.template processBoundaryEntities< TraverserBoundaryEntitiesProcessor >
                                           ( meshPointer,
                                             userData );

         // TODO: I think that this is not necessary
         /*if(CommunicatorType::isDistributed())
            fuPointer->template synchronize<CommunicatorType>();*/

      }
         
      class TraverserBoundaryEntitiesProcessor
      {
         public:

            template< typename GridEntity >
            __cuda_callable__
            static inline void processEntity( const MeshType& mesh,
                                              TraverserUserData& userData,
                                              const GridEntity& entity )
            {
               ( *userData.u )( entity ) = ( *userData.boundaryConditions )
                  ( *userData.u, entity, userData.time );
            }

      };

      class TraverserInteriorEntitiesProcessor
      {
         public:

            typedef typename MeshType::PointType PointType;

            template< typename EntityType >
            __cuda_callable__
            static inline void processEntity( const MeshType& mesh,
                                              TraverserUserData& userData,
                                              const EntityType& entity )
            {
               typedef Functions::FunctionAdapter< MeshType, RightHandSide > FunctionAdapter;
               ( *userData.fu )( entity ) = 
                  ( *userData.differentialOperator )( *userData.u, entity, userData.time )
                  + FunctionAdapter::getValue( *userData.rightHandSide, entity, userData.time );
               
            }
      }; 

   protected:

      TraverserUserData userData;

};

} // namespace PDE
} // namespace Solvers
} // namespace TNL