Commit 250d21c2 authored by Jakub Klinkovský
Browse files

Optimized grid traversers for faces and edges using CUDA streams

parent 6fd14031
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -18,6 +18,7 @@ SET( CURRENT_DIR ${CMAKE_SOURCE_DIR}/src/TNL )
set( headers 
     Assert.h
     Constants.h
     CudaStreamPool.h
     Curve.h
     DevicePointer.h
     File.h
+62 −0
Original line number Diff line number Diff line
#pragma once

#include <stdlib.h>
#include <stdexcept>
#include <string>
#include <unordered_map>

#include <TNL/Devices/Host.h>
#include <TNL/Devices/Cuda.h>

namespace TNL {

#ifdef HAVE_CUDA
/**
 * Singleton pool of CUDA streams addressed by integer identifiers.
 *
 * A stream is created with cudaStreamCreate on the first request for a given
 * identifier and cached, so later requests for the same identifier return the
 * same stream. No identifier is special-cased here: identifier 0 also maps to
 * a stream created by cudaStreamCreate.
 *
 * All cached streams are destroyed when the singleton is destroyed. The
 * release is done in the destructor rather than in an atexit() handler
 * registered from the constructor: a handler registered before the static
 * object finishes its initialization runs *after* that object's destructor,
 * so it would operate on an already destroyed map.
 */
class CudaStreamPool
{
   public:
      // the pool is a singleton, so copying is disabled
      CudaStreamPool( CudaStreamPool const& copy ) = delete;
      CudaStreamPool& operator=( CudaStreamPool const& copy ) = delete;

      /**
       * Returns the single pool instance, constructing it on first use.
       */
      inline static CudaStreamPool& getInstance()
      {
         static CudaStreamPool instance;
         return instance;
      }

      /**
       * Returns the stream associated with the identifier `s`, creating it
       * if it does not exist yet.
       *
       * The returned reference points into the pool and stays valid until
       * the entry is removed by free().
       *
       * Throws std::runtime_error when cudaStreamCreate fails; in that case
       * nothing is cached for `s`, so a later call tries to create the
       * stream again.
       */
      const cudaStream_t& getStream( int s )
      {
         auto result = pool.insert( {s, cudaStream_t()} );
         cudaStream_t& stream = result.first->second;
         if( result.second ) {
            // the identifier was not in the pool yet, create its stream now
            const cudaError_t status = cudaStreamCreate( &stream );
            if( status != cudaSuccess ) {
               // do not keep a handle that was never created
               pool.erase( result.first );
               throw std::runtime_error(
                  std::string( "CudaStreamPool: cudaStreamCreate failed for stream " )
                  + std::to_string( s ) + ": " + cudaGetErrorString( status ) );
            }
         }
         return stream;
      }

   private:
      // private constructor of the singleton
      inline CudaStreamPool()
      {
      }

      // releases all pooled streams together with the singleton
      inline ~CudaStreamPool()
      {
         free();
      }

   protected:
      using MapType = std::unordered_map< int, cudaStream_t >;

      /**
       * Destroys every pooled stream and empties the pool, which makes
       * repeated calls harmless.
       */
      inline void free( void )
      {
         for( auto& p : pool )
            // The result is ignored on purpose: this runs during program
            // teardown and there is nothing useful to do on failure.
            (void) cudaStreamDestroy( p.second );
         pool.clear();
      }

      MapType pool;
};
#endif

} // namespace TNL
+10 −3
Original line number Diff line number Diff line
@@ -10,8 +10,9 @@

#pragma once

#include <TNL/Meshes/Grid.h>
#include <TNL/SharedPointer.h>

#include <TNL/CudaStreamPool.h>

namespace TNL {
namespace Meshes {
@@ -50,7 +51,8 @@ class GridTraverser< Meshes::Grid< 1, Real, Devices::Host, Index > >
         const GridPointer& gridPointer,
         const CoordinatesType begin,
         const CoordinatesType end,
         SharedPointer< UserData, DeviceType >& userData );
         SharedPointer< UserData, DeviceType >& userData,
         const int& stream = 0 );
};

/****
@@ -79,7 +81,8 @@ class GridTraverser< Meshes::Grid< 1, Real, Devices::Cuda, Index > >
         const GridPointer& gridPointer,
         const CoordinatesType& begin,
         const CoordinatesType& end,
         SharedPointer< UserData, DeviceType >& userData );
         SharedPointer< UserData, DeviceType >& userData,
         const int& stream = 0 );
};

/****
@@ -112,6 +115,7 @@ class GridTraverser< Meshes::Grid< 2, Real, Devices::Host, Index > >
         const CoordinatesType begin,
         const CoordinatesType end,
         SharedPointer< UserData, DeviceType >& userData,
         const int& stream = 0,
         // gridEntityParameters are passed to GridEntity's constructor
         // (i.e. orientation and basis for faces)
         const GridEntityParameters&... gridEntityParameters );
@@ -147,6 +151,7 @@ class GridTraverser< Meshes::Grid< 2, Real, Devices::Cuda, Index > >
         const CoordinatesType& begin,
         const CoordinatesType& end,
         SharedPointer< UserData, DeviceType >& userData,
         const int& stream = 0,
         // gridEntityParameters are passed to GridEntity's constructor
         // (i.e. orientation and basis for faces)
         const GridEntityParameters&... gridEntityParameters );
@@ -183,6 +188,7 @@ class GridTraverser< Meshes::Grid< 3, Real, Devices::Host, Index > >
         const CoordinatesType begin,
         const CoordinatesType end,
         SharedPointer< UserData, DeviceType >& userData,
         const int& stream = 0,
         // gridEntityParameters are passed to GridEntity's constructor
         // (i.e. orientation and basis for faces and edges)
         const GridEntityParameters&... gridEntityParameters );
@@ -219,6 +225,7 @@ class GridTraverser< Meshes::Grid< 3, Real, Devices::Cuda, Index > >
         const CoordinatesType& begin,
         const CoordinatesType& end,
         SharedPointer< UserData, DeviceType >& userData,
         const int& stream = 0,
         // gridEntityParameters are passed to GridEntity's constructor
         // (i.e. orientation and basis for faces and edges)
         const GridEntityParameters&... gridEntityParameters );
+40 −12
Original line number Diff line number Diff line
@@ -29,7 +29,8 @@ processEntities(
   const GridPointer& gridPointer,
   const CoordinatesType begin,
   const CoordinatesType end,
   SharedPointer< UserData, DeviceType >& userDataPointer )
   SharedPointer< UserData, DeviceType >& userDataPointer,
   const int& stream )
{
   GridEntity entity( *gridPointer );
   if( processOnlyBoundaryEntities )
@@ -152,16 +153,20 @@ processEntities(
   const GridPointer& gridPointer,
   const CoordinatesType& begin,
   const CoordinatesType& end,
   SharedPointer< UserData, DeviceType >& userDataPointer )
   SharedPointer< UserData, DeviceType >& userDataPointer,
   const int& stream )
{
#ifdef HAVE_CUDA
   auto& pool = CudaStreamPool::getInstance();
   const cudaStream_t& s = pool.getStream( stream );

   Devices::Cuda::synchronizeDevice();
   if( processOnlyBoundaryEntities )
   {
      dim3 cudaBlockSize( 2 );
      dim3 cudaBlocks( 1 );
      GridBoundaryTraverser1D< Real, Index, GridEntity, UserData, EntitiesProcessor >
            <<< cudaBlocks, cudaBlockSize >>>
            <<< cudaBlocks, cudaBlockSize, 0, s >>>
            ( &gridPointer.template getData< Devices::Cuda >(),
              &userDataPointer.template modifyData< Devices::Cuda >(),
              begin,
@@ -176,15 +181,20 @@ processEntities(

      for( IndexType gridXIdx = 0; gridXIdx < cudaXGrids; gridXIdx ++ )
         GridTraverser1D< Real, Index, GridEntity, UserData, EntitiesProcessor >
            <<< cudaBlocks, cudaBlockSize >>>
            <<< cudaBlocks, cudaBlockSize, 0, s >>>
            ( &gridPointer.template getData< Devices::Cuda >(),
              &userDataPointer.template modifyData< Devices::Cuda >(),
              begin,
              end,
              gridXIdx );
   }
   cudaThreadSynchronize();

   // only launches into the stream 0 are synchronized
   if( stream == 0 )
   {
      cudaStreamSynchronize( s );
      checkCudaDevice;
   }
#endif
}

@@ -209,6 +219,7 @@ processEntities(
   const CoordinatesType begin,
   const CoordinatesType end,
   SharedPointer< UserData, DeviceType >& userDataPointer,
   const int& stream,
   const GridEntityParameters&... gridEntityParameters )
{
   if( processOnlyBoundaryEntities )
@@ -333,6 +344,7 @@ processEntities(
   const CoordinatesType& begin,
   const CoordinatesType& end,
   SharedPointer< UserData, DeviceType >& userDataPointer,
   const int& stream,
   const GridEntityParameters&... gridEntityParameters )
{
#ifdef HAVE_CUDA   
@@ -343,11 +355,14 @@ processEntities(
   const IndexType cudaXGrids = Devices::Cuda::getNumberOfGrids( cudaBlocks.x );
   const IndexType cudaYGrids = Devices::Cuda::getNumberOfGrids( cudaBlocks.y );

   auto& pool = CudaStreamPool::getInstance();
   const cudaStream_t& s = pool.getStream( stream );

   Devices::Cuda::synchronizeDevice();
   for( IndexType gridYIdx = 0; gridYIdx < cudaYGrids; gridYIdx ++ )
      for( IndexType gridXIdx = 0; gridXIdx < cudaXGrids; gridXIdx ++ )
         GridTraverser2D< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities, GridEntityParameters... >
            <<< cudaBlocks, cudaBlockSize >>>
            <<< cudaBlocks, cudaBlockSize, 0, s >>>
            ( &gridPointer.template getData< Devices::Cuda >(),
              &userDataPointer.template modifyData< Devices::Cuda >(),
              begin,
@@ -356,8 +371,12 @@ processEntities(
              gridYIdx,
              gridEntityParameters... );
 
   cudaThreadSynchronize();
   // only launches into the stream 0 are synchronized
   if( stream == 0 )
   {
      cudaStreamSynchronize( s );
      checkCudaDevice;
   }
#endif
}

@@ -382,6 +401,7 @@ processEntities(
   const CoordinatesType begin,
   const CoordinatesType end,
   SharedPointer< UserData, DeviceType >& userDataPointer,
   const int& stream,
   const GridEntityParameters&... gridEntityParameters )
{
   if( processOnlyBoundaryEntities )
@@ -535,6 +555,7 @@ processEntities(
   const CoordinatesType& begin,
   const CoordinatesType& end,
   SharedPointer< UserData, DeviceType >& userDataPointer,
   const int& stream,
   const GridEntityParameters&... gridEntityParameters )
{
#ifdef HAVE_CUDA   
@@ -547,12 +568,15 @@ processEntities(
   const IndexType cudaYGrids = Devices::Cuda::getNumberOfGrids( cudaBlocks.y );
   const IndexType cudaZGrids = Devices::Cuda::getNumberOfGrids( cudaBlocks.z );

   auto& pool = CudaStreamPool::getInstance();
   const cudaStream_t& s = pool.getStream( stream );

   Devices::Cuda::synchronizeDevice();
   for( IndexType gridZIdx = 0; gridZIdx < cudaZGrids; gridZIdx ++ )
      for( IndexType gridYIdx = 0; gridYIdx < cudaYGrids; gridYIdx ++ )
         for( IndexType gridXIdx = 0; gridXIdx < cudaXGrids; gridXIdx ++ )
            GridTraverser3D< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities, GridEntityParameters... >
               <<< cudaBlocks, cudaBlockSize >>>
               <<< cudaBlocks, cudaBlockSize, 0, s >>>
               ( &gridPointer.template getData< Devices::Cuda >(),
                 &userDataPointer.template modifyData< Devices::Cuda >(),
                 begin,
@@ -562,8 +586,12 @@ processEntities(
                 gridZIdx,
                 gridEntityParameters... );

   cudaThreadSynchronize();
   // only launches into the stream 0 are synchronized
   if( stream == 0 )
   {
      cudaStreamSynchronize( s );
      checkCudaDevice;
   }
#endif
}

+1 −0
Original line number Diff line number Diff line
@@ -11,6 +11,7 @@
#pragma once

#include <TNL/Meshes/Traverser.h>
#include <TNL/SharedPointer.h>

namespace TNL {
namespace Meshes {
Loading