diff --git a/src/TNL/Containers/Algorithms/CudaReductionKernel.h b/src/TNL/Containers/Algorithms/CudaReductionKernel.h index a5823849d443bd70831ddff6252631fb333d4aff..08a767c6010aee87b18b7f221f8ac6927e2d53c4 100644 --- a/src/TNL/Containers/Algorithms/CudaReductionKernel.h +++ b/src/TNL/Containers/Algorithms/CudaReductionKernel.h @@ -39,14 +39,19 @@ static constexpr int Reduction_registersPerThread = 32; // empirically determi static constexpr int Reduction_minBlocksPerMultiprocessor = 4; #endif -template< int blockSize, typename Operation, typename Index > +template< int blockSize, + typename Real, + typename FirstPhase, + typename SecondPhase, + typename Index, + typename ResultType = decltype( std::declval< FirstPhase >( 0,0 ) ) > __global__ void __launch_bounds__( Reduction_maxThreadsPerBlock, Reduction_minBlocksPerMultiprocessor ) -CudaReductionKernel( Operation operation, +CudaReductionKernel( const Real& initialValue, + FirstReduction& firstReduction, + SecondReduction& secondReduction, const Index size, - const typename Operation::DataType1* input1, - const typename Operation::DataType2* input2, - typename Operation::ResultType* output ) + ResultType* output ) { typedef Index IndexType; typedef typename Operation::ResultType ResultType; @@ -62,23 +67,23 @@ CudaReductionKernel( Operation operation, IndexType gid = blockIdx.x * blockDim. x + threadIdx.x; const IndexType gridSize = blockDim.x * gridDim.x; - sdata[ tid ] = operation.initialValue(); + sdata[ tid ] = initialValue; /*** * Read data into the shared memory. We start with the * sequential reduction. */ while( gid + 4 * gridSize < size ) { - operation.firstReduction( sdata[ tid ], gid, input1, input2 ); - operation.firstReduction( sdata[ tid ], gid + gridSize, input1, input2 ); - operation.firstReduction( sdata[ tid ], gid + 2 * gridSize, input1, input2 ); - operation.firstReduction( sdata[ tid ], gid + 3 * gridSize, input1, input2 ); + sdata[ tid ] = firstReduction( sdata[ tid ], gid ); + sdata[ tid ] = firstReduction( sdata[ tid ], gid + gridSize ); + sdata[ tid ] = firstReduction( sdata[ tid ], gid + 2 * gridSize ); + sdata[ tid ] = firstReduction( sdata[ tid ], gid + 3 * gridSize ); gid += 4 * gridSize; } while( gid + 2 * gridSize < size ) { - operation.firstReduction( sdata[ tid ], gid, input1, input2 ); - operation.firstReduction( sdata[ tid ], gid + gridSize, input1, input2 ); + firstReduction( sdata[ tid ], gid, input1, input2 ); + firstReduction( sdata[ tid ], gid + gridSize, input1, input2 ); gid += 2 * gridSize; } while( gid < size )