From cb834c849b1ce2af64f33f3370be88c9227c453d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com> Date: Sat, 5 Jan 2019 08:33:50 +0100 Subject: [PATCH] Added GridTraverserBenchmarkHelper. --- .../Traversers/GridTraversersBenchmark_1D.h | 122 +++++++++++++++--- 1 file changed, 104 insertions(+), 18 deletions(-) diff --git a/src/Benchmarks/Traversers/GridTraversersBenchmark_1D.h b/src/Benchmarks/Traversers/GridTraversersBenchmark_1D.h index e626b17e35..22f1d68996 100644 --- a/src/Benchmarks/Traversers/GridTraversersBenchmark_1D.h +++ b/src/Benchmarks/Traversers/GridTraversersBenchmark_1D.h @@ -28,13 +28,110 @@ namespace TNL { namespace Benchmarks { namespace Traversers { +template< typename Grid, + typename Device = typename Grid::DeviceType > +class GridTraverserBenchmarkHelper{}; + +template< typename Grid > +class GridTraverserBenchmarkHelper< Grid, Devices::Host > +{ + public: + + using GridType = Grid; + using GridPointer = Pointers::SharedPointer< Grid >; + using RealType = typename GridType::RealType; + using IndexType = typename GridType::IndexType; + using CoordinatesType = typename Grid::CoordinatesType; + using MeshFunction = Functions::MeshFunction< Grid >; + using MeshFunctionPointer = Pointers::SharedPointer< MeshFunction >; + using Cell = typename Grid::template EntityType< 1, Meshes::GridEntityNoStencilStorage >; + using Traverser = Meshes::Traverser< Grid, Cell >; + using WriteOneTraverserUserDataType = WriteOneUserData< MeshFunctionPointer >; + using WriteOneEntitiesProcessorType = WriteOneEntitiesProcessor< WriteOneTraverserUserDataType >; + + + static void noBCTraverserTest( const GridPointer& grid, + WriteOneTraverserUserDataType& userData, + std::size_t size ) + { + /*Meshes::GridTraverser< Grid >::template processEntities< Cell, WriteOneEntitiesProcessorType, WriteOneTraverserUserDataType, false >( + grid, + CoordinatesType( 0 ), + grid->getDimensions() - CoordinatesType( 1 ), + userData );*/ + + const CoordinatesType begin( 0 ); + const CoordinatesType end = CoordinatesType( size ) - CoordinatesType( 1 ); + //MeshFunction* _u = &u.template modifyData< Device >(); + Cell entity( *grid ); + for( IndexType x = begin.x(); x <= end.x(); x ++ ) + { + entity.getCoordinates().x() = x; + entity.refresh(); + WriteOneEntitiesProcessorType::processEntity( entity.getMesh(), userData, entity ); + } + + } +}; + +template< typename Grid > +class GridTraverserBenchmarkHelper< Grid, Devices::Cuda > +{ + public: + + using GridType = Grid; + using GridPointer = Pointers::SharedPointer< Grid >; + using RealType = typename GridType::RealType; + using IndexType = typename GridType::IndexType; + using CoordinatesType = typename Grid::CoordinatesType; + using MeshFunction = Functions::MeshFunction< Grid >; + using MeshFunctionPointer = Pointers::SharedPointer< MeshFunction >; + using Cell = typename Grid::template EntityType< 1, Meshes::GridEntityNoStencilStorage >; + using Traverser = Meshes::Traverser< Grid, Cell >; + using WriteOneTraverserUserDataType = WriteOneUserData< MeshFunctionPointer >; + using WriteOneEntitiesProcessorType = WriteOneEntitiesProcessor< WriteOneTraverserUserDataType >; + + + static void noBCTraverserTest( const GridPointer& grid, + WriteOneTraverserUserDataType& userData, + std::size_t size ) + { +#ifdef HAVE_CUDA + dim3 blockSize( 256 ), blocksCount, gridsCount; + Devices::Cuda::setupThreads( + blockSize, + blocksCount, + gridsCount, + size ); + dim3 gridIdx; + for( gridIdx.x = 0; gridIdx.x < gridsCount.x; gridIdx.x++ ) + { + dim3 gridSize; + Devices::Cuda::setupGrid( + blocksCount, + gridsCount, + gridIdx, + gridSize ); + Meshes::GridTraverser1D< RealType, IndexType, Cell, WriteOneTraverserUserDataType, WriteOneEntitiesProcessorType > + <<< blocksCount, blockSize >>> + ( &grid.template getData< Devices::Cuda >(), + userData, + CoordinatesType( 0 ), + CoordinatesType( size ) - CoordinatesType( 1 ), + gridIdx.x ); + + } +#endif + } +}; + template< typename Device, typename Real, typename Index > class GridTraversersBenchmark< 1, Device, Real, Index > { public: - + using Vector = Containers::Vector< Real, Device, Index >; using Grid = Meshes::Grid< 1, Real, Device, Index >; using GridPointer = Pointers::SharedPointer< Grid >; @@ -130,24 +227,13 @@ class GridTraversersBenchmark< 1, Device, Real, Index > void writeOneUsingTraverser() { using CoordinatesType = typename Grid::CoordinatesType; - traverser.template processAllEntities< WriteOneTraverserUserDataType, WriteOneEntitiesProcessorType > - ( grid, userData ); + //traverser.template processAllEntities< WriteOneTraverserUserDataType, WriteOneEntitiesProcessorType > + // ( grid, userData ); - /*Meshes::GridTraverser< Grid >::template processEntities< Cell, WriteOneEntitiesProcessorType, WriteOneTraverserUserDataType, false >( - grid, - CoordinatesType( 0 ), - grid->getDimensions() - CoordinatesType( 1 ), - userData );*/ - /*const CoordinatesType begin( 0 ); - const CoordinatesType end = CoordinatesType( size ) - CoordinatesType( 1 ); - MeshFunction* _u = &u.template modifyData< Device >(); - Cell entity( *grid ); - for( Index x = begin.x(); x <= end.x(); x ++ ) - { - entity.getCoordinates().x() = x; - entity.refresh(); - WriteOneEntitiesProcessorType::processEntity( entity.getMesh(), userData, entity ); - }*/ + GridTraverserBenchmarkHelper< Grid >::noBCTraverserTest( + grid, + userData, + size ); } void traverseUsingPureC() -- GitLab