Commit e7b4ef8b, authored by Illia Kolesnik, committed by Tomáš Oberhuber
Browse files

Bug fix for CSR MultiVector, optimizations for CSR LightWithoutAtomic

parent 466e0136
Loading
Loading
Loading
Loading
+77 −69
Original line number Diff line number Diff line
@@ -42,7 +42,6 @@ union Block {

/* Configuration */
constexpr size_t MAX_X_DIM = 2147483647;
constexpr int ELEMENTS_PER_WARP = 1024;
//-----------------------------------------------------------------

namespace TNL {
@@ -791,7 +790,8 @@ void CSR< Real, Device, Index, KernelType >::spmvCudaVectorized( const InVector&
template< typename Real,
          typename Index,
          int warpSize,
          int sharedPerWarp >
          int sharedPerWarp,
          int maxElemPerWarp >
__global__
void SpMVCSRAdaptive( const Real *inVector,
                      Real *outVector,
@@ -799,18 +799,18 @@ void SpMVCSRAdaptive( const Real *inVector,
                      const Index* columnIndexes,
                      const Real* values,
                      const Block *blocks,
                      Index blocks_size,
                      Index blocksSize,
                      Index getColumns,
                      Index gridID) {
   __shared__ Real shared_res[49152/sizeof(Real)];
   const Index index = (gridID * MAX_X_DIM) + (blockIdx.x * blockDim.x) + threadIdx.x;
   const Index blockIdx = index / warpSize;
   if (blockIdx >= blocks_size)
   if (blockIdx >= blocksSize)
      return;

   Block block = blocks[blockIdx];
   Real result = 0.0;
   const Index laneID = index % warpSize;
   const Index laneID = threadIdx.x % warpSize;
   const Index minID = rowPointers[block.index[0]/* minRow */];
   Index i, to, column, offset, maxID;
   if (block.byte[7] == 0) {
@@ -859,8 +859,8 @@ void SpMVCSRAdaptive( const Real *inVector,
      /////////////////////////////////////* CSR VECTOR L */////////////
      maxID = rowPointers[block.index[0]/* minRow */ + 1];

      offset = block.index[1]/* warpInRow */ * ELEMENTS_PER_WARP;
      to = minID + (block.index[1]/* warpInRow */ + 1) * ELEMENTS_PER_WARP;
      offset = block.index[1]/* warpInRow */ * maxElemPerWarp;
      to = minID + (block.index[1]/* warpInRow */ + 1) * maxElemPerWarp;
      if (to > maxID) to = maxID;

      for (i = minID + offset + laneID; i < to; i += warpSize) {
@@ -892,15 +892,15 @@ void SpMVCSRScalar( const Real *inVector,
                    const Index rows,
                    const Index getColumns,
                    const Index gridID) {
   const Index index = (gridID * MAX_X_DIM) + (blockIdx.x * blockDim.x) + threadIdx.x;
   if (index >= rows)
   const Index row = (gridID * MAX_X_DIM) + (blockIdx.x * blockDim.x) + threadIdx.x;
   if (row >= rows)
      return;

   Index column;
   Real result = 0.0;
   const Index endID = rowPointers[index + 1];
   const Index endID = rowPointers[row + 1];

   for (Index i = rowPointers[index]; i < endID; ++i) {
   for (Index i = rowPointers[row]; i < endID; ++i) {
      column = columnIndexes[i];
      if (column >= getColumns)
         break;
@@ -908,7 +908,7 @@ void SpMVCSRScalar( const Real *inVector,
      result += values[i] * inVector[column];
   }

   outVector[index] = result;
   outVector[row] = result;
}

template< typename Real,
@@ -922,21 +922,23 @@ void SpMVCSRMultiVector( const Real *inVector,
                         const Real* values,
                         const Index rows,
                         const Index getColumns,
                         const Index offset,
                         const Index warps, // warps per row
                         const Index gridID)
{
   const Index index = (gridID * MAX_X_DIM) + (blockIdx.x * blockDim.x) + threadIdx.x;
   const Index rowID = index / offset;
   const Index warpID =
      ((gridID * MAX_X_DIM) + (blockIdx.x * blockDim.x) + threadIdx.x) / warpSize;
   const Index rowID = warpID / warps;
   if (rowID >= rows)
      return;

   const Index inRowID = index % offset;
   const Index laneID = threadIdx.x % warpSize;
   const Index offset = warps * warpSize;

   Real result = 0.0;
   Index endID = rowPointers[rowID + 1];

   /* Calculate result */
   for (Index i = rowPointers[rowID] + inRowID; i < endID; i += offset) {
   for (Index i = rowPointers[rowID] + (warpID % warps) * warpSize + laneID;
            i < endID; i += offset) {
      Index column = columnIndexes[i];
      if (column >= getColumns)
         break;
@@ -951,7 +953,7 @@ void SpMVCSRMultiVector( const Real *inVector,
   result += __shfl_down_sync(0xFFFFFFFF, result, 2);
   result += __shfl_down_sync(0xFFFFFFFF, result, 1);
   /* Write result */
   if (index % warpSize == 0) atomicAdd(&outVector[rowID], result);
   if (laneID == 0) atomicAdd(&outVector[rowID], result);
}

template< typename Real,
@@ -967,13 +969,12 @@ void SpMVCSRVector( const Real *inVector,
                    const Index getColumns,
                    const Index gridID)
{
   const Index index = (gridID * MAX_X_DIM) + (blockIdx.x * blockDim.x) + threadIdx.x;
   const Index warpID = index / warpSize;
   const Index warpID = ((gridID * MAX_X_DIM) + (blockIdx.x * blockDim.x) + threadIdx.x) / warpSize;
   if (warpID >= rows)
      return;

   Real result = 0.0;
   const Index laneID = index % warpSize;
   const Index laneID = threadIdx.x % warpSize;
   Index endID = rowPointers[warpID + 1];

   /* Calculate result */
@@ -1017,7 +1018,7 @@ void SpMVCSRLight( const Real *inVector,
      /* Get row number */
      if (inGroupID == 0) row = atomicAdd(rowCnt, 1);

      /* Propagate row number in group */
      /* share row number in group */
      row = __shfl_sync(0xFFFFFFFF, row, groupID * groupSize);
      if (row >= rows)
         return;
@@ -1053,13 +1054,12 @@ void SpMVCSRLightWithoutAtomic2( const Real *inVector,
                                 const Index rows,
                                 const Index getColumns,
                                 const Index gridID) {
   const Index index = (gridID * MAX_X_DIM) + (blockIdx.x * blockDim.x) + threadIdx.x;
   const Index row = index / 2;

   const Index row =
      ((gridID * MAX_X_DIM) + (blockIdx.x * blockDim.x) + threadIdx.x) / 2;
   if (row >= rows)
      return;

   const Index inGroupID = index % 2;
   const Index inGroupID = threadIdx.x % 2;
   const Index maxID = rowPointers[row + 1];

   Real result = 0.0;
@@ -1089,13 +1089,12 @@ void SpMVCSRLightWithoutAtomic4( const Real *inVector,
                                 const Index rows,
                                 const Index getColumns,
                                 const Index gridID) {
   const Index index = (gridID * MAX_X_DIM) + (blockIdx.x * blockDim.x) + threadIdx.x;
   const Index row = index / 4;

   const Index row =
      ((gridID * MAX_X_DIM) + (blockIdx.x * blockDim.x) + threadIdx.x) / 4;
   if (row >= rows)
      return;

   const Index inGroupID = index % 4;
   const Index inGroupID = threadIdx.x % 4;
   const Index maxID = rowPointers[row + 1];

   Real result = 0.0;
@@ -1126,14 +1125,13 @@ void SpMVCSRLightWithoutAtomic8( const Real *inVector,
                                 const Index rows,
                                 const Index getColumns,
                                 const Index gridID) {
   const Index index = (gridID * MAX_X_DIM) + (blockIdx.x * blockDim.x) + threadIdx.x;
   const Index row = index / 8;
   Index i, column;

   const Index row =
      ((gridID * MAX_X_DIM) + (blockIdx.x * blockDim.x) + threadIdx.x) / 8;
   if (row >= rows)
      return;

   const Index inGroupID = index % 8;
   Index i, column;
   const Index inGroupID = threadIdx.x % 8;
   const Index maxID = rowPointers[row + 1];

   Real result = 0.0;
@@ -1165,14 +1163,14 @@ void SpMVCSRLightWithoutAtomic16( const Real *inVector,
                                  const Index rows,
                                  const Index getColumns,
                                  const Index gridID) {
   const Index index = (gridID * MAX_X_DIM) + (blockIdx.x * blockDim.x) + threadIdx.x;
   const Index row = index / 16;
   Index i, column;

   const Index row =
      ((gridID * MAX_X_DIM) + (blockIdx.x * blockDim.x) + threadIdx.x) / 16;
   if (row >= rows)
      return;

   const Index inGroupID = index % 16;

   Index i, column;
   const Index inGroupID = threadIdx.x % 16;
   const Index maxID = rowPointers[row + 1];

   Real result = 0.0;
@@ -1195,8 +1193,7 @@ void SpMVCSRLightWithoutAtomic16( const Real *inVector,
}

template< typename Real,
          typename Index,
          int warpSize >
          typename Index >
void SpMVCSRScalarPrepare( const Real *inVector,
                           Real* outVector,
                           const Index* rowPointers,
@@ -1267,8 +1264,7 @@ void SpMVCSRVectorPrepare( const Real *inVector,
}

template< typename Real,
          typename Index,
          int warpSize >
          typename Index >
void SpMVCSRLightPrepare( const Real *inVector,
                          Real* outVector,
                          const Index* rowPointers,
@@ -1278,7 +1274,7 @@ void SpMVCSRLightPrepare( const Real *inVector,
                          const Index rows,
                          const Index getColumns) {
   const Index threads = 1024; // max block size
   Index blocks, groupSize;
   Index groupSize;
   /* Copy rowCnt to GPU */
   unsigned rowCnt = 0;
   unsigned *kernelRowCnt = nullptr;
@@ -1287,7 +1283,8 @@ void SpMVCSRLightPrepare( const Real *inVector,

   cudaDeviceProp properties;
   cudaGetDeviceProperties( &properties, Cuda::DeviceInfo::getActiveDevice() );
   blocks = properties.multiProcessorCount * properties.maxThreadsPerMultiProcessor / threads;
   Index blocks = 
      properties.multiProcessorCount * properties.maxThreadsPerMultiProcessor / threads;

   const Index nnz = roundUpDivision(valuesSize, rows); // non zeroes per row
   if (nnz <= 2)
@@ -1316,7 +1313,8 @@ void SpMVCSRLightPrepare( const Real *inVector,

template< typename Real,
          typename Index,
          int warpSize >
          int warpSize,
          int maxElemPerWarp >
void SpMVCSRLightWithoutAtomicPrepare( const Real *inVector,
                                       Real* outVector,
                                       const Index* rowPointers,
@@ -1338,8 +1336,10 @@ void SpMVCSRLightWithoutAtomicPrepare( const Real *inVector,
      groupSize = 8;
   else if (nnz <= 16)
      groupSize = 16;
   else if (nnz <= maxElemPerWarp)
      groupSize = 32; // CSR Vector
   else
      groupSize = 32;
      groupSize = roundUpDivision(nnz, maxElemPerWarp) * 32; // CSR MultiVector

   neededThreads = groupSize * rows;
   /* Execute kernels on device */
@@ -1372,18 +1372,24 @@ void SpMVCSRLightWithoutAtomicPrepare( const Real *inVector,
                  inVector, outVector, rowPointers, columnIndexes, values,
                  rows, getColumns, grid
         );
      } else { // CSR SpMV Light with groupsize = 32 is CSR Vector
      } else if (groupSize == 32) { // CSR SpMV Light with groupsize = 32 is CSR Vector
         SpMVCSRVector<Real, Index, warpSize><<<blocks, threads>>>(
                  inVector, outVector, rowPointers, columnIndexes, values,
                  rows, getColumns, grid
         );
      } else { // Execute CSR MultiVector
         SpMVCSRMultiVector<Real, Index, warpSize><<<blocks, threads>>>(
                  inVector, outVector, rowPointers, columnIndexes, values,
                  rows, getColumns, groupSize / 32, grid
         );
      }
   }
}

template< typename Real,
          typename Index,
          int warpSize >
          int warpSize,
          int maxElemPerWarp>
void SpMVCSRMultiVectorPrepare( const Real *inVector,
                                Real* outVector,
                                const Index* rowPointers,
@@ -1398,9 +1404,8 @@ void SpMVCSRMultiVectorPrepare( const Real *inVector,
   Index blocks;

   const Index nnz = roundUpDivision(valuesSize, rows); // non zeroes per row
   const size_t neededWarps = roundUpDivision(nnz, ELEMENTS_PER_WARP); // warps per row
   const Index offset = neededWarps * ELEMENTS_PER_WARP;
   size_t neededThreads = offset * rows;
   const Index neededWarps = roundUpDivision(nnz, maxElemPerWarp); // warps per row
   size_t neededThreads = warpSize * neededWarps * rows;
   /* Execute kernels on device */
   for (Index grid = 0; neededThreads != 0; ++grid) {
      if (MAX_X_DIM * threads >= neededThreads) {
@@ -1431,7 +1436,7 @@ void SpMVCSRMultiVectorPrepare( const Real *inVector,
                  values,
                  rows,
                  getColumns,
                  offset,
                  neededWarps,
                  grid
         );
      }
@@ -1442,7 +1447,8 @@ void SpMVCSRMultiVectorPrepare( const Real *inVector,
template< typename Real,
          typename Index,
          typename Device,
          CSRKernel KernelType>
          CSRKernel KernelType,
          int maxElemPerWarp>
Index findLimit(const Index start, const Index max,
               const CSR< Real, Device, Index, KernelType >& matrix,
               const Index size,
@@ -1458,7 +1464,7 @@ Index findLimit(const Index start, const Index max,
            type = STREAM;
            return current;
         } else {                  // one long row
            if (sum <= ELEMENTS_PER_WARP)
            if (sum <= maxElemPerWarp)
               type = VECTOR;
            else
               type = LONG;
@@ -1475,7 +1481,8 @@ template< typename Real,
          typename Index,
          typename Device,
          CSRKernel KernelType,
          int warpSize >
          int warpSize,
          int maxElemPerWarp >
void SpMVCSRAdaptivePrepare( const Real *inVector,
                             Real* outVector,
                             const CSR< Real, Device, Index, KernelType >& matrix,
@@ -1488,10 +1495,9 @@ void SpMVCSRAdaptivePrepare( const Real *inVector,
   /* Configuration ---------------------------------------------------*/
   /* Execute 1024 threads per block for float, (12 elements per thread) for 48KB cache
              512  threads per block for double (12 elements per thread) */
   constexpr size_t THREADS_PER_BLOCK = sizeof(Real) == 4 ? 1024 : 512;
   constexpr Index THREADS_PER_BLOCK = sizeof(Real) == 4 ? 1024 : 512;
   constexpr Index WARPS_PER_BLOCK = THREADS_PER_BLOCK / 32;
   constexpr Index SHARED = 49152/sizeof(Real); 
   constexpr Index SHARED_PER_WARP = SHARED / WARPS_PER_BLOCK;
   constexpr Index SHARED_PER_WARP = 49152/sizeof(Real) / WARPS_PER_BLOCK;
   //--------------------------------------------------------------------
   Index blocks, sum, start = 0, nextStart = 0;
   const Index threads = THREADS_PER_BLOCK;
@@ -1502,9 +1508,11 @@ void SpMVCSRAdaptivePrepare( const Real *inVector,

   while (nextStart != rows - 1) {
      Type type;
      nextStart = findLimit(start, SHARED_PER_WARP, matrix, rows, type, sum);
      nextStart = findLimit<Real, Index, Device, KernelType, maxElemPerWarp>(
         start, SHARED_PER_WARP, matrix, rows, type, sum
      );
      if (type == LONG) {
         uint32_t parts = roundUpDivision(sum, ELEMENTS_PER_WARP);
         uint32_t parts = roundUpDivision(sum, maxElemPerWarp);
         for (uint32_t index = 0; index < parts; ++index) {
            inBlock.emplace_back(start, LONG, index);
         }
@@ -1532,7 +1540,7 @@ void SpMVCSRAdaptivePrepare( const Real *inVector,
         neededThreads -= MAX_X_DIM * threads;
      }

      SpMVCSRAdaptive<Real, Index, warpSize, SHARED_PER_WARP><<<blocks, threads>>>(
      SpMVCSRAdaptive<Real, Index, warpSize, SHARED_PER_WARP, maxElemPerWarp><<<blocks, threads>>>(
               inVector,
               outVector,
               rowPointers,
@@ -1701,7 +1709,7 @@ class CSRDeviceDependentCode< Devices::Cuda >
         switch(KernelType)
         {
            case CSRScalar:
               SpMVCSRScalarPrepare<Real, Index, 32>(
               SpMVCSRScalarPrepare<Real, Index>(
                  inVector.getData(),
                  outVector.getData(),
                  matrix.getRowPointers().getData(),
@@ -1723,7 +1731,7 @@ class CSRDeviceDependentCode< Devices::Cuda >
               );
               break;
            case CSRLight:
               SpMVCSRLightPrepare<Real, Index, 32>(
               SpMVCSRLightPrepare<Real, Index>(
                  inVector.getData(),
                  outVector.getData(),
                  matrix.getRowPointers().getData(),
@@ -1735,7 +1743,7 @@ class CSRDeviceDependentCode< Devices::Cuda >
               );
               break;
            case CSRAdaptive:
               SpMVCSRAdaptivePrepare<Real, Index, Device, KernelType, 32>(
               SpMVCSRAdaptivePrepare<Real, Index, Device, KernelType, 32, 1024>(
                  inVector.getData(),
                  outVector.getData(),
                  matrix,
@@ -1748,7 +1756,7 @@ class CSRDeviceDependentCode< Devices::Cuda >
               );
               break;
            case CSRMultiVector:
               SpMVCSRMultiVectorPrepare<Real, Index, 32>(
               SpMVCSRMultiVectorPrepare<Real, Index, 32, 1024>(
                  inVector.getData(),
                  outVector.getData(),
                  matrix.getRowPointers().getData(),
@@ -1760,7 +1768,7 @@ class CSRDeviceDependentCode< Devices::Cuda >
               );
               break;
            case CSRLightWithoutAtomic:
               SpMVCSRLightWithoutAtomicPrepare<Real, Index, 32>(
               SpMVCSRLightWithoutAtomicPrepare<Real, Index, 32, 1024>(
                  inVector.getData(),
                  outVector.getData(),
                  matrix.getRowPointers().getData(),