Loading src/TNL/Functions/MeshFunctionEvaluator.h +47 −46 Original line number Diff line number Diff line Loading @@ -21,7 +21,29 @@ namespace Functions { template< typename OutMeshFunction, typename InFunction, typename Real > class MeshFunctionEvaluatorTraverserUserData; class MeshFunctionEvaluatorTraverserUserData { public: typedef InFunction InFunctionType; MeshFunctionEvaluatorTraverserUserData( const InFunction* function, const Real& time, OutMeshFunction* meshFunction, const Real& outFunctionMultiplicator, const Real& inFunctionMultiplicator ) : meshFunction( meshFunction ), function( function ), time( time ), outFunctionMultiplicator( outFunctionMultiplicator ), inFunctionMultiplicator( inFunctionMultiplicator ) {} OutMeshFunction* meshFunction; const InFunction* function; const Real time, outFunctionMultiplicator, inFunctionMultiplicator; }; /*** * General mesh function evaluator. As an input function any type implementing Loading @@ -40,33 +62,35 @@ class MeshFunctionEvaluator "Input and output functions must have the same domain dimensions." ); public: typedef typename OutMeshFunction::MeshType MeshType; typedef typename MeshType::RealType MeshRealType; typedef typename MeshType::DeviceType MeshDeviceType; typedef typename MeshType::IndexType MeshIndexType; typedef typename OutMeshFunction::RealType RealType; typedef typename OutMeshFunction::MeshType MeshType; typedef typename MeshType::DeviceType DeviceType; typedef Functions::MeshFunctionEvaluatorTraverserUserData< OutMeshFunction, InFunction, RealType > TraverserUserData; static void evaluate( OutMeshFunction& meshFunction, const InFunction& function, template< typename OutMeshFunctionPointer, typename InFunctionPointer > static void evaluate( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time = 0.0, const RealType& outFunctionMultiplicator = 0.0, const RealType& inFunctionMultiplicator = 1.0 ); static void evaluateAllEntities( OutMeshFunction& meshFunction, const InFunction& function, template< typename OutMeshFunctionPointer, typename InFunctionPointer > static void evaluateAllEntities( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time = 0.0, const RealType& outFunctionMultiplicator = 0.0, const RealType& inFunctionMultiplicator = 1.0 ); static void evaluateInteriorEntities( OutMeshFunction& meshFunction, const InFunction& function, template< typename OutMeshFunctionPointer, typename InFunctionPointer > static void evaluateInteriorEntities( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time = 0.0, const RealType& outFunctionMultiplicator = 0.0, const RealType& inFunctionMultiplicator = 1.0 ); static void evaluateBoundaryEntities( OutMeshFunction& meshFunction, const InFunction& function, template< typename OutMeshFunctionPointer, typename InFunctionPointer > static void evaluateBoundaryEntities( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time = 0.0, const RealType& outFunctionMultiplicator = 0.0, const RealType& inFunctionMultiplicator = 1.0 ); Loading @@ -75,8 +99,9 @@ class MeshFunctionEvaluator enum EntitiesType { all, boundary, interior }; static void evaluateEntities( OutMeshFunction& meshFunction, const InFunction& function, template< typename OutMeshFunctionPointer, typename InFunctionPointer > static void evaluateEntities( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time, const RealType& outFunctionMultiplicator, const RealType& inFunctionMultiplicator, Loading @@ -85,30 +110,6 @@ class MeshFunctionEvaluator }; template< typename OutMeshFunction, typename InFunction, typename Real > class MeshFunctionEvaluatorTraverserUserData { public: typedef InFunction InFunctionType; MeshFunctionEvaluatorTraverserUserData( const InFunction* function, const Real* time, OutMeshFunction* meshFunction, const Real* outFunctionMultiplicator, const Real* inFunctionMultiplicator ) : meshFunction( meshFunction ), function( function ), time( time ), outFunctionMultiplicator( outFunctionMultiplicator ), inFunctionMultiplicator( inFunctionMultiplicator ){} OutMeshFunction* meshFunction; const InFunction* function; const Real *time, *outFunctionMultiplicator, *inFunctionMultiplicator; }; template< typename MeshType, typename UserData > Loading @@ -124,10 +125,10 @@ class MeshFunctionEvaluatorAssignmentEntitiesProcessor { typedef FunctionAdapter< MeshType, typename UserData::InFunctionType > FunctionAdapter; ( *userData.meshFunction )( entity ) = *userData.inFunctionMultiplicator * FunctionAdapter::getValue( *userData.function, entity, *userData.time ); userData.inFunctionMultiplicator * FunctionAdapter::getValue( *userData.function, entity, userData.time ); /*cerr << "Idx = " << entity.getIndex() << " Value = " << FunctionAdapter::getValue( *userData.function, entity, *userData.time ) << " Value = " << FunctionAdapter::getValue( *userData.function, entity, userData.time ) << " stored value = " << ( *userData.meshFunction )( entity ) << " multiplicators = " << std::endl;*/ } Loading @@ -147,11 +148,11 @@ class MeshFunctionEvaluatorAdditionEntitiesProcessor { typedef FunctionAdapter< MeshType, typename UserData::InFunctionType > FunctionAdapter; ( *userData.meshFunction )( entity ) = *userData.outFunctionMultiplicator * ( *userData.meshFunction )( entity ) + *userData.inFunctionMultiplicator * FunctionAdapter::getValue( *userData.function, entity, *userData.time ); userData.outFunctionMultiplicator * ( *userData.meshFunction )( entity ) + userData.inFunctionMultiplicator * FunctionAdapter::getValue( *userData.function, entity, userData.time ); /*cerr << "Idx = " << entity.getIndex() << " Value = " << FunctionAdapter::getValue( *userData.function, entity, *userData.time ) << " Value = " << FunctionAdapter::getValue( *userData.function, entity, userData.time ) << " stored value = " << ( *userData.meshFunction )( entity ) << " multiplicators = " << std::endl;*/ } Loading src/TNL/Functions/MeshFunctionEvaluator_impl.h +73 −111 Original line number Diff line number Diff line Loading @@ -18,14 +18,18 @@ namespace Functions { template< typename OutMeshFunction, typename InFunction > template< typename OutMeshFunctionPointer, typename InFunctionPointer > void MeshFunctionEvaluator< OutMeshFunction, InFunction >:: evaluate( OutMeshFunction& meshFunction, const InFunction& function, evaluate( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time, const RealType& outFunctionMultiplicator, const RealType& inFunctionMultiplicator ) { static_assert( std::is_same< typename std::decay< typename OutMeshFunctionPointer::ObjectType >::type, OutMeshFunction >::value, "expected a smart pointer" ); static_assert( std::is_same< typename std::decay< typename InFunctionPointer::ObjectType >::type, InFunction >::value, "expected a smart pointer" ); switch( InFunction::getDomainType() ) { case NonspaceDomain: Loading @@ -45,40 +49,52 @@ evaluate( OutMeshFunction& meshFunction, template< typename OutMeshFunction, typename InFunction > template< typename OutMeshFunctionPointer, typename InFunctionPointer > void MeshFunctionEvaluator< OutMeshFunction, InFunction >:: evaluateAllEntities( OutMeshFunction& meshFunction, const InFunction& function, evaluateAllEntities( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time, const RealType& outFunctionMultiplicator, const RealType& inFunctionMultiplicator ) { static_assert( std::is_same< typename std::decay< typename OutMeshFunctionPointer::ObjectType >::type, OutMeshFunction >::value, "expected a smart pointer" ); static_assert( std::is_same< typename std::decay< typename InFunctionPointer::ObjectType >::type, InFunction >::value, "expected a smart pointer" ); return evaluateEntities( meshFunction, function, time, outFunctionMultiplicator, inFunctionMultiplicator, all ); } template< typename OutMeshFunction, typename InFunction > template< typename OutMeshFunctionPointer, typename InFunctionPointer > void MeshFunctionEvaluator< OutMeshFunction, InFunction >:: evaluateInteriorEntities( OutMeshFunction& meshFunction, const InFunction& function, evaluateInteriorEntities( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time, const RealType& outFunctionMultiplicator, const RealType& inFunctionMultiplicator ) { static_assert( std::is_same< typename std::decay< typename OutMeshFunctionPointer::ObjectType >::type, OutMeshFunction >::value, "expected a smart pointer" ); static_assert( std::is_same< typename std::decay< typename InFunctionPointer::ObjectType >::type, InFunction >::value, "expected a smart pointer" ); return evaluateEntities( meshFunction, function, time, outFunctionMultiplicator, inFunctionMultiplicator, interior ); } template< typename OutMeshFunction, typename InFunction > template< typename OutMeshFunctionPointer, typename InFunctionPointer > void MeshFunctionEvaluator< OutMeshFunction, InFunction >:: evaluateBoundaryEntities( OutMeshFunction& meshFunction, const InFunction& function, evaluateBoundaryEntities( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time, const RealType& outFunctionMultiplicator, const RealType& inFunctionMultiplicator ) { static_assert( std::is_same< typename std::decay< typename OutMeshFunctionPointer::ObjectType >::type, OutMeshFunction >::value, "expected a smart pointer" ); static_assert( std::is_same< typename std::decay< typename InFunctionPointer::ObjectType >::type, InFunction >::value, "expected a smart pointer" ); return evaluateEntities( meshFunction, function, time, outFunctionMultiplicator, inFunctionMultiplicator, boundary ); } Loading @@ -86,73 +102,28 @@ evaluateBoundaryEntities( OutMeshFunction& meshFunction, template< typename OutMeshFunction, typename InFunction > template< typename OutMeshFunctionPointer, typename InFunctionPointer > void MeshFunctionEvaluator< OutMeshFunction, InFunction >:: evaluateEntities( OutMeshFunction& meshFunction, const InFunction& function, evaluateEntities( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time, const RealType& outFunctionMultiplicator, const RealType& inFunctionMultiplicator, EntitiesType entitiesType ) { static_assert( std::is_same< typename std::decay< typename OutMeshFunctionPointer::ObjectType >::type, OutMeshFunction >::value, "expected a smart pointer" ); static_assert( std::is_same< typename std::decay< typename InFunctionPointer::ObjectType >::type, InFunction >::value, "expected a smart pointer" ); typedef typename MeshType::template MeshEntity< OutMeshFunction::getEntitiesDimensions() > MeshEntityType; typedef Functions::MeshFunctionEvaluatorAssignmentEntitiesProcessor< MeshType, TraverserUserData > AssignmentEntitiesProcessor; typedef Functions::MeshFunctionEvaluatorAdditionEntitiesProcessor< MeshType, TraverserUserData > AdditionEntitiesProcessor; if( std::is_same< MeshDeviceType, Devices::Host >::value ) { TraverserUserData userData( &function, &time, &meshFunction, &outFunctionMultiplicator, &inFunctionMultiplicator ); Meshes::Traverser< MeshType, MeshEntityType > meshTraverser; switch( entitiesType ) { case all: if( outFunctionMultiplicator ) meshTraverser.template processAllEntities< TraverserUserData, AdditionEntitiesProcessor > ( meshFunction.getMeshPointer(), userData ); else meshTraverser.template processAllEntities< TraverserUserData, AssignmentEntitiesProcessor > ( meshFunction.getMeshPointer(), userData ); break; case interior: if( outFunctionMultiplicator ) meshTraverser.template processInteriorEntities< TraverserUserData, AdditionEntitiesProcessor > ( meshFunction.getMeshPointer(), userData ); else meshTraverser.template processInteriorEntities< TraverserUserData, AssignmentEntitiesProcessor > ( meshFunction.getMeshPointer(), userData ); break; case boundary: if( outFunctionMultiplicator ) meshTraverser.template processBoundaryEntities< TraverserUserData, AdditionEntitiesProcessor > ( meshFunction.getMeshPointer(), userData ); else meshTraverser.template processBoundaryEntities< TraverserUserData, AssignmentEntitiesProcessor > ( meshFunction.getMeshPointer(), userData ); break; } } if( std::is_same< MeshDeviceType, Devices::Cuda >::value ) { OutMeshFunction* kernelMeshFunction = Devices::Cuda::passToDevice( meshFunction ); InFunction* kernelFunction = Devices::Cuda::passToDevice( function ); RealType* kernelTime = Devices::Cuda::passToDevice( time ); RealType* kernelOutFunctionMultiplicator = Devices::Cuda::passToDevice( outFunctionMultiplicator ); RealType* kernelInFunctionMultiplicator = Devices::Cuda::passToDevice( inFunctionMultiplicator ); TraverserUserData userData( kernelFunction, kernelTime, kernelMeshFunction, kernelOutFunctionMultiplicator, kernelInFunctionMultiplicator ); checkCudaDevice; TraverserUserData userData( &function.template getData< DeviceType >(), time, &meshFunction.template modifyData< DeviceType >(), outFunctionMultiplicator, inFunctionMultiplicator ); Meshes::Traverser< MeshType, MeshEntityType > meshTraverser; switch( entitiesType ) { Loading @@ -160,48 +131,39 @@ evaluateEntities( OutMeshFunction& meshFunction, if( outFunctionMultiplicator ) meshTraverser.template processAllEntities< TraverserUserData, AdditionEntitiesProcessor > ( meshFunction.getMeshPointer(), ( meshFunction->getMeshPointer(), userData ); else meshTraverser.template processAllEntities< TraverserUserData, AssignmentEntitiesProcessor > ( meshFunction.getMeshPointer(), ( meshFunction->getMeshPointer(), userData ); break; case interior: if( outFunctionMultiplicator ) meshTraverser.template processInteriorEntities< TraverserUserData, AdditionEntitiesProcessor > ( meshFunction.getMeshPointer(), ( meshFunction->getMeshPointer(), userData ); else meshTraverser.template processInteriorEntities< TraverserUserData, AssignmentEntitiesProcessor > ( meshFunction.getMeshPointer(), ( meshFunction->getMeshPointer(), userData ); break; case boundary: if( outFunctionMultiplicator ) meshTraverser.template processBoundaryEntities< TraverserUserData, AdditionEntitiesProcessor > ( meshFunction.getMeshPointer(), ( meshFunction->getMeshPointer(), userData ); else meshTraverser.template processBoundaryEntities< TraverserUserData, AssignmentEntitiesProcessor > ( meshFunction.getMeshPointer(), ( meshFunction->getMeshPointer(), userData ); break; } checkCudaDevice; Devices::Cuda::freeFromDevice( kernelMeshFunction ); Devices::Cuda::freeFromDevice( kernelFunction ); Devices::Cuda::freeFromDevice( kernelTime ); Devices::Cuda::freeFromDevice( kernelOutFunctionMultiplicator ); Devices::Cuda::freeFromDevice( kernelInFunctionMultiplicator ); checkCudaDevice; } } } // namespace Functions Loading src/TNL/Functions/MeshFunction_impl.h +10 −3 Original line number Diff line number Diff line Loading @@ -9,6 +9,7 @@ /* See Copyright Notice in tnl/Copyright */ #include <TNL/Assert.h> #include <TNL/DevicePointer.h> #include <TNL/Functions/MeshFunction.h> #include <TNL/Functions/MeshFunctionEvaluator.h> #include <TNL/Functions/MeshFunctionNormGetter.h> Loading Loading @@ -342,7 +343,9 @@ MeshFunction< Mesh, MeshEntityDimensions, Real >& MeshFunction< Mesh, MeshEntityDimensions, Real >:: operator = ( const Function& f ) { MeshFunctionEvaluator< ThisType, Function >::evaluate( *this, f ); DevicePointer< ThisType > thisDevicePtr( *this ); DevicePointer< typename std::add_const< Function >::type > fDevicePtr( f ); MeshFunctionEvaluator< ThisType, Function >::evaluate( thisDevicePtr, fDevicePtr ); return *this; } Loading @@ -354,7 +357,9 @@ MeshFunction< Mesh, MeshEntityDimensions, Real >& MeshFunction< Mesh, MeshEntityDimensions, Real >:: operator += ( const Function& f ) { MeshFunctionEvaluator< ThisType, Function >::evaluate( *this, f, 1.0, 1.0 ); DevicePointer< ThisType > thisDevicePtr( *this ); DevicePointer< typename std::add_const< Function >::type > fDevicePtr( f ); MeshFunctionEvaluator< ThisType, Function >::evaluate( thisDevicePtr, fDevicePtr, 1.0, 1.0 ); return *this; } Loading @@ -366,7 +371,9 @@ MeshFunction< Mesh, MeshEntityDimensions, Real >& MeshFunction< Mesh, MeshEntityDimensions, Real >:: operator -= ( const Function& f ) { MeshFunctionEvaluator< ThisType, Function >::evaluate( *this, f, 1.0, -1.0 ); DevicePointer< ThisType > thisDevicePtr( *this ); DevicePointer< typename std::add_const< Function >::type > fDevicePtr( f ); MeshFunctionEvaluator< ThisType, Function >::evaluate( thisDevicePtr, fDevicePtr, 1.0, -1.0 ); return *this; } Loading tools/src/tnl-init.h +6 −4 Original line number Diff line number Diff line Loading @@ -35,13 +35,15 @@ bool renderFunction( const Config::ParameterContainer& parameters ) return false; typedef Functions::TestFunction< MeshType::meshDimensions, RealType > FunctionType; FunctionType function; typedef SharedPointer< FunctionType, typename MeshType::DeviceType > FunctionPointer; FunctionPointer function; std::cout << "Setting up the function ... " << std::endl; if( ! function.setup( parameters, "" ) ) if( ! function->setup( parameters, "" ) ) return false; std::cout << "done." << std::endl; typedef Functions::MeshFunction< MeshType, MeshType::meshDimensions > MeshFunctionType; MeshFunctionType meshFunction( meshPointer ); typedef SharedPointer< MeshFunctionType, typename MeshType::DeviceType > MeshFunctionPointer; MeshFunctionPointer meshFunction( meshPointer ); //if( ! discreteFunction.setSize( mesh.template getEntitiesCount< typename MeshType::Cell >() ) ) // return false; Loading Loading @@ -87,7 +89,7 @@ bool renderFunction( const Config::ParameterContainer& parameters ) } else std::cout << "+ -> Writing the function to " << outputFile << " ... " << std::endl; if( ! meshFunction.save( outputFile) ) if( ! meshFunction->save( outputFile) ) return false; time += tau; step ++; Loading Loading
src/TNL/Functions/MeshFunctionEvaluator.h +47 −46 Original line number Diff line number Diff line Loading @@ -21,7 +21,29 @@ namespace Functions { template< typename OutMeshFunction, typename InFunction, typename Real > class MeshFunctionEvaluatorTraverserUserData; class MeshFunctionEvaluatorTraverserUserData { public: typedef InFunction InFunctionType; MeshFunctionEvaluatorTraverserUserData( const InFunction* function, const Real& time, OutMeshFunction* meshFunction, const Real& outFunctionMultiplicator, const Real& inFunctionMultiplicator ) : meshFunction( meshFunction ), function( function ), time( time ), outFunctionMultiplicator( outFunctionMultiplicator ), inFunctionMultiplicator( inFunctionMultiplicator ) {} OutMeshFunction* meshFunction; const InFunction* function; const Real time, outFunctionMultiplicator, inFunctionMultiplicator; }; /*** * General mesh function evaluator. As an input function any type implementing Loading @@ -40,33 +62,35 @@ class MeshFunctionEvaluator "Input and output functions must have the same domain dimensions." ); public: typedef typename OutMeshFunction::MeshType MeshType; typedef typename MeshType::RealType MeshRealType; typedef typename MeshType::DeviceType MeshDeviceType; typedef typename MeshType::IndexType MeshIndexType; typedef typename OutMeshFunction::RealType RealType; typedef typename OutMeshFunction::MeshType MeshType; typedef typename MeshType::DeviceType DeviceType; typedef Functions::MeshFunctionEvaluatorTraverserUserData< OutMeshFunction, InFunction, RealType > TraverserUserData; static void evaluate( OutMeshFunction& meshFunction, const InFunction& function, template< typename OutMeshFunctionPointer, typename InFunctionPointer > static void evaluate( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time = 0.0, const RealType& outFunctionMultiplicator = 0.0, const RealType& inFunctionMultiplicator = 1.0 ); static void evaluateAllEntities( OutMeshFunction& meshFunction, const InFunction& function, template< typename OutMeshFunctionPointer, typename InFunctionPointer > static void evaluateAllEntities( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time = 0.0, const RealType& outFunctionMultiplicator = 0.0, const RealType& inFunctionMultiplicator = 1.0 ); static void evaluateInteriorEntities( OutMeshFunction& meshFunction, const InFunction& function, template< typename OutMeshFunctionPointer, typename InFunctionPointer > static void evaluateInteriorEntities( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time = 0.0, const RealType& outFunctionMultiplicator = 0.0, const RealType& inFunctionMultiplicator = 1.0 ); static void evaluateBoundaryEntities( OutMeshFunction& meshFunction, const InFunction& function, template< typename OutMeshFunctionPointer, typename InFunctionPointer > static void evaluateBoundaryEntities( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time = 0.0, const RealType& outFunctionMultiplicator = 0.0, const RealType& inFunctionMultiplicator = 1.0 ); Loading @@ -75,8 +99,9 @@ class MeshFunctionEvaluator enum EntitiesType { all, boundary, interior }; static void evaluateEntities( OutMeshFunction& meshFunction, const InFunction& function, template< typename OutMeshFunctionPointer, typename InFunctionPointer > static void evaluateEntities( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time, const RealType& outFunctionMultiplicator, const RealType& inFunctionMultiplicator, Loading @@ -85,30 +110,6 @@ class MeshFunctionEvaluator }; template< typename OutMeshFunction, typename InFunction, typename Real > class MeshFunctionEvaluatorTraverserUserData { public: typedef InFunction InFunctionType; MeshFunctionEvaluatorTraverserUserData( const InFunction* function, const Real* time, OutMeshFunction* meshFunction, const Real* outFunctionMultiplicator, const Real* inFunctionMultiplicator ) : meshFunction( meshFunction ), function( function ), time( time ), outFunctionMultiplicator( outFunctionMultiplicator ), inFunctionMultiplicator( inFunctionMultiplicator ){} OutMeshFunction* meshFunction; const InFunction* function; const Real *time, *outFunctionMultiplicator, *inFunctionMultiplicator; }; template< typename MeshType, typename UserData > Loading @@ -124,10 +125,10 @@ class MeshFunctionEvaluatorAssignmentEntitiesProcessor { typedef FunctionAdapter< MeshType, typename UserData::InFunctionType > FunctionAdapter; ( *userData.meshFunction )( entity ) = *userData.inFunctionMultiplicator * FunctionAdapter::getValue( *userData.function, entity, *userData.time ); userData.inFunctionMultiplicator * FunctionAdapter::getValue( *userData.function, entity, userData.time ); /*cerr << "Idx = " << entity.getIndex() << " Value = " << FunctionAdapter::getValue( *userData.function, entity, *userData.time ) << " Value = " << FunctionAdapter::getValue( *userData.function, entity, userData.time ) << " stored value = " << ( *userData.meshFunction )( entity ) << " multiplicators = " << std::endl;*/ } Loading @@ -147,11 +148,11 @@ class MeshFunctionEvaluatorAdditionEntitiesProcessor { typedef FunctionAdapter< MeshType, typename UserData::InFunctionType > FunctionAdapter; ( *userData.meshFunction )( entity ) = *userData.outFunctionMultiplicator * ( *userData.meshFunction )( entity ) + *userData.inFunctionMultiplicator * FunctionAdapter::getValue( *userData.function, entity, *userData.time ); userData.outFunctionMultiplicator * ( *userData.meshFunction )( entity ) + userData.inFunctionMultiplicator * FunctionAdapter::getValue( *userData.function, entity, userData.time ); /*cerr << "Idx = " << entity.getIndex() << " Value = " << FunctionAdapter::getValue( *userData.function, entity, *userData.time ) << " Value = " << FunctionAdapter::getValue( *userData.function, entity, userData.time ) << " stored value = " << ( *userData.meshFunction )( entity ) << " multiplicators = " << std::endl;*/ } Loading
src/TNL/Functions/MeshFunctionEvaluator_impl.h +73 −111 Original line number Diff line number Diff line Loading @@ -18,14 +18,18 @@ namespace Functions { template< typename OutMeshFunction, typename InFunction > template< typename OutMeshFunctionPointer, typename InFunctionPointer > void MeshFunctionEvaluator< OutMeshFunction, InFunction >:: evaluate( OutMeshFunction& meshFunction, const InFunction& function, evaluate( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time, const RealType& outFunctionMultiplicator, const RealType& inFunctionMultiplicator ) { static_assert( std::is_same< typename std::decay< typename OutMeshFunctionPointer::ObjectType >::type, OutMeshFunction >::value, "expected a smart pointer" ); static_assert( std::is_same< typename std::decay< typename InFunctionPointer::ObjectType >::type, InFunction >::value, "expected a smart pointer" ); switch( InFunction::getDomainType() ) { case NonspaceDomain: Loading @@ -45,40 +49,52 @@ evaluate( OutMeshFunction& meshFunction, template< typename OutMeshFunction, typename InFunction > template< typename OutMeshFunctionPointer, typename InFunctionPointer > void MeshFunctionEvaluator< OutMeshFunction, InFunction >:: evaluateAllEntities( OutMeshFunction& meshFunction, const InFunction& function, evaluateAllEntities( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time, const RealType& outFunctionMultiplicator, const RealType& inFunctionMultiplicator ) { static_assert( std::is_same< typename std::decay< typename OutMeshFunctionPointer::ObjectType >::type, OutMeshFunction >::value, "expected a smart pointer" ); static_assert( std::is_same< typename std::decay< typename InFunctionPointer::ObjectType >::type, InFunction >::value, "expected a smart pointer" ); return evaluateEntities( meshFunction, function, time, outFunctionMultiplicator, inFunctionMultiplicator, all ); } template< typename OutMeshFunction, typename InFunction > template< typename OutMeshFunctionPointer, typename InFunctionPointer > void MeshFunctionEvaluator< OutMeshFunction, InFunction >:: evaluateInteriorEntities( OutMeshFunction& meshFunction, const InFunction& function, evaluateInteriorEntities( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time, const RealType& outFunctionMultiplicator, const RealType& inFunctionMultiplicator ) { static_assert( std::is_same< typename std::decay< typename OutMeshFunctionPointer::ObjectType >::type, OutMeshFunction >::value, "expected a smart pointer" ); static_assert( std::is_same< typename std::decay< typename InFunctionPointer::ObjectType >::type, InFunction >::value, "expected a smart pointer" ); return evaluateEntities( meshFunction, function, time, outFunctionMultiplicator, inFunctionMultiplicator, interior ); } template< typename OutMeshFunction, typename InFunction > template< typename OutMeshFunctionPointer, typename InFunctionPointer > void MeshFunctionEvaluator< OutMeshFunction, InFunction >:: evaluateBoundaryEntities( OutMeshFunction& meshFunction, const InFunction& function, evaluateBoundaryEntities( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time, const RealType& outFunctionMultiplicator, const RealType& inFunctionMultiplicator ) { static_assert( std::is_same< typename std::decay< typename OutMeshFunctionPointer::ObjectType >::type, OutMeshFunction >::value, "expected a smart pointer" ); static_assert( std::is_same< typename std::decay< typename InFunctionPointer::ObjectType >::type, InFunction >::value, "expected a smart pointer" ); return evaluateEntities( meshFunction, function, time, outFunctionMultiplicator, inFunctionMultiplicator, boundary ); } Loading @@ -86,73 +102,28 @@ evaluateBoundaryEntities( OutMeshFunction& meshFunction, template< typename OutMeshFunction, typename InFunction > template< typename OutMeshFunctionPointer, typename InFunctionPointer > void MeshFunctionEvaluator< OutMeshFunction, InFunction >:: evaluateEntities( OutMeshFunction& meshFunction, const InFunction& function, evaluateEntities( OutMeshFunctionPointer& meshFunction, const InFunctionPointer& function, const RealType& time, const RealType& outFunctionMultiplicator, const RealType& inFunctionMultiplicator, EntitiesType entitiesType ) { static_assert( std::is_same< typename std::decay< typename OutMeshFunctionPointer::ObjectType >::type, OutMeshFunction >::value, "expected a smart pointer" ); static_assert( std::is_same< typename std::decay< typename InFunctionPointer::ObjectType >::type, InFunction >::value, "expected a smart pointer" ); typedef typename MeshType::template MeshEntity< OutMeshFunction::getEntitiesDimensions() > MeshEntityType; typedef Functions::MeshFunctionEvaluatorAssignmentEntitiesProcessor< MeshType, TraverserUserData > AssignmentEntitiesProcessor; typedef Functions::MeshFunctionEvaluatorAdditionEntitiesProcessor< MeshType, TraverserUserData > AdditionEntitiesProcessor; if( std::is_same< MeshDeviceType, Devices::Host >::value ) { TraverserUserData userData( &function, &time, &meshFunction, &outFunctionMultiplicator, &inFunctionMultiplicator ); Meshes::Traverser< MeshType, MeshEntityType > meshTraverser; switch( entitiesType ) { case all: if( outFunctionMultiplicator ) meshTraverser.template processAllEntities< TraverserUserData, AdditionEntitiesProcessor > ( meshFunction.getMeshPointer(), userData ); else meshTraverser.template processAllEntities< TraverserUserData, AssignmentEntitiesProcessor > ( meshFunction.getMeshPointer(), userData ); break; case interior: if( outFunctionMultiplicator ) meshTraverser.template processInteriorEntities< TraverserUserData, AdditionEntitiesProcessor > ( meshFunction.getMeshPointer(), userData ); else meshTraverser.template processInteriorEntities< TraverserUserData, AssignmentEntitiesProcessor > ( meshFunction.getMeshPointer(), userData ); break; case boundary: if( outFunctionMultiplicator ) meshTraverser.template processBoundaryEntities< TraverserUserData, AdditionEntitiesProcessor > ( meshFunction.getMeshPointer(), userData ); else meshTraverser.template processBoundaryEntities< TraverserUserData, AssignmentEntitiesProcessor > ( meshFunction.getMeshPointer(), userData ); break; } } if( std::is_same< MeshDeviceType, Devices::Cuda >::value ) { OutMeshFunction* kernelMeshFunction = Devices::Cuda::passToDevice( meshFunction ); InFunction* kernelFunction = Devices::Cuda::passToDevice( function ); RealType* kernelTime = Devices::Cuda::passToDevice( time ); RealType* kernelOutFunctionMultiplicator = Devices::Cuda::passToDevice( outFunctionMultiplicator ); RealType* kernelInFunctionMultiplicator = Devices::Cuda::passToDevice( inFunctionMultiplicator ); TraverserUserData userData( kernelFunction, kernelTime, kernelMeshFunction, kernelOutFunctionMultiplicator, kernelInFunctionMultiplicator ); checkCudaDevice; TraverserUserData userData( &function.template getData< DeviceType >(), time, &meshFunction.template modifyData< DeviceType >(), outFunctionMultiplicator, inFunctionMultiplicator ); Meshes::Traverser< MeshType, MeshEntityType > meshTraverser; switch( entitiesType ) { Loading @@ -160,48 +131,39 @@ evaluateEntities( OutMeshFunction& meshFunction, if( outFunctionMultiplicator ) meshTraverser.template processAllEntities< TraverserUserData, AdditionEntitiesProcessor > ( meshFunction.getMeshPointer(), ( meshFunction->getMeshPointer(), userData ); else meshTraverser.template processAllEntities< TraverserUserData, AssignmentEntitiesProcessor > ( meshFunction.getMeshPointer(), ( meshFunction->getMeshPointer(), userData ); break; case interior: if( outFunctionMultiplicator ) meshTraverser.template processInteriorEntities< TraverserUserData, AdditionEntitiesProcessor > ( meshFunction.getMeshPointer(), ( meshFunction->getMeshPointer(), userData ); else meshTraverser.template processInteriorEntities< TraverserUserData, AssignmentEntitiesProcessor > ( meshFunction.getMeshPointer(), ( meshFunction->getMeshPointer(), userData ); break; case boundary: if( outFunctionMultiplicator ) meshTraverser.template processBoundaryEntities< TraverserUserData, AdditionEntitiesProcessor > ( meshFunction.getMeshPointer(), ( meshFunction->getMeshPointer(), userData ); else meshTraverser.template processBoundaryEntities< TraverserUserData, AssignmentEntitiesProcessor > ( meshFunction.getMeshPointer(), ( meshFunction->getMeshPointer(), userData ); break; } checkCudaDevice; Devices::Cuda::freeFromDevice( kernelMeshFunction ); Devices::Cuda::freeFromDevice( kernelFunction ); Devices::Cuda::freeFromDevice( kernelTime ); Devices::Cuda::freeFromDevice( kernelOutFunctionMultiplicator ); Devices::Cuda::freeFromDevice( kernelInFunctionMultiplicator ); checkCudaDevice; } } } // namespace Functions Loading
src/TNL/Functions/MeshFunction_impl.h +10 −3 Original line number Diff line number Diff line Loading @@ -9,6 +9,7 @@ /* See Copyright Notice in tnl/Copyright */ #include <TNL/Assert.h> #include <TNL/DevicePointer.h> #include <TNL/Functions/MeshFunction.h> #include <TNL/Functions/MeshFunctionEvaluator.h> #include <TNL/Functions/MeshFunctionNormGetter.h> Loading Loading @@ -342,7 +343,9 @@ MeshFunction< Mesh, MeshEntityDimensions, Real >& MeshFunction< Mesh, MeshEntityDimensions, Real >:: operator = ( const Function& f ) { MeshFunctionEvaluator< ThisType, Function >::evaluate( *this, f ); DevicePointer< ThisType > thisDevicePtr( *this ); DevicePointer< typename std::add_const< Function >::type > fDevicePtr( f ); MeshFunctionEvaluator< ThisType, Function >::evaluate( thisDevicePtr, fDevicePtr ); return *this; } Loading @@ -354,7 +357,9 @@ MeshFunction< Mesh, MeshEntityDimensions, Real >& MeshFunction< Mesh, MeshEntityDimensions, Real >:: operator += ( const Function& f ) { MeshFunctionEvaluator< ThisType, Function >::evaluate( *this, f, 1.0, 1.0 ); DevicePointer< ThisType > thisDevicePtr( *this ); DevicePointer< typename std::add_const< Function >::type > fDevicePtr( f ); MeshFunctionEvaluator< ThisType, Function >::evaluate( thisDevicePtr, fDevicePtr, 1.0, 1.0 ); return *this; } Loading @@ -366,7 +371,9 @@ MeshFunction< Mesh, MeshEntityDimensions, Real >& MeshFunction< Mesh, MeshEntityDimensions, Real >:: operator -= ( const Function& f ) { MeshFunctionEvaluator< ThisType, Function >::evaluate( *this, f, 1.0, -1.0 ); DevicePointer< ThisType > thisDevicePtr( *this ); DevicePointer< typename std::add_const< Function >::type > fDevicePtr( f ); MeshFunctionEvaluator< ThisType, Function >::evaluate( thisDevicePtr, fDevicePtr, 1.0, -1.0 ); return *this; } Loading
tools/src/tnl-init.h +6 −4 Original line number Diff line number Diff line Loading @@ -35,13 +35,15 @@ bool renderFunction( const Config::ParameterContainer& parameters ) return false; typedef Functions::TestFunction< MeshType::meshDimensions, RealType > FunctionType; FunctionType function; typedef SharedPointer< FunctionType, typename MeshType::DeviceType > FunctionPointer; FunctionPointer function; std::cout << "Setting up the function ... " << std::endl; if( ! function.setup( parameters, "" ) ) if( ! function->setup( parameters, "" ) ) return false; std::cout << "done." << std::endl; typedef Functions::MeshFunction< MeshType, MeshType::meshDimensions > MeshFunctionType; MeshFunctionType meshFunction( meshPointer ); typedef SharedPointer< MeshFunctionType, typename MeshType::DeviceType > MeshFunctionPointer; MeshFunctionPointer meshFunction( meshPointer ); //if( ! discreteFunction.setSize( mesh.template getEntitiesCount< typename MeshType::Cell >() ) ) // return false; Loading Loading @@ -87,7 +89,7 @@ bool renderFunction( const Config::ParameterContainer& parameters ) } else std::cout << "+ -> Writing the function to " << outputFile << " ... " << std::endl; if( ! meshFunction.save( outputFile) ) if( ! meshFunction->save( outputFile) ) return false; time += tau; step ++; Loading