Commit 7658b3b3 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Optimizing CUDA explicit solver.

parent 2ded6e73
Loading
Loading
Loading
Loading
+0 −4
Original line number Diff line number Diff line
@@ -304,8 +304,6 @@ GridEntity( const GridType& grid )
  neighbourEntitiesStorage( *this )
{
   this->coordinates = CoordinatesType( ( Index ) 0 );
   this->orientation = EntityOrientationType( ( Index ) 0 );
   this->basis = EntityBasisType( ( Index ) 1 );
}

template< int Dimensions,
@@ -324,8 +322,6 @@ GridEntity( const GridType& grid,
  coordinates( coordinates ),
  neighbourEntitiesStorage( *this )
{
   this->orientation = EntityOrientationType( ( Index ) 0 );
   this->basis = EntityBasisType( ( Index ) 1 );
}

template< int Dimensions,
+16 −9
Original line number Diff line number Diff line
@@ -315,22 +315,28 @@ __global__ void
GridTraverser2D(
   const Meshes::Grid< 2, Real, Devices::Cuda, Index >* grid,
   UserData* userData,
   const TraverserKernelData< typename GridEntity::CoordinatesType >* kernelData,
   //const TraverserKernelData< typename GridEntity::CoordinatesType >* kernelData,
   const typename GridEntity::CoordinatesType begin,
   const typename GridEntity::CoordinatesType end,
   const typename GridEntity::CoordinatesType entityOrientation,
   const typename GridEntity::CoordinatesType entityBasis,
   const Index gridXIdx,
   const Index gridYIdx )
{
   typedef Meshes::Grid< 2, Real, Devices::Cuda, Index > GridType;
   typename GridType::CoordinatesType coordinates;

   coordinates.x() = kernelData->begin.x() + ( gridXIdx * Devices::Cuda::getMaxGridSize() + blockIdx.x ) * blockDim.x + threadIdx.x;
   coordinates.y() = kernelData->begin.y() + ( gridYIdx * Devices::Cuda::getMaxGridSize() + blockIdx.y ) * blockDim.y + threadIdx.y;  
   //coordinates.x() = kernelData->begin.x() + ( gridXIdx * Devices::Cuda::getMaxGridSize() + blockIdx.x ) * blockDim.x + threadIdx.x;
   //coordinates.y() = kernelData->begin.y() + ( gridYIdx * Devices::Cuda::getMaxGridSize() + blockIdx.y ) * blockDim.y + threadIdx.y;  
   
   coordinates.x() = begin.x() + ( gridXIdx * Devices::Cuda::getMaxGridSize() + blockIdx.x ) * blockDim.x + threadIdx.x;
   coordinates.y() = begin.y() + ( gridYIdx * Devices::Cuda::getMaxGridSize() + blockIdx.y ) * blockDim.y + threadIdx.y;  
   

   if( coordinates.x() <= kernelData->end.x() &&
       coordinates.y() <= kernelData->end.y() )
   if( coordinates.x() <= end.x() &&
       coordinates.y() <= end.y() )
   {
      GridEntity entity( *grid, coordinates, kernelData->entityOrientation, kernelData->entityBasis );
      GridEntity entity( *grid, coordinates, entityOrientation, entityBasis );
      entity.refresh();
      if( ! processOnlyBoundaryEntities || entity.isBoundaryEntity() )
      {
@@ -363,8 +369,8 @@ processEntities(
   SharedPointer< UserData, DeviceType >& userDataPointer )
{
#ifdef HAVE_CUDA   
   UniquePointer< TraverserKernelData< CoordinatesType >, Devices::Cuda >
      kernelData( begin, end, entityOrientation, entityBasis );
   //UniquePointer< TraverserKernelData< CoordinatesType >, Devices::Cuda >
   //   kernelData( begin, end, entityOrientation, entityBasis );

   dim3 cudaBlockSize( 16, 16 );
   dim3 cudaBlocks;
@@ -380,7 +386,8 @@ processEntities(
            <<< cudaBlocks, cudaBlockSize >>>
            ( &gridPointer.template getData< Devices::Cuda >(),
              &userDataPointer.template modifyData< Devices::Cuda >(),
              &kernelData.template getData< Devices::Cuda >(),
              //&kernelData.template getData< Devices::Cuda >(),
              begin, end, entityOrientation, entityBasis,
              gridXIdx,
              gridYIdx );
 
+0 −4
Original line number Diff line number Diff line
@@ -268,10 +268,6 @@ class GridEntity< Meshes::Grid< Dimensions, Real, Device, Index >, Dimensions, C
 
      CoordinatesType coordinates;
 
      EntityOrientationType orientation;
 
      EntityBasisType basis;
 
      NeighbourGridEntitiesStorageType neighbourEntitiesStorage;
 
      //__cuda_callable__ inline
+6 −6
Original line number Diff line number Diff line
@@ -222,9 +222,9 @@ getExplicitRHS( const RealType& time,
    */
   
   //cout << "u = " << u << endl;
   std::cerr << "==========================================================================================" << std::endl;
   std::cerr << "==========================================================================================" << std::endl;
   std::cerr << "==========================================================================================" << std::endl;
   //std::cerr << "==========================================================================================" << std::endl;
   //std::cerr << "==========================================================================================" << std::endl;
   //std::cerr << "==========================================================================================" << std::endl;
   this->bindDofs( meshPointer, uDofs );
   MeshFunctionPointer fuPointer( meshPointer, fuDofs );
   Solvers::PDE::ExplicitUpdater< Mesh, MeshFunctionType, DifferentialOperator, BoundaryCondition, RightHandSide > explicitUpdater;
@@ -236,9 +236,9 @@ getExplicitRHS( const RealType& time,
      this->rightHandSidePointer,
      this->uPointer,
      fuPointer );
   std::cerr << "******************************************************************************************" << std::endl;
   std::cerr << "******************************************************************************************" << std::endl;
   std::cerr << "******************************************************************************************" << std::endl;
   //std::cerr << "******************************************************************************************" << std::endl;
   //std::cerr << "******************************************************************************************" << std::endl;
   //std::cerr << "******************************************************************************************" << std::endl;
   /*Solvers::PDE::BoundaryConditionsSetter< MeshFunctionType, BoundaryCondition > boundaryConditionsSetter;
   boundaryConditionsSetter.template apply< typename Mesh::Cell >(
      this->boundaryConditionPointer,
+7 −2
Original line number Diff line number Diff line
@@ -83,11 +83,16 @@ class ExplicitTimeStepper

   Problem* problem;

   MeshPointer mesh;

   RealType timeStep;

   MeshDependentDataPointer meshDependentData;
   /****
    * Storing pointers to the shared pointers (instead of the shared
    * pointers themselves) is important here to avoid memory
    * deallocation in the assignment operator in SharedPointer.
    */
   MeshDependentDataPointer* meshDependentData;
   
   const MeshPointer* mesh;
 
   Timer preIterateTimer, explicitUpdaterTimer, mainTimer, postIterateTimer;
 
Loading