Commit b43a9719 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Implementing operator functions.

parent f343ff94
Loading
Loading
Loading
Loading
+36 −0
Original line number Diff line number Diff line
@@ -18,6 +18,42 @@
#ifndef TNLBOUNDARYOPERATORFUNCTION_H
#define	TNLBOUNDARYOPERATORFUNCTION_H

/***
 * This class evaluates given operator on given function.
 * The main role of this type is that the mesh function evaluator
 * evaluates this function only on the BOUNDARY mesh entities.
 */
template< typename BoundaryOperator,
          typename Function >
class tnlBoundaryOperatorFunction
{
   public:
      
      typedef BoundaryOperator BoundaryOperatorType;
      typedef Function FunctionType;
      typedef typename BoundaryOperator::MeshType MeshType;
      typedef typename BoundaryOperator::RealType RealType;
      
      tnlBoundaryOperatorFunction(
         const BoundaryOperatorType& boundaryOperator,
         const FunctionType& function )
      :  boundaryOperator( &boundaryOperator ), function( &function ){};
      
      template< typename MeshEntity >
      __cuda_callable__
      RealType operator()(
         const MeshEntity& meshEntity,
         const RealType& time )
      {
         return boundaryOperator->getValue( meshEntity, function->getData(), time );
      }
      
   protected:
      
      const BoundaryOperator* boundaryOperator;
      
      const FunctionType* function;         
};


#endif	/* TNLBOUNDARYOPERATORFUNCTION_H */
+0 −2
Original line number Diff line number Diff line
@@ -49,7 +49,5 @@ class tnlExactOperatorFunction
      const FunctionType& function;               
};



#endif	/* TNLEXACTOPERATORFUNCTION_H */
+1 −3
Original line number Diff line number Diff line
@@ -24,9 +24,7 @@
enum tnlFunctionType { GeneralFunction, 
                       MeshFunction,
                       AnalyticFunction,
                       AnalyticConstantFunction,
                       AnalyticOperator,
                       MeshOperator };
                       AnalyticConstantFunction };

template< int Dimensions,
          tnlFunctionType FunctionType = GeneralFunction >
+118 −15
Original line number Diff line number Diff line
@@ -21,7 +21,17 @@
#include <mesh/tnlGrid.h>
#include <functions/tnlMeshFunction.h>
#include <functions/tnlOperatorFunction.h>
#include <functions/tnlBoundaryOperatorFunction.h>


/***
 * General mesh function evaluator. As an input function any type implementing
 * getValue( meshEntity, time ) may be substituted.
 * Methods:
 *  evaluate() -  evaluate the input function on ALL mesh entities of the mesh function
 *  evaluateInteriorEntities() - evaluate the input function only on the INTERIOR mesh entities
 *  evaluateBoundaryEntities() - evaluate the input function only on the BOUNDARY mesh entities
 */
template< typename OutMeshFunction,
          typename InFunction >
class tnlMeshFunctionEvaluator
@@ -38,7 +48,7 @@ class tnlMeshFunctionEvaluator
      static_assert( MeshType::meshDimensions == InFunction::Dimensions, 
         "Input function and the mesh of the mesh function have both different number of dimensions." );
      
      static void evaluateAllEntities( OutMeshFunction& meshFunction,
      static void evaluate( OutMeshFunction& meshFunction,
                            const InFunction& function,                          
                            const RealType& time = 0.0,
                            const RealType& outFunctionMultiplicator = 0.0,
@@ -71,6 +81,9 @@ class tnlMeshFunctionEvaluator
      class TraverserUserData
      {
         public:
            
            typedef InFunction InFunctionType;
            
            TraverserUserData( const InFunction* function,
                               const RealType* time,
                               OutMeshFunction* meshFunction,
@@ -88,6 +101,12 @@ class tnlMeshFunctionEvaluator
      };
}; 

/****
 * Specialization of the mesh function evaluator for operator functions which are
 * defines only for the interior mesh entities. Therefore there is only one method
 *   evaluate()
 * which goes only over the interior mesh entities.
 */
template< typename OutMeshFunction,
          typename Operator,
          typename Function >
@@ -102,17 +121,77 @@ class tnlMeshFunctionEvaluator< OutMeshFunction, tnlOperatorFunction< Operator,
      typedef typename OutMeshFunction::Real RealType;
      typedef tnlOperatorFunction< Operator, Function > OperatorFunctionType;
      
      static void evaluateAllEntities( OutMeshFunction& meshFunction,
      static_assert( std::is_same< MeshType, typename OperatorFunctionType::MeshType >::value, 
         "Input function and the mesh of the mesh function have both different number of dimensions." );

      
      /****
       * Evaluate on interior mesh entities
       */
      static void evaluate( OutMeshFunction& meshFunction,
                            const OperatorFunctionType& function,                          
                            const RealType& time = 0.0,
                            const RealType& outFunctionMultiplicator = 0.0,
                                       const RealType& inFunctionMultiplicator = 1.0 )
                            const RealType& inFunctionMultiplicator = 1.0 );
            
      class TraverserUserData
      {
         evaluateInteriorEntities( meshFunction, function, time, outFunctionMultiplicator, inFunctionMultiplicator );
         typedef OperatorFunctionType InFunctionType;
         
         public:
            TraverserUserData( const OperatorFunctionType* operatorFunction,              
                               const RealType* time,
                               OutMeshFunction* meshFunction,
                               const RealType* outFunctionMultiplicator,
                               const RealType* inFunctionMultiplicator )
            : meshFunction( meshFunction ), operatorFunction( operatorFunction ),time( time ), 
              outFunctionMultiplicator( outFunctionMultiplicator ),
              inFunctionMultiplicator( inFunctionMultiplicator ){}

         protected:
            OutMeshFunction* meshFunction;            
            const OperatorFunctionType* operatorFunction;
            const RealType *time, *outFunctionMultiplicator, *inFunctionMultiplicator;
            
      };

};

/****
 * Specialization of the mesh function evaluator for boundary operator functions which are
 * defines only for the boundary mesh entities. Therefore there is only one method
 *   evaluate()
 * which goes only over the boundary mesh entities.
 */
template< typename OutMeshFunction,
          typename BoundaryOperator,
          typename Function >
class tnlMeshFunctionEvaluator< OutMeshFunction, tnlBoundaryOperatorFunction< BoundaryOperator, Function > >
{
   public:
      
      typedef typename OutMeshFunction::MeshType MeshType;
      typedef typename MeshType::RealType MeshRealType;
      typedef typename MeshType::DeviceType MeshDeviceType;
      typedef typename MeshType::IndexType MeshIndexType;
      typedef typename OutMeshFunction::Real RealType;
      typedef tnlBoundaryOperatorFunction< BoundaryOperator, Function > BoundaryOperatorFunctionType;
      
      static_assert( std::is_same < MeshType, typename BoundaryOperatorFunctionType::MeshType >::value, 
         "Input function and the mesh of the mesh function have both different number of dimensions." );

      
      /***
       * Evaluate on boundary mesh entities
       */
      static void evaluate( OutMeshFunction& meshFunction,
                            const BoundaryOperatorFunctionType& function,                          
                            const RealType& time = 0.0,
                            const RealType& outFunctionMultiplicator = 0.0,
                            const RealType& inFunctionMultiplicator = 1.0 );
      
      static void evaluateInteriorEntities( OutMeshFunction& meshFunction,
                                            const OperatorFunctionType& function,                          
                                            const BoundaryOperatorFunctionType& function,                          
                                            const RealType& time = 0.0,
                                            const RealType& outFunctionMultiplicator = 0.0,
                                            const RealType& inFunctionMultiplicator = 1.0 );
@@ -120,7 +199,9 @@ class tnlMeshFunctionEvaluator< OutMeshFunction, tnlOperatorFunction< Operator,
      class TraverserUserData
      {
         public:
            TraverserUserData( const OperatorFunctionType* operatorFunction,              
            typedef BoundaryOperatorFunctionType InFunctionType;
            
            TraverserUserData( const BoundaryOperatorFunctionType* operatorFunction,              
                               const RealType* time,
                               OutMeshFunction* meshFunction,
                               const RealType* outFunctionMultiplicator,
@@ -131,11 +212,33 @@ class tnlMeshFunctionEvaluator< OutMeshFunction, tnlOperatorFunction< Operator,

         protected:
            OutMeshFunction* meshFunction;            
            const OperatorFunctionType* operatorFunction;
            const BoundaryOperatorFunctionType* operatorFunction;
            const RealType *time, *outFunctionMultiplicator, *inFunctionMultiplicator;
            
      };
};

template< typename MeshType,
          typename UserData > 
class tnlMeshFunctionEvaluatorEntitiesProcessor
{
   template< typename EntityType >
   __cuda_callable__
   static inline void processEntity( const MeshType& mesh,
                                     UserData& userData,
                                     const EntityType& entity )
   {
      typedef tnlFunctionAdapter< MeshType, typename UserData::InFunction > FunctionAdapter;
      ( *userData.meshFunction )( entity ) = 
         *userData.outFunctionMultiplicator * ( *userData.meshFunction )( entity ) +
         *userData.inFunctionMultiplicator *
         FunctionAdapter::getValue( *userData.function, entity, *userData.time );
   }
};



#include <functions/tnlMeshFunctionEvaluator_impl.h>

#endif	/* TNLMESHFUNCTIONEVALUATOR_H */
+89 −63
Original line number Diff line number Diff line
@@ -25,11 +25,11 @@ template< typename OutMeshFunction,
          typename InFunction >
void
tnlMeshFunctionEvaluator< OutMeshFunction, InFunction >::
evaluateAllEntities( OutMeshFunction& meshFunction,
evaluate( OutMeshFunction& meshFunction,
          const InFunction& function,                          
                     const RealType& time = 0.0,
                     const RealType& outFunctionMultiplicator = 0.0,
                     const RealType& inFunctionMultiplicator = 1.0 )
          const RealType& time,
          const RealType& outFunctionMultiplicator,
          const RealType& inFunctionMultiplicator )
{
   return evaluateEntities( meshFunction, function, time, outFunctionMultiplicator, inFunctionMultiplicator, all );
}
@@ -40,9 +40,9 @@ void
tnlMeshFunctionEvaluator< OutMeshFunction, InFunction >::
evaluateInteriorEntities( OutMeshFunction& meshFunction,
                          const InFunction& function,                          
                          const RealType& time = 0.0,
                          const RealType& outFunctionMultiplicator = 0.0,
                          const RealType& inFunctionMultiplicator = 1.0 )
                          const RealType& time,
                          const RealType& outFunctionMultiplicator,
                          const RealType& inFunctionMultiplicator )
{
   return evaluateEntities( meshFunction, function, time, outFunctionMultiplicator, inFunctionMultiplicator, interior );
}
@@ -53,9 +53,9 @@ void
tnlMeshFunctionEvaluator< OutMeshFunction, InFunction >::
evaluateBoundaryEntities( OutMeshFunction& meshFunction,
                          const InFunction& function,                          
                          const RealType& time = 0.0,
                          const RealType& outFunctionMultiplicator = 0.0,
                          const RealType& inFunctionMultiplicator = 1.0 )
                          const RealType& time,
                          const RealType& outFunctionMultiplicator,
                          const RealType& inFunctionMultiplicator )
{
   return evaluateEntities( meshFunction, function, time, outFunctionMultiplicator, inFunctionMultiplicator, boundary );
}
@@ -73,23 +73,8 @@ evaluateEntities( OutMeshFunction& meshFunction,
                  const RealType& inFunctionMultiplicator,
                  EntitiesType entitiesType )
{
   typedef typename MeshType::template MeshEntities< meshEntityDimensions > MeshEntityType;
   
   class AssignEntitiesProcessor
   {
      template< typename EntityType >
      __cuda_callable__
      static inline void processEntity( const MeshType& mesh,
                                        TraverserUserData& userData,
                                        const EntityType& entity )
      {
         typedef tnlFunctionAdapter< MeshType, InFunction > FunctionAdapter;
         ( *userData.meshFunction )( entity ) = 
            *userData.outFunctionMultiplicator * ( *userData.meshFunction )( entity ) +
            *userData.inFunctionMultiplicator *
            FunctionAdapter::getValue( *userData.function, entity, *userData.time );
      }
   };
   typedef typename MeshType::template MeshEntities< OutMeshFunction::entityDimensions > MeshEntityType;
   typedef tnlMeshFunctionEvaluatorEntitiesProcessor< MeshType, TraverserUserData > EntitiesProcessor;
  
   if( std::is_same< MeshDeviceType, tnlHost >::value )
   {
@@ -99,19 +84,19 @@ evaluateEntities( OutMeshFunction& meshFunction,
      {
         case all:            
            meshTraverser.template processAllEntities< TraverserUserData,
                                                       AssignEntitiesProcessor >
                                                       EntitiesProcessor >
                                                     ( meshFunction.getMesh(),
                                                       userData );
            break;
         case interior:
            meshTraverser.template processInteriroEntities< TraverserUserData,
                                                            AssignEntitiesProcessor >
                                                            EntitiesProcessor >
                                                          ( meshFunction.getMesh(),
                                                            userData );
            break;
         case boundary:
            meshTraverser.template processBoundaryEntities< TraverserUserData,
                                                            AssignEntitiesProcessor >
                                                            EntitiesProcessor >
                                                          ( meshFunction.getMesh(),
                                                            userData );
            break;
@@ -132,19 +117,19 @@ evaluateEntities( OutMeshFunction& meshFunction,
      {
         case all:            
            meshTraverser.template processAllEntities< TraverserUserData,
                                                       AssignEntitiesProcessor >
                                                       EntitiesProcessor >
                                                     ( meshFunction.getMesh(),
                                                       userData );
            break;
         case interior:
            meshTraverser.template processInteriorEntities< TraverserUserData,
                                                            AssignEntitiesProcessor >
                                                            EntitiesProcessor >
                                                          ( meshFunction.getMesh(),
                                                            userData );
            break;
         case boundary:
            meshTraverser.template processBoundaryEntities< TraverserUserData,
                                                            AssignEntitiesProcessor >
                                                            EntitiesProcessor >
                                                          ( meshFunction.getMesh(),
                                                            userData );
            break;
@@ -163,39 +148,79 @@ evaluateEntities( OutMeshFunction& meshFunction,


template< typename OutMeshFunction,          
          typename Function,
          typename Operator >
          typename Operator,
          typename Function >
void
tnlMeshFunctionEvaluator< OutMeshFunction, tnlOperatorFunction< Operator, Function> >::
evaluateEntities( OutMeshFunction& meshFunction,
evaluate( OutMeshFunction& meshFunction,
          const OperatorFunctionType& operatorFunction,
          const RealType& time,
          const RealType& outFunctionMultiplicator,
          const RealType& inFunctionMultiplicator )
{
   typedef typename MeshType::template MeshEntities< meshEntityDimensions > MeshEntityType;
   typedef typename MeshType::template MeshEntities< OutMeshFunction::entityDimensions > MeshEntityType;
   typedef tnlMeshFunctionEvaluatorEntitiesProcessor< MeshType, TraverserUserData > EntitiesProcessor;
   
   class AssignEntitiesProcessor
   if( std::is_same< MeshDeviceType, tnlHost >::value )
   {
      template< typename EntityType >
      __cuda_callable__
      static inline void processEntity( const MeshType& mesh,
                                        TraverserUserData& userData,
                                        const EntityType& entity )
      TraverserUserData userData( &operatorFunction, &time, &meshFunction, &outFunctionMultiplicator, &inFunctionMultiplicator );
      tnlTraverser< MeshType, MeshEntityType > meshTraverser;
      meshTraverser.template processInterirorEntities< TraverserUserData, EntitiesProcessor >
         ( meshFunction.getMesh(),
           userData );
      
   }
   if( std::is_same< MeshDeviceType, tnlCuda >::value )
   {      
         typedef tnlFunctionAdapter< MeshType, InFunction > FunctionAdapter;
         ( *userData.meshFunction )( entity ) = 
            *userData.outFunctionMultiplicator * ( *userData.meshFunction )( entity ) +
            *userData.inFunctionMultiplicator *
            FunctionAdapter::getValue( *userData.function, entity, *userData.time );
      OutMeshFunction* kernelMeshFunction = tnlCuda::passToDevice( meshFunction );
      Function* kernelFunction = tnlCuda::passToDevice( *operatorFunction.function );
      Operator* kernelOperator = tnlCuda::passToDevice( *operatorFunction.operator_ );
      OperatorFunctionType auxOperatorFunction( *kernelOperator, *kernelFunction );
      OperatorFunctionType* kernelOperatorFunction = tnlCuda::passToDevice( auxOperatorFunction );
      RealType* kernelTime = tnlCuda::passToDevice( time );
      RealType* kernelOutFunctionMultiplicator = tnlCuda::passToDevice( outFunctionMultiplicator );
      RealType* kernelInFunctionMultiplicator = tnlCuda::passToDevice( inFunctionMultiplicator );
      
      TraverserUserData userData( kernelOperatorFunction, kernelTime, kernelMeshFunction, kernelOutFunctionMultiplicator, kernelInFunctionMultiplicator );
      checkCudaDevice;
      tnlTraverser< MeshType, MeshEntityType > meshTraverser;
      meshTraverser.template processInteriorEntities< TraverserUserData, EntitiesProcessor >
         ( meshFunction.getMesh(),
           userData );


      checkCudaDevice;      
      tnlCuda::freeFromDevice( kernelMeshFunction );
      tnlCuda::freeFromDevice( kernelFunction );
      tnlCuda::freeFromDevice( kernelOperator );
      tnlCuda::freeFromDevice( kernelOperatorFunction );
      tnlCuda::freeFromDevice( kernelTime );
      tnlCuda::freeFromDevice( kernelOutFunctionMultiplicator );
      tnlCuda::freeFromDevice( kernelInFunctionMultiplicator );
            
      checkCudaDevice;
   }
   };
}

template< typename OutMeshFunction,          
          typename BoundaryOperator,
          typename Function >
void
tnlMeshFunctionEvaluator< OutMeshFunction, tnlBoundaryOperatorFunction< BoundaryOperator, Function> >::
evaluate( OutMeshFunction& meshFunction,
          const BoundaryOperatorFunctionType& operatorFunction,
          const RealType& time,
          const RealType& outFunctionMultiplicator,
          const RealType& inFunctionMultiplicator )
{
   typedef typename MeshType::template MeshEntities< OutMeshFunction::entityDimensions > MeshEntityType;
   typedef tnlMeshFunctionEvaluatorEntitiesProcessor< MeshType, TraverserUserData > EntitiesProcessor;
   
   if( std::is_same< MeshDeviceType, tnlHost >::value )
   {
      TraverserUserData userData( &function, &time, &meshFunction, &outFunctionMultiplicator, &inFunctionMultiplicator );
      TraverserUserData userData( &operatorFunction, &time, &meshFunction, &outFunctionMultiplicator, &inFunctionMultiplicator );
      tnlTraverser< MeshType, MeshEntityType > meshTraverser;
      meshTraverser.template processInterirorEntities< TraverserUserData, AssignEntitiesProcessor >
      meshTraverser.template processBoundaryEntities< TraverserUserData, EntitiesProcessor >
         ( meshFunction.getMesh(),
           userData );
      
@@ -204,9 +229,9 @@ evaluateEntities( OutMeshFunction& meshFunction,
   {      
      OutMeshFunction* kernelMeshFunction = tnlCuda::passToDevice( meshFunction );
      Function* kernelFunction = tnlCuda::passToDevice( *operatorFunction.function );
      Operator* kernelOperator = tnlCuda::passToDevice( *operatorFunction.operator_ );
      OperatorFunctionType auxOperatorFunction( *kernelOperator, *kernelFunction );
      OperatorFunctionType* kernelOperatorFunction = tnlCuda::passToDevice( auxOperatorFunction );
      BoundaryOperator* kernelOperator = tnlCuda::passToDevice( *operatorFunction.operator_ );
      BoundaryOperatorFunctionType auxOperatorFunction( *kernelOperator, *kernelFunction );
      BoundaryOperatorFunctionType* kernelOperatorFunction = tnlCuda::passToDevice( auxOperatorFunction );
      RealType* kernelTime = tnlCuda::passToDevice( time );
      RealType* kernelOutFunctionMultiplicator = tnlCuda::passToDevice( outFunctionMultiplicator );
      RealType* kernelInFunctionMultiplicator = tnlCuda::passToDevice( inFunctionMultiplicator );
@@ -214,7 +239,7 @@ evaluateEntities( OutMeshFunction& meshFunction,
      TraverserUserData userData( kernelOperatorFunction, kernelTime, kernelMeshFunction, kernelOutFunctionMultiplicator, kernelInFunctionMultiplicator );
      checkCudaDevice;
      tnlTraverser< MeshType, MeshEntityType > meshTraverser;
      meshTraverser.template processInteriorEntities< TraverserUserData, AssignEntitiesProcessor >
      meshTraverser.template processInteriorEntities< TraverserUserData, EntitiesProcessor >
         ( meshFunction.getMesh(),
           userData );

@@ -233,5 +258,6 @@ evaluateEntities( OutMeshFunction& meshFunction,
}



#endif	/* TNLMESHFUNCTIONEVALUATOR_IMPL_H */
Loading