From cb834c849b1ce2af64f33f3370be88c9227c453d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com>
Date: Sat, 5 Jan 2019 08:33:50 +0100
Subject: [PATCH] Added GridTraverserBenchmarkHelper.

---
 .../Traversers/GridTraversersBenchmark_1D.h   | 122 +++++++++++++++---
 1 file changed, 104 insertions(+), 18 deletions(-)

diff --git a/src/Benchmarks/Traversers/GridTraversersBenchmark_1D.h b/src/Benchmarks/Traversers/GridTraversersBenchmark_1D.h
index e626b17e35..22f1d68996 100644
--- a/src/Benchmarks/Traversers/GridTraversersBenchmark_1D.h
+++ b/src/Benchmarks/Traversers/GridTraversersBenchmark_1D.h
@@ -28,13 +28,110 @@ namespace TNL {
    namespace Benchmarks {
       namespace Traversers {
 
+template< typename Grid,
+          typename Device = typename Grid::DeviceType >
+class GridTraverserBenchmarkHelper{};
+
+template< typename Grid >
+class GridTraverserBenchmarkHelper< Grid, Devices::Host >
+{
+   public:
+
+      using GridType = Grid;
+      using GridPointer = Pointers::SharedPointer< Grid >;
+      using RealType = typename GridType::RealType;
+      using IndexType = typename GridType::IndexType;
+      using CoordinatesType = typename Grid::CoordinatesType;
+      using MeshFunction = Functions::MeshFunction< Grid >;
+      using MeshFunctionPointer = Pointers::SharedPointer< MeshFunction >;
+      using Cell = typename Grid::template EntityType< 1, Meshes::GridEntityNoStencilStorage >;
+      using Traverser = Meshes::Traverser< Grid, Cell >;
+      using WriteOneTraverserUserDataType = WriteOneUserData< MeshFunctionPointer >;
+      using WriteOneEntitiesProcessorType = WriteOneEntitiesProcessor< WriteOneTraverserUserDataType >;
+      
+
+      static void noBCTraverserTest( const GridPointer& grid,
+                                     WriteOneTraverserUserDataType& userData,
+                                     std::size_t size )
+      {
+         /*Meshes::GridTraverser< Grid >::template processEntities< Cell, WriteOneEntitiesProcessorType, WriteOneTraverserUserDataType, false >(
+           grid,
+           CoordinatesType( 0 ),
+           grid->getDimensions() - CoordinatesType( 1 ),
+           userData );*/
+
+         const CoordinatesType begin( 0 );
+         const CoordinatesType end = CoordinatesType( size ) - CoordinatesType( 1 );
+         //MeshFunction* _u = &u.template modifyData< Device >();
+         Cell entity( *grid );
+         for( IndexType x = begin.x(); x <= end.x(); x ++ )
+         {
+            entity.getCoordinates().x() = x;
+            entity.refresh();
+            WriteOneEntitiesProcessorType::processEntity( entity.getMesh(), userData, entity );
+         }
+
+      }
+};
+
+template< typename Grid >
+class GridTraverserBenchmarkHelper< Grid, Devices::Cuda >
+{
+   public:
+
+      using GridType = Grid;
+      using GridPointer = Pointers::SharedPointer< Grid >;
+      using RealType = typename GridType::RealType;
+      using IndexType = typename GridType::IndexType;
+      using CoordinatesType = typename Grid::CoordinatesType;
+      using MeshFunction = Functions::MeshFunction< Grid >;
+      using MeshFunctionPointer = Pointers::SharedPointer< MeshFunction >;
+      using Cell = typename Grid::template EntityType< 1, Meshes::GridEntityNoStencilStorage >;
+      using Traverser = Meshes::Traverser< Grid, Cell >;
+      using WriteOneTraverserUserDataType = WriteOneUserData< MeshFunctionPointer >;
+      using WriteOneEntitiesProcessorType = WriteOneEntitiesProcessor< WriteOneTraverserUserDataType >;
+
+
+      static void noBCTraverserTest( const GridPointer& grid,
+                                     WriteOneTraverserUserDataType& userData,
+                                     std::size_t size )
+      {
+#ifdef HAVE_CUDA
+            dim3 blockSize( 256 ), blocksCount, gridsCount;
+            Devices::Cuda::setupThreads(
+               blockSize,
+               blocksCount,
+               gridsCount,
+               size );
+            dim3 gridIdx;
+            for( gridIdx.x = 0; gridIdx.x < gridsCount.x; gridIdx.x++ )
+            {
+               dim3 gridSize;
+               Devices::Cuda::setupGrid(
+                  blocksCount,
+                  gridsCount,
+                  gridIdx,
+                  gridSize );
+               Meshes::GridTraverser1D< RealType, IndexType, Cell, WriteOneTraverserUserDataType, WriteOneEntitiesProcessorType >
+               <<< blocksCount, blockSize >>>
+               ( &grid.template getData< Devices::Cuda >(),
+                 userData,
+                 CoordinatesType( 0 ),
+                 CoordinatesType( size ) - CoordinatesType( 1 ),
+                 gridIdx.x );
+
+            }
+#endif
+      }
+};
+
 template< typename Device,
           typename Real,
           typename Index >
 class GridTraversersBenchmark< 1, Device, Real, Index >
 {
    public:
-      
+
       using Vector = Containers::Vector< Real, Device, Index >;
       using Grid = Meshes::Grid< 1, Real, Device, Index >;
       using GridPointer = Pointers::SharedPointer< Grid >;
@@ -130,24 +227,13 @@ class GridTraversersBenchmark< 1, Device, Real, Index >
       void writeOneUsingTraverser()
       {
          using CoordinatesType = typename Grid::CoordinatesType;
-         traverser.template processAllEntities< WriteOneTraverserUserDataType, WriteOneEntitiesProcessorType >
-            ( grid, userData );
+         //traverser.template processAllEntities< WriteOneTraverserUserDataType, WriteOneEntitiesProcessorType >
+         //   ( grid, userData );
          
-         /*Meshes::GridTraverser< Grid >::template processEntities< Cell, WriteOneEntitiesProcessorType, WriteOneTraverserUserDataType, false >(
-           grid,
-           CoordinatesType( 0 ),
-           grid->getDimensions() - CoordinatesType( 1 ),
-           userData );*/
-         /*const CoordinatesType begin( 0 );
-         const CoordinatesType end = CoordinatesType( size ) - CoordinatesType( 1 );
-         MeshFunction* _u = &u.template modifyData< Device >();
-         Cell entity( *grid );
-         for( Index x = begin.x(); x <= end.x(); x ++ )
-         {
-            entity.getCoordinates().x() = x;
-            entity.refresh();
-            WriteOneEntitiesProcessorType::processEntity( entity.getMesh(), userData, entity );
-         }*/
+         GridTraverserBenchmarkHelper< Grid >::noBCTraverserTest(
+            grid,
+            userData,
+            size );
       }
 
       void traverseUsingPureC()
-- 
GitLab