Commit 8accbc52 authored by Jakub Klinkovský
Browse files

Optimized parallel CUDA scan algorithm to avoid unnecessary writing in the first phase

The original approach (prescan + uniform shift) is more efficient for
inputs that are expensive to evaluate, such as vector expressions.
parent 2f61104b
Loading
Loading
Loading
Loading
+219 −54
Original line number Diff line number Diff line
@@ -246,6 +246,116 @@ struct CudaTileScan
   }
};

/* CudaScanKernelUpsweep - compute the partial reduction of one tile of the
 * input per CUDA block.
 *
 * Block `b` handles the (at most) blockSize * valuesPerThread input elements
 * starting at index `begin + b * blockSize * valuesPerThread` and stores the
 * reduction of these elements in `reductionResults[ b ]`. Nothing is written
 * to any output array; only the per-block results are produced.
 *
 * The kernel must be launched with blockDim.x == blockSize (asserted below).
 * It uses statically allocated shared memory holding
 * blockSize * valuesPerThread values of ValueType, which is reused as the
 * storage for the block-wide reduction.
 *
 * \tparam blockSize        Number of threads per CUDA block.
 * \tparam valuesPerThread  Number of elements reduced sequentially by each
 *                          thread; must be odd (see the static_assert).
 *
 * \param input             View of the input data, read via operator[].
 * \param begin             Index of the first input element of the grid.
 * \param end               Index one past the last input element of the grid.
 * \param reduction         Binary functor combining two values.
 * \param zero              Value used to pad an incomplete tile
 *                          (presumably the identity element of `reduction`
 *                          - TODO confirm against the callers).
 * \param reductionResults  Array with at least one element per block of the
 *                          grid; receives the per-block results.
 */
template< int blockSize,
          int valuesPerThread,
          typename InputView,
          typename Reduction,
          typename ValueType >
__global__ void
CudaScanKernelUpsweep( const InputView input,
                       typename InputView::IndexType begin,
                       typename InputView::IndexType end,
                       Reduction reduction,
                       ValueType zero,
                       ValueType* reductionResults )
{
   using IndexType = typename InputView::IndexType;

   // verify the configuration
   TNL_ASSERT_EQ( blockDim.x, blockSize, "unexpected block size in CudaScanKernelUpsweep" );
   static_assert( valuesPerThread % 2,
                  "valuesPerThread must be odd, otherwise there would be shared memory bank conflicts "
                  "when threads access their chunks in shared memory sequentially" );

   // allocate shared memory
   // (the tile data and the block-reduce storage share the same memory,
   // hence the barrier between the sequential and the parallel reduction)
   using BlockReduce = CudaBlockReduce< blockSize, Reduction, ValueType >;
   union Shared {
      ValueType data[ blockSize * valuesPerThread ];
      typename BlockReduce::Storage blockReduceStorage;
   };
   __shared__ Shared storage;

   // calculate indices
   // NOTE: the arithmetic is done in IndexType and narrowed to int only after
   // clamping to [0, maxElementsInBlock], so that a wide IndexType cannot be
   // truncated into a bogus tile size or thread offset.
   constexpr int maxElementsInBlock = blockSize * valuesPerThread;
   const IndexType blockOffset = static_cast< IndexType >( blockIdx.x ) * maxElementsInBlock;
   const IndexType remainingElements = end - begin - blockOffset;
   int elementsInBlock = maxElementsInBlock;
   if( remainingElements < maxElementsInBlock )
      elementsInBlock = ( remainingElements > 0 ) ? static_cast< int >( remainingElements ) : 0;

   // update global array offset for the thread
   begin += blockOffset + static_cast< IndexType >( threadIdx.x );

   // Load data into the shared memory.
   // (consecutive threads read consecutive input elements)
   {
      int idx = threadIdx.x;
      while( idx < elementsInBlock )
      {
         storage.data[ idx ] = input[ begin ];
         begin += blockDim.x;
         idx += blockDim.x;
      }
      // fill the remaining (maxElementsInBlock - elementsInBlock) values with zero
      // (this helps to avoid divergent branches in the blocks below)
      while( idx < maxElementsInBlock )
      {
         storage.data[ idx ] = zero;
         idx += blockDim.x;
      }
   }
   __syncthreads();

   // Perform sequential reduction of the thread's chunk in shared memory.
   const int chunkOffset = threadIdx.x * valuesPerThread;
   ValueType value = storage.data[ chunkOffset ];
   #pragma unroll
   for( int i = 1; i < valuesPerThread; i++ )
      value = reduction( value, storage.data[ chunkOffset + i ] );
   // all threads must finish reading storage.data before the aliased
   // storage.blockReduceStorage is used
   __syncthreads();

   // Perform the parallel reduction.
   value = BlockReduce::reduce( reduction, value, threadIdx.x, storage.blockReduceStorage );

   // Store the block result in the global memory.
   if( threadIdx.x == 0 )
      reductionResults[ blockIdx.x ] = value;
}

/* CudaScanKernelDownsweep - scan one tile of the input per CUDA block and
 * write it into the output, starting from the value obtained by combining
 * the global `shift` with the block's entry of `reductionResults` (the
 * result of the spine scan).
 *
 * \param input             View of the input data.
 * \param output            View of the output data.
 * \param begin             Index of the first input element of the grid.
 * \param end               Index one past the last input element of the grid.
 * \param outputBegin       Index of the first output element of the grid.
 * \param reduction         Binary functor combining two values.
 * \param zero              Forwarded unchanged to the tile scan.
 * \param shift             Global shift combined with the per-block value.
 * \param reductionResults  Array of per-block initial values, indexed by
 *                          blockIdx.x.
 */
template< ScanType scanType,
          int blockSize,
          int valuesPerThread,
          typename InputView,
          typename OutputView,
          typename Reduction >
__global__ void
CudaScanKernelDownsweep( const InputView input,
                         OutputView output,
                         typename InputView::IndexType begin,
                         typename InputView::IndexType end,
                         typename OutputView::IndexType outputBegin,
                         Reduction reduction,
                         typename OutputView::ValueType zero,
                         typename OutputView::ValueType shift,
                         const typename OutputView::ValueType* reductionResults )
{
   using Value = typename OutputView::ValueType;
   using Scanner = CudaTileScan< scanType, blockSize, valuesPerThread, Reduction, Value >;

   // shared scratch space required by the tile scan
   __shared__ typename Scanner::Storage tileStorage;

   // initial value for this tile: the global shift combined with the
   // reduction of all previous tiles
   Value tileShift = reduction( shift, reductionResults[ blockIdx.x ] );

   // scan the tile from the input directly into the output
   Scanner::scan( input, output, begin, end, outputBegin, reduction, zero, tileShift, tileStorage );
}

/* CudaScanKernelParallel - scan each tile of the input separately in each CUDA
 * block (first phase to be followed by CudaScanKernelUniformShift when there
 * are multiple CUDA blocks).
 */
template< ScanType scanType,
          int blockSize,
          int valuesPerThread,
@@ -253,7 +363,7 @@ template< ScanType scanType,
          typename OutputView,
          typename Reduction >
__global__ void
CudaScanKernelFirstPhase( const InputView input,
CudaScanKernelParallel( const InputView input,
                        OutputView output,
                        typename InputView::IndexType begin,
                        typename InputView::IndexType end,
@@ -276,26 +386,33 @@ CudaScanKernelFirstPhase( const InputView input,
      blockResults[ blockIdx.x ] = value;
}

/* CudaScanKernelUniformShift - apply a uniform shift to a pre-scanned output
 * array.
 *
 * \param blockResults  An array of per-block shifts coming from the first phase
 *                      (computed by CudaScanKernelParallel)
 * \param shift         A global shift to be applied to all elements of the
 *                      output array.
 */
template< int blockSize,
          int valuesPerThread,
          typename OutputView,
          typename Reduction >
__global__ void
CudaScanKernelSecondPhase( OutputView output,
CudaScanKernelUniformShift( OutputView output,
                            typename OutputView::IndexType outputBegin,
                            typename OutputView::IndexType outputEnd,
                            Reduction reduction,
                           int gridOffset,
                            const typename OutputView::ValueType* blockResults,
                            typename OutputView::ValueType shift )
{
   // load the block result into a __shared__ variable first
   __shared__ typename OutputView::ValueType blockResult;
   if( threadIdx.x == 0 )
      blockResult = blockResults[ gridOffset + blockIdx.x ];
      blockResult = blockResults[ blockIdx.x ];

   // update the output offset for the thread
   TNL_ASSERT_EQ( blockDim.x, blockSize, "unexpected block size in CudaScanKernelSecondPhase" );
   TNL_ASSERT_EQ( blockDim.x, blockSize, "unexpected block size in CudaScanKernelUniformShift" );
   constexpr int maxElementsInBlock = blockSize * valuesPerThread;
   const int threadOffset = blockIdx.x * maxElementsInBlock + threadIdx.x;
   outputBegin += threadOffset;
@@ -318,6 +435,7 @@ CudaScanKernelSecondPhase( OutputView output,
 * \tparam valuesPerThread  Number of elements processed by each thread sequentially.
 */
template< ScanType scanType,
          ScanPhaseType phaseType,
          int blockSize = 256,
          // valuesPerThread should be odd to avoid shared memory bank conflicts
          int valuesPerThread = 7 >
@@ -371,6 +489,7 @@ struct CudaScanKernelLauncher
         end,
         outputBegin,
         reduction,
         zero,
         zero );
   }

@@ -411,7 +530,7 @@ struct CudaScanKernelLauncher

         // run the kernel with just 1 block
         if( end - begin <= blockSize )
            CudaScanKernelFirstPhase< scanType, blockSize, 1 ><<< 1, blockSize >>>
            CudaScanKernelParallel< scanType, blockSize, 1 ><<< 1, blockSize >>>
               ( input.getConstView(),
                 output.getView(),
                 begin,
@@ -422,7 +541,7 @@ struct CudaScanKernelLauncher
                 // blockResults are shifted by 1, because the 0-th element should stay zero
                 &blockResults.getData()[ 1 ] );
         else if( end - begin <= blockSize * 3 )
            CudaScanKernelFirstPhase< scanType, blockSize, 3 ><<< 1, blockSize >>>
            CudaScanKernelParallel< scanType, blockSize, 3 ><<< 1, blockSize >>>
               ( input.getConstView(),
                 output.getView(),
                 begin,
@@ -433,7 +552,7 @@ struct CudaScanKernelLauncher
                 // blockResults are shifted by 1, because the 0-th element should stay zero
                 &blockResults.getData()[ 1 ] );
         else if( end - begin <= blockSize * 5 )
            CudaScanKernelFirstPhase< scanType, blockSize, 5 ><<< 1, blockSize >>>
            CudaScanKernelParallel< scanType, blockSize, 5 ><<< 1, blockSize >>>
               ( input.getConstView(),
                 output.getView(),
                 begin,
@@ -444,7 +563,7 @@ struct CudaScanKernelLauncher
                 // blockResults are shifted by 1, because the 0-th element should stay zero
                 &blockResults.getData()[ 1 ] );
         else
            CudaScanKernelFirstPhase< scanType, blockSize, valuesPerThread ><<< 1, blockSize >>>
            CudaScanKernelParallel< scanType, blockSize, valuesPerThread ><<< 1, blockSize >>>
               ( input.getConstView(),
                 output.getView(),
                 begin,
@@ -488,7 +607,10 @@ struct CudaScanKernelLauncher
            cudaGridSize.x = roundUpDivision( currentSize, maxElementsInBlock );

            // run the kernel
            CudaScanKernelFirstPhase< scanType, blockSize, valuesPerThread ><<< cudaGridSize, cudaBlockSize >>>
            switch( phaseType )
            {
               case ScanPhaseType::WriteInFirstPhase:
                  CudaScanKernelParallel< scanType, blockSize, valuesPerThread ><<< cudaGridSize, cudaBlockSize >>>
                     ( input.getConstView(),
                       output.getView(),
                       begin + gridOffset,
@@ -497,6 +619,18 @@ struct CudaScanKernelLauncher
                       reduction,
                       zero,
                       &blockResults.getData()[ gridIdx * maxGridSize() ] );
                  break;

               case ScanPhaseType::WriteInSecondPhase:
                  CudaScanKernelUpsweep< blockSize, valuesPerThread ><<< cudaGridSize, cudaBlockSize >>>
                     ( input.getConstView(),
                       begin + gridOffset,
                       begin + gridOffset + currentSize,
                       reduction,
                       zero,
                       &blockResults.getData()[ gridIdx * maxGridSize() ] );
                  break;
            }
         }

         // synchronize the null-stream after all grids
@@ -505,7 +639,7 @@ struct CudaScanKernelLauncher

         // blockResults now contains scan results for each block. The first phase
         // ends by computing an exclusive scan of this array.
         CudaScanKernelLauncher< ScanType::Exclusive >::perform(
         CudaScanKernelLauncher< ScanType::Exclusive, ScanPhaseType::WriteInSecondPhase >::perform(
            blockResults,
            blockResults,
            0,
@@ -552,10 +686,23 @@ struct CudaScanKernelLauncher
                       typename InputArray::IndexType end,
                       typename OutputArray::IndexType outputBegin,
                       Reduction&& reduction,
                       typename OutputArray::ValueType zero )
                       typename OutputArray::ValueType zero,
                       typename OutputArray::ValueType shift )
   {
      using Index = typename InputArray::IndexType;

      // if the input was already scanned with just one block in the first phase,
      // it must be shifted uniformly in the second phase
      if( end - begin <= blockSize * valuesPerThread ) {
         CudaScanKernelUniformShift< blockSize, valuesPerThread ><<< 1, blockSize >>>
            ( output.getView(),
              outputBegin,
              outputBegin + end - begin,
              reduction,
              blockShifts.getData(),
              shift );
      }
      else {
         // compute the number of grids
         constexpr int maxElementsInBlock = blockSize * valuesPerThread;
         const Index numberOfBlocks = roundUpDivision( end - begin, maxElementsInBlock );
@@ -573,14 +720,32 @@ struct CudaScanKernelLauncher
            cudaGridSize.x = roundUpDivision( currentSize, maxElementsInBlock );

            // run the kernel
         CudaScanKernelSecondPhase< blockSize, valuesPerThread ><<< cudaGridSize, cudaBlockSize >>>
            switch( phaseType )
            {
               case ScanPhaseType::WriteInFirstPhase:
                  CudaScanKernelUniformShift< blockSize, valuesPerThread ><<< cudaGridSize, cudaBlockSize >>>
                     ( output.getView(),
                       outputBegin + gridOffset,
                       outputBegin + gridOffset + currentSize,
                       reduction,
              gridIdx * maxGridSize(),
              blockShifts.getData(),
              zero );
                       &blockShifts.getData()[ gridIdx * maxGridSize() ],
                       shift );
                  break;

               case ScanPhaseType::WriteInSecondPhase:
                  CudaScanKernelDownsweep< scanType, blockSize, valuesPerThread ><<< cudaGridSize, cudaBlockSize >>>
                     ( input.getConstView(),
                       output.getView(),
                       begin + gridOffset,
                       begin + gridOffset + currentSize,
                       outputBegin + gridOffset,
                       reduction,
                       zero,
                       shift,
                       &blockShifts.getData()[ gridIdx * maxGridSize() ] );
                  break;
            }
         }
      }

      // synchronize the null-stream after all grids
+4 −4
Original line number Diff line number Diff line
@@ -21,7 +21,7 @@ namespace TNL {
namespace Algorithms {
namespace detail {

template< ScanType Type >
template< ScanType Type, ScanPhaseType PhaseType >
struct DistributedScan
{
   template< typename InputDistributedArray,
@@ -48,7 +48,7 @@ struct DistributedScan
         // perform first phase on the local data
         const auto inputLocalView = input.getConstLocalView();
         auto outputLocalView = output.getLocalView();
         const auto block_results = Scan< DeviceType, Type >::performFirstPhase( inputLocalView, outputLocalView, begin, end, begin, reduction, zero );
         const auto block_results = Scan< DeviceType, Type, PhaseType >::performFirstPhase( inputLocalView, outputLocalView, begin, end, begin, reduction, zero );
         const ValueType local_result = block_results.getElement( block_results.getSize() - 1 );

         // exchange local results between ranks
@@ -60,11 +60,11 @@ struct DistributedScan
         MPI::Alltoall( dataForScatter, 1, rank_results.getData(), 1, group );

         // compute the scan of the per-rank results
         Scan< Devices::Host, ScanType::Exclusive >::perform( rank_results, rank_results, 0, nproc, 0, reduction, zero );
         Scan< Devices::Host, ScanType::Exclusive, ScanPhaseType::WriteInSecondPhase >::perform( rank_results, rank_results, 0, nproc, 0, reduction, zero );

         // perform the second phase, using the per-block and per-rank results
         const int rank = MPI::GetRank( group );
         Scan< DeviceType, Type >::performSecondPhase( inputLocalView, outputLocalView, block_results, begin, end, begin, reduction, rank_results[ rank ] );
         Scan< DeviceType, Type, PhaseType >::performSecondPhase( inputLocalView, outputLocalView, block_results, begin, end, begin, reduction, zero, rank_results[ rank ] );
      }
   }
};
+13 −10
Original line number Diff line number Diff line
@@ -21,11 +21,11 @@ namespace TNL {
namespace Algorithms {
namespace detail {

template< typename Device, ScanType Type >
template< typename Device, ScanType Type, ScanPhaseType PhaseType = ScanPhaseType::WriteInSecondPhase >
struct Scan;

template< ScanType Type >
struct Scan< Devices::Sequential, Type >
template< ScanType Type, ScanPhaseType PhaseType >
struct Scan< Devices::Sequential, Type, PhaseType >
{
   template< typename InputArray,
             typename OutputArray,
@@ -63,11 +63,12 @@ struct Scan< Devices::Sequential, Type >
                       typename InputArray::IndexType end,
                       typename OutputArray::IndexType outputBegin,
                       Reduction&& reduction,
                       typename OutputArray::ValueType zero );
                       typename OutputArray::ValueType zero,
                       typename OutputArray::ValueType shift );
};

template< ScanType Type >
struct Scan< Devices::Host, Type >
template< ScanType Type, ScanPhaseType PhaseType >
struct Scan< Devices::Host, Type, PhaseType >
{
   template< typename InputArray,
             typename OutputArray,
@@ -105,11 +106,12 @@ struct Scan< Devices::Host, Type >
                       typename InputArray::IndexType end,
                       typename OutputArray::IndexType outputBegin,
                       Reduction&& reduction,
                       typename OutputArray::ValueType zero );
                       typename OutputArray::ValueType zero,
                       typename OutputArray::ValueType shift );
};

template< ScanType Type >
struct Scan< Devices::Cuda, Type >
template< ScanType Type, ScanPhaseType PhaseType >
struct Scan< Devices::Cuda, Type, PhaseType >
{
   template< typename InputArray,
             typename OutputArray,
@@ -147,7 +149,8 @@ struct Scan< Devices::Cuda, Type >
                       typename InputArray::IndexType end,
                       typename OutputArray::IndexType outputBegin,
                       Reduction&& reduction,
                       typename OutputArray::ValueType zero );
                       typename OutputArray::ValueType zero,
                       typename OutputArray::ValueType shift );
};

} // namespace detail
+32 −28
Original line number Diff line number Diff line
@@ -27,12 +27,12 @@ namespace TNL {
namespace Algorithms {
namespace detail {

template< ScanType Type >
template< ScanType Type, ScanPhaseType PhaseType >
   template< typename InputArray,
             typename OutputArray,
             typename Reduction >
void
Scan< Devices::Sequential, Type >::
Scan< Devices::Sequential, Type, PhaseType >::
perform( const InputArray& input,
         OutputArray& output,
         typename InputArray::IndexType begin,
@@ -59,12 +59,12 @@ perform( const InputArray& input,
   }
}

template< ScanType Type >
template< ScanType Type, ScanPhaseType PhaseType >
   template< typename InputArray,
             typename OutputArray,
             typename Reduction >
auto
Scan< Devices::Sequential, Type >::
Scan< Devices::Sequential, Type, PhaseType >::
performFirstPhase( const InputArray& input,
                   OutputArray& output,
                   typename InputArray::IndexType begin,
@@ -80,13 +80,13 @@ performFirstPhase( const InputArray& input,
   return block_results;
}

template< ScanType Type >
template< ScanType Type, ScanPhaseType PhaseType >
   template< typename InputArray,
             typename OutputArray,
             typename BlockShifts,
             typename Reduction >
void
Scan< Devices::Sequential, Type >::
Scan< Devices::Sequential, Type, PhaseType >::
performSecondPhase( const InputArray& input,
                    OutputArray& output,
                    const BlockShifts& blockShifts,
@@ -94,18 +94,19 @@ performSecondPhase( const InputArray& input,
                    typename InputArray::IndexType end,
                    typename OutputArray::IndexType outputBegin,
                    Reduction&& reduction,
                    typename OutputArray::ValueType zero )
                    typename OutputArray::ValueType zero,
                    typename OutputArray::ValueType shift )
{
   // artificial second phase - only one block, use the shift as the initial value
   perform( input, output, begin, end, outputBegin, reduction, reduction( zero, blockShifts[ 0 ] ) );
   perform( input, output, begin, end, outputBegin, reduction, reduction( zero, reduction( shift, blockShifts[ 0 ] ) ) );
}

template< ScanType Type >
template< ScanType Type, ScanPhaseType PhaseType >
   template< typename InputArray,
             typename OutputArray,
             typename Reduction >
void
Scan< Devices::Host, Type >::
Scan< Devices::Host, Type, PhaseType >::
perform( const InputArray& input,
         OutputArray& output,
         typename InputArray::IndexType begin,
@@ -158,12 +159,12 @@ perform( const InputArray& input,
      Scan< Devices::Sequential, Type >::perform( input, output, begin, end, outputBegin, reduction, zero );
}

template< ScanType Type >
template< ScanType Type, ScanPhaseType PhaseType >
   template< typename InputArray,
             typename OutputArray,
             typename Reduction >
auto
Scan< Devices::Host, Type >::
Scan< Devices::Host, Type, PhaseType >::
performFirstPhase( const InputArray& input,
                   OutputArray& output,
                   typename InputArray::IndexType begin,
@@ -212,13 +213,13 @@ performFirstPhase( const InputArray& input,
      return Scan< Devices::Sequential, Type >::performFirstPhase( input, output, begin, end, outputBegin, reduction, zero );
}

template< ScanType Type >
template< ScanType Type, ScanPhaseType PhaseType >
   template< typename InputArray,
             typename OutputArray,
             typename BlockShifts,
             typename Reduction >
void
Scan< Devices::Host, Type >::
Scan< Devices::Host, Type, PhaseType >::
performSecondPhase( const InputArray& input,
                    OutputArray& output,
                    const BlockShifts& blockShifts,
@@ -226,7 +227,8 @@ performSecondPhase( const InputArray& input,
                    typename InputArray::IndexType end,
                    typename OutputArray::IndexType outputBegin,
                    Reduction&& reduction,
                    typename OutputArray::ValueType zero )
                    typename OutputArray::ValueType zero,
                    typename OutputArray::ValueType shift )
{
#ifdef HAVE_OPENMP
   using IndexType = typename InputArray::IndexType;
@@ -250,20 +252,20 @@ performSecondPhase( const InputArray& input,
         const IndexType block_output_begin = outputBegin + block_offset;

         // phase 2: per-block scan using the block results as initial values
         Scan< Devices::Sequential, Type >::perform( input, output, block_begin, block_end, block_output_begin, reduction, reduction( zero, blockShifts[ block_idx ] ) );
         Scan< Devices::Sequential, Type >::perform( input, output, block_begin, block_end, block_output_begin, reduction, reduction( zero, reduction( shift, blockShifts[ block_idx ] ) ) );
      }
   }
   else
#endif
      Scan< Devices::Sequential, Type >::performSecondPhase( input, output, blockShifts, begin, end, outputBegin, reduction, zero );
      Scan< Devices::Sequential, Type >::performSecondPhase( input, output, blockShifts, begin, end, outputBegin, reduction, zero, shift );
}

template< ScanType Type >
template< ScanType Type, ScanPhaseType PhaseType >
   template< typename InputArray,
             typename OutputArray,
             typename Reduction >
void
Scan< Devices::Cuda, Type >::
Scan< Devices::Cuda, Type, PhaseType >::
perform( const InputArray& input,
         OutputArray& output,
         typename InputArray::IndexType begin,
@@ -276,7 +278,7 @@ perform( const InputArray& input,
   if( end <= begin )
      return;

   detail::CudaScanKernelLauncher< Type >::perform(
   detail::CudaScanKernelLauncher< Type, PhaseType >::perform(
      input,
      output,
      begin,
@@ -289,12 +291,12 @@ perform( const InputArray& input,
#endif
}

template< ScanType Type >
template< ScanType Type, ScanPhaseType PhaseType >
   template< typename InputArray,
             typename OutputArray,
             typename Reduction >
auto
Scan< Devices::Cuda, Type >::
Scan< Devices::Cuda, Type, PhaseType >::
performFirstPhase( const InputArray& input,
                   OutputArray& output,
                   typename InputArray::IndexType begin,
@@ -310,7 +312,7 @@ performFirstPhase( const InputArray& input,
      return block_results;
   }

   return detail::CudaScanKernelLauncher< Type >::performFirstPhase(
   return detail::CudaScanKernelLauncher< Type, PhaseType >::performFirstPhase(
      input,
      output,
      begin,
@@ -323,13 +325,13 @@ performFirstPhase( const InputArray& input,
#endif
}

template< ScanType Type >
template< ScanType Type, ScanPhaseType PhaseType >
   template< typename InputArray,
             typename OutputArray,
             typename BlockShifts,
             typename Reduction >
void
Scan< Devices::Cuda, Type >::
Scan< Devices::Cuda, Type, PhaseType >::
performSecondPhase( const InputArray& input,
                    OutputArray& output,
                    const BlockShifts& blockShifts,
@@ -337,13 +339,14 @@ performSecondPhase( const InputArray& input,
                    typename InputArray::IndexType end,
                    typename OutputArray::IndexType outputBegin,
                    Reduction&& reduction,
                    typename OutputArray::ValueType zero )
                    typename OutputArray::ValueType zero,
                    typename OutputArray::ValueType shift )
{
#ifdef HAVE_CUDA
   if( end <= begin )
      return;

   detail::CudaScanKernelLauncher< Type >::performSecondPhase(
   detail::CudaScanKernelLauncher< Type, PhaseType >::performSecondPhase(
      input,
      output,
      blockShifts,
@@ -351,7 +354,8 @@ performSecondPhase( const InputArray& input,
      end,
      outputBegin,
      std::forward< Reduction >( reduction ),
      zero );
      zero,
      shift );
#else
   throw Exceptions::CudaSupportMissing();
#endif
+5 −0
Original line number Diff line number Diff line
@@ -21,6 +21,11 @@ enum class ScanType {
   Inclusive
};

// Selects which of the two phases of the parallel scan writes the scanned
// values into the output array.
enum class ScanPhaseType {
   // The first phase scans each block into the output and the second phase
   // applies a uniform shift to the pre-scanned output (in the CUDA launcher:
   // CudaScanKernelParallel followed by CudaScanKernelUniformShift).
   WriteInFirstPhase,
   // The first phase only computes per-block reductions and the second phase
   // scans the input into the output (in the CUDA launcher:
   // CudaScanKernelUpsweep followed by CudaScanKernelDownsweep).
   WriteInSecondPhase
};

} // namespace detail
} // namespace Algorithms
} // namespace TNL
Loading