From 546a69d0614f60bdcf78b85dcf603569a797b6da Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com>
Date: Mon, 9 Aug 2021 21:51:51 +0200
Subject: [PATCH] Optimizing CPU kernel for CSR format.

---
 .../Segments/Kernels/CSRScalarKernel.hpp      | 111 +++++++++++++++++-
 1 file changed, 108 insertions(+), 3 deletions(-)

diff --git a/src/TNL/Algorithms/Segments/Kernels/CSRScalarKernel.hpp b/src/TNL/Algorithms/Segments/Kernels/CSRScalarKernel.hpp
index d98f886613..5b9c5e7233 100644
--- a/src/TNL/Algorithms/Segments/Kernels/CSRScalarKernel.hpp
+++ b/src/TNL/Algorithms/Segments/Kernels/CSRScalarKernel.hpp
@@ -21,6 +21,108 @@ namespace TNL {
 namespace Algorithms {
     namespace Segments {
 
+template< typename Index,
+          typename Device,
+          typename Fetch,
+          typename Reduce,
+          typename Keep,
+          bool DispatchScalarCSR = details::CheckFetchLambda< Index, Fetch >::hasAllParameters() >
+struct CSRScalarKernelreduceSegmentsDispatcher;
+
+template< typename Index,
+          typename Device,
+          typename Fetch,
+          typename Reduction,
+          typename ResultKeeper >
+struct CSRScalarKernelreduceSegmentsDispatcher< Index, Device, Fetch, Reduction, ResultKeeper, true >
+{
+
+   template< typename Offsets,
+             typename Real >
+   static void reduce( const Offsets& offsets,
+                       Index first,
+                       Index last,
+                       Fetch& fetch,
+                       const Reduction& reduction,
+                       ResultKeeper& keep,
+                       const Real& zero )
+   {
+      auto l = [=] __cuda_callable__ ( const Index segmentIdx ) mutable {
+         const Index begin = offsets[ segmentIdx ];
+         const Index end = offsets[ segmentIdx + 1 ];
+         Real aux( zero );
+         Index localIdx( 0 );
+         bool compute( true );
+         for( Index globalIdx = begin; globalIdx < end && compute; globalIdx++ )
+            aux = reduction( aux, fetch( segmentIdx, localIdx++, globalIdx, compute ) );
+         keep( segmentIdx, aux );
+      };
+
+      if( std::is_same< Device, TNL::Devices::Sequential >::value )
+      {
+         for( Index segmentIdx = first; segmentIdx < last; segmentIdx ++ )
+            l( segmentIdx );
+      }
+      else if( std::is_same< Device, TNL::Devices::Host >::value )
+      {
+#ifdef HAVE_OPENMP
+         #pragma omp parallel for firstprivate( l ) schedule( dynamic, 100 ), if( Devices::Host::isOMPEnabled() )
+#endif
+         for( Index segmentIdx = first; segmentIdx < last; segmentIdx ++ )
+            l( segmentIdx );
+      }
+      else
+         Algorithms::ParallelFor< Device >::exec( first, last, l );
+   }
+};
+
+template< typename Index,
+          typename Device,
+          typename Fetch,
+          typename Reduce,
+          typename Keep >
+struct CSRScalarKernelreduceSegmentsDispatcher< Index, Device, Fetch, Reduce, Keep, false >
+{
+   template< typename OffsetsView,
+             typename Real >
+   static void reduce( const OffsetsView& offsets,
+                       Index first,
+                       Index last,
+                       Fetch& fetch,
+                       const Reduce& reduction,
+                       Keep& keep,
+                       const Real& zero )
+   {
+      auto l = [=] __cuda_callable__ ( const Index segmentIdx ) mutable {
+         const Index begin = offsets[ segmentIdx ];
+         const Index end = offsets[ segmentIdx + 1 ];
+         Real aux( zero );
+         bool compute( true );
+         for( Index globalIdx = begin; globalIdx < end && compute; globalIdx++ )
+            aux = reduction( aux, fetch( globalIdx, compute ) );
+         keep( segmentIdx, aux );
+      };
+
+      if( std::is_same< Device, TNL::Devices::Sequential >::value )
+      {
+         for( Index segmentIdx = first; segmentIdx < last; segmentIdx ++ )
+            l( segmentIdx );
+      }
+      else if( std::is_same< Device, TNL::Devices::Host >::value )
+      {
+#ifdef HAVE_OPENMP
+         #pragma omp parallel for firstprivate( l ) schedule( dynamic, 100 ), if( Devices::Host::isOMPEnabled() )
+#endif
+         for( Index segmentIdx = first; segmentIdx < last; segmentIdx ++ )
+            l( segmentIdx );
+      }
+      else
+         Algorithms::ParallelFor< Device >::exec( first, last, l );
+
+   }
+};
+
+
 template< typename Index,
           typename Device >
    template< typename Offsets >
@@ -84,6 +186,9 @@ reduceSegments( const OffsetsView& offsets,
                 const Real& zero,
                 Args... args )
 {
+   CSRScalarKernelreduceSegmentsDispatcher< Index, Device, Fetch, Reduction, ResultKeeper >::reduce(
+      offsets, first, last, fetch, reduction, keeper, zero );
+   /*
    auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable {
       const IndexType begin = offsets[ segmentIdx ];
       const IndexType end = offsets[ segmentIdx + 1 ];
@@ -102,7 +207,7 @@ reduceSegments( const OffsetsView& offsets,
 #endif
       for( Index segmentIdx = first; segmentIdx < last; segmentIdx ++ )
          l( segmentIdx, args... );
-      /*{
+      {
          const IndexType begin = offsets[ segmentIdx ];
         const IndexType end = offsets[ segmentIdx + 1 ];
         Real aux( zero );
@@ -111,10 +216,10 @@ reduceSegments( const OffsetsView& offsets,
         for( IndexType globalIdx = begin; globalIdx < end && compute; globalIdx++ )
            aux = reduction( aux, detail::FetchLambdaAdapter< IndexType, Fetch >::call( fetch, segmentIdx, localIdx++, globalIdx, compute ) );
         keeper( segmentIdx, aux );
-      }*/
+      }
    }
    else
-      Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
+      Algorithms::ParallelFor< Device >::exec( first, last, l, args... );*/
 }
     } // namespace Segments
 } // namespace Algorithms
-- 
GitLab
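
Note on the dispatch introduced above: details::CheckFetchLambda< Index, Fetch >::hasAllParameters() selects one of the two specializations at compile time, depending on whether the user's fetch lambda takes the full parameter set ( segmentIdx, localIdx, globalIdx, compute ) or only ( globalIdx, compute ). This replaces the per-call detail::FetchLambdaAdapter in the old (now commented-out) implementation, and the brief path no longer has to carry the localIdx counter through the inner loop. The standalone C++17 sketch below illustrates the same technique; it is NOT TNL code: hasAllParameters() here is a hypothetical stand-in built on std::is_invocable, and reduceSegments is a simplified serial mirror of the lambda 'l' from the patch, not TNL's API.

// Standalone sketch (not TNL code): compile-time dispatch on the fetch
// lambda's arity, analogous to CheckFetchLambda::hasAllParameters().
// Requires C++17 (std::is_invocable, if constexpr).
#include <iostream>
#include <type_traits>
#include <vector>

// Hypothetical stand-in for details::CheckFetchLambda: true when the
// fetch lambda accepts ( segmentIdx, localIdx, globalIdx, compute ).
template< typename Index, typename Fetch >
constexpr bool hasAllParameters()
{
   return std::is_invocable< Fetch, Index, Index, Index, bool& >::value;
}

// Serial CSR segment reduction mirroring the lambda 'l' in the patch.
template< typename Index, typename Fetch, typename Reduction, typename Keep, typename Real >
void reduceSegments( const std::vector< Index >& offsets,
                     Fetch&& fetch, Reduction&& reduction, Keep&& keep, const Real& zero )
{
   for( Index segmentIdx = 0; segmentIdx < (Index) offsets.size() - 1; segmentIdx++ )
   {
      Real aux( zero );
      Index localIdx( 0 );
      bool compute( true );
      for( Index globalIdx = offsets[ segmentIdx ]; globalIdx < offsets[ segmentIdx + 1 ] && compute; globalIdx++ )
      {
         if constexpr( hasAllParameters< Index, Fetch >() )
            aux = reduction( aux, fetch( segmentIdx, localIdx++, globalIdx, compute ) );  // full variant
         else
            aux = reduction( aux, fetch( globalIdx, compute ) );                          // brief variant
      }
      keep( segmentIdx, aux );
   }
}

int main()
{
   std::vector< int > offsets{ 0, 2, 5, 6 };          // three CSR rows
   std::vector< double > values{ 1, 2, 3, 4, 5, 6 };  // nonzero values
   auto plus = []( double a, double b ) { return a + b; };
   auto keep = []( int row, double sum ) { std::cout << "row " << row << ": " << sum << "\n"; };

   // Brief fetch ( globalIdx, compute ): dispatches to the 'false' path.
   reduceSegments< int >( offsets,
                          [&]( int globalIdx, bool& ) { return values[ globalIdx ]; },
                          plus, keep, 0.0 );

   // Full fetch ( segmentIdx, localIdx, globalIdx, compute ): the 'true' path.
   reduceSegments< int >( offsets,
                          [&]( int, int, int globalIdx, bool& ) { return values[ globalIdx ]; },
                          plus, keep, 0.0 );
   // Both runs print: row 0: 3, row 1: 12, row 2: 6.
}

The patch achieves the same effect with two template specializations instead of if constexpr, which works in earlier C++ standards and also lets the brief specialization omit localIdx entirely; the Sequential/Host/other-device split additionally bypasses the ParallelFor machinery for plain CPU loops.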
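
The Host branch in both specializations also deserves a remark: schedule( dynamic, 100 ) hands out chunks of 100 segments on demand, which balances load when CSR rows differ widely in length, and firstprivate( l ) gives each thread its own copy of the lambda. (The comma between OpenMP clauses is legal, if unusual.) A minimal sketch of that pattern follows; the skewed offsets/values setup is illustrative, not TNL code, and the TNL-specific if( Devices::Host::isOMPEnabled() ) clause is omitted.

// Sketch of the Host branch: OpenMP loop over CSR segments with
// schedule( dynamic, 100 ) and firstprivate( l ), as in the patch.
// Compile with: g++ -fopenmp -O2 example.cpp
#include <cstdio>
#include <vector>

int main()
{
   const int segments = 1000;
   std::vector< int > offsets( segments + 1 );
   for( int i = 0; i <= segments; i++ )
      offsets[ i ] = i * ( i - 1 ) / 2;            // row i has i nonzeros: skewed on purpose
   std::vector< double > values( offsets[ segments ], 1.0 );
   std::vector< double > sums( segments );

   auto l = [&]( int segmentIdx ) {
      double aux = 0.0;
      for( int globalIdx = offsets[ segmentIdx ]; globalIdx < offsets[ segmentIdx + 1 ]; globalIdx++ )
         aux += values[ globalIdx ];
      sums[ segmentIdx ] = aux;
   };

   // Chunks of 100 rows are claimed on demand, so threads that draw short
   // rows go back for more work instead of idling at a static partition;
   // each thread copies the lambda via firstprivate( l ).
#ifdef _OPENMP
   #pragma omp parallel for firstprivate( l ) schedule( dynamic, 100 )
#endif
   for( int segmentIdx = 0; segmentIdx < segments; segmentIdx++ )
      l( segmentIdx );

   std::printf( "row 999 sum: %g\n", sums[ 999 ] );  // 999 ones -> 999
}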