From 546a69d0614f60bdcf78b85dcf603569a797b6da Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com>
Date: Mon, 9 Aug 2021 21:51:51 +0200
Subject: [PATCH] Optimize the CPU kernel for the CSR format.

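Route the segment reduction of the scalar CSR kernel through a new
compile-time dispatcher, CSRScalarKernelReduceSegmentsDispatcher, which
uses details::CheckFetchLambda to detect the signature of the fetch
lambda. A lambda taking the full argument list ( segmentIdx, localIdx,
globalIdx, compute ) and a lambda taking only ( globalIdx, compute ) are
now both called directly, so the inner loop no longer routes every
element access through detail::FetchLambdaAdapter. On the host, the loop
over segments is parallelized with OpenMP using dynamic scheduling;
other devices fall back to Algorithms::ParallelFor. The previous
implementation is kept below in a comment for reference.

For illustration only (hypothetical names, not part of this patch), the
two supported fetch forms for an SpMV-like reduction would look like:

   // full form: also receives the position within the segment
   auto fullFetch = [=] __cuda_callable__
      ( int segmentIdx, int localIdx, int globalIdx, bool& compute ) -> double
   { return values[ globalIdx ] * x[ columnIndexes[ globalIdx ] ]; };

   // brief form: receives only the global element index
   auto briefFetch = [=] __cuda_callable__
      ( int globalIdx, bool& compute ) -> double
   { return values[ globalIdx ] * x[ columnIndexes[ globalIdx ] ]; };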
---
 .../Segments/Kernels/CSRScalarKernel.hpp      | 115 +++++++++++++++++-
 1 file changed, 112 insertions(+), 3 deletions(-)

diff --git a/src/TNL/Algorithms/Segments/Kernels/CSRScalarKernel.hpp b/src/TNL/Algorithms/Segments/Kernels/CSRScalarKernel.hpp
index d98f886613..5b9c5e7233 100644
--- a/src/TNL/Algorithms/Segments/Kernels/CSRScalarKernel.hpp
+++ b/src/TNL/Algorithms/Segments/Kernels/CSRScalarKernel.hpp
@@ -21,6 +21,112 @@ namespace TNL {
    namespace Algorithms {
       namespace Segments {
 
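+// Dispatcher selecting the CPU reduction implementation at compile time,
+// based on whether the fetch lambda accepts the full argument list.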
+template< typename Index,
+          typename Device,
+          typename Fetch,
+          typename Reduction,
+          typename ResultKeeper,
+          bool DispatchScalarCSR = details::CheckFetchLambda< Index, Fetch >::hasAllParameters() >
+struct CSRScalarKernelReduceSegmentsDispatcher;
+
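+// Specialization for fetch lambdas taking ( segmentIdx, localIdx, globalIdx, compute ).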
+template< typename Index,
+          typename Device,
+          typename Fetch,
+          typename Reduction,
+          typename ResultKeeper >
+struct CSRScalarKernelReduceSegmentsDispatcher< Index, Device, Fetch, Reduction, ResultKeeper, true >
+{
+
+   template< typename Offsets,
+             typename Real >
+   static void reduce( const Offsets& offsets,
+                       Index first,
+                       Index last,
+                       Fetch& fetch,
+                       const Reduction& reduction,
+                       ResultKeeper& keep,
+                       const Real& zero )
+   {
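+      // Reduce a single segment: iterate over its elements and apply the reduction.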
+      auto l = [=] __cuda_callable__ ( const Index segmentIdx ) mutable {
+         const Index begin = offsets[ segmentIdx ];
+         const Index end = offsets[ segmentIdx + 1 ];
+         Real aux( zero );
+         Index localIdx( 0 );
+         bool compute( true );
+         for( Index globalIdx = begin; globalIdx < end && compute; globalIdx++ )
+            aux = reduction( aux, fetch( segmentIdx, localIdx++, globalIdx, compute ) );
+         keep( segmentIdx, aux );
+      };
+
+      if( std::is_same< Device, TNL::Devices::Sequential >::value )
+      {
+         for( Index segmentIdx = first; segmentIdx < last; segmentIdx++ )
+            l( segmentIdx );
+      }
+      else if( std::is_same< Device, TNL::Devices::Host >::value )
+      {
+#ifdef HAVE_OPENMP
+        #pragma omp parallel for firstprivate( l ) schedule( dynamic, 100 ) if( Devices::Host::isOMPEnabled() )
+#endif
+         for( Index segmentIdx = first; segmentIdx < last; segmentIdx++ )
+            l( segmentIdx );
+      }
+      else
+         Algorithms::ParallelFor< Device >::exec( first, last, l );
+   }
+};
+
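+// Specialization for fetch lambdas taking only ( globalIdx, compute ).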
+template< typename Index,
+          typename Device,
+          typename Fetch,
+          typename Reduction,
+          typename ResultKeeper >
+struct CSRScalarKernelReduceSegmentsDispatcher< Index, Device, Fetch, Reduction, ResultKeeper, false >
+{
+   template< typename OffsetsView,
+             typename Real >
+   static void reduce( const OffsetsView& offsets,
+                       Index first,
+                       Index last,
+                       Fetch& fetch,
+                       const Reduction& reduction,
+                       ResultKeeper& keep,
+                       const Real& zero )
+   {
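+      // Reduce a single segment; the brief fetch form does not use localIdx.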
+      auto l = [=] __cuda_callable__ ( const Index segmentIdx ) mutable {
+         const Index begin = offsets[ segmentIdx ];
+         const Index end = offsets[ segmentIdx + 1 ];
+         Real aux( zero );
+         bool compute( true );
+         for( Index globalIdx = begin; globalIdx < end && compute; globalIdx++ )
+            aux = reduction( aux, fetch( globalIdx, compute ) );
+         keep( segmentIdx, aux );
+      };
+
+      if( std::is_same< Device, TNL::Devices::Sequential >::value )
+      {
+         for( Index segmentIdx = first; segmentIdx < last; segmentIdx++ )
+            l( segmentIdx );
+      }
+      else if( std::is_same< Device, TNL::Devices::Host >::value )
+      {
+#ifdef HAVE_OPENMP
+        #pragma omp parallel for firstprivate( l ) schedule( dynamic, 100 ) if( Devices::Host::isOMPEnabled() )
+#endif
+         for( Index segmentIdx = first; segmentIdx < last; segmentIdx++ )
+            l( segmentIdx );
+      }
+      else
+         Algorithms::ParallelFor< Device >::exec( first, last, l );
+   }
+};
+
 template< typename Index,
           typename Device >
     template< typename Offsets >
@@ -84,6 +190,9 @@ reduceSegments( const OffsetsView& offsets,
                    const Real& zero,
                    Args... args )
 {
+   CSRScalarKernelReduceSegmentsDispatcher< Index, Device, Fetch, Reduction, ResultKeeper >::reduce(
+      offsets, first, last, fetch, reduction, keeper, zero );
+   /*
     auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable {
         const IndexType begin = offsets[ segmentIdx ];
         const IndexType end = offsets[ segmentIdx + 1 ];
@@ -102,7 +211,7 @@ reduceSegments( const OffsetsView& offsets,
 #endif
         for( Index segmentIdx = first; segmentIdx < last; segmentIdx ++ )
             l( segmentIdx, args... );
-        /*{
+        {
             const IndexType begin = offsets[ segmentIdx ];
             const IndexType end = offsets[ segmentIdx + 1 ];
             Real aux( zero );
@@ -111,10 +220,10 @@ reduceSegments( const OffsetsView& offsets,
             for( IndexType globalIdx = begin; globalIdx < end && compute; globalIdx++  )
                 aux = reduction( aux, detail::FetchLambdaAdapter< IndexType, Fetch >::call( fetch, segmentIdx, localIdx++, globalIdx, compute ) );
             keeper( segmentIdx, aux );
-        }*/
+        }
     }
     else
-        Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
+        Algorithms::ParallelFor< Device >::exec( first, last, l, args... );*/
 }
       } // namespace Segments
    }  // namespace Algorithms
-- 
GitLab