Commit 572f13b7 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Implementing operator functions.

parent ba69493b
Loading
Loading
Loading
Loading
+4 −2
Original line number Diff line number Diff line
@@ -78,11 +78,13 @@ class tnlMeshFunction :
      
      template< typename EntityType >
      __cuda_callable__
      RealType& operator()( const EntityType& meshEntity );
      RealType& operator()( const EntityType& meshEntity,
                            const RealType& time = 0.0 );
      
      template< typename EntityType >
      __cuda_callable__
      const RealType& operator()( const EntityType& meshEntity ) const;
      const RealType& operator()( const EntityType& meshEntity,
                                  const RealType& time = 0.0 ) const;
      
      __cuda_callable__
      RealType& operator[]( const IndexType& meshEntityIndex );
+34 −62
Original line number Diff line number Diff line
@@ -23,6 +23,10 @@
#include <functions/tnlOperatorFunction.h>
#include <functions/tnlBoundaryOperatorFunction.h>

template< typename OutMeshFunction,
          typename InFunction,
          typename Real >
class tnlMeshFunctionEvaluatorTraverserUserData;

/***
 * General mesh function evaluator. As an input function any type implementing
@@ -42,6 +46,8 @@ class tnlMeshFunctionEvaluator : public tnlFunction< OutMeshFunction::getMeshEnt
      typedef typename MeshType::DeviceType MeshDeviceType;
      typedef typename MeshType::IndexType MeshIndexType;
      typedef typename OutMeshFunction::RealType RealType;
      typedef tnlMeshFunctionEvaluatorTraverserUserData< OutMeshFunction, InFunction, RealType > TraverserUserData;

      
      const static int meshEntityDimensions = OutMeshFunction::entityDimensions;
      
@@ -78,26 +84,6 @@ class tnlMeshFunctionEvaluator : public tnlFunction< OutMeshFunction::getMeshEnt
                                    EntitiesType entitiesType );

      
      class TraverserUserData
      {
         public:
            
            typedef InFunction InFunctionType;
            
            TraverserUserData( const InFunction* function,
                               const RealType* time,
                               OutMeshFunction* meshFunction,
                               const RealType* outFunctionMultiplicator,
                               const RealType* inFunctionMultiplicator )
            : meshFunction( meshFunction ), function( function ), time( time ), 
              outFunctionMultiplicator( outFunctionMultiplicator ),
              inFunctionMultiplicator( inFunctionMultiplicator ){}

            OutMeshFunction* meshFunction;            
            const InFunction* function;
            const RealType *time, *outFunctionMultiplicator, *inFunctionMultiplicator;
            
      };
}; 

/****
@@ -120,6 +106,7 @@ class tnlMeshFunctionEvaluator< OutMeshFunction, tnlOperatorFunction< Operator,
      typedef typename MeshType::IndexType MeshIndexType;
      typedef typename OutMeshFunction::RealType RealType;
      typedef tnlOperatorFunction< Operator, Function > OperatorFunctionType;
      typedef tnlMeshFunctionEvaluatorTraverserUserData< OutMeshFunction, OperatorFunctionType, RealType > TraverserUserData;
      
      static_assert( std::is_same< MeshType, typename OperatorFunctionType::MeshType >::value, 
         "Input function and the mesh of the mesh function have both different number of dimensions." );
@@ -133,28 +120,6 @@ class tnlMeshFunctionEvaluator< OutMeshFunction, tnlOperatorFunction< Operator,
                            const RealType& time = 0.0,
                            const RealType& outFunctionMultiplicator = 0.0,
                            const RealType& inFunctionMultiplicator = 1.0 );
            
      class TraverserUserData
      {
         public:
            
            typedef OperatorFunctionType InFunctionType;
         
            TraverserUserData( const OperatorFunctionType* operatorFunction,              
                               const RealType* time,
                               OutMeshFunction* meshFunction,
                               const RealType* outFunctionMultiplicator,
                               const RealType* inFunctionMultiplicator )
            : meshFunction( meshFunction ), function( operatorFunction ),time( time ), 
              outFunctionMultiplicator( outFunctionMultiplicator ),
              inFunctionMultiplicator( inFunctionMultiplicator ){}

            OutMeshFunction* meshFunction;            
            const OperatorFunctionType* function;
            const RealType *time, *outFunctionMultiplicator, *inFunctionMultiplicator;
            
      };

};

/****
@@ -177,6 +142,8 @@ class tnlMeshFunctionEvaluator< OutMeshFunction, tnlBoundaryOperatorFunction< Bo
      typedef typename MeshType::IndexType MeshIndexType;
      typedef typename OutMeshFunction::RealType RealType;
      typedef tnlBoundaryOperatorFunction< BoundaryOperator, Function > BoundaryOperatorFunctionType;
      typedef tnlMeshFunctionEvaluatorTraverserUserData< OutMeshFunction, BoundaryOperatorFunctionType, RealType > TraverserUserData;

      
      static_assert( std::is_same < MeshType, typename BoundaryOperatorFunctionType::MeshType >::value, 
         "Input boundary operator mesh type and the output mesh function mesh are different types." );
@@ -195,27 +162,32 @@ class tnlMeshFunctionEvaluator< OutMeshFunction, tnlBoundaryOperatorFunction< Bo
                                            const RealType& time = 0.0,
                                            const RealType& outFunctionMultiplicator = 0.0,
                                            const RealType& inFunctionMultiplicator = 1.0 );
};

      class TraverserUserData
template< typename OutMeshFunction,
          typename InFunction,
          typename Real >
class tnlMeshFunctionEvaluatorTraverserUserData
{
   public:
            typedef BoundaryOperatorFunctionType InFunctionType;

            TraverserUserData( const BoundaryOperatorFunctionType* operatorFunction,              
                               const RealType* time,
      typedef InFunction InFunctionType;

      tnlMeshFunctionEvaluatorTraverserUserData( const InFunction* function,
                                                 const Real* time,
                                                 OutMeshFunction* meshFunction,
                               const RealType* outFunctionMultiplicator,
                               const RealType* inFunctionMultiplicator )
            : meshFunction( meshFunction ), function( operatorFunction ), time( time ), 
                                                 const Real* outFunctionMultiplicator,
                                                 const Real* inFunctionMultiplicator )
      : meshFunction( meshFunction ), function( function ), time( time ), 
        outFunctionMultiplicator( outFunctionMultiplicator ),
        inFunctionMultiplicator( inFunctionMultiplicator ){}

      OutMeshFunction* meshFunction;            
            const BoundaryOperatorFunctionType* function;
            const RealType *time, *outFunctionMultiplicator, *inFunctionMultiplicator;
      const InFunction* function;
      const Real *time, *outFunctionMultiplicator, *inFunctionMultiplicator;

};
};


template< typename MeshType,
          typename UserData > 
+4 −2
Original line number Diff line number Diff line
@@ -169,7 +169,8 @@ template< typename Mesh,
__cuda_callable__
typename tnlMeshFunction< Mesh, MeshEntityDimensions, Real >::RealType& 
tnlMeshFunction< Mesh, MeshEntityDimensions, Real >::
operator()( const EntityType& meshEntity )
operator()( const EntityType& meshEntity,
            const RealType& time )
{
   static_assert( EntityType::entityDimensions == MeshEntityDimensions, "Calling with wrong EntityType -- entity dimensions do not match." );
   return this->data[ meshEntity.getIndex() ];
@@ -182,7 +183,8 @@ template< typename Mesh,
__cuda_callable__
const typename tnlMeshFunction< Mesh, MeshEntityDimensions, Real >::RealType& 
tnlMeshFunction< Mesh, MeshEntityDimensions, Real >::
operator()( const EntityType& meshEntity ) const
operator()( const EntityType& meshEntity,
            const RealType& time ) const
{
   static_assert( EntityType::entityDimensions == MeshEntityDimensions, "Calling with wrong EntityType -- entity dimensions do not match." );
   return this->data[ meshEntity.getIndex() ];
+8 −3
Original line number Diff line number Diff line
@@ -27,13 +27,18 @@
 * evaluates this function only on the INTERIOR mesh entities.
 */
template< typename Operator,
          typename Function >
class tnlOperatorFunction : public tnlFunction< Operator::getMeshEntityDimensions(), MeshFunction >
          typename MeshFunction >
class tnlOperatorFunction : public tnlFunction< Operator::getMeshEntityDimensions(), ::MeshFunction >
{   
   public:
      
      static_assert( MeshFunction::getFunctionType() == ::MeshFunction,
         "Only mesh functions may be used in the operator function." );
      static_assert( std::is_same< typename Operator::MeshType, typename MeshFunction::MeshType >::value,
          "Both, operator and mesh function must be defined on the same mesh." );
      
      typedef Operator OperatorType;
      typedef Function FunctionType;
      typedef MeshFunction FunctionType;
      typedef typename OperatorType::MeshType MeshType;
      typedef typename OperatorType::RealType RealType;
      typedef typename OperatorType::DeviceType DeviceType;
+2 −3
Original line number Diff line number Diff line
@@ -215,11 +215,10 @@ getExplicitRHS( const RealType& time,
      this->rightHandSide,
      this->u,
      fu );
   tnlBoundaryConditionsSetter< Mesh, MeshFunctionType, BoundaryCondition > boundaryConditionsSetter;
   tnlBoundaryConditionsSetter< MeshFunctionType, BoundaryCondition > boundaryConditionsSetter;
   boundaryConditionsSetter.template apply< typename Mesh::Cell >(
      time + tau,
      mesh,
      this->boundaryCondition,
      time + tau,
      this->u );
   /*cout << "u = " << u << endl;
   cout << "fu = " << fu << endl;
Loading