Commit d1ef48e7 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

CUDA optimizations in tnlGridTraveser.

parent 0444bfbe
Loading
Loading
Loading
Loading
+69 −17
Original line number Diff line number Diff line
@@ -72,8 +72,7 @@ template< typename Real,
          typename Index,
          typename GridEntity,
          typename UserData,
          typename EntitiesProcessor,
          bool processOnlyBoundaryEntities >
          typename EntitiesProcessor >
__global__ void 
tnlGridTraverser1D(
   const tnlGrid< 1, Real, tnlCuda, Index >* grid,
@@ -95,13 +94,49 @@ tnlGridTraverser1D(
   
   if( coordinates.x() <= end->x() )
   {
      if( ! processOnlyBoundaryEntities || entity.isBoundaryEntity() )
      //if( ! processOnlyBoundaryEntities || entity.isBoundaryEntity() )
      {
         entity.refresh();
         EntitiesProcessor::processEntity( entity.getMesh(), *userData, entity );
      }
   }
}

template< typename Real,
          typename Index,
          typename GridEntity,
          typename UserData,
          typename EntitiesProcessor >
__global__ void 
tnlGridBoundaryTraverser1D(
   const tnlGrid< 1, Real, tnlCuda, Index >* grid,
   UserData* userData,
   const typename GridEntity::CoordinatesType* begin,
   const typename GridEntity::CoordinatesType* end,
   const typename GridEntity::CoordinatesType* entityOrientation,
   const typename GridEntity::CoordinatesType* entityBasis )
{
   typedef Real RealType;
   typedef Index IndexType;
   typedef tnlGrid< 1, Real, tnlCuda, Index > GridType;
   typename GridType::CoordinatesType coordinates;
   
   if( threadIdx.x == 0 )
   {   
      coordinates.x() = begin->x();
      GridEntity entity( *grid, coordinates, *entityOrientation, *entityBasis );
      entity.refresh();
      EntitiesProcessor::processEntity( entity.getMesh(), *userData, entity );
   }
   if( threadIdx.x == 1 )
   {   
      coordinates.x() = end->x();
      GridEntity entity( *grid, coordinates, *entityOrientation, *entityBasis );
      entity.refresh();
      EntitiesProcessor::processEntity( entity.getMesh(), *userData, entity );
   }
}

#endif

template< typename Real,           
@@ -129,13 +164,28 @@ processEntities(
   typename GridEntity::MeshType* kernelGrid = tnlCuda::passToDevice( grid );
   UserData* kernelUserData = tnlCuda::passToDevice( userData );
      
   if( processOnlyBoundaryEntities )
   {
      dim3 cudaBlockSize( 2 );
      dim3 cudaBlocks( 1 );
      tnlGridBoundaryTraverser1D< Real, Index, GridEntity, UserData, EntitiesProcessor >
            <<< cudaBlocks, cudaBlockSize >>>
            ( kernelGrid,
              kernelUserData,
              kernelBegin,
              kernelEnd,
              kernelEntityOrientation,
              kernelEntityBasis );
   }
   else
   {
      dim3 cudaBlockSize( 256 );
      dim3 cudaBlocks;
      cudaBlocks.x = tnlCuda::getNumberOfBlocks( end.x() - begin.x() + 1, cudaBlockSize.x );
      const IndexType cudaXGrids = tnlCuda::getNumberOfGrids( cudaBlocks.x );

      for( IndexType gridXIdx = 0; gridXIdx < cudaXGrids; gridXIdx ++ )
      tnlGridTraverser1D< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities >
         tnlGridTraverser1D< Real, Index, GridEntity, UserData, EntitiesProcessor >
            <<< cudaBlocks, cudaBlockSize >>>
            ( kernelGrid,
              kernelUserData,
@@ -144,6 +194,7 @@ processEntities(
              kernelEntityOrientation,
              kernelEntityBasis,
              gridXIdx );
   }
   cudaThreadSynchronize();
   checkCudaDevice;
   tnlCuda::freeFromDevice( kernelGrid );
@@ -316,6 +367,7 @@ processEntities(
              kernelEntityBasis,
              gridXIdx,
              gridYIdx );
      
   cudaThreadSynchronize();
   checkCudaDevice;   
   tnlCuda::freeFromDevice( kernelGrid );
+2 −1
Original line number Diff line number Diff line
@@ -52,7 +52,8 @@ class tnlExplicitUpdaterTraverserUserData
        rightHandSide( &rightHandSide ),
        u( &u ),
        fu( &fu )
      {};
      {
      };
};


+1 −1
Original line number Diff line number Diff line
@@ -23,7 +23,7 @@ template< typename Mesh, typename Real >class HeatEquationBenchmarkRhs
         typedef typename MeshEntity::MeshType::VertexType VertexType;
         VertexType v = entity.getCenter();
         return 0.0;
      };
      }
};

#endif /* HeatEquationBenchmarkRHS_H_ */