Commit 1bbcf539 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Optimized CUDA traverser for boundary entities in 2D -> speed-up almost 20%.

parent 33b01f62
Loading
Loading
Loading
Loading
+53 −99
Original line number Diff line number Diff line
@@ -403,66 +403,62 @@ template< typename Real,
          bool processOnlyBoundaryEntities,
          typename... GridEntityParameters >
__global__ void 
GridTraverser2DBoundaryAlongX(
GridTraverser2DBoundary(
   const Meshes::Grid< 2, Real, Devices::Cuda, Index >* grid,
   UserData userData,
   const Index beginX,
   const Index endX,
   const Index fixedY,
   const Index beginY,
   const Index endY,
   const dim3 gridIdx,
   const GridEntityParameters... gridEntityParameters )
{
   typedef Meshes::Grid< 2, Real, Devices::Cuda, Index > GridType;
   typename GridType::CoordinatesType coordinates;
   using GridType = Meshes::Grid< 2, Real, Devices::Cuda, Index >;
   using CoordinatesType = typename GridType::CoordinatesType;
   
   coordinates.x() = beginX + Devices::Cuda::getGlobalThreadIdx_x( gridIdx );
   coordinates.y() = fixedY;  
   Index entitiesAlongX = endX - beginX + 1;
   Index entitiesAlongY = endY - beginY;
   
   if( coordinates.x() <= endX )
   Index threadId = Devices::Cuda::getGlobalThreadIdx_x( gridIdx );
   if( threadId < entitiesAlongX )
   {
      GridEntity entity( *grid, coordinates, gridEntityParameters... );
      GridEntity entity( *grid, 
         CoordinatesType( beginX + threadId, beginY ),
         gridEntityParameters... );
      //printf( "X1: Thread %d -> %d %d x %d %d \n ", threadId, 
      //   entity.getCoordinates().x(), entity.getCoordinates().y(),
      //   grid->getDimensions().x(), grid->getDimensions().y() );
      entity.refresh();
      EntitiesProcessor::processEntity
      ( *grid,
        userData,
        entity );
      EntitiesProcessor::processEntity( *grid, userData, entity );
   }
   else if( ( threadId -= entitiesAlongX ) < entitiesAlongX && threadId >= 0 )
   {
      GridEntity entity( *grid, 
         CoordinatesType( beginX + threadId, endY ),
         gridEntityParameters... );
      entity.refresh();
      //printf( "X2: Thread %d -> %d %d \n ", threadId, entity.getCoordinates().x(), entity.getCoordinates().y() );
      EntitiesProcessor::processEntity( *grid, userData, entity );
   }

template< typename Real,
          typename Index,
          typename GridEntity,
          typename UserData,
          typename EntitiesProcessor,
          bool processOnlyBoundaryEntities,
          typename... GridEntityParameters >
__global__ void 
GridTraverser2DBoundaryAlongY(
   const Meshes::Grid< 2, Real, Devices::Cuda, Index >* grid,
   UserData userData,
   const Index beginY,
   const Index endY,
   const Index fixedX,
   const dim3 gridIdx,
   const GridEntityParameters... gridEntityParameters )
   else if( ( ( threadId -= entitiesAlongX ) < entitiesAlongY - 1 ) && threadId >= 0 )
   {
   typedef Meshes::Grid< 2, Real, Devices::Cuda, Index > GridType;
   typename GridType::CoordinatesType coordinates;

   coordinates.x() = fixedX;
   coordinates.y() = beginY + Devices::Cuda::getGlobalThreadIdx_x( gridIdx );
   
   if( coordinates.y() <= endY )
      GridEntity entity( *grid,
         CoordinatesType( beginX, beginY + threadId + 1 ),
      gridEntityParameters... );
      entity.refresh();
      //printf( "Y1: Thread %d -> %d %d \n ", threadId, entity.getCoordinates().x(), entity.getCoordinates().y() );
      EntitiesProcessor::processEntity( *grid, userData, entity );      
   }
   else if( ( ( threadId -= entitiesAlongY - 1 ) < entitiesAlongY - 1  ) && threadId >= 0 )
   {
      GridEntity entity( *grid, coordinates, gridEntityParameters... );
      GridEntity entity( *grid,
         CoordinatesType( endX, beginY + threadId + 1 ),
      gridEntityParameters... );
      entity.refresh();
      EntitiesProcessor::processEntity
      ( *grid,
        userData,
        entity );
      //printf( "Y2: Thread %d -> %d %d \n ", threadId, entity.getCoordinates().x(), entity.getCoordinates().y() );
      EntitiesProcessor::processEntity( *grid, userData, entity );
   }
}

#endif

template< typename Real,
@@ -490,68 +486,26 @@ processEntities(
       ( GridEntity::getEntityDimension() == 2 || GridEntity::getEntityDimension() == 0 ) )
   {
      dim3 cudaBlockSize( 256 );      
      dim3 cudaBlocksCountAlongX, cudaGridsCountAlongX,
           cudaBlocksCountAlongY, cudaGridsCountAlongY;
      Devices::Cuda::setupThreads( cudaBlockSize, cudaBlocksCountAlongX, cudaGridsCountAlongX, end.x() - begin.x() + 1 );
      Devices::Cuda::setupThreads( cudaBlockSize, cudaBlocksCountAlongY, cudaGridsCountAlongY, end.y() - begin.y() - 1 );
            
      auto& pool = CudaStreamPool::getInstance();
      Devices::Cuda::synchronizeDevice();
      
      const cudaStream_t& s1 = pool.getStream( stream );
      const cudaStream_t& s2 = pool.getStream( stream + 1 );
      dim3 cudaBlocksCount, cudaGridsCount;
      IndexType cudaThreadsCount = 2 * ( end.x() - begin.x() + end.y() - begin.y() + 1 );
      Devices::Cuda::setupThreads( cudaBlockSize, cudaBlocksCount, cudaGridsCount, cudaThreadsCount );
      dim3 gridIdx, cudaGridSize;
      for( gridIdx.x = 0; gridIdx.x < cudaGridsCountAlongX.x; gridIdx.x++ )
      Devices::Cuda::synchronizeDevice();
      for( gridIdx.x = 0; gridIdx.x < cudaGridsCount.x; gridIdx.x++ )
      {
         Devices::Cuda::setupGrid( cudaBlocksCountAlongX, cudaGridsCountAlongX, gridIdx, cudaGridSize );
         Devices::Cuda::setupGrid( cudaBlocksCount, cudaGridsCount, gridIdx, cudaGridSize );
         //Devices::Cuda::printThreadsSetup( cudaBlockSize, cudaBlocksCountAlongX, cudaGridSize, cudaGridsCountAlongX );
         GridTraverser2DBoundaryAlongX< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities, GridEntityParameters... >
               <<< cudaGridSize, cudaBlockSize, 0, s1 >>>
         GridTraverser2DBoundary< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities, GridEntityParameters... >
               <<< cudaGridSize, cudaBlockSize >>>
               ( &gridPointer.template getData< Devices::Cuda >(),
                 userData,
                 begin.x(),
                 end.x(),
                 begin.y(),
                 gridIdx,
                 gridEntityParameters... );
         GridTraverser2DBoundaryAlongX< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities, GridEntityParameters... >
               <<< cudaGridSize, cudaBlockSize, 0, s2 >>>
               ( &gridPointer.template getData< Devices::Cuda >(),
                 userData,
                 begin.x(),
                 end.x(),
                 end.y(),
                 gridIdx,
                 gridEntityParameters... );
      }      
      const cudaStream_t& s3 = pool.getStream( stream + 2 );
      const cudaStream_t& s4 = pool.getStream( stream + 3 );
      for( gridIdx.x = 0; gridIdx.x < cudaGridsCountAlongY.x; gridIdx.x++ )
      {
         Devices::Cuda::setupGrid( cudaBlocksCountAlongY, cudaGridsCountAlongY, gridIdx, cudaGridSize );
         GridTraverser2DBoundaryAlongY< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities, GridEntityParameters... >
               <<< cudaGridSize, cudaBlockSize, 0, s3 >>>
               ( &gridPointer.template getData< Devices::Cuda >(),
                 userData,
                 begin.y() + 1,
                 end.y() - 1,
                 begin.x(),
                 gridIdx,
                 gridEntityParameters... );
         GridTraverser2DBoundaryAlongY< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities, GridEntityParameters... >
               <<< cudaGridSize, cudaBlockSize, 0, s4 >>>
               ( &gridPointer.template getData< Devices::Cuda >(),
                 userData,
                 begin.y() + 1,
                 end.y() - 1,
                 end.x(),
                 gridIdx,
                 gridEntityParameters... );
      }
      cudaStreamSynchronize( s1 );
      cudaStreamSynchronize( s2 );
      cudaStreamSynchronize( s3 );
      cudaStreamSynchronize( s4 );
      TNL_CHECK_CUDA_DEVICE;
   }
   else