From 0d6312b2396c3b9b1d9d37684d066d8243785171 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jakub=20Klinkovsk=C3=BD?= <klinkjak@fjfi.cvut.cz>
Date: Sat, 18 Feb 2017 00:10:04 +0100
Subject: [PATCH] Mesh traverser for CUDA

---
 src/TNL/Meshes/MeshDetails/Traverser_impl.h | 158 ++++++++++++++++++++
 src/TNL/Meshes/Traverser.h                  |  27 ++++
 2 files changed, 185 insertions(+)

diff --git a/src/TNL/Meshes/MeshDetails/Traverser_impl.h b/src/TNL/Meshes/MeshDetails/Traverser_impl.h
index 64b91d2952..666b19915e 100644
--- a/src/TNL/Meshes/MeshDetails/Traverser_impl.h
+++ b/src/TNL/Meshes/MeshDetails/Traverser_impl.h
@@ -80,5 +80,163 @@ processAllEntities( const MeshPointer& meshPointer,
    }
 }
 
+
+#ifdef HAVE_CUDA
+template< int EntitiesDimension,
+          typename EntitiesProcessor,
+          typename Mesh,
+          typename UserData >
+__global__ void
+MeshTraverserBoundaryEntitiesKernel( Mesh* mesh,
+                                     UserData* userData,
+                                     typename Mesh::GlobalIndexType entitiesCount )
+{
+   for( typename Mesh::GlobalIndexType i = blockIdx.x * blockDim.x + threadIdx.x;
+        i < entitiesCount;
+        i += blockDim.x * gridDim.x )
+   {
+      const auto entityIndex = mesh.template getBoundaryEntityIndex< EntitiesDimension >( i );
+      auto& entity = mesh.template getEntity< EntitiesDimension >( entityIndex );
+      // TODO: if the Mesh::IdType is void, then we should also pass the entityIndex
+      EntitiesProcessor::processEntity( mesh, userData, entity );
+   }
+}
+
+template< int EntitiesDimension,
+          typename EntitiesProcessor,
+          typename Mesh,
+          typename UserData >
+__global__ void
+MeshTraverserInteriorEntitiesKernel( Mesh* mesh,
+                                     UserData* userData,
+                                     typename Mesh::GlobalIndexType entitiesCount )
+{
+   for( typename Mesh::GlobalIndexType i = blockIdx.x * blockDim.x + threadIdx.x;
+        i < entitiesCount;
+        i += blockDim.x * gridDim.x )
+   {
+      const auto entityIndex = mesh.template getInteriorEntityIndex< EntitiesDimension >( i );
+      auto& entity = mesh.template getEntity< EntitiesDimension >( entityIndex );
+      // TODO: if the Mesh::IdType is void, then we should also pass the entityIndex
+      EntitiesProcessor::processEntity( mesh, userData, entity );
+   }
+}
+
+template< int EntitiesDimension,
+          typename EntitiesProcessor,
+          typename Mesh,
+          typename UserData >
+__global__ void
+MeshTraverserAllEntitiesKernel( Mesh* mesh,
+                                UserData* userData,
+                                typename Mesh::GlobalIndexType entitiesCount )
+{
+   for( typename Mesh::GlobalIndexType entityIndex = blockIdx.x * blockDim.x + threadIdx.x;
+        entityIndex < entitiesCount;
+        entityIndex += blockDim.x * gridDim.x )
+   {
+      auto& entity = mesh.template getEntity< EntitiesDimension >( entityIndex );
+      // TODO: if the Mesh::IdType is void, then we should also pass the entityIndex
+      EntitiesProcessor::processEntity( mesh, userData, entity );
+   }
+}
+
+#if (__CUDA_ARCH__ >= 300 )
+   static constexpr int Traverser_minBlocksPerMultiprocessor = 8;
+#else
+   static constexpr int Traverser_minBlocksPerMultiprocessor = 4;
+#endif
+#endif
+
+template< typename MeshConfig,
+          typename MeshEntity,
+          int EntitiesDimension >
+   template< typename UserData,
+             typename EntitiesProcessor >
+void
+Traverser< Mesh< MeshConfig, Devices::Cuda >, MeshEntity, EntitiesDimension >::
+processBoundaryEntities( const MeshPointer& meshPointer,
+                         SharedPointer< UserData, DeviceType >& userDataPointer ) const
+{
+#ifdef HAVE_CUDA
+   auto entitiesCount = meshPointer->template getBoundaryEntitiesCount< EntitiesDimension >();
+
+   dim3 blockSize( 256 );
+   dim3 gridSize;
+   const int desGridSize = 4 * Traverser_minBlocksPerMultiprocessor
+                             * Devices::CudaDeviceInfo::getCudaMultiprocessors( Devices::CudaDeviceInfo::getActiveDevice() );
+   gridSize.x = min( desGridSize, Devices::Cuda::getNumberOfBlocks( entitiesCount, blockSize.x ) );
+
+   Devices::Cuda::synchronizeDevice();
+   MeshTraverserBoundaryEntitiesKernel< EntitiesDimension, EntitiesProcessor >
+      <<< gridSize, blockSize >>>
+      ( &meshPointer.template modifyData< Devices::Cuda >(),
+        &userDataPointer.template modifyData< Devices::Cuda >(),
+        entitiesCount );
+#else
+   CudaSupportMissingMessage;
+#endif
+}
+
+template< typename MeshConfig,
+          typename MeshEntity,
+          int EntitiesDimension >
+   template< typename UserData,
+             typename EntitiesProcessor >
+void
+Traverser< Mesh< MeshConfig, Devices::Cuda >, MeshEntity, EntitiesDimension >::
+processInteriorEntities( const MeshPointer& meshPointer,
+                         SharedPointer< UserData, DeviceType >& userDataPointer ) const
+{
+#ifdef HAVE_CUDA
+   auto entitiesCount = meshPointer->template getInteriorEntitiesCount< EntitiesDimension >();
+
+   dim3 blockSize( 256 );
+   dim3 gridSize;
+   const int desGridSize = 4 * Traverser_minBlocksPerMultiprocessor
+                             * Devices::CudaDeviceInfo::getCudaMultiprocessors( Devices::CudaDeviceInfo::getActiveDevice() );
+   gridSize.x = min( desGridSize, Devices::Cuda::getNumberOfBlocks( entitiesCount, blockSize.x ) );
+
+   Devices::Cuda::synchronizeDevice();
+   MeshTraverserInteriorEntitiesKernel< EntitiesDimension, EntitiesProcessor >
+      <<< gridSize, blockSize >>>
+      ( &meshPointer.template modifyData< Devices::Cuda >(),
+        &userDataPointer.template modifyData< Devices::Cuda >(),
+        entitiesCount );
+#else
+   CudaSupportMissingMessage;
+#endif
+}
+
+template< typename MeshConfig,
+          typename MeshEntity,
+          int EntitiesDimension >
+   template< typename UserData,
+             typename EntitiesProcessor >
+void
+Traverser< Mesh< MeshConfig, Devices::Cuda >, MeshEntity, EntitiesDimension >::
+processAllEntities( const MeshPointer& meshPointer,
+                    SharedPointer< UserData, DeviceType >& userDataPointer ) const
+{
+#ifdef HAVE_CUDA
+   auto entitiesCount = meshPointer->template getEntitiesCount< EntitiesDimension >();
+
+   dim3 blockSize( 256 );
+   dim3 gridSize;
+   const int desGridSize = 4 * Traverser_minBlocksPerMultiprocessor
+                             * Devices::CudaDeviceInfo::getCudaMultiprocessors( Devices::CudaDeviceInfo::getActiveDevice() );
+   gridSize.x = min( desGridSize, Devices::Cuda::getNumberOfBlocks( entitiesCount, blockSize.x ) );
+
+   Devices::Cuda::synchronizeDevice();
+   MeshTraverserAllEntitiesKernel< EntitiesDimension, EntitiesProcessor >
+      <<< gridSize, blockSize >>>
+      ( &meshPointer.template modifyData< Devices::Cuda >(),
+        &userDataPointer.template modifyData< Devices::Cuda >(),
+        entitiesCount );
+#else
+   CudaSupportMissingMessage;
+#endif
+}
+
 } // namespace Meshes
 } // namespace TNL
diff --git a/src/TNL/Meshes/Traverser.h b/src/TNL/Meshes/Traverser.h
index db1fc3b579..ce0e0bf99d 100644
--- a/src/TNL/Meshes/Traverser.h
+++ b/src/TNL/Meshes/Traverser.h
@@ -11,6 +11,7 @@
 #pragma once
 
 #include <TNL/SharedPointer.h>
+#include <TNL/Meshes/Mesh.h>
 
 namespace TNL {
 namespace Meshes {
@@ -41,6 +42,32 @@ class Traverser
                                SharedPointer< UserData, DeviceType >& userDataPointer ) const;
 };
 
+template< typename MeshConfig,
+          typename MeshEntity,
+          int EntitiesDimension >
+class Traverser< Mesh< MeshConfig, Devices::Cuda >, MeshEntity, EntitiesDimension >
+{
+   public:
+      using MeshType = Mesh< MeshConfig, Devices::Cuda >;
+      using MeshPointer = SharedPointer< MeshType >;
+      using DeviceType = typename MeshType::DeviceType;
+
+      template< typename UserData,
+                typename EntitiesProcessor >
+      void processBoundaryEntities( const MeshPointer& meshPointer,
+                                    SharedPointer< UserData, DeviceType >& userDataPointer ) const;
+
+      template< typename UserData,
+                typename EntitiesProcessor >
+      void processInteriorEntities( const MeshPointer& meshPointer,
+                                    SharedPointer< UserData, DeviceType >& userDataPointer ) const;
+
+      template< typename UserData,
+                typename EntitiesProcessor >
+      void processAllEntities( const MeshPointer& meshPointer,
+                               SharedPointer< UserData, DeviceType >& userDataPointer ) const;
+};
+
 } // namespace Meshes
 } // namespace TNL
 
-- 
GitLab