Commit ee5a7cb4 authored by Tomáš Oberhuber
Browse files

Optimized CUDA grid traverser.

parent 4d7c5d66
Loading
Loading
Loading
Loading
+186 −37
Original line number Diff line number Diff line
@@ -312,16 +312,23 @@ GridTraverser2D(
   coordinates.x() = begin.x() + ( gridXIdx * Devices::Cuda::getMaxGridSize() + blockIdx.x ) * blockDim.x + threadIdx.x;
   coordinates.y() = begin.y() + ( gridYIdx * Devices::Cuda::getMaxGridSize() + blockIdx.y ) * blockDim.y + threadIdx.y;
   
   /*if( ( !processOnlyBoundaryEntities && coordinates <= end ) ||
       (  processOnlyBoundaryEntities &&
          ( coordinates.x() == begin.x() || coordinates.y() == begin.y() ||
            coordinates.x() == end.x() || coordinates.y() == end.y() ) ) )
   /*if( processOnlyBoundaryEntities && 
      ( GridEntity::getDimensions() == 2 || GridEntity::getDimensions() == 0 ) )
   {
      if( coordinates.x() == begin.x() || coordinates.x() == end.x() ||
          coordinates.y() == begin.y() || coordinates.y() == end.y() )
      {
         GridEntity entity( *grid, coordinates, gridEntityParameters... );
         entity.refresh();
      EntitiesProcessor::processEntity( entity.getMesh(), *userData, entity );      
         EntitiesProcessor::processEntity
         ( *grid,
           *userData,
           entity );
      }
      return;
   }*/
   
   
   if( coordinates <= end )
   {
      GridEntity entity( *grid, coordinates, gridEntityParameters... );
@@ -334,8 +341,76 @@ GridTraverser2D(
           entity );
      }
   }
}

/**
 * CUDA kernel that processes grid entities lying on a single line with a
 * constant y-coordinate.
 *
 * Every thread computes one x-coordinate from beginX, gridIdx and the
 * x-components of blockIdx, blockDim and threadIdx. If that coordinate does
 * not exceed endX, the thread constructs a GridEntity at ( x, fixedY ),
 * refreshes it and passes it to EntitiesProcessor::processEntity.
 *
 * \param grid                 grid passed (dereferenced) to the entity constructor
 *                             and to the entities processor
 * \param userData             user data passed (dereferenced) to the entities processor
 * \param beginX               x-coordinate added to the thread offset
 * \param endX                 largest x-coordinate that is still processed (inclusive)
 * \param fixedY               y-coordinate used for every entity
 * \param gridIdx              multiplied by Devices::Cuda::getMaxGridSize() to shift
 *                             the block index
 * \param gridEntityParameters additional arguments forwarded to the GridEntity constructor
 *
 * The template parameter processOnlyBoundaryEntities is not referenced in the body.
 */
template< typename Real,
          typename Index,
          typename GridEntity,
          typename UserData,
          typename EntitiesProcessor,
          bool processOnlyBoundaryEntities,
          typename... GridEntityParameters >
__global__ void
GridTraverser2DBoundaryAlongX(
   const Meshes::Grid< 2, Real, Devices::Cuda, Index >* grid,
   UserData* userData,
   const Index beginX,
   const Index endX,
   const Index fixedY,
   const Index gridIdx,
   const GridEntityParameters... gridEntityParameters )
{
   using CoordinatesType = typename Meshes::Grid< 2, Real, Devices::Cuda, Index >::CoordinatesType;

   CoordinatesType position;
   position.y() = fixedY;
   position.x() = beginX + ( gridIdx * Devices::Cuda::getMaxGridSize() + blockIdx.x ) * blockDim.x + threadIdx.x;

   // Threads mapped behind the end of the line have nothing to do.
   if( position.x() > endX )
      return;

   GridEntity boundaryEntity( *grid, position, gridEntityParameters... );
   boundaryEntity.refresh();
   EntitiesProcessor::processEntity( *grid, *userData, boundaryEntity );
}

/**
 * CUDA kernel that processes grid entities lying on a single line with a
 * constant x-coordinate.
 *
 * Every thread computes one y-coordinate from beginY, gridIdx and the
 * x-components of blockIdx, blockDim and threadIdx (the launch is indexed
 * in one dimension only). If that coordinate does not exceed endY, the
 * thread constructs a GridEntity at ( fixedX, y ), refreshes it and passes
 * it to EntitiesProcessor::processEntity.
 *
 * \param grid                 grid passed (dereferenced) to the entity constructor
 *                             and to the entities processor
 * \param userData             user data passed (dereferenced) to the entities processor
 * \param beginY               y-coordinate added to the thread offset
 * \param endY                 largest y-coordinate that is still processed (inclusive)
 * \param fixedX               x-coordinate used for every entity
 * \param gridIdx              multiplied by Devices::Cuda::getMaxGridSize() to shift
 *                             the block index
 * \param gridEntityParameters additional arguments forwarded to the GridEntity constructor
 *
 * The template parameter processOnlyBoundaryEntities is not referenced in the body.
 */
template< typename Real,
          typename Index,
          typename GridEntity,
          typename UserData,
          typename EntitiesProcessor,
          bool processOnlyBoundaryEntities,
          typename... GridEntityParameters >
__global__ void
GridTraverser2DBoundaryAlongY(
   const Meshes::Grid< 2, Real, Devices::Cuda, Index >* grid,
   UserData* userData,
   const Index beginY,
   const Index endY,
   const Index fixedX,
   const Index gridIdx,
   const GridEntityParameters... gridEntityParameters )
{
   using CoordinatesType = typename Meshes::Grid< 2, Real, Devices::Cuda, Index >::CoordinatesType;

   CoordinatesType position;
   position.x() = fixedX;
   position.y() = beginY + ( gridIdx * Devices::Cuda::getMaxGridSize() + blockIdx.x ) * blockDim.x + threadIdx.x;

   // Threads mapped behind the end of the line have nothing to do.
   if( position.y() > endY )
      return;

   GridEntity boundaryEntity( *grid, position, gridEntityParameters... );
   boundaryEntity.refresh();
   EntitiesProcessor::processEntity( *grid, *userData, boundaryEntity );
}

#endif

template< typename Real,
@@ -359,6 +434,79 @@ processEntities(
   const GridEntityParameters&... gridEntityParameters )
{
#ifdef HAVE_CUDA
   if( processOnlyBoundaryEntities && 
      ( GridEntity::getDimensions() == 2 || GridEntity::getDimensions() == 0 ) )
   {
      dim3 cudaBlockSize( 256 );
      const IndexType entitiesAlongX = end.x() - begin.x() + 1;
      const IndexType entitiesAlongY = end.y() - begin.y() - 1;
      dim3 cudaBlocksAlongX, cudaBlocksAlongY;
      cudaBlocksAlongX.x = Devices::Cuda::getNumberOfBlocks( entitiesAlongX, cudaBlockSize.x );
      cudaBlocksAlongY.x = Devices::Cuda::getNumberOfBlocks( entitiesAlongY, cudaBlockSize.x );
      const IndexType cudaGridsAlongX = Devices::Cuda::getNumberOfGrids( cudaBlocksAlongX.x );
      const IndexType cudaGridsAlongY = Devices::Cuda::getNumberOfGrids( cudaBlocksAlongY.x );
      
      auto& pool = CudaStreamPool::getInstance();
      Devices::Cuda::synchronizeDevice();
      
      const cudaStream_t& s1 = pool.getStream( stream );
      const cudaStream_t& s2 = pool.getStream( stream + 1 );
      for( IndexType gridIdx = 0; gridIdx < cudaGridsAlongX; gridIdx++ )
      {
         GridTraverser2DBoundaryAlongX< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities, GridEntityParameters... >
               <<< cudaBlocksAlongX, cudaBlockSize, 0, s1 >>>
               ( &gridPointer.template getData< Devices::Cuda >(),
                 &userDataPointer.template modifyData< Devices::Cuda >(),
                 begin.x(),
                 end.x(),
                 begin.y(),
                 gridIdx,
                 gridEntityParameters... );
         GridTraverser2DBoundaryAlongX< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities, GridEntityParameters... >
               <<< cudaBlocksAlongX, cudaBlockSize, 0, s2 >>>
               ( &gridPointer.template getData< Devices::Cuda >(),
                 &userDataPointer.template modifyData< Devices::Cuda >(),
                 begin.x(),
                 end.x(),
                 end.y(),
                 gridIdx,
                 gridEntityParameters... );
      }
      const cudaStream_t& s3 = pool.getStream( stream + 2 );
      const cudaStream_t& s4 = pool.getStream( stream + 3 );
      for( IndexType gridIdx = 0; gridIdx < cudaGridsAlongX; gridIdx++ )
      {
         GridTraverser2DBoundaryAlongY< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities, GridEntityParameters... >
               <<< cudaBlocksAlongY, cudaBlockSize, 0, s3 >>>
               ( &gridPointer.template getData< Devices::Cuda >(),
                 &userDataPointer.template modifyData< Devices::Cuda >(),
                 begin.y() + 1,
                 end.y() - 1,
                 begin.x(),
                 gridIdx,
                 gridEntityParameters... );
         GridTraverser2DBoundaryAlongY< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities, GridEntityParameters... >
               <<< cudaBlocksAlongY, cudaBlockSize, 0, s4 >>>
               ( &gridPointer.template getData< Devices::Cuda >(),
                 &userDataPointer.template modifyData< Devices::Cuda >(),
                 begin.y() + 1,
                 end.y() - 1,
                 end.x(),
                 gridIdx,
                 gridEntityParameters... );
      }
      // only launches into the stream 0 are synchronized
      if( stream == 0 )
      {
         //cudaStreamSynchronize( s1 );
         //cudaStreamSynchronize( s2 );
         //cudaStreamSynchronize( s3 );
         //cudaStreamSynchronize( s4 );
         checkCudaDevice;
      }         
   }
   else
   {
      dim3 cudaBlockSize( 16, 16 );
      dim3 cudaBlocks;
      cudaBlocks.x = Devices::Cuda::getNumberOfBlocks( end.x() - begin.x() + 1, cudaBlockSize.x );
@@ -373,7 +521,7 @@ processEntities(
      for( IndexType gridYIdx = 0; gridYIdx < cudaYGrids; gridYIdx ++ )
         for( IndexType gridXIdx = 0; gridXIdx < cudaXGrids; gridXIdx ++ )
            GridTraverser2D< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities, GridEntityParameters... >
            <<< cudaBlocks, cudaBlockSize, 0, s >>>
               <<< cudaBlocks, cudaBlockSize, 0 >>> //, s >>>
               ( &gridPointer.template getData< Devices::Cuda >(),
                 &userDataPointer.template modifyData< Devices::Cuda >(),
                 begin,
@@ -385,9 +533,10 @@ processEntities(
      // only launches into the stream 0 are synchronized
      if( stream == 0 )
      {
      cudaStreamSynchronize( s );
         //cudaStreamSynchronize( s );
         checkCudaDevice;
      }
   }
#endif
}