Commit cb834c84 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Added GridTraverserBenchmarkHelper.

parent 1ace5365
Loading
Loading
Loading
Loading
+104 −18
Original line number Diff line number Diff line
@@ -28,6 +28,103 @@ namespace TNL {
   namespace Benchmarks {
      namespace Traversers {

template< typename Grid,
          typename Device = typename Grid::DeviceType >
class GridTraverserBenchmarkHelper{};

template< typename Grid >
class GridTraverserBenchmarkHelper< Grid, Devices::Host >
{
   public:

      using GridType = Grid;
      using GridPointer = Pointers::SharedPointer< Grid >;
      using RealType = typename GridType::RealType;
      using IndexType = typename GridType::IndexType;
      using CoordinatesType = typename Grid::CoordinatesType;
      using MeshFunction = Functions::MeshFunction< Grid >;
      using MeshFunctionPointer = Pointers::SharedPointer< MeshFunction >;
      using Cell = typename Grid::template EntityType< 1, Meshes::GridEntityNoStencilStorage >;
      using Traverser = Meshes::Traverser< Grid, Cell >;
      using WriteOneTraverserUserDataType = WriteOneUserData< MeshFunctionPointer >;
      using WriteOneEntitiesProcessorType = WriteOneEntitiesProcessor< WriteOneTraverserUserDataType >;
      

      static void noBCTraverserTest( const GridPointer& grid,
                                     WriteOneTraverserUserDataType& userData,
                                     std::size_t size )
      {
         /*Meshes::GridTraverser< Grid >::template processEntities< Cell, WriteOneEntitiesProcessorType, WriteOneTraverserUserDataType, false >(
           grid,
           CoordinatesType( 0 ),
           grid->getDimensions() - CoordinatesType( 1 ),
           userData );*/

         const CoordinatesType begin( 0 );
         const CoordinatesType end = CoordinatesType( size ) - CoordinatesType( 1 );
         //MeshFunction* _u = &u.template modifyData< Device >();
         Cell entity( *grid );
         for( IndexType x = begin.x(); x <= end.x(); x ++ )
         {
            entity.getCoordinates().x() = x;
            entity.refresh();
            WriteOneEntitiesProcessorType::processEntity( entity.getMesh(), userData, entity );
         }

      }
};

template< typename Grid >
class GridTraverserBenchmarkHelper< Grid, Devices::Cuda >
{
   public:

      using GridType = Grid;
      using GridPointer = Pointers::SharedPointer< Grid >;
      using RealType = typename GridType::RealType;
      using IndexType = typename GridType::IndexType;
      using CoordinatesType = typename Grid::CoordinatesType;
      using MeshFunction = Functions::MeshFunction< Grid >;
      using MeshFunctionPointer = Pointers::SharedPointer< MeshFunction >;
      using Cell = typename Grid::template EntityType< 1, Meshes::GridEntityNoStencilStorage >;
      using Traverser = Meshes::Traverser< Grid, Cell >;
      using WriteOneTraverserUserDataType = WriteOneUserData< MeshFunctionPointer >;
      using WriteOneEntitiesProcessorType = WriteOneEntitiesProcessor< WriteOneTraverserUserDataType >;


      static void noBCTraverserTest( const GridPointer& grid,
                                     WriteOneTraverserUserDataType& userData,
                                     std::size_t size )
      {
#ifdef HAVE_CUDA
            dim3 blockSize( 256 ), blocksCount, gridsCount;
            Devices::Cuda::setupThreads(
               blockSize,
               blocksCount,
               gridsCount,
               size );
            dim3 gridIdx;
            for( gridIdx.x = 0; gridIdx.x < gridsCount.x; gridIdx.x++ )
            {
               dim3 gridSize;
               Devices::Cuda::setupGrid(
                  blocksCount,
                  gridsCount,
                  gridIdx,
                  gridSize );
               Meshes::GridTraverser1D< RealType, IndexType, Cell, WriteOneTraverserUserDataType, WriteOneEntitiesProcessorType >
               <<< blocksCount, blockSize >>>
               ( &grid.template getData< Devices::Cuda >(),
                 userData,
                 CoordinatesType( 0 ),
                 CoordinatesType( size ) - CoordinatesType( 1 ),
                 gridIdx.x );

            }
#endif
      }
};

template< typename Device,
          typename Real,
          typename Index >
@@ -130,24 +227,13 @@ class GridTraversersBenchmark< 1, Device, Real, Index >
      void writeOneUsingTraverser()
      {
         using CoordinatesType = typename Grid::CoordinatesType;
         traverser.template processAllEntities< WriteOneTraverserUserDataType, WriteOneEntitiesProcessorType >
            ( grid, userData );
         //traverser.template processAllEntities< WriteOneTraverserUserDataType, WriteOneEntitiesProcessorType >
         //   ( grid, userData );
         
         /*Meshes::GridTraverser< Grid >::template processEntities< Cell, WriteOneEntitiesProcessorType, WriteOneTraverserUserDataType, false >(
         GridTraverserBenchmarkHelper< Grid >::noBCTraverserTest(
            grid,
           CoordinatesType( 0 ),
           grid->getDimensions() - CoordinatesType( 1 ),
           userData );*/
         /*const CoordinatesType begin( 0 );
         const CoordinatesType end = CoordinatesType( size ) - CoordinatesType( 1 );
         MeshFunction* _u = &u.template modifyData< Device >();
         Cell entity( *grid );
         for( Index x = begin.x(); x <= end.x(); x ++ )
         {
            entity.getCoordinates().x() = x;
            entity.refresh();
            WriteOneEntitiesProcessorType::processEntity( entity.getMesh(), userData, entity );
         }*/
            userData,
            size );
      }

      void traverseUsingPureC()