From 650b12ed929ea3b2eed5097c96c755c91f191b95 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com>
Date: Fri, 9 Jul 2021 10:57:54 +0200
Subject: [PATCH] Added CUDA kernel for RowMajor ordering of Ellpack.

---
 src/TNL/Algorithms/Segments/EllpackView.h     |   2 +
 src/TNL/Algorithms/Segments/EllpackView.hpp   | 147 ++++++++++++++++--
 .../Matrices/SparseMatrixTest_Ellpack.h       |  10 +-
 .../SparseMatrixVectorProductTest_Ellpack.h   |  10 +-
 4 files changed, 155 insertions(+), 14 deletions(-)

diff --git a/src/TNL/Algorithms/Segments/EllpackView.h b/src/TNL/Algorithms/Segments/EllpackView.h
index b8066f6353..1a14db3384 100644
--- a/src/TNL/Algorithms/Segments/EllpackView.h
+++ b/src/TNL/Algorithms/Segments/EllpackView.h
@@ -22,6 +22,8 @@ namespace TNL {
    namespace Algorithms {
       namespace Segments {
 
+enum EllpackKernelType { Scalar, Vector, Vector2, Vector4, Vector8, Vector16 }; // NOTE(review): not referenced anywhere in this patch — presumably reserved for future kernel selection; confirm before merging
+
 template< typename Device,
           typename Index,
           ElementsOrganization Organization = Segments::DefaultElementsOrganization< Device >::getOrganization(),
diff --git a/src/TNL/Algorithms/Segments/EllpackView.hpp b/src/TNL/Algorithms/Segments/EllpackView.hpp
index 7abf2caed6..6f49c55eef 100644
--- a/src/TNL/Algorithms/Segments/EllpackView.hpp
+++ b/src/TNL/Algorithms/Segments/EllpackView.hpp
@@ -19,6 +19,124 @@ namespace TNL {
    namespace Algorithms {
       namespace Segments {
 
+#ifdef HAVE_CUDA
+template< typename Index,
+          typename Fetch,
+          typename Reduction,
+          typename ResultKeeper,
+          typename Real >
+__global__ void
+EllpackCudaReductionKernelFull( Index first, Index last, Fetch fetch, const Reduction reduction, ResultKeeper keep, const Real zero, Index segmentSize )
+{
+   const int warpSize = 32;
+   const int gridID = 0;
+   const Index segmentIdx = first + ((gridID * TNL::Cuda::getMaxGridXSize() ) + (blockIdx.x * blockDim.x) + threadIdx.x) / warpSize;
+   if (segmentIdx >= last)
+      return;
+
+   Real result = zero;
+   const Index laneID = threadIdx.x & 31; // & is cheaper than %
+   const Index begin = segmentIdx * segmentSize;
+   const Index end = begin + segmentSize;
+
+   /* Each lane visits elements strided by warpSize, so the local (within-segment)
+      index of element i is i - begin; stop early when fetch clears 'compute'. */
+   bool compute( true );
+   for( Index i = begin + laneID; i < end && compute; i += warpSize)
+      result = reduction( result, fetch( segmentIdx, i - begin, i, compute ) );
+
+   /* Parallel reduction within the warp (whole warps share one segmentIdx, so all 32 lanes are active). */
+   result = reduction( result, __shfl_down_sync(0xFFFFFFFF, result, 16 ) );
+   result = reduction( result, __shfl_down_sync(0xFFFFFFFF, result,  8 ) );
+   result = reduction( result, __shfl_down_sync(0xFFFFFFFF, result,  4 ) );
+   result = reduction( result, __shfl_down_sync(0xFFFFFFFF, result,  2 ) );
+   result = reduction( result, __shfl_down_sync(0xFFFFFFFF, result,  1 ) );
+   /* Write result */
+   if( laneID == 0 )
+      keep( segmentIdx, result );
+}
+
+template< typename Index,
+          typename Fetch,
+          typename Reduction,
+          typename ResultKeeper,
+          typename Real >
+__global__ void
+EllpackCudaReductionKernelCompact( Index first, Index last, Fetch fetch, const Reduction reduction, ResultKeeper keep, const Real zero, Index segmentSize )
+{
+   const int warpSize = 32;
+   const int gridID = 0;
+   const Index segmentIdx = first + ((gridID * TNL::Cuda::getMaxGridXSize() ) + (blockIdx.x * blockDim.x) + threadIdx.x) / warpSize;
+   if (segmentIdx >= last)
+      return;
+
+   Real result = zero;
+   const Index laneID = threadIdx.x & 31; // & is cheaper than %
+   const Index begin = segmentIdx * segmentSize;
+   const Index end = begin + segmentSize;
+
+   /* Strided warp-parallel fetch; stop early when fetch clears 'compute'. */
+   bool compute( true );
+   for( Index i = begin + laneID; i < end && compute; i += warpSize)
+      result = reduction( result, fetch( i, compute ) );
+
+   /* Parallel reduction within the warp (whole warps share one segmentIdx, so all 32 lanes are active). */
+   result = reduction( result, __shfl_down_sync(0xFFFFFFFF, result, 16 ) );
+   result = reduction( result, __shfl_down_sync(0xFFFFFFFF, result,  8 ) );
+   result = reduction( result, __shfl_down_sync(0xFFFFFFFF, result,  4 ) );
+   result = reduction( result, __shfl_down_sync(0xFFFFFFFF, result,  2 ) );
+   result = reduction( result, __shfl_down_sync(0xFFFFFFFF, result,  1 ) );
+   /* Write result */
+   if( laneID == 0 )
+      keep( segmentIdx, result );
+
+}
+#endif
+
+template< typename Index,
+          typename Fetch,
+          typename Reduction,
+          typename ResultKeeper,
+          typename Real,
+          bool FullFetch = details::CheckFetchLambda< Index, Fetch >::hasAllParameters() > // dispatch on fetch-lambda arity
+struct EllpackCudaReductionDispatcher
+{
+   static void
+   exec( Index first, Index last, Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero, Index segmentSize )
+   {
+   #ifdef HAVE_CUDA
+      const Index segmentsCount = last - first;
+      const Index threadsCount = segmentsCount * 32; // one warp (32 threads) per segment
+      const Index blocksCount = Cuda::getNumberOfBlocks( threadsCount, 256 );
+      dim3 blockSize( 256 );
+      dim3 gridSize( blocksCount ); // NOTE(review): kernel hardwires gridID = 0 — assumes blocksCount fits one grid; confirm for very large segment counts
+      EllpackCudaReductionKernelFull<<< gridSize, blockSize >>>( first, last, fetch, reduction, keeper, zero, segmentSize );
+      cudaDeviceSynchronize(); // NOTE(review): no error check after launch/sync — consider TNL's post-launch device check
+   #endif
+   }
+};
+
+template< typename Index,
+          typename Fetch,
+          typename Reduction,
+          typename ResultKeeper,
+          typename Real >
+struct EllpackCudaReductionDispatcher< Index, Fetch, Reduction, ResultKeeper, Real, false > // compact fetch( globalIdx, compute )
+{
+   static void
+   exec( Index first, Index last, Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero, Index segmentSize )
+   {
+   #ifdef HAVE_CUDA
+      const Index segmentsCount = last - first;
+      const Index threadsCount = segmentsCount * 32; // one warp (32 threads) per segment
+      const Index blocksCount = Cuda::getNumberOfBlocks( threadsCount, 256 );
+      dim3 blockSize( 256 );
+      dim3 gridSize( blocksCount ); // NOTE(review): kernel hardwires gridID = 0 — assumes blocksCount fits one grid; confirm for very large segment counts
+      EllpackCudaReductionKernelCompact<<< gridSize, blockSize >>>( first, last, fetch, reduction, keeper, zero, segmentSize );
+      cudaDeviceSynchronize(); // NOTE(review): no error check after launch/sync — consider TNL's post-launch device check
+   #endif
+   }
+};
 
 template< typename Device,
           typename Index,
@@ -277,18 +395,23 @@ reduceSegments( IndexType first, IndexType last, Fetch& fetch, const Reduction&
    using RealType = typename details::FetchLambdaAdapter< Index, Fetch >::ReturnType;
    if( Organization == RowMajorOrder )
    {
-      const IndexType segmentSize = this->segmentSize;
-      auto l = [=] __cuda_callable__ ( const IndexType segmentIdx ) mutable {
-         const IndexType begin = segmentIdx * segmentSize;
-         const IndexType end = begin + segmentSize;
-         RealType aux( zero );
-         IndexType localIdx( 0 );
-         bool compute( true );
-         for( IndexType j = begin; j < end && compute; j++  )
-            aux = reduction( aux, detail::FetchLambdaAdapter< IndexType, Fetch >::call( fetch, segmentIdx, localIdx++, j, compute ) );
-         keeper( segmentIdx, aux );
-      };
-      Algorithms::ParallelFor< Device >::exec( first, last, l );
+      if( std::is_same< Device, Devices::Cuda >::value )
+         EllpackCudaReductionDispatcher< IndexType, Fetch, Reduction, ResultKeeper, Real>::exec( first, last, fetch, reduction, keeper, zero, segmentSize );
+      else
+      {
+         const IndexType segmentSize = this->segmentSize;
+         auto l = [=] __cuda_callable__ ( const IndexType segmentIdx ) mutable {
+            const IndexType begin = segmentIdx * segmentSize;
+            const IndexType end = begin + segmentSize;
+            RealType aux( zero );
+            IndexType localIdx( 0 );
+            bool compute( true );
+            for( IndexType j = begin; j < end && compute; j++  )
+               aux = reduction( aux, details::FetchLambdaAdapter< IndexType, Fetch >::call( fetch, segmentIdx, localIdx++, j, compute ) );
+            keeper( segmentIdx, aux );
+         };
+         Algorithms::ParallelFor< Device >::exec( first, last, l );
+      }
    }
    else
    {
diff --git a/src/UnitTests/Matrices/SparseMatrixTest_Ellpack.h b/src/UnitTests/Matrices/SparseMatrixTest_Ellpack.h
index ef56ec63a4..b13a19c6a4 100644
--- a/src/UnitTests/Matrices/SparseMatrixTest_Ellpack.h
+++ b/src/UnitTests/Matrices/SparseMatrixTest_Ellpack.h
@@ -46,7 +46,15 @@ using MatrixTypes = ::testing::Types
     TNL::Matrices::SparseMatrix< int,     TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, ColumnMajorEllpack >,
     TNL::Matrices::SparseMatrix< long,    TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, ColumnMajorEllpack >,
     TNL::Matrices::SparseMatrix< float,   TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, ColumnMajorEllpack >,
-    TNL::Matrices::SparseMatrix< double,  TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, ColumnMajorEllpack >
+    TNL::Matrices::SparseMatrix< double,  TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, ColumnMajorEllpack >,
+    TNL::Matrices::SparseMatrix< int,     TNL::Devices::Cuda, int,   TNL::Matrices::GeneralMatrix, RowMajorEllpack >,
+    TNL::Matrices::SparseMatrix< long,    TNL::Devices::Cuda, int,   TNL::Matrices::GeneralMatrix, RowMajorEllpack >,
+    TNL::Matrices::SparseMatrix< float,   TNL::Devices::Cuda, int,   TNL::Matrices::GeneralMatrix, RowMajorEllpack >,
+    TNL::Matrices::SparseMatrix< double,  TNL::Devices::Cuda, int,   TNL::Matrices::GeneralMatrix, RowMajorEllpack >,
+    TNL::Matrices::SparseMatrix< int,     TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, RowMajorEllpack >,
+    TNL::Matrices::SparseMatrix< long,    TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, RowMajorEllpack >,
+    TNL::Matrices::SparseMatrix< float,   TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, RowMajorEllpack >,
+    TNL::Matrices::SparseMatrix< double,  TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, RowMajorEllpack >
 #endif
 >;
 
diff --git a/src/UnitTests/Matrices/SparseMatrixVectorProductTest_Ellpack.h b/src/UnitTests/Matrices/SparseMatrixVectorProductTest_Ellpack.h
index abb4213ca2..c93aace755 100644
--- a/src/UnitTests/Matrices/SparseMatrixVectorProductTest_Ellpack.h
+++ b/src/UnitTests/Matrices/SparseMatrixVectorProductTest_Ellpack.h
@@ -46,7 +46,15 @@ using MatrixTypes = ::testing::Types
     TNL::Matrices::SparseMatrix< int,     TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, ColumnMajorEllpack >,
     TNL::Matrices::SparseMatrix< long,    TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, ColumnMajorEllpack >,
     TNL::Matrices::SparseMatrix< float,   TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, ColumnMajorEllpack >,
-    TNL::Matrices::SparseMatrix< double,  TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, ColumnMajorEllpack >
+    TNL::Matrices::SparseMatrix< double,  TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, ColumnMajorEllpack >,
+    TNL::Matrices::SparseMatrix< int,     TNL::Devices::Cuda, int,   TNL::Matrices::GeneralMatrix, RowMajorEllpack >,
+    TNL::Matrices::SparseMatrix< long,    TNL::Devices::Cuda, int,   TNL::Matrices::GeneralMatrix, RowMajorEllpack >,
+    TNL::Matrices::SparseMatrix< float,   TNL::Devices::Cuda, int,   TNL::Matrices::GeneralMatrix, RowMajorEllpack >,
+    TNL::Matrices::SparseMatrix< double,  TNL::Devices::Cuda, int,   TNL::Matrices::GeneralMatrix, RowMajorEllpack >,
+    TNL::Matrices::SparseMatrix< int,     TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, RowMajorEllpack >,
+    TNL::Matrices::SparseMatrix< long,    TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, RowMajorEllpack >,
+    TNL::Matrices::SparseMatrix< float,   TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, RowMajorEllpack >,
+    TNL::Matrices::SparseMatrix< double,  TNL::Devices::Cuda, long,  TNL::Matrices::GeneralMatrix, RowMajorEllpack >
 #endif
 >;
 
-- 
GitLab