Commit d05990b9 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber Committed by Jakub Klinkovský
Browse files

Added multi vector CSR kernel to CSR Light.

parent 777c0c95
Loading
Loading
Loading
Loading
+101 −9
Original line number Diff line number Diff line
@@ -219,6 +219,97 @@ void SpMVCSRVector( OffsetsView offsets,
   if( laneID == 0 )
      keep( warpID, result );
}

template< int BlockSize,
          int ThreadsPerSegment,
          typename Offsets,
          typename Index,
          typename Fetch,
          typename Reduction,
          typename ResultKeeper,
          typename Real >
__global__
void reduceSegmentsCSRLightMultivectorKernel(
    int gridIdx,
    const Offsets offsets,
    Index first,
    Index last,
    Fetch fetch,
    const Reduction reduce,
    ResultKeeper keep,
    const Real zero )
{
    const Index segmentIdx =  TNL::Cuda::getGlobalThreadIdx( gridIdx ) / ThreadsPerSegment + first;
    if( segmentIdx >= last )
        return;

    __shared__ Real shared[ BlockSize / 32 ];
    if( threadIdx.x < BlockSize / TNL::Cuda::getWarpSize() )
        shared[ threadIdx.x ] = zero;

    const int laneIdx = threadIdx.x & ( ThreadsPerSegment - 1 ); // & is cheaper than %
    const int inWarpLaneIdx = threadIdx.x & ( TNL::Cuda::getWarpSize() - 1 ); // & is cheaper than %
    const Index beginIdx = offsets[ segmentIdx ];
    const Index endIdx   = offsets[ segmentIdx + 1 ] ;

    Real result = zero;
    bool compute( true );
    Index localIdx = laneIdx;
    for( Index globalIdx = beginIdx + laneIdx; globalIdx < endIdx && compute; globalIdx += ThreadsPerSegment )
    {
       result = reduce( result, details::FetchLambdaAdapter< Index, Fetch >::call( fetch, segmentIdx, localIdx, globalIdx, compute ) );
       localIdx += ThreadsPerSegment;
    }
    result += __shfl_down_sync(0xFFFFFFFF, result, 16);
    result += __shfl_down_sync(0xFFFFFFFF, result, 8);
    result += __shfl_down_sync(0xFFFFFFFF, result, 4);
    result += __shfl_down_sync(0xFFFFFFFF, result, 2);
    result += __shfl_down_sync(0xFFFFFFFF, result, 1);

    const Index warpIdx = threadIdx.x / TNL::Cuda::getWarpSize();
    if( inWarpLaneIdx == 0 )
        shared[ warpIdx ] = result;

    __syncthreads();
    // Reduction in shared
    if( warpIdx == 0 && inWarpLaneIdx < 16 )
    {
        //constexpr int totalWarps = BlockSize / WarpSize;
        constexpr int warpsPerSegment = ThreadsPerSegment / TNL::Cuda::getWarpSize();
        if( warpsPerSegment >= 32 )
        {
            shared[ inWarpLaneIdx ] =  reduce( shared[ inWarpLaneIdx ], shared[ inWarpLaneIdx + 16 ] );
            __syncwarp();
        }
        if( warpsPerSegment >= 16 )
        {
            shared[ inWarpLaneIdx ] =  reduce( shared[ inWarpLaneIdx ], shared[ inWarpLaneIdx +  8 ] );
            __syncwarp();
        }
        if( warpsPerSegment >= 8 )
        {
            shared[ inWarpLaneIdx ] =  reduce( shared[ inWarpLaneIdx ], shared[ inWarpLaneIdx +  4 ] );
            __syncwarp();
        }
        if( warpsPerSegment >= 4 )
        {
            shared[ inWarpLaneIdx ] =  reduce( shared[ inWarpLaneIdx ], shared[ inWarpLaneIdx +  2 ] );
            __syncwarp();
        }
        if( warpsPerSegment >= 2 )
        {
            shared[ inWarpLaneIdx ] =  reduce( shared[ inWarpLaneIdx ], shared[ inWarpLaneIdx +  1 ] );
            __syncwarp();
        }
        constexpr int segmentsCount = BlockSize / ThreadsPerSegment;
        if( inWarpLaneIdx < segmentsCount && segmentIdx + inWarpLaneIdx < last )
        {
            //printf( "Long: segmentIdx %d -> %d \n", segmentIdx, aux );
            keep( segmentIdx + inWarpLaneIdx, shared[ inWarpLaneIdx * ThreadsPerSegment / 32 ] );
        }
    }
}

#endif
template< typename Index,
          typename Device,
@@ -302,20 +393,21 @@ struct CSRLightKernelreduceSegmentsDispatcher< Index, Device, Fetch, Reduce, Kee
         else if (threadsPerSegment == 16)
            SpMVCSRLightWithoutAtomic16<Real, Index, OffsetsView, Fetch, Reduce, Keep ><<<blocks, threads>>>(
               offsets, first, last, fetch, reduce, keep, zero, grid );
         else // if (threadsPerSegment == 32)
         else if (threadsPerSegment == 32)
         { // CSR SpMV Light with threadsPerSegment = 32 is CSR Vector
            SpMVCSRVector<Real, Index, OffsetsView, Fetch, Reduce, Keep ><<<blocks, threads>>>(
               offsets, first, last, fetch, reduce, keep, zero, grid );
         }
         /*else
         else if (threadsPerSegment == 64 )
         { // Execute CSR MultiVector
            SpMVCSRMultiVector<Real, Index, warpSize><<<blocks, threads>>>(
                     inVector, outVector, matrix.getoffsets().getData(),
                     matrix.getColumnIndexes().getData(), matrix.getValues().getData(),
                     rows, threadsPerSegment / 32, grid
            );
         }*/

            reduceSegmentsCSRLightMultivectorKernel< 128, 64 ><<<blocks, threads>>>(
                     grid, offsets, first, last, fetch, reduce, keep, zero );
         }
         else //if (threadsPerSegment == 64 )
         { // Execute CSR MultiVector
            reduceSegmentsCSRLightMultivectorKernel< 128, 128 ><<<blocks, threads>>>(
                     grid, offsets, first, last, fetch, reduce, keep, zero );
         }
      }
#endif