Commit e3b27a61, authored by Illia Kolesnik, committed by Tomáš Oberhuber
Browse files

Fixed blocks filling

parent 5c9c5d81
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -29,7 +29,7 @@ enum class Type {
};

union Block {
   void set(uint32_t row, Type type = Type::VECTOR, uint32_t index = 0) noexcept {
   Block(uint32_t row, Type type = Type::VECTOR, uint32_t index = 0) noexcept {
      this->index[0] = row;
      this->index[1] = index;
      this->byte[7] = (uint8_t)type;
@@ -87,6 +87,7 @@ public:


   Containers::Vector< Block, Device, Index > blocks;
   
   Index maxElementsPerWarp = 1024;

   using Sparse< Real, Device, Index >::getAllocatedElementsCount;
+82 −145
Original line number Diff line number Diff line
@@ -16,7 +16,7 @@
#include <TNL/Algorithms/AtomicOperations.h>
#include <TNL/Exceptions/NotImplementedError.h>
#include <TNL/Atomic.h>
#include <vector>
#include <vector> // for blocks in CSR Adaptive

#ifdef HAVE_CUSPARSE
#include <cuda.h>
@@ -113,7 +113,7 @@ void CSR< Real, Device, Index, KernelType >::setCompressedRowLengths( ConstCompr
   this->columnIndexes.setSize( this->rowPointers.getElement( this->rows ) );
   this->columnIndexes.setValue( this->columns );

   // if (KernelType == CSRAdaptive)
   if (KernelType == CSRAdaptive)
      this->setBlocks();
}

@@ -121,11 +121,11 @@ void CSR< Real, Device, Index, KernelType >::setCompressedRowLengths( ConstCompr
template< typename Real,
          typename Index,
          typename Device,
          CSRKernel KernelType,
          int maxElemPerWarp>
          CSRKernel KernelType>
Index findLimit(const Index start, const Index max,
               const CSR< Real, Device, Index, KernelType >& matrix,
               const Index size,
               const Index maxElemPerWarp,
               Type &type,
               Index &sum) {
   sum = 0;
@@ -158,33 +158,34 @@ template< typename Real,
void CSR< Real, Device, Index, KernelType >::setBlocks()
{
   /* Builds the block descriptors stored in this->blocks (used by the
    * CSR Adaptive kernel).
    *
    * Rows are consumed left to right: findLimit() reports where the block
    * starting at `start` ends, the block type, and the number of elements
    * it accumulated (`sum`). A LONG row is emitted as several blocks that
    * share the same starting row and differ only in their part index.
    * After the loop one more block holding the final row index is appended.
    *
    * NOTE(review): `rows` is the size of the row-pointer array, so the loop
    * terminates when findLimit() returns `rows - 1`. Confirm callers never
    * reach this with an empty row-pointer array.
    */

   // Limit handed to findLimit() as `max` and used to split LONG rows.
   // Declared as Index (not a bare int literal) so that findLimit() deduces
   // one consistent Index type from all of its arguments.
   const Index maxElementsPerBlock = 384;

   const Index rows = this->getRowPointers().getSize();
   Index sum = 0, start = 0, nextStart = 0;

   /* Fill blocks */
   std::vector<Block> inBlock;
   inBlock.reserve(rows); // reserve space to avoid reallocation

   while (nextStart != rows - 1) {
      Type type = Type::VECTOR; // overwritten by findLimit()
      nextStart = findLimit(
         start, maxElementsPerBlock, *this, rows, this->maxElementsPerWarp, type, sum
      );
      if (type == Type::LONG) {
         // One block per chunk of at most maxElementsPerBlock elements.
         const uint32_t parts = roundUpDivision(sum, maxElementsPerBlock);
         for (uint32_t index = 0; index < parts; ++index) {
            inBlock.emplace_back(start, Type::LONG, index);
         }
      } else {
         inBlock.emplace_back(start, type);
      }

      start = nextStart;
   }
   // Trailing block recording the last row index returned by findLimit().
   inBlock.emplace_back(nextStart);

   /* Copy values */
   this->blocks.setSize(inBlock.size());
   for (size_t i = 0; i < inBlock.size(); ++i)
      this->blocks.setElement(i, inBlock[i]);
}

template< typename Real,
@@ -1495,40 +1496,6 @@ void SpMVCSRMultiVectorPrepare( const Real *inVector,
   }
}

/* Find limit of block */
// template< typename Real,
//           typename Index,
//           typename Device,
//           CSRKernel KernelType,
//           int maxElemPerWarp>
// Index findLimit(const Index start, const Index max,
//                const CSR< Real, Device, Index, KernelType >& matrix,
//                const Index size,
//                Type &type,
//                Index &sum) {
//    sum = 0;
//    for (Index current = start; current < size - 1; ++current) {
//       Index elements = matrix.getRowPointers().getElement(current + 1) -
//                        matrix.getRowPointers().getElement(current);
//       sum += elements;
//       if (sum > max) {
//          if (current - start > 1) { // extra row
//             type = STREAM;
//             return current;
//          } else {                  // one long row
//             if (sum <= maxElemPerWarp)
//                type = VECTOR;
//             else
//                type = LONG;
//             return current + 1;
//          }
//       }
//    }

//    type = STREAM;
//    return size - 1; // return last row pointer
// }

template< typename Real,
          typename Index,
          typename Device,
@@ -1551,38 +1518,10 @@ void SpMVCSRAdaptivePrepare( const Real *inVector,
   constexpr Index WARPS_PER_BLOCK = THREADS_PER_BLOCK / 32;
   constexpr Index SHARED_PER_WARP = 49152/sizeof(Real) / WARPS_PER_BLOCK;
   //--------------------------------------------------------------------
   // Index blocks, sum, start = 0, nextStart = 0;
   Index blocks;
   const Index threads = THREADS_PER_BLOCK;

   /* Fill blocks */
   // std::vector<Block> inBlock;
   // inBlock.reserve(rows); // reserve space to avoid reallocation

   // while (nextStart != rows - 1) {
   //    Type type;
   //    nextStart = findLimit<Real, Index, Device, KernelType, maxElemPerWarp>(
   //       start, SHARED_PER_WARP, matrix, rows, type, sum
   //    );
   //    if (type == LONG) {
   //       uint32_t parts = roundUpDivision(sum, maxElemPerWarp);
   //       for (uint32_t index = 0; index < parts; ++index) {
   //          inBlock.emplace_back(start, LONG, index);
   //       }
   //    } else {
   //       inBlock.emplace_back(start, type);
   //    }

   //    start = nextStart;
   // }
   // inBlock.emplace_back(nextStart);

   // /* blocks to GPU */
   // Block *blocksAdaptive = nullptr;
   // cudaMalloc((void **)&blocksAdaptive, sizeof(*blocksAdaptive) * inBlock.size());
   // cudaMemcpy(blocksAdaptive, inBlock.data(), inBlock.size() * sizeof(*blocksAdaptive), cudaMemcpyHostToDevice);

   // size_t neededThreads = inBlock.size() * 32; // one warp per block
   size_t neededThreads = matrix.blocks.getSize() * 32; // one warp per block
   /* Execute kernels on device */
   for (Index grid = 0; neededThreads != 0; ++grid) {
@@ -1606,8 +1545,6 @@ void SpMVCSRAdaptivePrepare( const Real *inVector,
               grid
      );
   }

   // cudaFree(blocksAdaptive);
}

#endif
@@ -1760,43 +1697,43 @@ class CSRDeviceDependentCode< Devices::Cuda >
                                                              inVector.getData(),
                                                              outVector.getData() );
#else
         // switch(KernelType)
         // {
         //    case CSRScalar:
               // SpMVCSRScalarPrepare<Real, Index>(
               //    inVector.getData(),
               //    outVector.getData(),
               //    matrix.getRowPointers().getData(),
               //    matrix.getColumnIndexes().getData(),
               //    matrix.getValues().getData(),
               //    matrix.getRowPointers().getSize() - 1,
               //    matrix.getColumns()
               // );
         //       break;
         //    case CSRVector:
               // SpMVCSRVectorPrepare<Real, Index, 32>(
               //    inVector.getData(),
               //    outVector.getData(),
               //    matrix.getRowPointers().getData(),
               //    matrix.getColumnIndexes().getData(),
               //    matrix.getValues().getData(),
               //    matrix.getRowPointers().getSize() - 1,
               //    matrix.getColumns()
               // );
         //       break;
         //    case CSRLight:
               // SpMVCSRLightPrepare<Real, Index>(
               //    inVector.getData(),
               //    outVector.getData(),
               //    matrix.getRowPointers().getData(),
               //    matrix.getColumnIndexes().getData(),
               //    matrix.getValues().getData(),
               //    matrix.getValues().getSize(),
               //    matrix.getRowPointers().getSize() - 1,
               //    matrix.getColumns()
               // );
         //       break;
         //    case CSRAdaptive:
         switch(KernelType)
         {
            case CSRScalar:
               SpMVCSRScalarPrepare<Real, Index>(
                  inVector.getData(),
                  outVector.getData(),
                  matrix.getRowPointers().getData(),
                  matrix.getColumnIndexes().getData(),
                  matrix.getValues().getData(),
                  matrix.getRowPointers().getSize() - 1,
                  matrix.getColumns()
               );
               break;
            case CSRVector:
               SpMVCSRVectorPrepare<Real, Index, 32>(
                  inVector.getData(),
                  outVector.getData(),
                  matrix.getRowPointers().getData(),
                  matrix.getColumnIndexes().getData(),
                  matrix.getValues().getData(),
                  matrix.getRowPointers().getSize() - 1,
                  matrix.getColumns()
               );
               break;
            case CSRLight:
               SpMVCSRLightPrepare<Real, Index>(
                  inVector.getData(),
                  outVector.getData(),
                  matrix.getRowPointers().getData(),
                  matrix.getColumnIndexes().getData(),
                  matrix.getValues().getData(),
                  matrix.getValues().getSize(),
                  matrix.getRowPointers().getSize() - 1,
                  matrix.getColumns()
               );
               break;
            case CSRAdaptive:
               SpMVCSRAdaptivePrepare<Real, Index, Device, KernelType, 32, 1024>(
                  inVector.getData(),
                  outVector.getData(),
@@ -1808,32 +1745,32 @@ class CSRDeviceDependentCode< Devices::Cuda >
                  matrix.getRowPointers().getSize(), // don't add -1 !
                  matrix.getColumns()
               );
         //       break;
         //    case CSRMultiVector:
               // SpMVCSRMultiVectorPrepare<Real, Index, 32, 1024>(
               //    inVector.getData(),
               //    outVector.getData(),
               //    matrix.getRowPointers().getData(),
               //    matrix.getColumnIndexes().getData(),
               //    matrix.getValues().getData(),
               //    matrix.getValues().getSize(),
               //    matrix.getRowPointers().getSize() - 1,
               //    matrix.getColumns()
               // );
         //       break;
         //    case CSRLightWithoutAtomic:
               // SpMVCSRLightWithoutAtomicPrepare<Real, Index, 32, 1024>(
               //    inVector.getData(),
               //    outVector.getData(),
               //    matrix.getRowPointers().getData(),
               //    matrix.getColumnIndexes().getData(),
               //    matrix.getValues().getData(),
               //    matrix.getValues().getSize(),
               //    matrix.getRowPointers().getSize() - 1,
               //    matrix.getColumns()
               // );
         //       break;
         // }
               break;
            case CSRMultiVector:
               SpMVCSRMultiVectorPrepare<Real, Index, 32, 1024>(
                  inVector.getData(),
                  outVector.getData(),
                  matrix.getRowPointers().getData(),
                  matrix.getColumnIndexes().getData(),
                  matrix.getValues().getData(),
                  matrix.getValues().getSize(),
                  matrix.getRowPointers().getSize() - 1,
                  matrix.getColumns()
               );
               break;
            case CSRLightWithoutAtomic:
               SpMVCSRLightWithoutAtomicPrepare<Real, Index, 32, 1024>(
                  inVector.getData(),
                  outVector.getData(),
                  matrix.getRowPointers().getData(),
                  matrix.getColumnIndexes().getData(),
                  matrix.getValues().getData(),
                  matrix.getValues().getSize(),
                  matrix.getRowPointers().getSize() - 1,
                  matrix.getColumns()
               );
               break;
         }
#endif /* HAVE_CUDA */
#endif
      }