Commit cb173073 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber Committed by Jakub Klinkovský
Browse files

Refactoring Light CSR kernel.

parent d05990b9
Loading
Loading
Loading
Loading
+85 −6
Original line number Diff line number Diff line
@@ -179,7 +179,7 @@ void SpMVCSRLightWithoutAtomic16( OffsetsView offsets,
      keep( segmentIdx, result );
}

template< typename Real,
/*template< typename Real,
          typename Index,
          typename OffsetsView,
          typename Fetch,
@@ -204,22 +204,71 @@ void SpMVCSRVector( OffsetsView offsets,
   const Index laneID = threadIdx.x & 31; // & is cheaper than %
   Index endID = offsets[warpID + 1];

   /* Calculate result */
   // Calculate result
   bool compute = true;
   for (Index i = offsets[warpID] + laneID; i < endID; i += warpSize)
      result = reduce( result, fetch( i, compute ) );

   /* Reduction */
   // Reduction
   result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result, 16 ) );
   result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  8 ) );
   result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  4 ) );
   result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  2 ) );
   result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  1 ) );
   /* Write result */
   // Write result
   if( laneID == 0 )
      keep( warpID, result );
}*/

template< int ThreadsPerSegment,
          typename Real,
          typename Index,
          typename OffsetsView,
          typename Fetch,
          typename Reduce,
          typename Keep >
__global__
void SpMVCSRVector( OffsetsView offsets,
                    const Index first,
                    const Index last,
                    Fetch fetch,
                    Reduce reduce,
                    Keep keep,
                    const Real zero,
                    const Index gridID )
{
   //const int warpSize = 32;
   const Index warpID = first + ((gridID * TNL::Cuda::getMaxGridXSize() ) + (blockIdx.x * blockDim.x) + threadIdx.x) / ThreadsPerSegment;
   if (warpID >= last)
      return;

   Real result = zero;
   const Index laneID = threadIdx.x & ( ThreadsPerSegment - 1 ); // & is cheaper than %
   Index endID = offsets[warpID + 1];

   // Calculate result
   bool compute = true;
   for (Index i = offsets[warpID] + laneID; i < endID; i += ThreadsPerSegment )
      result = reduce( result, fetch( i, compute ) );

   // Reduction
   if( ThreadsPerSegment > 16 )
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result, 16 ) );
   if( ThreadsPerSegment > 8 )
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  8 ) );
   if( ThreadsPerSegment > 4 )
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  4 ) );
   if( ThreadsPerSegment > 2 )
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  2 ) );
   if( ThreadsPerSegment > 1 )
      result = reduce( result, __shfl_down_sync(0xFFFFFFFF, result,  1 ) );

   // Store result
   if( laneID == 0 )
      keep( warpID, result );
}


template< int BlockSize,
          int ThreadsPerSegment,
          typename Offsets,
@@ -381,7 +430,37 @@ struct CSRLightKernelreduceSegmentsDispatcher< Index, Device, Fetch, Reduce, Kee
            neededThreads -= TNL::Cuda::getMaxGridXSize() * threads;
         }

         if( threadsPerSegment == 1 )
            SpMVCSRVector< 1, Real, Index, OffsetsView, Fetch, Reduce, Keep ><<< blocks, threads >>>(
               offsets, first, last, fetch, reduce, keep, zero, grid );
         if( threadsPerSegment == 2 )
            SpMVCSRVector< 2, Real, Index, OffsetsView, Fetch, Reduce, Keep ><<< blocks, threads >>>(
               offsets, first, last, fetch, reduce, keep, zero, grid );
         if( threadsPerSegment == 4 )
            SpMVCSRVector< 4, Real, Index, OffsetsView, Fetch, Reduce, Keep ><<< blocks, threads >>>(
               offsets, first, last, fetch, reduce, keep, zero, grid );
         if( threadsPerSegment == 8 )
            SpMVCSRVector< 8, Real, Index, OffsetsView, Fetch, Reduce, Keep ><<< blocks, threads >>>(
               offsets, first, last, fetch, reduce, keep, zero, grid );
         if( threadsPerSegment == 16 )
            SpMVCSRVector< 16, Real, Index, OffsetsView, Fetch, Reduce, Keep ><<< blocks, threads >>>(
               offsets, first, last, fetch, reduce, keep, zero, grid );
         if( threadsPerSegment == 32 )
            SpMVCSRVector< 32, Real, Index, OffsetsView, Fetch, Reduce, Keep ><<< blocks, threads >>>(
               offsets, first, last, fetch, reduce, keep, zero, grid );
         if( threadsPerSegment == 64 )
         { // Execute CSR MultiVector
            reduceSegmentsCSRLightMultivectorKernel< 128, 64 ><<<blocks, threads>>>(
                     grid, offsets, first, last, fetch, reduce, keep, zero );
         }
         if (threadsPerSegment >= 128 )
         { // Execute CSR MultiVector
            reduceSegmentsCSRLightMultivectorKernel< 128, 128 ><<<blocks, threads>>>(
                     grid, offsets, first, last, fetch, reduce, keep, zero );
         }


         /*if (threadsPerSegment == 2)
            SpMVCSRLightWithoutAtomic2<Real, Index, OffsetsView, Fetch, Reduce, Keep ><<<blocks, threads>>>(
               offsets, first, last, fetch, reduce, keep, zero, grid );
         else if (threadsPerSegment == 4)
@@ -407,7 +486,7 @@ struct CSRLightKernelreduceSegmentsDispatcher< Index, Device, Fetch, Reduce, Kee
         { // Execute CSR MultiVector
            reduceSegmentsCSRLightMultivectorKernel< 128, 128 ><<<blocks, threads>>>(
                     grid, offsets, first, last, fetch, reduce, keep, zero );
         }
         }*/
      }
#endif