Commit 546a69d0 authored by Tomáš Oberhuber, committed by Jakub Klinkovský

Optimizing CPU kernel for CSR format.

parent b00aed9a
1 merge request: !105 TO/matrices-adaptive-csr
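In outline: judging from the implementation commented out below, `reduceSegments` previously called every fetch lambda through `detail::FetchLambdaAdapter` with the full parameter set ( segmentIdx, localIdx, globalIdx, compute ). This commit adds a dispatcher that checks the lambda's arity once, at compile time, via `details::CheckFetchLambda`, and runs a leaner loop when only the global element index is needed. A rough sketch of the two lambda shapes being distinguished — `values` and the concrete types are illustrative, not taken from the diff:

// "Full" fetch lambda: has all parameters, so CheckFetchLambda reports
// hasAllParameters() == true and the first specialization below is used.
auto fullFetch = [=] __cuda_callable__
   ( int segmentIdx, int localIdx, int globalIdx, bool& compute ) -> double {
      return values[ globalIdx ];   // `values` is a captured view (illustrative)
   };

// "Brief" fetch lambda: only ( globalIdx, compute ), so the second
// specialization runs and the per-element localIdx counter is never updated.
auto briefFetch = [=] __cuda_callable__ ( int globalIdx, bool& compute ) -> double {
   return values[ globalIdx ];
};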
@@ -21,6 +21,108 @@ namespace TNL {
namespace Algorithms {
namespace Segments {
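// Dispatcher selecting the reduceSegments kernel according to the user's
// fetch lambda: DispatchScalarCSR is true when the lambda takes all
// parameters ( segmentIdx, localIdx, globalIdx, compute ) and false when it
// takes only ( globalIdx, compute ).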
template< typename Index,
typename Device,
typename Fetch,
typename Reduce,
typename Keep,
bool DispatchScalarCSR = details::CheckFetchLambda< Index, Fetch >::hasAllParameters() >
struct CSRScalarKernelreduceSegmentsDispatcher;
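// Specialization for fetch lambdas taking all parameters: each element is
// fetched as fetch( segmentIdx, localIdx, globalIdx, compute ).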
template< typename Index,
typename Device,
typename Fetch,
typename Reduction,
typename ResultKeeper >
struct CSRScalarKernelreduceSegmentsDispatcher< Index, Device, Fetch, Reduction, ResultKeeper, true >
{
template< typename Offsets,
typename Real >
static void reduce( const Offsets& offsets,
Index first,
Index last,
Fetch& fetch,
const Reduction& reduction,
ResultKeeper& keep,
const Real& zero )
{
auto l = [=] __cuda_callable__ ( const Index segmentIdx ) mutable {
const Index begin = offsets[ segmentIdx ];
const Index end = offsets[ segmentIdx + 1 ];
Real aux( zero );
Index localIdx( 0 );
bool compute( true );
for( Index globalIdx = begin; globalIdx < end && compute; globalIdx++ )
aux = reduction( aux, fetch( segmentIdx, localIdx++, globalIdx, compute ) );
keep( segmentIdx, aux );
};
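// Choose how to run the per-segment lambda: a serial loop for
// Devices::Sequential, an OpenMP loop (when available and enabled) for
// Devices::Host, and ParallelFor for any other device, e.g. a GPU.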
if( std::is_same< Device, TNL::Devices::Sequential >::value )
{
for( Index segmentIdx = first; segmentIdx < last; segmentIdx ++ )
l( segmentIdx );
}
else if( std::is_same< Device, TNL::Devices::Host >::value )
{
#ifdef HAVE_OPENMP
#pragma omp parallel for firstprivate( l ) schedule( dynamic, 100 ), if( Devices::Host::isOMPEnabled() )
#endif
for( Index segmentIdx = first; segmentIdx < last; segmentIdx ++ )
l( segmentIdx );
}
else
Algorithms::ParallelFor< Device >::exec( first, last, l );
}
};
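// Specialization for "brief" fetch lambdas taking only ( globalIdx, compute );
// here the kernel does not have to maintain the per-element localIdx counter.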
template< typename Index,
typename Device,
typename Fetch,
typename Reduce,
typename Keep >
struct CSRScalarKernelreduceSegmentsDispatcher< Index, Device, Fetch, Reduce, Keep, false >
{
template< typename OffsetsView,
typename Real >
static void reduce( const OffsetsView& offsets,
Index first,
Index last,
Fetch& fetch,
const Reduce& reduction,
Keep& keep,
const Real& zero )
{
auto l = [=] __cuda_callable__ ( const Index segmentIdx ) mutable {
const Index begin = offsets[ segmentIdx ];
const Index end = offsets[ segmentIdx + 1 ];
Real aux( zero );
bool compute( true );
for( Index globalIdx = begin; globalIdx < end && compute; globalIdx++ )
aux = reduction( aux, fetch( globalIdx, compute ) );
keep( segmentIdx, aux );
};
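// The same device dispatch as in the specialization above.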
if( std::is_same< Device, TNL::Devices::Sequential >::value )
{
for( Index segmentIdx = first; segmentIdx < last; segmentIdx ++ )
l( segmentIdx );
}
else if( std::is_same< Device, TNL::Devices::Host >::value )
{
#ifdef HAVE_OPENMP
#pragma omp parallel for firstprivate( l ) schedule( dynamic, 100 ), if( Devices::Host::isOMPEnabled() )
#endif
for( Index segmentIdx = first; segmentIdx < last; segmentIdx ++ )
l( segmentIdx );
}
else
Algorithms::ParallelFor< Device >::exec( first, last, l );
}
};
template< typename Index,
typename Device >
template< typename Offsets >
@@ -84,6 +186,9 @@ reduceSegments( const OffsetsView& offsets,
const Real& zero,
Args... args )
{
+   CSRScalarKernelreduceSegmentsDispatcher< Index, Device, Fetch, Reduction, ResultKeeper >::reduce(
+      offsets, first, last, fetch, reduction, keeper, zero );
+   /*
auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable {
const IndexType begin = offsets[ segmentIdx ];
const IndexType end = offsets[ segmentIdx + 1 ];
@@ -102,7 +207,7 @@ reduceSegments( const OffsetsView& offsets,
#endif
for( Index segmentIdx = first; segmentIdx < last; segmentIdx ++ )
l( segmentIdx, args... );
-      /*{
+      {
const IndexType begin = offsets[ segmentIdx ];
const IndexType end = offsets[ segmentIdx + 1 ];
Real aux( zero );
@@ -111,10 +216,10 @@ reduceSegments( const OffsetsView& offsets,
for( IndexType globalIdx = begin; globalIdx < end && compute; globalIdx++ )
aux = reduction( aux, detail::FetchLambdaAdapter< IndexType, Fetch >::call( fetch, segmentIdx, localIdx++, globalIdx, compute ) );
keeper( segmentIdx, aux );
-      }*/
+      }
}
else
-      Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
+      Algorithms::ParallelFor< Device >::exec( first, last, l, args... );*/
}
} // namespace Segments
} // namespace Algorithms
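For context, a minimal usage sketch of the scalar kernel after this change: computing CSR row sums through the "brief" fetch path. The header path, the vector API calls, and the static invocation are assumptions based on this diff and the TNL API of that era, not something this commit shows:

#include <TNL/Containers/Vector.h>
#include <TNL/Algorithms/Segments/CSRScalarKernel.h>   // assumed header path

using namespace TNL;

int main()
{
   // CSR offsets for a 3-row matrix holding 4 elements:
   // row 0 owns [0,2), row 1 owns [2,3), row 2 owns [3,4).
   Containers::Vector< int, Devices::Host > offsets{ 0, 2, 3, 4 };
   Containers::Vector< double, Devices::Host > values{ 1.0, 2.0, 3.0, 4.0 };
   Containers::Vector< double, Devices::Host > rowSums( 3 );
   rowSums.setValue( 0.0 );

   auto valuesView = values.getConstView();
   auto rowSumsView = rowSums.getView();

   // "Brief" fetch: only ( globalIdx, compute ), so the dispatcher takes
   // the second (false) specialization above.
   auto fetch = [=] __cuda_callable__ ( int globalIdx, bool& compute ) {
      return valuesView[ globalIdx ];
   };
   auto reduction = [] __cuda_callable__ ( double a, double b ) { return a + b; };
   auto keep = [=] __cuda_callable__ ( int segmentIdx, double value ) mutable {
      rowSumsView[ segmentIdx ] = value;
   };

   Algorithms::Segments::CSRScalarKernel< int, Devices::Host >::reduceSegments(
      offsets.getConstView(), 0, 3, fetch, reduction, keep, 0.0 );
   // rowSums now holds [ 3, 3, 4 ].
}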