Commit c60aa9e4 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Rewriting parallel reduction with lambda functions.

parent 1997234c
Loading
Loading
Loading
Loading
+17 −12
Original line number Diff line number Diff line
@@ -39,14 +39,19 @@ static constexpr int Reduction_registersPerThread = 32; // empirically determi
   static constexpr int Reduction_minBlocksPerMultiprocessor = 4;
#endif

template< int blockSize, typename Operation, typename Index >
template< int blockSize,
   typename Real,
   typename FirstPhase,
   typename SecondPhase,
   typename Index,
   typename ResultType = decltype( std::declval< FirstPhase >( 0,0 ) ) >
__global__ void
__launch_bounds__( Reduction_maxThreadsPerBlock, Reduction_minBlocksPerMultiprocessor )
CudaReductionKernel( Operation operation,
CudaReductionKernel( const Real& initialValue,
                     FirstReduction& firstReduction,
                     SecondReduction& secondReduction,
                     const Index size,
                     const typename Operation::DataType1* input1,
                     const typename Operation::DataType2* input2,
                     typename Operation::ResultType* output )
                     ResultType* output )
{
   typedef Index IndexType;
   typedef typename Operation::ResultType ResultType;
@@ -62,23 +67,23 @@ CudaReductionKernel( Operation operation,
         IndexType gid = blockIdx.x * blockDim. x + threadIdx.x;
   const IndexType gridSize = blockDim.x * gridDim.x;

   sdata[ tid ] = operation.initialValue();
   sdata[ tid ] = initialValue;
   /***
    * Read data into the shared memory. We start with the
    * sequential reduction.
    */
   while( gid + 4 * gridSize < size )
   {
      operation.firstReduction( sdata[ tid ], gid,                input1, input2 );
      operation.firstReduction( sdata[ tid ], gid + gridSize,     input1, input2 );
      operation.firstReduction( sdata[ tid ], gid + 2 * gridSize, input1, input2 );
      operation.firstReduction( sdata[ tid ], gid + 3 * gridSize, input1, input2 );
      sdata[ tid ] = firstReduction( sdata[ tid ], gid );
      sdata[ tid ] = firstReduction( sdata[ tid ], gid + gridSize );
      sdata[ tid ] = firstReduction( sdata[ tid ], gid + 2 * gridSize );
      sdata[ tid ] = firstReduction( sdata[ tid ], gid + 3 * gridSize );
      gid += 4 * gridSize;
   }
   while( gid + 2 * gridSize < size )
   {
      operation.firstReduction( sdata[ tid ], gid,                input1, input2 );
      operation.firstReduction( sdata[ tid ], gid + gridSize,     input1, input2 );
      firstReduction( sdata[ tid ], gid,                input1, input2 );
      firstReduction( sdata[ tid ], gid + gridSize,     input1, input2 );
      gid += 2 * gridSize;
   }
   while( gid < size )