Skip to content
Snippets Groups Projects
Commit 0d6312b2 authored by Jakub Klinkovský
Browse files

Mesh traverser for CUDA

parent 53ef455d
No related branches found
No related tags found
No related merge requests found
......@@ -80,5 +80,163 @@ processAllEntities( const MeshPointer& meshPointer,
}
}
#ifdef HAVE_CUDA
/**
 * CUDA kernel calling EntitiesProcessor::processEntity for boundary entities
 * of dimension EntitiesDimension.
 *
 * Launch layout: one-dimensional grid of one-dimensional blocks. The
 * grid-stride loop makes the kernel correct for any grid size, including
 * grids smaller than the number of entities.
 *
 * \param mesh           pointer to the mesh; it is dereferenced in device
 *                       code, so it must be valid on the device.
 * \param userData       pointer to the user data forwarded to the processor;
 *                       must be valid on the device as well.
 * \param entitiesCount  number of boundary entities to process; the loop
 *                       visits boundary positions 0 .. entitiesCount-1.
 */
template< int EntitiesDimension,
          typename EntitiesProcessor,
          typename Mesh,
          typename UserData >
__global__ void
MeshTraverserBoundaryEntitiesKernel( Mesh* mesh,
                                     UserData* userData,
                                     typename Mesh::GlobalIndexType entitiesCount )
{
   using GlobalIndexType = typename Mesh::GlobalIndexType;

   // Do the index arithmetic in GlobalIndexType so that it does not wrap
   // around at 32 bits when GlobalIndexType is wider than unsigned int.
   const GlobalIndexType stride = static_cast< GlobalIndexType >( blockDim.x ) * gridDim.x;

   for( GlobalIndexType i = static_cast< GlobalIndexType >( blockIdx.x ) * blockDim.x + threadIdx.x;
        i < entitiesCount;
        i += stride )
   {
      // mesh is a pointer, hence "->" (the "." operator is ill-formed here)
      const auto entityIndex = mesh->template getBoundaryEntityIndex< EntitiesDimension >( i );
      auto& entity = mesh->template getEntity< EntitiesDimension >( entityIndex );
      // TODO: if the Mesh::IdType is void, then we should also pass the entityIndex
      // NOTE(review): assumes processEntity takes the mesh and the user data
      // by reference - confirm against the EntitiesProcessor implementations.
      EntitiesProcessor::processEntity( *mesh, *userData, entity );
   }
}
/**
 * CUDA kernel calling EntitiesProcessor::processEntity for interior entities
 * of dimension EntitiesDimension.
 *
 * Launch layout: one-dimensional grid of one-dimensional blocks. The
 * grid-stride loop makes the kernel correct for any grid size, including
 * grids smaller than the number of entities.
 *
 * \param mesh           pointer to the mesh; it is dereferenced in device
 *                       code, so it must be valid on the device.
 * \param userData       pointer to the user data forwarded to the processor;
 *                       must be valid on the device as well.
 * \param entitiesCount  number of interior entities to process; the loop
 *                       visits interior positions 0 .. entitiesCount-1.
 */
template< int EntitiesDimension,
          typename EntitiesProcessor,
          typename Mesh,
          typename UserData >
__global__ void
MeshTraverserInteriorEntitiesKernel( Mesh* mesh,
                                     UserData* userData,
                                     typename Mesh::GlobalIndexType entitiesCount )
{
   using GlobalIndexType = typename Mesh::GlobalIndexType;

   // Do the index arithmetic in GlobalIndexType so that it does not wrap
   // around at 32 bits when GlobalIndexType is wider than unsigned int.
   const GlobalIndexType stride = static_cast< GlobalIndexType >( blockDim.x ) * gridDim.x;

   for( GlobalIndexType i = static_cast< GlobalIndexType >( blockIdx.x ) * blockDim.x + threadIdx.x;
        i < entitiesCount;
        i += stride )
   {
      // mesh is a pointer, hence "->" (the "." operator is ill-formed here)
      const auto entityIndex = mesh->template getInteriorEntityIndex< EntitiesDimension >( i );
      auto& entity = mesh->template getEntity< EntitiesDimension >( entityIndex );
      // TODO: if the Mesh::IdType is void, then we should also pass the entityIndex
      // NOTE(review): assumes processEntity takes the mesh and the user data
      // by reference - confirm against the EntitiesProcessor implementations.
      EntitiesProcessor::processEntity( *mesh, *userData, entity );
   }
}
/**
 * CUDA kernel calling EntitiesProcessor::processEntity for the entities of
 * dimension EntitiesDimension with indices 0 .. entitiesCount-1.
 *
 * Launch layout: one-dimensional grid of one-dimensional blocks. The
 * grid-stride loop makes the kernel correct for any grid size, including
 * grids smaller than the number of entities.
 *
 * \param mesh           pointer to the mesh; it is dereferenced in device
 *                       code, so it must be valid on the device.
 * \param userData       pointer to the user data forwarded to the processor;
 *                       must be valid on the device as well.
 * \param entitiesCount  number of entities to process; the loop index is used
 *                       directly as the entity index.
 */
template< int EntitiesDimension,
          typename EntitiesProcessor,
          typename Mesh,
          typename UserData >
__global__ void
MeshTraverserAllEntitiesKernel( Mesh* mesh,
                                UserData* userData,
                                typename Mesh::GlobalIndexType entitiesCount )
{
   using GlobalIndexType = typename Mesh::GlobalIndexType;

   // Do the index arithmetic in GlobalIndexType so that it does not wrap
   // around at 32 bits when GlobalIndexType is wider than unsigned int.
   const GlobalIndexType stride = static_cast< GlobalIndexType >( blockDim.x ) * gridDim.x;

   for( GlobalIndexType entityIndex = static_cast< GlobalIndexType >( blockIdx.x ) * blockDim.x + threadIdx.x;
        entityIndex < entitiesCount;
        entityIndex += stride )
   {
      // mesh is a pointer, hence "->" (the "." operator is ill-formed here)
      auto& entity = mesh->template getEntity< EntitiesDimension >( entityIndex );
      // TODO: if the Mesh::IdType is void, then we should also pass the entityIndex
      // NOTE(review): assumes processEntity takes the mesh and the user data
      // by reference - confirm against the EntitiesProcessor implementations.
      EntitiesProcessor::processEntity( *mesh, *userData, entity );
   }
}
// Factor used by the traverser launchers when computing the desired grid
// size: 8 when __CUDA_ARCH__ is at least 300, otherwise 4.
// NOTE(review): __CUDA_ARCH__ is defined only while compiling device code,
// so the host-side launchers presumably always see the value 4 - confirm
// that this is intended.
static constexpr int Traverser_minBlocksPerMultiprocessor =
#if defined( __CUDA_ARCH__ ) && ( __CUDA_ARCH__ >= 300 )
   8;
#else
   4;
#endif
#endif
/**
 * Launches MeshTraverserBoundaryEntitiesKernel so that
 * EntitiesProcessor::processEntity is called for the boundary entities of
 * dimension EntitiesDimension.
 *
 * The block size is fixed to 256 threads and the grid size is capped by a
 * multiple of the multiprocessor count of the active device; the kernel
 * covers the remaining entities with its stride loop.
 *
 * \param meshPointer      shared pointer to the mesh being traversed.
 * \param userDataPointer  shared pointer to the user data passed to the
 *                         entities processor.
 *
 * Without HAVE_CUDA only CudaSupportMissingMessage is evaluated.
 */
template< typename MeshConfig,
          typename MeshEntity,
          int EntitiesDimension >
   template< typename UserData,
             typename EntitiesProcessor >
void
Traverser< Mesh< MeshConfig, Devices::Cuda >, MeshEntity, EntitiesDimension >::
processBoundaryEntities( const MeshPointer& meshPointer,
                         SharedPointer< UserData, DeviceType >& userDataPointer ) const
{
#ifdef HAVE_CUDA
   auto entitiesCount = meshPointer->template getBoundaryEntitiesCount< EntitiesDimension >();

   dim3 blockSize( 256 );
   dim3 gridSize;
   const int desGridSize = 4 * Traverser_minBlocksPerMultiprocessor
                             * Devices::CudaDeviceInfo::getCudaMultiprocessors( Devices::CudaDeviceInfo::getActiveDevice() );
   gridSize.x = min( desGridSize, Devices::Cuda::getNumberOfBlocks( entitiesCount, blockSize.x ) );

   Devices::Cuda::synchronizeDevice();

   // With no entities the grid size would be zero, which is an invalid
   // launch configuration; there is nothing to process, so skip the launch.
   if( entitiesCount == 0 )
      return;

   // NOTE(review): the launch result is not checked here - consider adding
   // the project's CUDA error check after the launch.
   MeshTraverserBoundaryEntitiesKernel< EntitiesDimension, EntitiesProcessor >
      <<< gridSize, blockSize >>>
      ( &meshPointer.template modifyData< Devices::Cuda >(),
        &userDataPointer.template modifyData< Devices::Cuda >(),
        entitiesCount );
#else
   CudaSupportMissingMessage;
#endif
}
/**
 * Launches MeshTraverserInteriorEntitiesKernel so that
 * EntitiesProcessor::processEntity is called for the interior entities of
 * dimension EntitiesDimension.
 *
 * The block size is fixed to 256 threads and the grid size is capped by a
 * multiple of the multiprocessor count of the active device; the kernel
 * covers the remaining entities with its stride loop.
 *
 * \param meshPointer      shared pointer to the mesh being traversed.
 * \param userDataPointer  shared pointer to the user data passed to the
 *                         entities processor.
 *
 * Without HAVE_CUDA only CudaSupportMissingMessage is evaluated.
 */
template< typename MeshConfig,
          typename MeshEntity,
          int EntitiesDimension >
   template< typename UserData,
             typename EntitiesProcessor >
void
Traverser< Mesh< MeshConfig, Devices::Cuda >, MeshEntity, EntitiesDimension >::
processInteriorEntities( const MeshPointer& meshPointer,
                         SharedPointer< UserData, DeviceType >& userDataPointer ) const
{
#ifdef HAVE_CUDA
   auto entitiesCount = meshPointer->template getInteriorEntitiesCount< EntitiesDimension >();

   dim3 blockSize( 256 );
   dim3 gridSize;
   const int desGridSize = 4 * Traverser_minBlocksPerMultiprocessor
                             * Devices::CudaDeviceInfo::getCudaMultiprocessors( Devices::CudaDeviceInfo::getActiveDevice() );
   gridSize.x = min( desGridSize, Devices::Cuda::getNumberOfBlocks( entitiesCount, blockSize.x ) );

   Devices::Cuda::synchronizeDevice();

   // With no entities the grid size would be zero, which is an invalid
   // launch configuration; there is nothing to process, so skip the launch.
   if( entitiesCount == 0 )
      return;

   // NOTE(review): the launch result is not checked here - consider adding
   // the project's CUDA error check after the launch.
   MeshTraverserInteriorEntitiesKernel< EntitiesDimension, EntitiesProcessor >
      <<< gridSize, blockSize >>>
      ( &meshPointer.template modifyData< Devices::Cuda >(),
        &userDataPointer.template modifyData< Devices::Cuda >(),
        entitiesCount );
#else
   CudaSupportMissingMessage;
#endif
}
/**
 * Launches MeshTraverserAllEntitiesKernel so that
 * EntitiesProcessor::processEntity is called for all entities of dimension
 * EntitiesDimension.
 *
 * The block size is fixed to 256 threads and the grid size is capped by a
 * multiple of the multiprocessor count of the active device; the kernel
 * covers the remaining entities with its stride loop.
 *
 * \param meshPointer      shared pointer to the mesh being traversed.
 * \param userDataPointer  shared pointer to the user data passed to the
 *                         entities processor.
 *
 * Without HAVE_CUDA only CudaSupportMissingMessage is evaluated.
 */
template< typename MeshConfig,
          typename MeshEntity,
          int EntitiesDimension >
   template< typename UserData,
             typename EntitiesProcessor >
void
Traverser< Mesh< MeshConfig, Devices::Cuda >, MeshEntity, EntitiesDimension >::
processAllEntities( const MeshPointer& meshPointer,
                    SharedPointer< UserData, DeviceType >& userDataPointer ) const
{
#ifdef HAVE_CUDA
   auto entitiesCount = meshPointer->template getEntitiesCount< EntitiesDimension >();

   dim3 blockSize( 256 );
   dim3 gridSize;
   const int desGridSize = 4 * Traverser_minBlocksPerMultiprocessor
                             * Devices::CudaDeviceInfo::getCudaMultiprocessors( Devices::CudaDeviceInfo::getActiveDevice() );
   gridSize.x = min( desGridSize, Devices::Cuda::getNumberOfBlocks( entitiesCount, blockSize.x ) );

   Devices::Cuda::synchronizeDevice();

   // With no entities the grid size would be zero, which is an invalid
   // launch configuration; there is nothing to process, so skip the launch.
   if( entitiesCount == 0 )
      return;

   // NOTE(review): the launch result is not checked here - consider adding
   // the project's CUDA error check after the launch.
   MeshTraverserAllEntitiesKernel< EntitiesDimension, EntitiesProcessor >
      <<< gridSize, blockSize >>>
      ( &meshPointer.template modifyData< Devices::Cuda >(),
        &userDataPointer.template modifyData< Devices::Cuda >(),
        entitiesCount );
#else
   CudaSupportMissingMessage;
#endif
}
} // namespace Meshes
} // namespace TNL
......@@ -11,6 +11,7 @@
#pragma once
#include <TNL/SharedPointer.h>
#include <TNL/Meshes/Mesh.h>
namespace TNL {
namespace Meshes {
......@@ -41,6 +42,32 @@ class Traverser
SharedPointer< UserData, DeviceType >& userDataPointer ) const;
};
/**
 * Partial specialization of Traverser for meshes of the type
 * Mesh< MeshConfig, Devices::Cuda >.
 *
 * Each method takes the entities processor as a template parameter and
 * receives the mesh and the user data through shared pointers.
 */
template< typename MeshConfig,
          typename MeshEntity,
          int EntitiesDimension >
class Traverser< Mesh< MeshConfig, Devices::Cuda >, MeshEntity, EntitiesDimension >
{
public:
   // Mesh type handled by this specialization.
   typedef Mesh< MeshConfig, Devices::Cuda > MeshType;
   // Device type taken from the mesh type.
   typedef typename MeshType::DeviceType DeviceType;
   // Shared pointer through which the mesh is passed to the traverser.
   typedef SharedPointer< MeshType > MeshPointer;

   // Traversal of the boundary entities of dimension EntitiesDimension.
   template< typename UserData,
             typename EntitiesProcessor >
   void
   processBoundaryEntities( const MeshPointer& meshPointer,
                            SharedPointer< UserData, DeviceType >& userDataPointer ) const;

   // Traversal of the interior entities of dimension EntitiesDimension.
   template< typename UserData,
             typename EntitiesProcessor >
   void
   processInteriorEntities( const MeshPointer& meshPointer,
                            SharedPointer< UserData, DeviceType >& userDataPointer ) const;

   // Traversal of all entities of dimension EntitiesDimension.
   template< typename UserData,
             typename EntitiesProcessor >
   void
   processAllEntities( const MeshPointer& meshPointer,
                       SharedPointer< UserData, DeviceType >& userDataPointer ) const;
};
} // namespace Meshes
} // namespace TNL
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment