diff --git a/src/TNL/Meshes/MeshDetails/Traverser_impl.h b/src/TNL/Meshes/MeshDetails/Traverser_impl.h
index 64b91d29524b6b8fb435ec5f247dea7b5d773055..666b19915e50325f39b9881ad1c4bf346f4b90b7 100644
--- a/src/TNL/Meshes/MeshDetails/Traverser_impl.h
+++ b/src/TNL/Meshes/MeshDetails/Traverser_impl.h
@@ -80,5 +80,221 @@ processAllEntities( const MeshPointer& meshPointer,
    }
 }
 
+
+#ifdef HAVE_CUDA
+// CUDA kernel that calls EntitiesProcessor::processEntity on every boundary
+// entity of dimension EntitiesDimension.
+//
+// mesh          - mesh whose entities are traversed
+// userData      - data handed over to the entities processor
+// entitiesCount - number of boundary entities to process
+template< int EntitiesDimension,
+          typename EntitiesProcessor,
+          typename Mesh,
+          typename UserData >
+__global__ void
+MeshTraverserBoundaryEntitiesKernel( Mesh* mesh,
+                                     UserData* userData,
+                                     typename Mesh::GlobalIndexType entitiesCount )
+{
+   // every thread starts at its global index and advances by the total number
+   // of launched threads, so a grid smaller than entitiesCount still covers all
+   for( typename Mesh::GlobalIndexType i = blockIdx.x * blockDim.x + threadIdx.x;
+        i < entitiesCount;
+        i += blockDim.x * gridDim.x )
+   {
+      const auto entityIndex = mesh->template getBoundaryEntityIndex< EntitiesDimension >( i );
+      auto& entity = mesh->template getEntity< EntitiesDimension >( entityIndex );
+      // TODO: if the Mesh::IdType is void, then we should also pass the entityIndex
+      EntitiesProcessor::processEntity( *mesh, *userData, entity );
+   }
+}
+
+// CUDA kernel that calls EntitiesProcessor::processEntity on every interior
+// entity of dimension EntitiesDimension.
+//
+// mesh          - mesh whose entities are traversed
+// userData      - data handed over to the entities processor
+// entitiesCount - number of interior entities to process
+template< int EntitiesDimension,
+          typename EntitiesProcessor,
+          typename Mesh,
+          typename UserData >
+__global__ void
+MeshTraverserInteriorEntitiesKernel( Mesh* mesh,
+                                     UserData* userData,
+                                     typename Mesh::GlobalIndexType entitiesCount )
+{
+   // every thread starts at its global index and advances by the total number
+   // of launched threads, so a grid smaller than entitiesCount still covers all
+   for( typename Mesh::GlobalIndexType i = blockIdx.x * blockDim.x + threadIdx.x;
+        i < entitiesCount;
+        i += blockDim.x * gridDim.x )
+   {
+      const auto entityIndex = mesh->template getInteriorEntityIndex< EntitiesDimension >( i );
+      auto& entity = mesh->template getEntity< EntitiesDimension >( entityIndex );
+      // TODO: if the Mesh::IdType is void, then we should also pass the entityIndex
+      EntitiesProcessor::processEntity( *mesh, *userData, entity );
+   }
+}
+
+// CUDA kernel that calls EntitiesProcessor::processEntity on every entity of
+// dimension EntitiesDimension.
+//
+// mesh          - mesh whose entities are traversed
+// userData      - data handed over to the entities processor
+// entitiesCount - number of entities to process
+template< int EntitiesDimension,
+          typename EntitiesProcessor,
+          typename Mesh,
+          typename UserData >
+__global__ void
+MeshTraverserAllEntitiesKernel( Mesh* mesh,
+                                UserData* userData,
+                                typename Mesh::GlobalIndexType entitiesCount )
+{
+   // every thread starts at its global index and advances by the total number
+   // of launched threads, so a grid smaller than entitiesCount still covers all
+   for( typename Mesh::GlobalIndexType entityIndex = blockIdx.x * blockDim.x + threadIdx.x;
+        entityIndex < entitiesCount;
+        entityIndex += blockDim.x * gridDim.x )
+   {
+      auto& entity = mesh->template getEntity< EntitiesDimension >( entityIndex );
+      // TODO: if the Mesh::IdType is void, then we should also pass the entityIndex
+      EntitiesProcessor::processEntity( *mesh, *userData, entity );
+   }
+}
+
+// Used by the launchers below to cap the grid size at
+// 4 * Traverser_minBlocksPerMultiprocessor blocks per multiprocessor.
+// NOTE: __CUDA_ARCH__ is defined only while compiling device code, so the
+// host-side launchers always see the value from the #else branch.
+#if (__CUDA_ARCH__ >= 300 )
+   static constexpr int Traverser_minBlocksPerMultiprocessor = 8;
+#else
+   static constexpr int Traverser_minBlocksPerMultiprocessor = 4;
+#endif
+#endif
+
+// Processes all boundary entities of dimension EntitiesDimension by launching
+// MeshTraverserBoundaryEntitiesKernel; does nothing when there are none.
+//
+// meshPointer     - shared pointer to the traversed mesh
+// userDataPointer - shared pointer to the data passed to EntitiesProcessor
+template< typename MeshConfig,
+          typename MeshEntity,
+          int EntitiesDimension >
+   template< typename UserData,
+             typename EntitiesProcessor >
+void
+Traverser< Mesh< MeshConfig, Devices::Cuda >, MeshEntity, EntitiesDimension >::
+processBoundaryEntities( const MeshPointer& meshPointer,
+                         SharedPointer< UserData, DeviceType >& userDataPointer ) const
+{
+#ifdef HAVE_CUDA
+   auto entitiesCount = meshPointer->template getBoundaryEntitiesCount< EntitiesDimension >();
+
+   // an empty grid is not a valid CUDA launch configuration
+   if( entitiesCount == 0 )
+      return;
+
+   dim3 blockSize( 256 );
+   dim3 gridSize;
+   // the kernel loops over the entities, so the grid size can be capped
+   const int desGridSize = 4 * Traverser_minBlocksPerMultiprocessor
+                             * Devices::CudaDeviceInfo::getCudaMultiprocessors( Devices::CudaDeviceInfo::getActiveDevice() );
+   gridSize.x = min( desGridSize, Devices::Cuda::getNumberOfBlocks( entitiesCount, blockSize.x ) );
+
+   Devices::Cuda::synchronizeDevice();
+   MeshTraverserBoundaryEntitiesKernel< EntitiesDimension, EntitiesProcessor >
+      <<< gridSize, blockSize >>>
+      ( &meshPointer.template modifyData< Devices::Cuda >(),
+        &userDataPointer.template modifyData< Devices::Cuda >(),
+        entitiesCount );
+#else
+   CudaSupportMissingMessage;
+#endif
+}
+
+// Processes all interior entities of dimension EntitiesDimension by launching
+// MeshTraverserInteriorEntitiesKernel; does nothing when there are none.
+//
+// meshPointer     - shared pointer to the traversed mesh
+// userDataPointer - shared pointer to the data passed to EntitiesProcessor
+template< typename MeshConfig,
+          typename MeshEntity,
+          int EntitiesDimension >
+   template< typename UserData,
+             typename EntitiesProcessor >
+void
+Traverser< Mesh< MeshConfig, Devices::Cuda >, MeshEntity, EntitiesDimension >::
+processInteriorEntities( const MeshPointer& meshPointer,
+                         SharedPointer< UserData, DeviceType >& userDataPointer ) const
+{
+#ifdef HAVE_CUDA
+   auto entitiesCount = meshPointer->template getInteriorEntitiesCount< EntitiesDimension >();
+
+   // an empty grid is not a valid CUDA launch configuration
+   if( entitiesCount == 0 )
+      return;
+
+   dim3 blockSize( 256 );
+   dim3 gridSize;
+   // the kernel loops over the entities, so the grid size can be capped
+   const int desGridSize = 4 * Traverser_minBlocksPerMultiprocessor
+                             * Devices::CudaDeviceInfo::getCudaMultiprocessors( Devices::CudaDeviceInfo::getActiveDevice() );
+   gridSize.x = min( desGridSize, Devices::Cuda::getNumberOfBlocks( entitiesCount, blockSize.x ) );
+
+   Devices::Cuda::synchronizeDevice();
+   MeshTraverserInteriorEntitiesKernel< EntitiesDimension, EntitiesProcessor >
+      <<< gridSize, blockSize >>>
+      ( &meshPointer.template modifyData< Devices::Cuda >(),
+        &userDataPointer.template modifyData< Devices::Cuda >(),
+        entitiesCount );
+#else
+   CudaSupportMissingMessage;
+#endif
+}
+
+// Processes all entities of dimension EntitiesDimension by launching
+// MeshTraverserAllEntitiesKernel; does nothing when there are none.
+//
+// meshPointer     - shared pointer to the traversed mesh
+// userDataPointer - shared pointer to the data passed to EntitiesProcessor
+template< typename MeshConfig,
+          typename MeshEntity,
+          int EntitiesDimension >
+   template< typename UserData,
+             typename EntitiesProcessor >
+void
+Traverser< Mesh< MeshConfig, Devices::Cuda >, MeshEntity, EntitiesDimension >::
+processAllEntities( const MeshPointer& meshPointer,
+                    SharedPointer< UserData, DeviceType >& userDataPointer ) const
+{
+#ifdef HAVE_CUDA
+   auto entitiesCount = meshPointer->template getEntitiesCount< EntitiesDimension >();
+
+   // an empty grid is not a valid CUDA launch configuration
+   if( entitiesCount == 0 )
+      return;
+
+   dim3 blockSize( 256 );
+   dim3 gridSize;
+   // the kernel loops over the entities, so the grid size can be capped
+   const int desGridSize = 4 * Traverser_minBlocksPerMultiprocessor
+                             * Devices::CudaDeviceInfo::getCudaMultiprocessors( Devices::CudaDeviceInfo::getActiveDevice() );
+   gridSize.x = min( desGridSize, Devices::Cuda::getNumberOfBlocks( entitiesCount, blockSize.x ) );
+
+   Devices::Cuda::synchronizeDevice();
+   MeshTraverserAllEntitiesKernel< EntitiesDimension, EntitiesProcessor >
+      <<< gridSize, blockSize >>>
+      ( &meshPointer.template modifyData< Devices::Cuda >(),
+        &userDataPointer.template modifyData< Devices::Cuda >(),
+        entitiesCount );
+#else
+   CudaSupportMissingMessage;
+#endif
+}
+
 } // namespace Meshes
 } // namespace TNL
diff --git a/src/TNL/Meshes/Traverser.h b/src/TNL/Meshes/Traverser.h
index db1fc3b57925e67b15f364ac7bbafd426e7c3f7a..ce0e0bf99d818b696c94ed8791dad5717dcb3cde 100644
--- a/src/TNL/Meshes/Traverser.h
+++ b/src/TNL/Meshes/Traverser.h
@@ -11,6 +11,7 @@
 #pragma once
 
 #include <TNL/SharedPointer.h>
+#include <TNL/Meshes/Mesh.h>
 
 namespace TNL {
 namespace Meshes {
@@ -41,6 +42,36 @@ class Traverser
                             SharedPointer< UserData, DeviceType >& userDataPointer ) const;
 };
 
+// Specialization of Traverser for Mesh< MeshConfig, Devices::Cuda >.
+// Each process*Entities method hands the selected entities of dimension
+// EntitiesDimension, together with the mesh and the user data, to
+// EntitiesProcessor::processEntity.
+template< typename MeshConfig,
+          typename MeshEntity,
+          int EntitiesDimension >
+class Traverser< Mesh< MeshConfig, Devices::Cuda >, MeshEntity, EntitiesDimension >
+{
+   public:
+      using MeshType = Mesh< MeshConfig, Devices::Cuda >;
+      using MeshPointer = SharedPointer< MeshType >;
+      using DeviceType = typename MeshType::DeviceType;
+
+      template< typename UserData,
+                typename EntitiesProcessor >
+      void processBoundaryEntities( const MeshPointer& meshPointer,
+                                    SharedPointer< UserData, DeviceType >& userDataPointer ) const;
+
+      template< typename UserData,
+                typename EntitiesProcessor >
+      void processInteriorEntities( const MeshPointer& meshPointer,
+                                    SharedPointer< UserData, DeviceType >& userDataPointer ) const;
+
+      template< typename UserData,
+                typename EntitiesProcessor >
+      void processAllEntities( const MeshPointer& meshPointer,
+                               SharedPointer< UserData, DeviceType >& userDataPointer ) const;
+};
+
 } // namespace Meshes
 } // namespace TNL
 