Commit 429bd511 authored by Jakub Klinkovský
Browse files

Refactored CUDA parallel scan kernel

Using an odd value of valuesPerThread avoids shared memory bank
conflicts even without special interleaving of sharedData. We also save
some shared memory this way, because the padding elements are not needed.

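Small inputs can be scanned with just one CUDA block, which avoids the
scan of block results and the second-phase kernel. The array of block
results is usually small enough to be scanned this way (1 launch instead
of 2), hence large arrays can be scanned with just 3 kernel launches
instead of 4.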
parent c37987b1
Loading
Loading
Loading
Loading
+128 −62
Original line number Diff line number Diff line
@@ -45,6 +45,9 @@ CudaScanKernelFirstPhase( const InputView input,
   TNL_ASSERT_EQ( blockDim.x, blockSize, "unexpected block size in CudaScanKernelFirstPhase" );
   static_assert( blockSize / Cuda::getWarpSize() <= Cuda::getWarpSize(),
                  "blockSize is too large, it would not be possible to scan warpResults using one warp" );
   static_assert( valuesPerThread % 2,
                  "valuesPerThread must be odd, otherwise there would be shared memory bank conflicts "
                  "when threads access their chunks in sharedData sequentially" );

   // calculate indices
   constexpr int maxElementsInBlock = blockSize * valuesPerThread;
@@ -57,8 +60,7 @@ CudaScanKernelFirstPhase( const InputView input,
   outputBegin += threadOffset;

   // allocate shared memory
   constexpr int shmemElements = maxElementsInBlock + maxElementsInBlock / Cuda::getNumberOfSharedMemoryBanks();
   __shared__ ValueType sharedData[ shmemElements ];  // accessed via Cuda::getInterleaving()
   __shared__ ValueType sharedData[ maxElementsInBlock ];
   __shared__ ValueType chunkResults[ blockSize + blockSize / Cuda::getNumberOfSharedMemoryBanks() ];  // accessed via Cuda::getInterleaving()
   __shared__ ValueType warpResults[ Cuda::getWarpSize() ];

@@ -67,7 +69,7 @@ CudaScanKernelFirstPhase( const InputView input,
      int idx = threadIdx.x;
      while( idx < elementsInBlock )
      {
         sharedData[ Cuda::getInterleaving( idx ) ] = input[ begin ];
         sharedData[ idx ] = input[ begin ];
         begin += blockDim.x;
         idx += blockDim.x;
      }
@@ -75,7 +77,7 @@ CudaScanKernelFirstPhase( const InputView input,
      // (this helps to avoid divergent branches in the blocks below)
      while( idx < maxElementsInBlock )
      {
         sharedData[ Cuda::getInterleaving( idx ) ] = zero;
         sharedData[ idx ] = zero;
         idx += blockDim.x;
      }
   }
@@ -85,10 +87,10 @@ CudaScanKernelFirstPhase( const InputView input,
   const int chunkOffset = threadIdx.x * valuesPerThread;
   const int chunkResultIdx = Cuda::getInterleaving( threadIdx.x );
   {
      ValueType chunkResult = sharedData[ Cuda::getInterleaving( chunkOffset ) ];
      ValueType chunkResult = sharedData[ chunkOffset ];
      #pragma unroll
      for( int i = 1; i < valuesPerThread; i++ )
         chunkResult = reduction( chunkResult, sharedData[ Cuda::getInterleaving( chunkOffset + i ) ] );
         chunkResult = reduction( chunkResult, sharedData[ chunkOffset + i ] );

      // store the result of the sequential reduction of the chunk in chunkResults
      chunkResults[ chunkResultIdx ] = chunkResult;
@@ -135,13 +137,12 @@ CudaScanKernelFirstPhase( const InputView input,
      #pragma unroll
      for( int i = 0; i < valuesPerThread; i++ )
      {
         const int sharedIdx = Cuda::getInterleaving( chunkOffset + i );
         const ValueType inputValue = sharedData[ sharedIdx ];
         const ValueType inputValue = sharedData[ chunkOffset + i ];
         if( scanType == ScanType::Exclusive )
            sharedData[ sharedIdx ] = value;
            sharedData[ chunkOffset + i ] = value;
         value = reduction( value, inputValue );
         if( scanType == ScanType::Inclusive )
            sharedData[ sharedIdx ] = value;
            sharedData[ chunkOffset + i ] = value;
      }

      // The last thread of the block stores the block result in the global memory.
@@ -155,7 +156,7 @@ CudaScanKernelFirstPhase( const InputView input,
      int idx = threadIdx.x;
      while( idx < elementsInBlock )
      {
         output[ outputBegin ] = sharedData[ Cuda::getInterleaving( idx ) ];
         output[ outputBegin ] = sharedData[ idx ];
         outputBegin += blockDim.x;
         idx += blockDim.x;
      }
@@ -181,7 +182,7 @@ CudaScanKernelSecondPhase( OutputView output,
      blockResult = blockResults[ gridOffset + blockIdx.x ];

   // update the output offset for the thread
   TNL_ASSERT_EQ( blockDim.x, blockSize, "unexpected block size in CudaScanKernelFirstPhase" );
   TNL_ASSERT_EQ( blockDim.x, blockSize, "unexpected block size in CudaScanKernelSecondPhase" );
   constexpr int maxElementsInBlock = blockSize * valuesPerThread;
   const int threadOffset = blockIdx.x * maxElementsInBlock + threadIdx.x;
   outputBegin += threadOffset;
@@ -205,7 +206,8 @@ CudaScanKernelSecondPhase( OutputView output,
 */
template< ScanType scanType,
          int blockSize = 256,
          int valuesPerThread = 8 >
          // valuesPerThread should be odd to avoid shared memory bank conflicts
          int valuesPerThread = 7 >
struct CudaScanKernelLauncher
{
   /****
@@ -243,6 +245,11 @@ struct CudaScanKernelLauncher
         outputBegin,
         reduction,
         zero );

      // if the first-phase kernel was launched with just one block, skip the second phase
      if( blockShifts.getSize() <= 2 )
         return;

      performSecondPhase(
         input,
         output,
@@ -283,6 +290,70 @@ struct CudaScanKernelLauncher
   {
      using Index = typename InputArray::IndexType;

      if( end - begin <= blockSize * valuesPerThread ) {
         // allocate array for the block results
         Containers::Array< typename OutputArray::ValueType, Devices::Cuda > blockResults;
         blockResults.setSize( 2 );
         blockResults.setElement( 0, zero );

         // run the kernel with just 1 block
         if( end - begin <= blockSize )
            CudaScanKernelFirstPhase< scanType, blockSize, 1 ><<< 1, blockSize >>>
               ( input.getConstView(),
                 output.getView(),
                 begin,
                 end,
                 outputBegin,
                 reduction,
                 zero,
                 // blockResults are shifted by 1, because the 0-th element should stay zero
                 &blockResults.getData()[ 1 ] );
         else if( end - begin <= blockSize * 3 )
            CudaScanKernelFirstPhase< scanType, blockSize, 3 ><<< 1, blockSize >>>
               ( input.getConstView(),
                 output.getView(),
                 begin,
                 end,
                 outputBegin,
                 reduction,
                 zero,
                 // blockResults are shifted by 1, because the 0-th element should stay zero
                 &blockResults.getData()[ 1 ] );
         else if( end - begin <= blockSize * 5 )
            CudaScanKernelFirstPhase< scanType, blockSize, 5 ><<< 1, blockSize >>>
               ( input.getConstView(),
                 output.getView(),
                 begin,
                 end,
                 outputBegin,
                 reduction,
                 zero,
                 // blockResults are shifted by 1, because the 0-th element should stay zero
                 &blockResults.getData()[ 1 ] );
         else
            CudaScanKernelFirstPhase< scanType, blockSize, valuesPerThread ><<< 1, blockSize >>>
               ( input.getConstView(),
                 output.getView(),
                 begin,
                 end,
                 outputBegin,
                 reduction,
                 zero,
                 // blockResults are shifted by 1, because the 0-th element should stay zero
                 &blockResults.getData()[ 1 ] );

         // synchronize the null-stream
         cudaStreamSynchronize(0);
         TNL_CHECK_CUDA_DEVICE;

         // Store the number of CUDA grids for the purpose of unit testing, i.e.
         // to check if we test the algorithm with more than one CUDA grid.
         gridsCount() = 1;

         // blockResults now contains shift values for each block - to be used in the second phase
         return blockResults;
      }
      else {
         // compute the number of grids
         constexpr int maxElementsInBlock = blockSize * valuesPerThread;
         const Index numberOfBlocks = roundUpDivision( end - begin, maxElementsInBlock );
@@ -291,7 +362,6 @@ struct CudaScanKernelLauncher
         // allocate array for the block results
         Containers::Array< typename OutputArray::ValueType, Devices::Cuda > blockResults;
         blockResults.setSize( numberOfBlocks + 1 );
      blockResults.setElement( 0, zero );

         // loop over all grids
         for( Index gridIdx = 0; gridIdx < numberOfGrids; gridIdx++ ) {
@@ -313,8 +383,7 @@ struct CudaScanKernelLauncher
                 outputBegin + gridOffset,
                 reduction,
                 zero,
              // blockResults are shifted by 1, because the 0-th element should stay zero
              &blockResults.getData()[ gridIdx * maxGridSize() + 1 ] );
                 &blockResults.getData()[ gridIdx * maxGridSize() ] );
         }

         // synchronize the null-stream after all grids
@@ -323,10 +392,7 @@ struct CudaScanKernelLauncher

         // blockResults now contains scan results for each block. The first phase
         // ends by computing an exclusive scan of this array.
      if( numberOfBlocks > 1 ) {
         // we perform an inclusive scan, but the 0-th is zero and block results
         // were shifted by 1, so effectively we get an exclusive scan
         CudaScanKernelLauncher< ScanType::Inclusive >::perform(
         CudaScanKernelLauncher< ScanType::Exclusive >::perform(
            blockResults,
            blockResults,
            0,
@@ -334,7 +400,6 @@ struct CudaScanKernelLauncher
            0,
            reduction,
            zero );
      }

         // Store the number of CUDA grids for the purpose of unit testing, i.e.
         // to check if we test the algorithm with more than one CUDA grid.
@@ -343,6 +408,7 @@ struct CudaScanKernelLauncher
         // blockResults now contains shift values for each block - to be used in the second phase
         return blockResults;
      }
   }

   /****
    * \brief Performs the second phase of prefix sum.