From 7bc3df04ea99d0ec09f2935900526abbabfd8364 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com>
Date: Thu, 12 Aug 2021 18:54:24 +0200
Subject: [PATCH] Threads per segment in Light CSR can be set by the user.

---
 src/Benchmarks/SpMV/spmv.h                    | 15 ++-
 src/TNL/Algorithms/Segments/CSR.h             |  4 +
 src/TNL/Algorithms/Segments/CSRView.h         |  4 +
 .../Segments/Kernels/CSRLightKernel.h         | 18 +++-
 .../Segments/Kernels/CSRLightKernel.hpp       | 94 ++++++++++++++++---
 src/TNL/Matrices/SparseMatrix.hpp             |  2 +-
 6 files changed, 117 insertions(+), 20 deletions(-)

diff --git a/src/Benchmarks/SpMV/spmv.h b/src/Benchmarks/SpMV/spmv.h
index 767d446a63..e187f54367 100644
--- a/src/Benchmarks/SpMV/spmv.h
+++ b/src/Benchmarks/SpMV/spmv.h
@@ -436,8 +436,14 @@ benchmarkSpMVCSRLight( BenchmarkType& benchmark,
    auto spmvCuda = [&]() {
       cudaMatrix.vectorProduct( cudaInVector, cudaOutVector );
    };
-   SpmvBenchmarkResult< Real, Devices::Cuda, int > cudaBenchmarkResults( MatrixInfo< HostMatrix >::getFormat(), csrResultVector, cudaOutVector, cudaMatrix.getNonzeroElementsCount() );
-   benchmark.time< Devices::Cuda >( resetCudaVectors, "GPU", spmvCuda, cudaBenchmarkResults );
+
+   for( auto threadsPerRow : std::vector< int >{ 1, 2, 4, 8, 16, 32 } )
+   {
+      cudaMatrix.getSegments().getKernel().setThreadsPerSegment( threadsPerRow );
+      String format = MatrixInfo< HostMatrix >::getFormat() + " " + convertToString( threadsPerRow );
+      SpmvBenchmarkResult< Real, Devices::Cuda, int > cudaBenchmarkResults( format, csrResultVector, cudaOutVector, cudaMatrix.getNonzeroElementsCount() );
+      benchmark.time< Devices::Cuda >( resetCudaVectors, "GPU", spmvCuda, cudaBenchmarkResults );
+   }
  #endif
 }
 
@@ -570,10 +576,13 @@ benchmarkSpmv( BenchmarkType& benchmark,
    ////
    // Perform benchmark on host with CSR as a reference CPU format
    //
+   auto nonzeros = csrHostMatrix.getNonzeroElementsCount();
    benchmark.addCommonLogs( BenchmarkType::CommonLogs( {
       { "matrix name", convertToString( inputFileName ) },
       { "rows", convertToString( csrHostMatrix.getRows() ) },
-      { "columns", convertToString( csrHostMatrix.getColumns() ) } } ) );
+      { "columns", convertToString( csrHostMatrix.getColumns() ) },
+      { "nonzeros", convertToString( nonzeros ) },
+      { "nonzeros per row", convertToString( ( double ) nonzeros / ( double ) csrHostMatrix.getRows() ) } } ) );
 
    HostVector hostInVector( csrHostMatrix.getRows() ), hostOutVector( csrHostMatrix.getRows() );
 
diff --git a/src/TNL/Algorithms/Segments/CSR.h b/src/TNL/Algorithms/Segments/CSR.h
index aa3f16d6bd..27bdfe3e2c 100644
--- a/src/TNL/Algorithms/Segments/CSR.h
+++ b/src/TNL/Algorithms/Segments/CSR.h
@@ -507,6 +507,10 @@ class CSR
       template< typename Fetch >
       SegmentsPrinter< CSR, Fetch > print( Fetch&& fetch ) const;
 
+      KernelType& getKernel() { return kernel; }
+
+      const KernelType& getKernel() const { return kernel; }
+
    protected:
 
       OffsetsContainer offsets;
diff --git a/src/TNL/Algorithms/Segments/CSRView.h b/src/TNL/Algorithms/Segments/CSRView.h
index 884ed71cf3..b593dc4677 100644
--- a/src/TNL/Algorithms/Segments/CSRView.h
+++ b/src/TNL/Algorithms/Segments/CSRView.h
@@ -143,6 +143,10 @@ class CSRView
       template< typename Fetch >
       SegmentsPrinter< CSRView, Fetch > print( Fetch&& fetch ) const;
 
+      KernelType& getKernel() { return kernel; }
+
+      const KernelType& getKernel() const { return kernel; }
+
    protected:
 
       OffsetsView offsets;
diff --git a/src/TNL/Algorithms/Segments/Kernels/CSRLightKernel.h b/src/TNL/Algorithms/Segments/Kernels/CSRLightKernel.h
index a3aa961b40..49a662ccf2 100644
--- a/src/TNL/Algorithms/Segments/Kernels/CSRLightKernel.h
+++ b/src/TNL/Algorithms/Segments/Kernels/CSRLightKernel.h
@@ -20,6 +20,8 @@ namespace TNL {
    namespace Algorithms {
       namespace Segments {
 
+enum LightCSRSThreadsMapping { LightCSRConstantThreads, CSRLightAutomaticThreads, CSRLightAutomaticThreadsLightSpMV };
+
 template< typename Index,
           typename Device >
 struct CSRLightKernel
@@ -40,6 +42,8 @@ struct CSRLightKernel
 
    static TNL::String getKernelType();
 
+   TNL::String getSetup() const;
+
    template< typename OffsetsView,
              typename Fetch,
              typename Reduction,
@@ -53,8 +57,20 @@ struct CSRLightKernel
                         ResultKeeper& keeper,
                         const Real& zero ) const;
 
+
+   void setThreadsMapping( LightCSRSThreadsMapping mapping );
+
+   LightCSRSThreadsMapping getThreadsMapping() const;
+
+   void setThreadsPerSegment( int threadsPerSegment );
+
+   int getThreadsPerSegment() const;
+
    protected:
-      int threadsPerSegment = 0;
+
+      LightCSRSThreadsMapping mapping = LightCSRConstantThreads;
+
+      int threadsPerSegment = 32;
 };
 
       } // namespace Segments
diff --git a/src/TNL/Algorithms/Segments/Kernels/CSRLightKernel.hpp b/src/TNL/Algorithms/Segments/Kernels/CSRLightKernel.hpp
index 59d8ae0e3e..1c35182884 100644
--- a/src/TNL/Algorithms/Segments/Kernels/CSRLightKernel.hpp
+++ b/src/TNL/Algorithms/Segments/Kernels/CSRLightKernel.hpp
@@ -426,6 +426,7 @@ struct CSRLightKernelreduceSegmentsDispatcher< Index, Device, Fetch, Reduce, Kee
 #ifdef HAVE_CUDA
       const size_t threads = 128;
       Index blocks, groupSize;
+
       size_t  neededThreads = threadsPerSegment * ( last - first );
 
       for (Index grid = 0; neededThreads != 0; ++grid)
@@ -513,21 +514,40 @@ CSRLightKernel< Index, Device >::
 init( const Offsets& offsets )
 {
    const Index segmentsCount = offsets.getSize() - 1;
-   //size_t neededThreads = segmentsCount * 32;//warpSize;
-
-   const Index elementsInSegment = roundUpDivision( offsets.getElement( segmentsCount ), segmentsCount ); // non zeroes per row
-   if( elementsInSegment <= 2 )
-      this->threadsPerSegment = 2;
-   else if( elementsInSegment <= 4 )
-      this->threadsPerSegment = 4;
-   else if( elementsInSegment <= 8 )
-      this->threadsPerSegment = 8;
-   else if( elementsInSegment <= 16 )
-      this->threadsPerSegment = 16;
-   else //if (nnz <= 2 * matrix.MAX_ELEMENTS_PER_WARP)
-      this->threadsPerSegment = 32; // CSR Vector
-   //else
-   //   threadsPerSegment = roundUpDivision(nnz, matrix.MAX_ELEMENTS_PER_WARP) * 32; // CSR MultiVector
+
+   if( this->getThreadsMapping() == CSRLightAutomaticThreads )
+   {
+      const Index elementsInSegment = roundUpDivision( offsets.getElement( segmentsCount ), segmentsCount ); // non zeroes per row
+      if( elementsInSegment <= 2 )
+         this->threadsPerSegment = 2;
+      else if( elementsInSegment <= 4 )
+         this->threadsPerSegment = 4;
+      else if( elementsInSegment <= 8 )
+         this->threadsPerSegment = 8;
+      else if( elementsInSegment <= 16 )
+         this->threadsPerSegment = 16;
+      else //if (nnz <= 2 * matrix.MAX_ELEMENTS_PER_WARP)
+         this->threadsPerSegment = 32; // CSR Vector
+      //else
+      //   threadsPerSegment = roundUpDivision(nnz, matrix.MAX_ELEMENTS_PER_WARP) * 32; // CSR MultiVector
+   }
+
+   if( this->getThreadsMapping() == CSRLightAutomaticThreadsLightSpMV )
+   {
+      const Index elementsInSegment = roundUpDivision( offsets.getElement( segmentsCount ), segmentsCount ); // non zeroes per row
+      if( elementsInSegment <= 2 )
+         this->threadsPerSegment = 2;
+      else if( elementsInSegment <= 4 )
+         this->threadsPerSegment = 4;
+      else if( elementsInSegment <= 8 )
+         this->threadsPerSegment = 8;
+      else if( elementsInSegment <= 16 )
+         this->threadsPerSegment = 16;
+      else //if (nnz <= 2 * matrix.MAX_ELEMENTS_PER_WARP)
+         this->threadsPerSegment = 32; // CSR Vector
+      //else
+      //   threadsPerSegment = roundUpDivision(nnz, matrix.MAX_ELEMENTS_PER_WARP) * 32; // CSR MultiVector
+   }
 
    TNL_ASSERT_GE( this->threadsPerSegment, 0, "" );
    TNL_ASSERT_LE( this->threadsPerSegment, 33, "" );
@@ -594,6 +614,50 @@ reduceSegments( const OffsetsView& offsets,
       offsets, first, last, fetch, reduce, keep, zero, this->threadsPerSegment );
 }
 
+template< typename Index,
+          typename Device >
+void
+CSRLightKernel< Index, Device >::
+setThreadsMapping( LightCSRSThreadsMapping mapping )
+{
+   this-> mapping = mapping;
+}
+
+template< typename Index,
+          typename Device >
+LightCSRSThreadsMapping
+CSRLightKernel< Index, Device >::
+getThreadsMapping() const
+{
+   return this->mapping;
+}
+
+template< typename Index,
+          typename Device >
+void
+CSRLightKernel< Index, Device >::
+setThreadsPerSegment( int threadsPerSegment )
+{
+   if( threadsPerSegment !=  1 &&
+       threadsPerSegment !=  2 &&
+       threadsPerSegment !=  4 &&
+       threadsPerSegment !=  8 &&
+       threadsPerSegment != 16 &&
+       threadsPerSegment != 32 )
+       throw std::runtime_error( "Number of threads per segment must be power of 2 - 1, 2, ... 32." );
+   this->threadsPerSegment = threadsPerSegment;
+}
+
+template< typename Index,
+          typename Device >
+int
+CSRLightKernel< Index, Device >::
+getThreadsPerSegment() const
+{
+   return this->threadsPerSegment;
+}
+
+
       } // namespace Segments
    }  // namespace Algorithms
 } // namespace TNL
diff --git a/src/TNL/Matrices/SparseMatrix.hpp b/src/TNL/Matrices/SparseMatrix.hpp
index ac86884257..dd11c6cf74 100644
--- a/src/TNL/Matrices/SparseMatrix.hpp
+++ b/src/TNL/Matrices/SparseMatrix.hpp
@@ -524,7 +524,7 @@ vectorProduct( const InVector& inVector,
                const IndexType firstRow,
                const IndexType lastRow ) const
 {
-   this->view.vectorProduct( inVector, outVector, matrixMultiplicator, outVectorMultiplicator, firstRow, lastRow );
+   this->getView().vectorProduct( inVector, outVector, matrixMultiplicator, outVectorMultiplicator, firstRow, lastRow );
 }
 
 template< typename Real,
-- 
GitLab