Commit a8ecc63c authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Added smart pointers to MeshFunctionEvaluator

parent e97f245e
Loading
Loading
Loading
Loading
+47 −46
Original line number Diff line number Diff line
@@ -21,7 +21,29 @@ namespace Functions {
template< typename OutMeshFunction,
          typename InFunction,
          typename Real >
class MeshFunctionEvaluatorTraverserUserData;
class MeshFunctionEvaluatorTraverserUserData
{
   public:
      typedef InFunction InFunctionType;

      MeshFunctionEvaluatorTraverserUserData( const InFunction* function,
                                              const Real& time,
                                              OutMeshFunction* meshFunction,
                                              const Real& outFunctionMultiplicator,
                                              const Real& inFunctionMultiplicator )
      : meshFunction( meshFunction ),
        function( function ),
        time( time ),
        outFunctionMultiplicator( outFunctionMultiplicator ),
        inFunctionMultiplicator( inFunctionMultiplicator )
      {}

      OutMeshFunction* meshFunction;
      const InFunction* function;
      const Real time, outFunctionMultiplicator, inFunctionMultiplicator;

};


/***
 * General mesh function evaluator. As an input function any type implementing
@@ -40,33 +62,35 @@ class MeshFunctionEvaluator
                  "Input and output functions must have the same domain dimensions." );

   public:
      typedef typename OutMeshFunction::MeshType MeshType;
      typedef typename MeshType::RealType MeshRealType;
      typedef typename MeshType::DeviceType MeshDeviceType;
      typedef typename MeshType::IndexType MeshIndexType;
      typedef typename OutMeshFunction::RealType RealType;
      typedef typename OutMeshFunction::MeshType MeshType;
      typedef typename MeshType::DeviceType DeviceType;
      typedef Functions::MeshFunctionEvaluatorTraverserUserData< OutMeshFunction, InFunction, RealType > TraverserUserData;

      static void evaluate( OutMeshFunction& meshFunction,
                            const InFunction& function,
      template< typename OutMeshFunctionPointer, typename InFunctionPointer >
      static void evaluate( OutMeshFunctionPointer& meshFunction,
                            const InFunctionPointer& function,
                            const RealType& time = 0.0,
                            const RealType& outFunctionMultiplicator = 0.0,
                            const RealType& inFunctionMultiplicator = 1.0 );

      static void evaluateAllEntities( OutMeshFunction& meshFunction,
                                       const InFunction& function,
      template< typename OutMeshFunctionPointer, typename InFunctionPointer >
      static void evaluateAllEntities( OutMeshFunctionPointer& meshFunction,
                                       const InFunctionPointer& function,
                                       const RealType& time = 0.0,
                                       const RealType& outFunctionMultiplicator = 0.0,
                                       const RealType& inFunctionMultiplicator = 1.0 );
 
      static void evaluateInteriorEntities( OutMeshFunction& meshFunction,
                                            const InFunction& function,
      template< typename OutMeshFunctionPointer, typename InFunctionPointer >
      static void evaluateInteriorEntities( OutMeshFunctionPointer& meshFunction,
                                            const InFunctionPointer& function,
                                            const RealType& time = 0.0,
                                            const RealType& outFunctionMultiplicator = 0.0,
                                            const RealType& inFunctionMultiplicator = 1.0 );

      static void evaluateBoundaryEntities( OutMeshFunction& meshFunction,
                                            const InFunction& function,
      template< typename OutMeshFunctionPointer, typename InFunctionPointer >
      static void evaluateBoundaryEntities( OutMeshFunctionPointer& meshFunction,
                                            const InFunctionPointer& function,
                                            const RealType& time = 0.0,
                                            const RealType& outFunctionMultiplicator = 0.0,
                                            const RealType& inFunctionMultiplicator = 1.0 );
@@ -75,8 +99,9 @@ class MeshFunctionEvaluator

      enum EntitiesType { all, boundary, interior };
 
      static void evaluateEntities( OutMeshFunction& meshFunction,
                                    const InFunction& function,
      template< typename OutMeshFunctionPointer, typename InFunctionPointer >
      static void evaluateEntities( OutMeshFunctionPointer& meshFunction,
                                    const InFunctionPointer& function,
                                    const RealType& time,
                                    const RealType& outFunctionMultiplicator,
                                    const RealType& inFunctionMultiplicator,
@@ -85,30 +110,6 @@ class MeshFunctionEvaluator
 
};

template< typename OutMeshFunction,
          typename InFunction,
          typename Real >
class MeshFunctionEvaluatorTraverserUserData
{
   public:

      typedef InFunction InFunctionType;

      MeshFunctionEvaluatorTraverserUserData( const InFunction* function,
                                              const Real* time,
                                              OutMeshFunction* meshFunction,
                                              const Real* outFunctionMultiplicator,
                                              const Real* inFunctionMultiplicator )
      : meshFunction( meshFunction ), function( function ), time( time ),
        outFunctionMultiplicator( outFunctionMultiplicator ),
        inFunctionMultiplicator( inFunctionMultiplicator ){}

      OutMeshFunction* meshFunction;
      const InFunction* function;
      const Real *time, *outFunctionMultiplicator, *inFunctionMultiplicator;

};


template< typename MeshType,
          typename UserData >
@@ -124,10 +125,10 @@ class MeshFunctionEvaluatorAssignmentEntitiesProcessor
      {
         typedef FunctionAdapter< MeshType, typename UserData::InFunctionType > FunctionAdapter;
         ( *userData.meshFunction )( entity ) =
            *userData.inFunctionMultiplicator *
            FunctionAdapter::getValue( *userData.function, entity, *userData.time );
            userData.inFunctionMultiplicator *
            FunctionAdapter::getValue( *userData.function, entity, userData.time );
         /*cerr << "Idx = " << entity.getIndex()
            << " Value = " << FunctionAdapter::getValue( *userData.function, entity, *userData.time )
            << " Value = " << FunctionAdapter::getValue( *userData.function, entity, userData.time )
            << " stored value = " << ( *userData.meshFunction )( entity )
            << " multiplicators = " << std::endl;*/
      }
@@ -147,11 +148,11 @@ class MeshFunctionEvaluatorAdditionEntitiesProcessor
      {
         typedef FunctionAdapter< MeshType, typename UserData::InFunctionType > FunctionAdapter;
         ( *userData.meshFunction )( entity ) =
            *userData.outFunctionMultiplicator * ( *userData.meshFunction )( entity ) +
            *userData.inFunctionMultiplicator *
            FunctionAdapter::getValue( *userData.function, entity, *userData.time );
            userData.outFunctionMultiplicator * ( *userData.meshFunction )( entity ) +
            userData.inFunctionMultiplicator *
            FunctionAdapter::getValue( *userData.function, entity, userData.time );
         /*cerr << "Idx = " << entity.getIndex()
            << " Value = " << FunctionAdapter::getValue( *userData.function, entity, *userData.time )
            << " Value = " << FunctionAdapter::getValue( *userData.function, entity, userData.time )
            << " stored value = " << ( *userData.meshFunction )( entity )
            << " multiplicators = " << std::endl;*/
      }
+73 −111
Original line number Diff line number Diff line
@@ -18,14 +18,18 @@ namespace Functions {

template< typename OutMeshFunction,
          typename InFunction >
   template< typename OutMeshFunctionPointer, typename InFunctionPointer >
void
MeshFunctionEvaluator< OutMeshFunction, InFunction >::
evaluate( OutMeshFunction& meshFunction,
          const InFunction& function,
evaluate( OutMeshFunctionPointer& meshFunction,
          const InFunctionPointer& function,
          const RealType& time,
          const RealType& outFunctionMultiplicator,
          const RealType& inFunctionMultiplicator )
{
   static_assert( std::is_same< typename std::decay< typename OutMeshFunctionPointer::ObjectType >::type, OutMeshFunction >::value, "expected a smart pointer" );
   static_assert( std::is_same< typename std::decay< typename InFunctionPointer::ObjectType >::type, InFunction >::value, "expected a smart pointer" );

   switch( InFunction::getDomainType() )
   {
      case NonspaceDomain:
@@ -45,40 +49,52 @@ evaluate( OutMeshFunction& meshFunction,

template< typename OutMeshFunction,
          typename InFunction >
   template< typename OutMeshFunctionPointer, typename InFunctionPointer >
void
MeshFunctionEvaluator< OutMeshFunction, InFunction >::
evaluateAllEntities( OutMeshFunction& meshFunction,
                     const InFunction& function,
evaluateAllEntities( OutMeshFunctionPointer& meshFunction,
                     const InFunctionPointer& function,
                     const RealType& time,
                     const RealType& outFunctionMultiplicator,
                     const RealType& inFunctionMultiplicator )
{
   static_assert( std::is_same< typename std::decay< typename OutMeshFunctionPointer::ObjectType >::type, OutMeshFunction >::value, "expected a smart pointer" );
   static_assert( std::is_same< typename std::decay< typename InFunctionPointer::ObjectType >::type, InFunction >::value, "expected a smart pointer" );

   return evaluateEntities( meshFunction, function, time, outFunctionMultiplicator, inFunctionMultiplicator, all );
}

template< typename OutMeshFunction,
          typename InFunction >
   template< typename OutMeshFunctionPointer, typename InFunctionPointer >
void
MeshFunctionEvaluator< OutMeshFunction, InFunction >::
evaluateInteriorEntities( OutMeshFunction& meshFunction,
                          const InFunction& function,
evaluateInteriorEntities( OutMeshFunctionPointer& meshFunction,
                          const InFunctionPointer& function,
                          const RealType& time,
                          const RealType& outFunctionMultiplicator,
                          const RealType& inFunctionMultiplicator )
{
   static_assert( std::is_same< typename std::decay< typename OutMeshFunctionPointer::ObjectType >::type, OutMeshFunction >::value, "expected a smart pointer" );
   static_assert( std::is_same< typename std::decay< typename InFunctionPointer::ObjectType >::type, InFunction >::value, "expected a smart pointer" );

   return evaluateEntities( meshFunction, function, time, outFunctionMultiplicator, inFunctionMultiplicator, interior );
}

template< typename OutMeshFunction,
          typename InFunction >
   template< typename OutMeshFunctionPointer, typename InFunctionPointer >
void
MeshFunctionEvaluator< OutMeshFunction, InFunction >::
evaluateBoundaryEntities( OutMeshFunction& meshFunction,
                          const InFunction& function,
evaluateBoundaryEntities( OutMeshFunctionPointer& meshFunction,
                          const InFunctionPointer& function,
                          const RealType& time,
                          const RealType& outFunctionMultiplicator,
                          const RealType& inFunctionMultiplicator )
{
   static_assert( std::is_same< typename std::decay< typename OutMeshFunctionPointer::ObjectType >::type, OutMeshFunction >::value, "expected a smart pointer" );
   static_assert( std::is_same< typename std::decay< typename InFunctionPointer::ObjectType >::type, InFunction >::value, "expected a smart pointer" );

   return evaluateEntities( meshFunction, function, time, outFunctionMultiplicator, inFunctionMultiplicator, boundary );
}

@@ -86,73 +102,28 @@ evaluateBoundaryEntities( OutMeshFunction& meshFunction,

template< typename OutMeshFunction,
          typename InFunction >
   template< typename OutMeshFunctionPointer, typename InFunctionPointer >
void
MeshFunctionEvaluator< OutMeshFunction, InFunction >::
evaluateEntities( OutMeshFunction& meshFunction,
                  const InFunction& function,
evaluateEntities( OutMeshFunctionPointer& meshFunction,
                  const InFunctionPointer& function,
                  const RealType& time,
                  const RealType& outFunctionMultiplicator,
                  const RealType& inFunctionMultiplicator,
                  EntitiesType entitiesType )
{
   static_assert( std::is_same< typename std::decay< typename OutMeshFunctionPointer::ObjectType >::type, OutMeshFunction >::value, "expected a smart pointer" );
   static_assert( std::is_same< typename std::decay< typename InFunctionPointer::ObjectType >::type, InFunction >::value, "expected a smart pointer" );

   typedef typename MeshType::template MeshEntity< OutMeshFunction::getEntitiesDimensions() > MeshEntityType;
   typedef Functions::MeshFunctionEvaluatorAssignmentEntitiesProcessor< MeshType, TraverserUserData > AssignmentEntitiesProcessor;
   typedef Functions::MeshFunctionEvaluatorAdditionEntitiesProcessor< MeshType, TraverserUserData > AdditionEntitiesProcessor;
 
   if( std::is_same< MeshDeviceType, Devices::Host >::value )
   {
      TraverserUserData userData( &function, &time, &meshFunction, &outFunctionMultiplicator, &inFunctionMultiplicator );
      Meshes::Traverser< MeshType, MeshEntityType > meshTraverser;
      switch( entitiesType )
      {
         case all:
            if( outFunctionMultiplicator )
               meshTraverser.template processAllEntities< TraverserUserData,
                                                          AdditionEntitiesProcessor >
                                                        ( meshFunction.getMeshPointer(),
                                                          userData );
            else
               meshTraverser.template processAllEntities< TraverserUserData,
                                                         AssignmentEntitiesProcessor >
                                                       ( meshFunction.getMeshPointer(),
                                                         userData );
            break;
         case interior:
            if( outFunctionMultiplicator )
               meshTraverser.template processInteriorEntities< TraverserUserData,
                                                               AdditionEntitiesProcessor >
                                                             ( meshFunction.getMeshPointer(),
                                                               userData );
            else
               meshTraverser.template processInteriorEntities< TraverserUserData,
                                                               AssignmentEntitiesProcessor >
                                                             ( meshFunction.getMeshPointer(),
                                                               userData );            
            break;
         case boundary:
            if( outFunctionMultiplicator )
               meshTraverser.template processBoundaryEntities< TraverserUserData,
                                                               AdditionEntitiesProcessor >
                                                             ( meshFunction.getMeshPointer(),
                                                               userData );
            else
               meshTraverser.template processBoundaryEntities< TraverserUserData,
                                                               AssignmentEntitiesProcessor >
                                                             ( meshFunction.getMeshPointer(),
                                                               userData );
            break;
      }
   }
   if( std::is_same< MeshDeviceType, Devices::Cuda >::value )
   {
      OutMeshFunction* kernelMeshFunction = Devices::Cuda::passToDevice( meshFunction );
      InFunction* kernelFunction = Devices::Cuda::passToDevice( function );
      RealType* kernelTime = Devices::Cuda::passToDevice( time );
      RealType* kernelOutFunctionMultiplicator = Devices::Cuda::passToDevice( outFunctionMultiplicator );
      RealType* kernelInFunctionMultiplicator = Devices::Cuda::passToDevice( inFunctionMultiplicator );
 
      TraverserUserData userData( kernelFunction, kernelTime, kernelMeshFunction, kernelOutFunctionMultiplicator, kernelInFunctionMultiplicator );
      checkCudaDevice;
   TraverserUserData userData( &function.template getData< DeviceType >(),
                               time,
                               &meshFunction.template modifyData< DeviceType >(),
                               outFunctionMultiplicator,
                               inFunctionMultiplicator );
   Meshes::Traverser< MeshType, MeshEntityType > meshTraverser;
   switch( entitiesType )
   {
@@ -160,48 +131,39 @@ evaluateEntities( OutMeshFunction& meshFunction,
         if( outFunctionMultiplicator )
            meshTraverser.template processAllEntities< TraverserUserData,
                                                       AdditionEntitiesProcessor >
                                                        ( meshFunction.getMeshPointer(),
                                                     ( meshFunction->getMeshPointer(),
                                                       userData );
         else
            meshTraverser.template processAllEntities< TraverserUserData,
                                                      AssignmentEntitiesProcessor >
                                                       ( meshFunction.getMeshPointer(),
                                                    ( meshFunction->getMeshPointer(),
                                                      userData );
         break;
      case interior:
         if( outFunctionMultiplicator )
            meshTraverser.template processInteriorEntities< TraverserUserData,
                                                            AdditionEntitiesProcessor >
                                                             ( meshFunction.getMeshPointer(),
                                                          ( meshFunction->getMeshPointer(),
                                                            userData );
         else
            meshTraverser.template processInteriorEntities< TraverserUserData,
                                                            AssignmentEntitiesProcessor >
                                                             ( meshFunction.getMeshPointer(),
                                                          ( meshFunction->getMeshPointer(),
                                                            userData );
         break;
      case boundary:
         if( outFunctionMultiplicator )
            meshTraverser.template processBoundaryEntities< TraverserUserData,
                                                            AdditionEntitiesProcessor >
                                                             ( meshFunction.getMeshPointer(),
                                                          ( meshFunction->getMeshPointer(),
                                                            userData );
         else
            meshTraverser.template processBoundaryEntities< TraverserUserData,
                                                            AssignmentEntitiesProcessor >
                                                             ( meshFunction.getMeshPointer(),
                                                          ( meshFunction->getMeshPointer(),
                                                            userData );
         break;
   }

      checkCudaDevice;
      Devices::Cuda::freeFromDevice( kernelMeshFunction );
      Devices::Cuda::freeFromDevice( kernelFunction );
      Devices::Cuda::freeFromDevice( kernelTime );
      Devices::Cuda::freeFromDevice( kernelOutFunctionMultiplicator );
      Devices::Cuda::freeFromDevice( kernelInFunctionMultiplicator );
      checkCudaDevice;
   }
}

} // namespace Functions
+10 −3
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@
/* See Copyright Notice in tnl/Copyright */

#include <TNL/Assert.h>
#include <TNL/DevicePointer.h>
#include <TNL/Functions/MeshFunction.h>
#include <TNL/Functions/MeshFunctionEvaluator.h>
#include <TNL/Functions/MeshFunctionNormGetter.h>
@@ -342,7 +343,9 @@ MeshFunction< Mesh, MeshEntityDimensions, Real >&
MeshFunction< Mesh, MeshEntityDimensions, Real >::
operator = ( const Function& f )
{
   MeshFunctionEvaluator< ThisType, Function >::evaluate( *this, f );
   DevicePointer< ThisType > thisDevicePtr( *this );
   DevicePointer< typename std::add_const< Function >::type > fDevicePtr( f );
   MeshFunctionEvaluator< ThisType, Function >::evaluate( thisDevicePtr, fDevicePtr );
   return *this;
}

@@ -354,7 +357,9 @@ MeshFunction< Mesh, MeshEntityDimensions, Real >&
MeshFunction< Mesh, MeshEntityDimensions, Real >::
operator += ( const Function& f )
{
   MeshFunctionEvaluator< ThisType, Function >::evaluate( *this, f, 1.0, 1.0 );
   DevicePointer< ThisType > thisDevicePtr( *this );
   DevicePointer< typename std::add_const< Function >::type > fDevicePtr( f );
   MeshFunctionEvaluator< ThisType, Function >::evaluate( thisDevicePtr, fDevicePtr, 1.0, 1.0 );
   return *this;
}

@@ -366,7 +371,9 @@ MeshFunction< Mesh, MeshEntityDimensions, Real >&
MeshFunction< Mesh, MeshEntityDimensions, Real >::
operator -= ( const Function& f )
{
   MeshFunctionEvaluator< ThisType, Function >::evaluate( *this, f, 1.0, -1.0 );
   DevicePointer< ThisType > thisDevicePtr( *this );
   DevicePointer< typename std::add_const< Function >::type > fDevicePtr( f );
   MeshFunctionEvaluator< ThisType, Function >::evaluate( thisDevicePtr, fDevicePtr, 1.0, -1.0 );
   return *this;
}

+6 −4
Original line number Diff line number Diff line
@@ -35,13 +35,15 @@ bool renderFunction( const Config::ParameterContainer& parameters )
      return false;

   typedef Functions::TestFunction< MeshType::meshDimensions, RealType > FunctionType;
   FunctionType function;
   typedef SharedPointer< FunctionType, typename MeshType::DeviceType > FunctionPointer;
   FunctionPointer function;
   std::cout << "Setting up the function ... " << std::endl;
   if( ! function.setup( parameters, "" ) )
   if( ! function->setup( parameters, "" ) )
      return false;
   std::cout << "done." << std::endl;
   typedef Functions::MeshFunction< MeshType, MeshType::meshDimensions > MeshFunctionType;
   MeshFunctionType meshFunction( meshPointer );
   typedef SharedPointer< MeshFunctionType, typename MeshType::DeviceType > MeshFunctionPointer;
   MeshFunctionPointer meshFunction( meshPointer );
   //if( ! discreteFunction.setSize( mesh.template getEntitiesCount< typename MeshType::Cell >() ) )
   //   return false;
 
@@ -87,7 +89,7 @@ bool renderFunction( const Config::ParameterContainer& parameters )
      }
      else
        std::cout << "+ -> Writing the function to " << outputFile << " ... " << std::endl;
      if( ! meshFunction.save( outputFile) )
      if( ! meshFunction->save( outputFile) )
         return false;
      time += tau;
      step ++;