Skip to content
Snippets Groups Projects
Commit ffc00260 authored by Jakub Klinkovský, committed by Jakub Klinkovský
Browse files

NDArray: simplified executors for operations

parent f6d08f4b
No related branches found
No related tags found
1 merge request: !18 NDArray
......@@ -90,86 +90,29 @@ struct SequentialExecutorRTL< Array, IndexTag< 0 > >
template< typename Array,
typename DimTag = IndexTag< Array::getDimension() > >
struct OpenMPExecutor
typename Device = typename Array::DeviceType >
struct ParallelExecutorDeviceDispatch
{
template< typename Func >
void operator()( const Array& array, Func f )
{
SequentialExecutor< Array, IndexTag< 3 > > exec;
const auto size0 = array.template getSize< get< 0 >( typename Array::PermutationType{} ) >();
const auto size1 = array.template getSize< get< 1 >( typename Array::PermutationType{} ) >();
const auto size2 = array.template getSize< get< 2 >( typename Array::PermutationType{} ) >();
#ifdef HAVE_OPENMP
#pragma omp parallel for collapse(3)
#endif
for( typename Array::IndexType i0 = 0; i0 < size0; i0++ )
for( typename Array::IndexType i1 = 0; i1 < size1; i1++ )
for( typename Array::IndexType i2 = 0; i2 < size2; i2++ )
exec( array, f, i0, i1, i2 );
}
};
using Index = typename Array::IndexType;
// OpenMP traversal of a three-dimensional array.
//
// The loop bounds are the array sizes taken along the axes listed in
// Array::PermutationType (positions 0, 1 and 2). For every index triple the
// functor is invoked through call_with_permuted_arguments.
template< typename Array >
struct OpenMPExecutor< Array, IndexTag< 3 > >
{
   template< typename Func >
   void operator()( const Array& array, Func f )
   {
      using Index = typename Array::IndexType;
      using Permutation = typename Array::PermutationType;

      const auto extent0 = array.template getSize< get< 0 >( Permutation{} ) >();
      const auto extent1 = array.template getSize< get< 1 >( Permutation{} ) >();
      const auto extent2 = array.template getSize< get< 2 >( Permutation{} ) >();

      // collapse(2): only the two outermost loops are merged and distributed
      // among threads; the innermost loop is executed sequentially inside
      // each collapsed iteration.
#ifdef HAVE_OPENMP
      #pragma omp parallel for collapse(2)
#endif
      for( Index a = 0; a < extent0; ++a )
         for( Index b = 0; b < extent1; ++b )
            for( Index c = 0; c < extent2; ++c )
               call_with_permuted_arguments< Permutation >( f, a, b, c );
   }
};
auto kernel = [=] ( Index i2, Index i1, Index i0 )
{
SequentialExecutor< Array, IndexTag< 3 > > exec;
exec( array, f, i0, i1, i2 );
};
template< typename Array >
struct OpenMPExecutor< Array, IndexTag< 2 > >
{
template< typename Func >
void operator()( const Array& array, Func f )
{
const auto size0 = array.template getSize< get< 0 >( typename Array::PermutationType{} ) >();
const auto size1 = array.template getSize< get< 1 >( typename Array::PermutationType{} ) >();
#ifdef HAVE_OPENMP
#pragma omp parallel for
#endif
for( typename Array::IndexType i0 = 0; i0 < size0; i0++ )
for( typename Array::IndexType i1 = 0; i1 < size1; i1++ )
call_with_permuted_arguments< typename Array::PermutationType >( f, i0, i1 );
const Index size0 = array.template getSize< get< 0 >( typename Array::PermutationType{} ) >();
const Index size1 = array.template getSize< get< 1 >( typename Array::PermutationType{} ) >();
const Index size2 = array.template getSize< get< 2 >( typename Array::PermutationType{} ) >();
ParallelFor3D< Device >::exec( (Index) 0, (Index) 0, (Index) 0, size2, size1, size0, kernel );
}
};
// OpenMP traversal of a one-dimensional array.
//
// The functor is invoked (through call_with_permuted_arguments) once for each
// index from 0 up to, but not including, the array size along the axis given
// by position 0 of Array::PermutationType.
template< typename Array >
struct OpenMPExecutor< Array, IndexTag< 1 > >
{
   template< typename Func >
   void operator()( const Array& array, Func f )
   {
      using Index = typename Array::IndexType;
      using Permutation = typename Array::PermutationType;

      const auto extent = array.template getSize< get< 0 >( Permutation{} ) >();

      // the pragma is compiled in only when OpenMP support is enabled;
      // otherwise this is a plain sequential loop
#ifdef HAVE_OPENMP
      #pragma omp parallel for
#endif
      for( Index i = 0; i < extent; ++i )
         call_with_permuted_arguments< Permutation >( f, i );
   }
};
template< typename Array,
typename DimTag = IndexTag< Array::getDimension() > >
struct CudaExecutor
struct ParallelExecutorDeviceDispatch< Array, Devices::Cuda >
{
template< typename Func >
void operator()( const Array& array, Func f )
......@@ -189,12 +132,25 @@ struct CudaExecutor
}
};
template< typename Array,
typename DimTag = IndexTag< Array::getDimension() > >
struct ParallelExecutor
{
template< typename Func >
void operator()( const Array& array, Func f )
{
ParallelExecutorDeviceDispatch< Array > dispatch;
dispatch( array, f );
}
};
template< typename Array >
struct CudaExecutor< Array, IndexTag< 3 > >
struct ParallelExecutor< Array, IndexTag< 3 > >
{
template< typename Func >
void operator()( const Array& array, Func f )
{
using Device = typename Array::DeviceType;
using Index = typename Array::IndexType;
auto kernel = [=] __cuda_callable__ ( Index i2, Index i1, Index i0 )
......@@ -205,16 +161,17 @@ struct CudaExecutor< Array, IndexTag< 3 > >
const Index size0 = array.template getSize< get< 0 >( typename Array::PermutationType{} ) >();
const Index size1 = array.template getSize< get< 1 >( typename Array::PermutationType{} ) >();
const Index size2 = array.template getSize< get< 2 >( typename Array::PermutationType{} ) >();
ParallelFor3D< Devices::Cuda >::exec( (Index) 0, (Index) 0, (Index) 0, size2, size1, size0, kernel );
ParallelFor3D< Device >::exec( (Index) 0, (Index) 0, (Index) 0, size2, size1, size0, kernel );
}
};
template< typename Array >
struct CudaExecutor< Array, IndexTag< 2 > >
struct ParallelExecutor< Array, IndexTag< 2 > >
{
template< typename Func >
void operator()( const Array& array, Func f )
{
using Device = typename Array::DeviceType;
using Index = typename Array::IndexType;
auto kernel = [=] __cuda_callable__ ( Index i1, Index i0 )
......@@ -224,16 +181,17 @@ struct CudaExecutor< Array, IndexTag< 2 > >
const Index size0 = array.template getSize< get< 0 >( typename Array::PermutationType{} ) >();
const Index size1 = array.template getSize< get< 1 >( typename Array::PermutationType{} ) >();
ParallelFor2D< Devices::Cuda >::exec( (Index) 0, (Index) 0, size1, size0, kernel );
ParallelFor2D< Device >::exec( (Index) 0, (Index) 0, size1, size0, kernel );
}
};
template< typename Array >
struct CudaExecutor< Array, IndexTag< 1 > >
struct ParallelExecutor< Array, IndexTag< 1 > >
{
template< typename Func >
void operator()( const Array& array, Func f )
{
using Device = typename Array::DeviceType;
using Index = typename Array::IndexType;
auto kernel = [=] __cuda_callable__ ( Index i )
......@@ -242,7 +200,7 @@ struct CudaExecutor< Array, IndexTag< 1 > >
};
const Index size = array.template getSize< get< 0 >( typename Array::PermutationType{} ) >();
ParallelFor< Devices::Cuda >::exec( (Index) 0, size, kernel );
ParallelFor< Device >::exec( (Index) 0, size, kernel );
}
};
......@@ -265,7 +223,7 @@ struct ExecutorDispatcher< Array, Devices::Host >
void operator()( const Array& array, Func f )
{
if( Devices::Host::isOMPEnabled() && Devices::Host::getMaxThreadsCount() > 1 )
OpenMPExecutor< Array >()( array, f );
ParallelExecutor< Array >()( array, f );
else
SequentialExecutor< Array >()( array, f );
}
......@@ -277,7 +235,7 @@ struct ExecutorDispatcher< Array, Devices::Cuda >
template< typename Func >
void operator()( const Array& array, Func f )
{
CudaExecutor< Array >()( array, f );
ParallelExecutor< Array >()( array, f );
}
};
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment