Commit 40044cd5, authored by Tomáš Oberhuber, committed by Jakub Klinkovský
Browse files

Small possible optimization in the Light CSR kernel.

parent ab6d5310
Loading
Loading
Loading
Loading
+23 −12
Original line number Diff line number Diff line
@@ -29,7 +29,7 @@ template< typename Real,
          typename Reduce,
          typename Keep >
__global__
void SpMVCSRLightWithoutAtomic2( OffsetsView offsets,
void SpMVCSRLight2( OffsetsView offsets,
                                 const Index first,
                                 const Index last,
                                 Fetch fetch,
@@ -66,7 +66,7 @@ template< typename Real,
          typename Reduce,
          typename Keep >
__global__
void SpMVCSRLightWithoutAtomic4( OffsetsView offsets,
void SpMVCSRLight4( OffsetsView offsets,
                                 const Index first,
                                 const Index last,
                                 Fetch fetch,
@@ -105,7 +105,7 @@ template< typename Real,
          typename Reduce,
          typename Keep >
__global__
void SpMVCSRLightWithoutAtomic8( OffsetsView offsets,
void SpMVCSRLight8( OffsetsView offsets,
                                 const Index first,
                                 const Index last,
                                 Fetch fetch,
@@ -145,7 +145,7 @@ template< typename Real,
          typename Reduce,
          typename Keep >
__global__
void SpMVCSRLightWithoutAtomic16( OffsetsView offsets,
void SpMVCSRLight16( OffsetsView offsets,
                                  const Index first,
                                  const Index last,
                                  Fetch fetch,
@@ -253,14 +253,25 @@ void SpMVCSRVector( OffsetsView offsets,

   // Reduction
   if( ThreadsPerSegment > 16 )
   {
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result, 16 ) );
   if( ThreadsPerSegment > 8 )
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  8 ) );
   if( ThreadsPerSegment > 4 )
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  4 ) );
   if( ThreadsPerSegment > 2 )
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  2 ) );
   if( ThreadsPerSegment > 1 )
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  1 ) );
   } else if( ThreadsPerSegment > 8 ) {
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  8 ) );
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  4 ) );
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  2 ) );
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  1 ) );
   } else if( ThreadsPerSegment > 4 ) {
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  4 ) );
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  2 ) );
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  1 ) );
   } else if( ThreadsPerSegment > 2 ) {
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  2 ) );
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  1 ) );
   } else if( ThreadsPerSegment > 1 )
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  1 ) );

   // Store result
@@ -461,16 +472,16 @@ struct CSRLightKernelreduceSegmentsDispatcher< Index, Device, Fetch, Reduce, Kee


         /*if (threadsPerSegment == 2)
            SpMVCSRLightWithoutAtomic2<Real, Index, OffsetsView, Fetch, Reduce, Keep ><<<blocks, threads>>>(
            SpMVCSRLight2<Real, Index, OffsetsView, Fetch, Reduce, Keep ><<<blocks, threads>>>(
               offsets, first, last, fetch, reduce, keep, zero, grid );
         else if (threadsPerSegment == 4)
            SpMVCSRLightWithoutAtomic4<Real, Index, OffsetsView, Fetch, Reduce, Keep ><<<blocks, threads>>>(
            SpMVCSRLight4<Real, Index, OffsetsView, Fetch, Reduce, Keep ><<<blocks, threads>>>(
               offsets, first, last, fetch, reduce, keep, zero, grid );
         else if (threadsPerSegment == 8)
            SpMVCSRLightWithoutAtomic8<Real, Index, OffsetsView, Fetch, Reduce, Keep ><<<blocks, threads>>>(
            SpMVCSRLight8<Real, Index, OffsetsView, Fetch, Reduce, Keep ><<<blocks, threads>>>(
               offsets, first, last, fetch, reduce, keep, zero, grid );
         else if (threadsPerSegment == 16)
            SpMVCSRLightWithoutAtomic16<Real, Index, OffsetsView, Fetch, Reduce, Keep ><<<blocks, threads>>>(
            SpMVCSRLight16<Real, Index, OffsetsView, Fetch, Reduce, Keep ><<<blocks, threads>>>(
               offsets, first, last, fetch, reduce, keep, zero, grid );
         else if (threadsPerSegment == 32)
         { // CSR SpMV Light with threadsPerSegment = 32 is CSR Vector