Skip to content
Snippets Groups Projects
Commit e8f91de7 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Debuging heat eqquation benchmark.

parent d1678f7c
No related branches found
No related tags found
No related merge requests found
......@@ -87,6 +87,8 @@ class HeatEquationBenchmarkProblem:
BoundaryConditionPointer boundaryConditionPointer;
RightHandSidePointer rightHandSidePointer;
MeshFunctionPointer fu, u;
tnlString cudaKernelType;
MeshType* cudaMesh;
......
......@@ -233,9 +233,9 @@ template< typename GridType,
typename BoundaryConditions,
typename MeshFunction >
__global__ void
boundaryConditionsTemplatedCompact( const GridType grid,
const BoundaryConditions boundaryConditions,
MeshFunction u,
boundaryConditionsTemplatedCompact( const GridType* grid,
const BoundaryConditions* boundaryConditions,
MeshFunction* u,
const typename GridType::RealType time,
const typename GridEntity::CoordinatesType begin,
const typename GridEntity::CoordinatesType end,
......@@ -271,62 +271,17 @@ boundaryConditionsTemplatedCompact( const GridType grid,
}
}
/*template< typename Grid,
int EntityDimensions = 2,
typename Config = tnlGridEntityNoStencilStorage >
struct TestEntity
{
typedef Grid GridType;
typedef GridType MeshType;
typedef typename GridType::RealType RealType;
typedef typename GridType::IndexType IndexType;
typedef typename GridType::CoordinatesType CoordinatesType;
typedef Config ConfigType;
static const int meshDimensions = GridType::meshDimensions;
static const int entityDimensions = EntityDimensions;
constexpr static int getDimensions() { return EntityDimensions; };
constexpr static int getMeshDimensions() { return meshDimensions; };
typedef TestEntity< GridType, EntityDimensions, Config > ThisType;
typedef tnlNeighbourGridEntitiesStorage< ThisType > NeighbourGridEntitiesStorageType;
__cuda_callable__ TestEntity( const GridType& grid,
const CoordinatesType& coordinates,
const CoordinatesType& entityOrientation,
const CoordinatesType& entityBasis )
: grid( grid ), coordinates( coordinates ),
entityOrientation( 0 ),
entityBasis( 1 ),
neighbourEntitiesStorage( *this )
{
}
const GridType& grid;
CoordinatesType coordinates;
CoordinatesType entityOrientation;
CoordinatesType entityBasis;
NeighbourGridEntitiesStorageType neighbourEntitiesStorage;
};*/
template< typename GridType,
typename GridEntity,
typename DifferentialOperator,
typename RightHandSide,
typename MeshFunction >
__global__ void
heatEquationTemplatedCompact( const GridType grid,
const DifferentialOperator differentialOperator,
const RightHandSide rightHandSide,
MeshFunction u,
MeshFunction fu,
heatEquationTemplatedCompact( const GridType* grid,
const DifferentialOperator* differentialOperator,
const RightHandSide* rightHandSide,
MeshFunction* u,
MeshFunction* fu,
const typename GridType::RealType time,
const typename GridEntity::CoordinatesType begin,
const typename GridEntity::CoordinatesType end,
......@@ -474,9 +429,9 @@ getExplicitRHS( const RealType& time,
{
typedef typename MeshType::Cell CellType;
typedef typename CellType::CoordinatesType CoordinatesType;
MeshFunctionType u( mesh, uDofs );
MeshFunctionType fu( mesh, fuDofs );
fu.getData().setValue( 1.0 );
u->bind( mesh, uDofs );
fu->bind( mesh, fuDofs );
fu->getData().setValue( 1.0 );
const CoordinatesType begin( 0,0 );
const CoordinatesType& end = mesh->getDimensions();
CellType cell( mesh.template getData< DeviceType >() );
......@@ -493,9 +448,9 @@ getExplicitRHS( const RealType& time,
for( IndexType gridXIdx = 0; gridXIdx < cudaXGrids; gridXIdx ++ )
boundaryConditionsTemplatedCompact< MeshType, CellType, BoundaryCondition, MeshFunctionType >
<<< cudaBlocks, cudaBlockSize >>>
( mesh.template getData< DeviceType >(),
boundaryConditionPointer.template getData< DeviceType >(),
u,
( &mesh.template getData< tnlCuda >(),
&boundaryConditionPointer.template getData< tnlCuda >(),
&u.template getData< tnlCuda >(),
time,
begin,
end,
......@@ -510,11 +465,11 @@ getExplicitRHS( const RealType& time,
for( IndexType gridXIdx = 0; gridXIdx < cudaXGrids; gridXIdx ++ )
heatEquationTemplatedCompact< MeshType, CellType, DifferentialOperator, RightHandSide, MeshFunctionType >
<<< cudaBlocks, cudaBlockSize >>>
( mesh.template getData< DeviceType >(),
differentialOperatorPointer.template getData< DeviceType >(),
rightHandSidePointer.template getData< DeviceType >(),
u,
fu,
( &mesh.template getData< DeviceType >(),
&differentialOperatorPointer.template getData< DeviceType >(),
&rightHandSidePointer.template getData< DeviceType >(),
&u.template getData< DeviceType >(),
&fu.template getData< DeviceType >(),
time,
begin,
end,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment