diff --git a/tests/benchmarks/heat-equation-benchmark/HeatEquationBenchmarkProblem.h b/tests/benchmarks/heat-equation-benchmark/HeatEquationBenchmarkProblem.h index 6fa97c0dc69fea61687fb1f7c1e5763862706bef..c8badd7856092795dcca346972d68d0d3fddea7a 100644 --- a/tests/benchmarks/heat-equation-benchmark/HeatEquationBenchmarkProblem.h +++ b/tests/benchmarks/heat-equation-benchmark/HeatEquationBenchmarkProblem.h @@ -87,6 +87,8 @@ class HeatEquationBenchmarkProblem: BoundaryConditionPointer boundaryConditionPointer; RightHandSidePointer rightHandSidePointer; + MeshFunctionPointer fu, u; + tnlString cudaKernelType; MeshType* cudaMesh; diff --git a/tests/benchmarks/heat-equation-benchmark/HeatEquationBenchmarkProblem_impl.h b/tests/benchmarks/heat-equation-benchmark/HeatEquationBenchmarkProblem_impl.h index eebfbc31963ae850d903f133793437ab9c14f7a4..abd9110370ef45a66d9b651b2b3766fa82fd2fa2 100644 --- a/tests/benchmarks/heat-equation-benchmark/HeatEquationBenchmarkProblem_impl.h +++ b/tests/benchmarks/heat-equation-benchmark/HeatEquationBenchmarkProblem_impl.h @@ -233,9 +233,9 @@ template< typename GridType, typename BoundaryConditions, typename MeshFunction > __global__ void -boundaryConditionsTemplatedCompact( const GridType grid, - const BoundaryConditions boundaryConditions, - MeshFunction u, +boundaryConditionsTemplatedCompact( const GridType* grid, + const BoundaryConditions* boundaryConditions, + MeshFunction* u, const typename GridType::RealType time, const typename GridEntity::CoordinatesType begin, const typename GridEntity::CoordinatesType end, @@ -271,62 +271,17 @@ boundaryConditionsTemplatedCompact( const GridType grid, } } -/*template< typename Grid, - int EntityDimensions = 2, - typename Config = tnlGridEntityNoStencilStorage > -struct TestEntity -{ - typedef Grid GridType; - typedef GridType MeshType; - typedef typename GridType::RealType RealType; - typedef typename GridType::IndexType IndexType; - typedef typename GridType::CoordinatesType CoordinatesType; - typedef Config ConfigType; - - static const int meshDimensions = GridType::meshDimensions; - - static const int entityDimensions = EntityDimensions; - - constexpr static int getDimensions() { return EntityDimensions; }; - - constexpr static int getMeshDimensions() { return meshDimensions; }; - - typedef TestEntity< GridType, EntityDimensions, Config > ThisType; - typedef tnlNeighbourGridEntitiesStorage< ThisType > NeighbourGridEntitiesStorageType; - - - __cuda_callable__ TestEntity( const GridType& grid, - const CoordinatesType& coordinates, - const CoordinatesType& entityOrientation, - const CoordinatesType& entityBasis ) - : grid( grid ), coordinates( coordinates ), - entityOrientation( 0 ), - entityBasis( 1 ), - neighbourEntitiesStorage( *this ) - { - - } - - const GridType& grid; - - CoordinatesType coordinates; - CoordinatesType entityOrientation; - CoordinatesType entityBasis; - - NeighbourGridEntitiesStorageType neighbourEntitiesStorage; -};*/ - template< typename GridType, typename GridEntity, typename DifferentialOperator, typename RightHandSide, typename MeshFunction > __global__ void -heatEquationTemplatedCompact( const GridType grid, - const DifferentialOperator differentialOperator, - const RightHandSide rightHandSide, - MeshFunction u, - MeshFunction fu, +heatEquationTemplatedCompact( const GridType* grid, + const DifferentialOperator* differentialOperator, + const RightHandSide* rightHandSide, + MeshFunction* u, + MeshFunction* fu, const typename GridType::RealType time, const typename GridEntity::CoordinatesType begin, const typename GridEntity::CoordinatesType end, @@ -474,9 +429,9 @@ getExplicitRHS( const RealType& time, { typedef typename MeshType::Cell CellType; typedef typename CellType::CoordinatesType CoordinatesType; - MeshFunctionType u( mesh, uDofs ); - MeshFunctionType fu( mesh, fuDofs ); - fu.getData().setValue( 1.0 ); + u->bind( mesh, uDofs ); + fu->bind( mesh, fuDofs ); + fu->getData().setValue( 1.0 ); const CoordinatesType begin( 0,0 ); const CoordinatesType& end = mesh->getDimensions(); CellType cell( mesh.template getData< DeviceType >() ); @@ -493,9 +448,9 @@ getExplicitRHS( const RealType& time, for( IndexType gridXIdx = 0; gridXIdx < cudaXGrids; gridXIdx ++ ) boundaryConditionsTemplatedCompact< MeshType, CellType, BoundaryCondition, MeshFunctionType > <<< cudaBlocks, cudaBlockSize >>> - ( mesh.template getData< DeviceType >(), - boundaryConditionPointer.template getData< DeviceType >(), - u, + ( &mesh.template getData< tnlCuda >(), + &boundaryConditionPointer.template getData< tnlCuda >(), + &u.template getData< tnlCuda >(), time, begin, end, @@ -510,11 +465,11 @@ getExplicitRHS( const RealType& time, for( IndexType gridXIdx = 0; gridXIdx < cudaXGrids; gridXIdx ++ ) heatEquationTemplatedCompact< MeshType, CellType, DifferentialOperator, RightHandSide, MeshFunctionType > <<< cudaBlocks, cudaBlockSize >>> - ( mesh.template getData< DeviceType >(), - differentialOperatorPointer.template getData< DeviceType >(), - rightHandSidePointer.template getData< DeviceType >(), - u, - fu, + ( &mesh.template getData< DeviceType >(), + &differentialOperatorPointer.template getData< DeviceType >(), + &rightHandSidePointer.template getData< DeviceType >(), + &u.template getData< DeviceType >(), + &fu.template getData< DeviceType >(), time, begin, end,