Commit a768f494 authored by Jakub Klinkovský
Browse files

Passing TraverserUserData via SharedPointer

Apart from squashing the begin, end, entityOrientation and entityBasis
parameters into a single TraverserKernelData structure, this change does
not improve the performance of the LinearSystemAssembler, ExplicitUpdater,
etc. In tnl-mhfem, however, it allows us to pass MeshDependentData, which
is already available as a SharedPointer, directly to the grid traverser
without duplicating the transfer to the GPU.
parent 673b7a22
Loading
Loading
Loading
Loading
+6 −5
Original line number Diff line number Diff line
@@ -119,7 +119,8 @@ evaluateEntities( OutMeshFunctionPointer& meshFunction,
   typedef Functions::MeshFunctionEvaluatorAssignmentEntitiesProcessor< MeshType, TraverserUserData > AssignmentEntitiesProcessor;
   typedef Functions::MeshFunctionEvaluatorAdditionEntitiesProcessor< MeshType, TraverserUserData > AdditionEntitiesProcessor;
 
   TraverserUserData userData( &function.template getData< DeviceType >(),
   SharedPointer< TraverserUserData, DeviceType >
      userData( &function.template getData< DeviceType >(),
                time,
                &meshFunction.template modifyData< DeviceType >(),
                outFunctionMultiplicator,
+4 −3
Original line number Diff line number Diff line
@@ -28,7 +28,8 @@ getCompressedRowsLengths( const MeshPointer& meshPointer,
                          CompressedRowsLengthsVectorPointer& rowLengthsPointer ) const
{
   {
      TraversalUserData userData( &differentialOperatorPointer.template getData< DeviceType >(),
      SharedPointer< TraversalUserData, DeviceType >
         userData( &differentialOperatorPointer.template getData< DeviceType >(),
                   &boundaryConditionsPointer.template getData< DeviceType >(),
                   &rowLengthsPointer.template modifyData< DeviceType >() );
      Meshes::Traverser< MeshType, EntityType > meshTraversal;
+13 −11
Original line number Diff line number Diff line
@@ -10,6 +10,8 @@

#pragma once

#include <TNL/SharedPointer.h>


namespace TNL {
namespace Meshes {
@@ -50,7 +52,7 @@ class GridTraverser< Meshes::Grid< 1, Real, Devices::Host, Index > >
         const CoordinatesType& end,
         const CoordinatesType& entityOrientation,
         const CoordinatesType& entityBasis,
         UserData& userData );
         SharedPointer< UserData, DeviceType >& userData );
};

/****
@@ -65,7 +67,7 @@ class GridTraverser< Meshes::Grid< 1, Real, Devices::Cuda, Index > >
      typedef Meshes::Grid< 1, Real, Devices::Cuda, Index > GridType;
      typedef SharedPointer< GridType > GridPointer;
      typedef Real RealType;
      typedef Devices::Host DeviceType;
      typedef Devices::Cuda DeviceType;
      typedef Index IndexType;
      typedef typename GridType::CoordinatesType CoordinatesType;
 
@@ -81,7 +83,7 @@ class GridTraverser< Meshes::Grid< 1, Real, Devices::Cuda, Index > >
         const CoordinatesType& end,
         const CoordinatesType& entityOrientation,
         const CoordinatesType& entityBasis,
         UserData& userData );
         SharedPointer< UserData, DeviceType >& userData );
};

/****
@@ -110,11 +112,11 @@ class GridTraverser< Meshes::Grid< 2, Real, Devices::Host, Index > >
      static void
      processEntities(
         const GridPointer& gridPointer,
         const CoordinatesType begin,
         const CoordinatesType end,
         const CoordinatesType& begin,
         const CoordinatesType& end,
         const CoordinatesType& entityOrientation,
         const CoordinatesType& entityBasis,
         UserData& userData );
         SharedPointer< UserData, DeviceType >& userData );
};

/****
@@ -129,7 +131,7 @@ class GridTraverser< Meshes::Grid< 2, Real, Devices::Cuda, Index > >
      typedef Meshes::Grid< 2, Real, Devices::Cuda, Index > GridType;
      typedef SharedPointer< GridType > GridPointer;
      typedef Real RealType;
      typedef Devices::Host DeviceType;
      typedef Devices::Cuda DeviceType;
      typedef Index IndexType;
      typedef typename GridType::CoordinatesType CoordinatesType;
 
@@ -147,7 +149,7 @@ class GridTraverser< Meshes::Grid< 2, Real, Devices::Cuda, Index > >
         const CoordinatesType& end,
         const CoordinatesType& entityOrientation,
         const CoordinatesType& entityBasis,
         UserData& userData );
         SharedPointer< UserData, DeviceType >& userData );
};

/****
@@ -181,7 +183,7 @@ class GridTraverser< Meshes::Grid< 3, Real, Devices::Host, Index > >
         const CoordinatesType& end,
         const CoordinatesType& entityOrientation,
         const CoordinatesType& entityBasis,
         UserData& userData );
         SharedPointer< UserData, DeviceType >& userData );
};

/****
@@ -196,7 +198,7 @@ class GridTraverser< Meshes::Grid< 3, Real, Devices::Cuda, Index > >
      typedef Meshes::Grid< 3, Real, Devices::Cuda, Index > GridType;
      typedef SharedPointer< GridType > GridPointer;
      typedef Real RealType;
      typedef Devices::Host DeviceType;
      typedef Devices::Cuda DeviceType;
      typedef Index IndexType;
      typedef typename GridType::CoordinatesType CoordinatesType;
 
@@ -215,7 +217,7 @@ class GridTraverser< Meshes::Grid< 3, Real, Devices::Cuda, Index > >
         const CoordinatesType& end,
         const CoordinatesType& entityOrientation,
         const CoordinatesType& entityBasis,
         UserData& userData );
         SharedPointer< UserData, DeviceType >& userData );
};

} // namespace Meshes
+86 −122
Original line number Diff line number Diff line
@@ -10,9 +10,32 @@

#pragma once

#include <TNL/UniquePointer.h>


namespace TNL {
namespace Meshes {

/****
 * Aggregate of the four coordinate values that describe one traversal:
 * the first and last coordinates to visit and the orientation and basis
 * of the entities being processed. Grouping them lets the traverser
 * kernels receive a single pointer instead of four separate ones.
 *
 * CoordinatesType only needs to be copy-constructible; every value is
 * stored as an independent copy inside the structure.
 */
template< typename CoordinatesType >
struct TraverserKernelData
{
   // First coordinates of the traversed range.
   CoordinatesType begin;
   // Last coordinates of the traversed range (kernels compare with <=).
   CoordinatesType end;
   // Orientation handed to each constructed grid entity.
   CoordinatesType entityOrientation;
   // Basis handed to each constructed grid entity.
   CoordinatesType entityBasis;

   // Copies all four values into the corresponding members, in declaration order.
   TraverserKernelData( CoordinatesType beginCoordinates,
                        CoordinatesType endCoordinates,
                        CoordinatesType orientation,
                        CoordinatesType basis )
   : begin( beginCoordinates ),
     end( endCoordinates ),
     entityOrientation( orientation ),
     entityBasis( basis )
   {
   }
};


/****
 * 1D traverser, host
 */
@@ -31,10 +54,8 @@ processEntities(
   const CoordinatesType& end,
   const CoordinatesType& entityOrientation,
   const CoordinatesType& entityBasis,
   UserData& userData )
   SharedPointer< UserData, DeviceType >& userDataPointer )
{

   
   GridEntity entity( *gridPointer );
   entity.setOrientation( entityOrientation );
   entity.setBasis( entityBasis );
@@ -42,10 +63,10 @@ processEntities(
   {
      entity.getCoordinates() = begin;
      entity.refresh();
      EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
      EntitiesProcessor::processEntity( entity.getMesh(), *userDataPointer, entity );
      entity.getCoordinates() = end;
      entity.refresh();
      EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
      EntitiesProcessor::processEntity( entity.getMesh(), *userDataPointer, entity );
   }
   else
   {
@@ -54,7 +75,7 @@ processEntities(
           entity.getCoordinates().x() ++ )
      {
         entity.refresh();
         EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
         EntitiesProcessor::processEntity( entity.getMesh(), *userDataPointer, entity );
      }
   }
}
@@ -72,10 +93,7 @@ __global__ void
GridTraverser1D(
   const Meshes::Grid< 1, Real, Devices::Cuda, Index >* grid,
   UserData* userData,
   const typename GridEntity::CoordinatesType* begin,
   const typename GridEntity::CoordinatesType* end,
   const typename GridEntity::CoordinatesType* entityOrientation,
   const typename GridEntity::CoordinatesType* entityBasis,
   const TraverserKernelData< typename GridEntity::CoordinatesType >* kernelData,
   const Index gridIdx )
{
   typedef Real RealType;
@@ -83,12 +101,12 @@ GridTraverser1D(
   typedef Meshes::Grid< 1, Real, Devices::Cuda, Index > GridType;
   typename GridType::CoordinatesType coordinates;
 
   coordinates.x() = begin->x() + ( gridIdx * Devices::Cuda::getMaxGridSize() + blockIdx.x ) * blockDim.x + threadIdx.x;
   coordinates.x() = kernelData->begin.x() + ( gridIdx * Devices::Cuda::getMaxGridSize() + blockIdx.x ) * blockDim.x + threadIdx.x;
 
   GridEntity entity( *grid, coordinates, *entityOrientation, *entityBasis );
   GridEntity entity( *grid, coordinates, kernelData->entityOrientation, kernelData->entityBasis );
   
   entity.refresh();
   if( coordinates.x() <= end->x() )
   if( coordinates.x() <= kernelData->end.x() )
      EntitiesProcessor::processEntity( entity.getMesh(), *userData, entity );
}

@@ -101,10 +119,7 @@ __global__ void
GridBoundaryTraverser1D(
   const Meshes::Grid< 1, Real, Devices::Cuda, Index >* grid,
   UserData* userData,
   const typename GridEntity::CoordinatesType* begin,
   const typename GridEntity::CoordinatesType* end,
   const typename GridEntity::CoordinatesType* entityOrientation,
   const typename GridEntity::CoordinatesType* entityBasis )
   const TraverserKernelData< typename GridEntity::CoordinatesType >* kernelData )
{
   typedef Real RealType;
   typedef Index IndexType;
@@ -113,15 +128,15 @@ GridBoundaryTraverser1D(
 
   if( threadIdx.x == 0 )
   {
      coordinates.x() = begin->x();
      GridEntity entity( *grid, coordinates, *entityOrientation, *entityBasis );
      coordinates.x() = kernelData->begin.x();
      GridEntity entity( *grid, coordinates, kernelData->entityOrientation, kernelData->entityBasis );
      entity.refresh();
      EntitiesProcessor::processEntity( entity.getMesh(), *userData, entity );
   }
   if( threadIdx.x == 1 )
   {
      coordinates.x() = end->x();
      GridEntity entity( *grid, coordinates, *entityOrientation, *entityBasis );
      coordinates.x() = kernelData->end.x();
      GridEntity entity( *grid, coordinates, kernelData->entityOrientation, kernelData->entityBasis );
      entity.refresh();
      EntitiesProcessor::processEntity( entity.getMesh(), *userData, entity );
   }
@@ -144,15 +159,11 @@ processEntities(
   const CoordinatesType& end,
   const CoordinatesType& entityOrientation,
   const CoordinatesType& entityBasis,
   UserData& userData )
   SharedPointer< UserData, DeviceType >& userDataPointer )
{
#ifdef HAVE_CUDA
   CoordinatesType* kernelBegin = Devices::Cuda::passToDevice( begin );
   CoordinatesType* kernelEnd = Devices::Cuda::passToDevice( end );
   CoordinatesType* kernelEntityOrientation = Devices::Cuda::passToDevice( entityOrientation );
   CoordinatesType* kernelEntityBasis = Devices::Cuda::passToDevice( entityBasis );
   //typename GridEntity::MeshType* kernelGrid = Devices::Cuda::passToDevice( *gridPointer );
   UserData* kernelUserData = Devices::Cuda::passToDevice( userData );
   UniquePointer< TraverserKernelData< CoordinatesType >, Devices::Cuda >
      kernelData( begin, end, entityOrientation, entityBasis );

   Devices::Cuda::synchronizeDevice();
   if( processOnlyBoundaryEntities )
@@ -162,11 +173,8 @@ processEntities(
      GridBoundaryTraverser1D< Real, Index, GridEntity, UserData, EntitiesProcessor >
            <<< cudaBlocks, cudaBlockSize >>>
            ( &gridPointer.template getData< Devices::Cuda >(),
              kernelUserData,
              kernelBegin,
              kernelEnd,
              kernelEntityOrientation,
              kernelEntityBasis );
              &userDataPointer.template modifyData< Devices::Cuda >(),
              &kernelData.template getData< Devices::Cuda >() );
   }
   else
   {
@@ -179,22 +187,12 @@ processEntities(
         GridTraverser1D< Real, Index, GridEntity, UserData, EntitiesProcessor >
            <<< cudaBlocks, cudaBlockSize >>>
            ( &gridPointer.template getData< Devices::Cuda >(),
              kernelUserData,
              kernelBegin,
              kernelEnd,
              kernelEntityOrientation,
              kernelEntityBasis,
              &userDataPointer.template modifyData< Devices::Cuda >(),
              &kernelData.template getData< Devices::Cuda >(),
              gridXIdx );
   }
   cudaThreadSynchronize();
   checkCudaDevice;
   //Devices::Cuda::freeFromDevice( kernelGrid );
   Devices::Cuda::freeFromDevice( kernelBegin );
   Devices::Cuda::freeFromDevice( kernelEnd );
   Devices::Cuda::freeFromDevice( kernelEntityOrientation );
   Devices::Cuda::freeFromDevice( kernelEntityBasis );
   Devices::Cuda::freeFromDevice( kernelUserData );
   checkCudaDevice;
#endif
}

@@ -215,11 +213,11 @@ void
GridTraverser< Meshes::Grid< 2, Real, Devices::Host, Index > >::
processEntities(
   const GridPointer& gridPointer,
   const CoordinatesType begin,
   const CoordinatesType end,
   const CoordinatesType& begin,
   const CoordinatesType& end,
   const CoordinatesType& entityOrientation,
   const CoordinatesType& entityBasis,
   UserData& userData )
   SharedPointer< UserData, DeviceType >& userDataPointer )
{
   GridEntity entity( *gridPointer );
   entity.setOrientation( entityOrientation );
@@ -234,10 +232,10 @@ processEntities(
         {
            entity.getCoordinates().y() = begin.y();
            entity.refresh();
            EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
            EntitiesProcessor::processEntity( entity.getMesh(), *userDataPointer, entity );
            entity.getCoordinates().y() = end.y();
            entity.refresh();
            EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
            EntitiesProcessor::processEntity( entity.getMesh(), *userDataPointer, entity );
         }
      if( XOrthogonalBoundary )
         for( entity.getCoordinates().y() = begin.y();
@@ -246,10 +244,10 @@ processEntities(
         {
            entity.getCoordinates().x() = begin.x();
            entity.refresh();
            EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
            EntitiesProcessor::processEntity( entity.getMesh(), *userDataPointer, entity );
            entity.getCoordinates().x() = end.x();
            entity.refresh();
            EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
            EntitiesProcessor::processEntity( entity.getMesh(), *userDataPointer, entity );
         }
   }
   else
@@ -263,7 +261,7 @@ processEntities(
              entity.getCoordinates().x() ++ )
         {
            entity.refresh();
            EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
            EntitiesProcessor::processEntity( entity.getMesh(), *userDataPointer, entity );
         }
   }
}
@@ -282,23 +280,20 @@ __global__ void
GridTraverser2D(
   const Meshes::Grid< 2, Real, Devices::Cuda, Index >* grid,
   UserData* userData,
   const typename GridEntity::CoordinatesType* begin,
   const typename GridEntity::CoordinatesType* end,
   const typename GridEntity::CoordinatesType* entityOrientation,
   const typename GridEntity::CoordinatesType* entityBasis,
   const TraverserKernelData< typename GridEntity::CoordinatesType >* kernelData,
   const Index gridXIdx,
   const Index gridYIdx )
{
   typedef Meshes::Grid< 2, Real, Devices::Cuda, Index > GridType;
   typename GridType::CoordinatesType coordinates;

   coordinates.x() = begin->x() + ( gridXIdx * Devices::Cuda::getMaxGridSize() + blockIdx.x ) * blockDim.x + threadIdx.x;
   coordinates.y() = begin->y() + ( gridYIdx * Devices::Cuda::getMaxGridSize() + blockIdx.y ) * blockDim.y + threadIdx.y;  
   coordinates.x() = kernelData->begin.x() + ( gridXIdx * Devices::Cuda::getMaxGridSize() + blockIdx.x ) * blockDim.x + threadIdx.x;
   coordinates.y() = kernelData->begin.y() + ( gridYIdx * Devices::Cuda::getMaxGridSize() + blockIdx.y ) * blockDim.y + threadIdx.y;  
   
   GridEntity entity( *grid, coordinates, *entityOrientation, *entityBasis );
   GridEntity entity( *grid, coordinates, kernelData->entityOrientation, kernelData->entityBasis );

   if( entity.getCoordinates().x() <= end->x() &&
       entity.getCoordinates().y() <= end->y() )
   if( entity.getCoordinates().x() <= kernelData->end.x() &&
       entity.getCoordinates().y() <= kernelData->end.y() )
   {
      entity.refresh();
      if( ! processOnlyBoundaryEntities || entity.isBoundaryEntity() )
@@ -329,16 +324,11 @@ processEntities(
   const CoordinatesType& end,
   const CoordinatesType& entityOrientation,
   const CoordinatesType& entityBasis,
   UserData& userData )
   SharedPointer< UserData, DeviceType >& userDataPointer )
{
#ifdef HAVE_CUDA   
   CoordinatesType* kernelBegin = Devices::Cuda::passToDevice( begin );
   CoordinatesType* kernelEnd = Devices::Cuda::passToDevice( end );
   CoordinatesType* kernelEntityOrientation = Devices::Cuda::passToDevice( entityOrientation );
   CoordinatesType* kernelEntityBasis = Devices::Cuda::passToDevice( entityBasis );
   //typename GridEntity::MeshType* kernelGrid = Devices::Cuda::passToDevice( *gridPointer );
   UserData* kernelUserData = Devices::Cuda::passToDevice( userData );
   checkCudaDevice;   
   UniquePointer< TraverserKernelData< CoordinatesType >, Devices::Cuda >
      kernelData( begin, end, entityOrientation, entityBasis );

   dim3 cudaBlockSize( 16, 16 );
   dim3 cudaBlocks;
@@ -353,23 +343,13 @@ processEntities(
         GridTraverser2D< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities >
            <<< cudaBlocks, cudaBlockSize >>>
            ( &gridPointer.template getData< Devices::Cuda >(),
              kernelUserData,
              kernelBegin,
              kernelEnd,
              kernelEntityOrientation,
              kernelEntityBasis,
              &userDataPointer.template modifyData< Devices::Cuda >(),
              &kernelData.template getData< Devices::Cuda >(),
              gridXIdx,
              gridYIdx );
 
   cudaThreadSynchronize();
   checkCudaDevice;
   //Devices::Cuda::freeFromDevice( kernelGrid );
   Devices::Cuda::freeFromDevice( kernelBegin );
   Devices::Cuda::freeFromDevice( kernelEnd );
   Devices::Cuda::freeFromDevice( kernelEntityOrientation );
   Devices::Cuda::freeFromDevice( kernelEntityBasis );
   Devices::Cuda::freeFromDevice( kernelUserData );
   checkCudaDevice;
#endif
}

@@ -394,7 +374,7 @@ processEntities(
   const CoordinatesType& end,
   const CoordinatesType& entityOrientation,
   const CoordinatesType& entityBasis,
   UserData& userData )
   SharedPointer< UserData, DeviceType >& userDataPointer )
{
   GridEntity entity( *gridPointer );
   entity.setOrientation( entityOrientation );
@@ -412,10 +392,10 @@ processEntities(
            {
               entity.getCoordinates().z() = begin.z();
               entity.refresh();
               EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
               EntitiesProcessor::processEntity( entity.getMesh(), *userDataPointer, entity );
               entity.getCoordinates().z() = end.z();
               entity.refresh();
               EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
               EntitiesProcessor::processEntity( entity.getMesh(), *userDataPointer, entity );
            }
      if( YOrthogonalBoundary )
         for( entity.getCoordinates().z() = begin.z();
@@ -427,10 +407,10 @@ processEntities(
            {
               entity.getCoordinates().y() = begin.y();
               entity.refresh();
               EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
               EntitiesProcessor::processEntity( entity.getMesh(), *userDataPointer, entity );
               entity.getCoordinates().y() = end.y();
               entity.refresh();
               EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
               EntitiesProcessor::processEntity( entity.getMesh(), *userDataPointer, entity );
            }
      if( XOrthogonalBoundary )
         for( entity.getCoordinates().z() = begin.z();
@@ -442,10 +422,10 @@ processEntities(
            {
               entity.getCoordinates().x() = begin.x();
               entity.refresh();
               EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
               EntitiesProcessor::processEntity( entity.getMesh(), *userDataPointer, entity );
               entity.getCoordinates().x() = end.x();
               entity.refresh();
               EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
               EntitiesProcessor::processEntity( entity.getMesh(), *userDataPointer, entity );
            }
   }
   else
@@ -461,7 +441,7 @@ processEntities(
                 entity.getCoordinates().x() ++ )
            {
               entity.refresh();
               EntitiesProcessor::processEntity( entity.getMesh(), userData, entity );
               EntitiesProcessor::processEntity( entity.getMesh(), *userDataPointer, entity );
            }
   }
}
@@ -480,10 +460,7 @@ __global__ void
GridTraverser3D(
   const Meshes::Grid< 3, Real, Devices::Cuda, Index >* grid,
   UserData* userData,
   const typename GridEntity::CoordinatesType* begin,
   const typename GridEntity::CoordinatesType* end,
   const typename GridEntity::CoordinatesType* entityOrientation,
   const typename GridEntity::CoordinatesType* entityBasis,
   const TraverserKernelData< typename GridEntity::CoordinatesType >* kernelData,
   const Index gridXIdx,
   const Index gridYIdx,
   const Index gridZIdx )
@@ -491,15 +468,15 @@ GridTraverser3D(
   typedef Meshes::Grid< 3, Real, Devices::Cuda, Index > GridType;
   typename GridType::CoordinatesType coordinates;

   coordinates.x() = begin->x() + ( gridXIdx * Devices::Cuda::getMaxGridSize() + blockIdx.x ) * blockDim.x + threadIdx.x;
   coordinates.y() = begin->y() + ( gridYIdx * Devices::Cuda::getMaxGridSize() + blockIdx.y ) * blockDim.y + threadIdx.y;
   coordinates.z() = begin->z() + ( gridZIdx * Devices::Cuda::getMaxGridSize() + blockIdx.z ) * blockDim.z + threadIdx.z;
   coordinates.x() = kernelData->begin.x() + ( gridXIdx * Devices::Cuda::getMaxGridSize() + blockIdx.x ) * blockDim.x + threadIdx.x;
   coordinates.y() = kernelData->begin.y() + ( gridYIdx * Devices::Cuda::getMaxGridSize() + blockIdx.y ) * blockDim.y + threadIdx.y;
   coordinates.z() = kernelData->begin.z() + ( gridZIdx * Devices::Cuda::getMaxGridSize() + blockIdx.z ) * blockDim.z + threadIdx.z;
 
   GridEntity entity( *grid, coordinates, *entityOrientation, *entityBasis );
   GridEntity entity( *grid, coordinates, kernelData->entityOrientation, kernelData->entityBasis );

   if( entity.getCoordinates().x() <= end->x() &&
       entity.getCoordinates().y() <= end->y() &&
       entity.getCoordinates().z() <= end->z() )
   if( entity.getCoordinates().x() <= kernelData->end.x() &&
       entity.getCoordinates().y() <= kernelData->end.y() &&
       entity.getCoordinates().z() <= kernelData->end.z() )
   {
      entity.refresh();
      if( ! processOnlyBoundaryEntities || entity.isBoundaryEntity() )
@@ -531,15 +508,11 @@ processEntities(
   const CoordinatesType& end,
   const CoordinatesType& entityOrientation,
   const CoordinatesType& entityBasis,
   UserData& userData )
   SharedPointer< UserData, DeviceType >& userDataPointer )
{
#ifdef HAVE_CUDA   
   CoordinatesType* kernelBegin = Devices::Cuda::passToDevice( begin );
   CoordinatesType* kernelEnd = Devices::Cuda::passToDevice( end );
   CoordinatesType* kernelEntityOrientation = Devices::Cuda::passToDevice( entityOrientation );
   CoordinatesType* kernelEntityBasis = Devices::Cuda::passToDevice( entityBasis );
   //typename GridEntity::MeshType* kernelGrid = Devices::Cuda::passToDevice( grid );
   UserData* kernelUserData = Devices::Cuda::passToDevice( userData );
   UniquePointer< TraverserKernelData< CoordinatesType >, Devices::Cuda >
      kernelData( begin, end, entityOrientation, entityBasis );
      
   dim3 cudaBlockSize( 8, 8, 8 );
   dim3 cudaBlocks;
@@ -557,23 +530,14 @@ processEntities(
            GridTraverser3D< Real, Index, GridEntity, UserData, EntitiesProcessor, processOnlyBoundaryEntities >
               <<< cudaBlocks, cudaBlockSize >>>
               ( &gridPointer.template getData< Devices::Cuda >(),
                 kernelUserData,
                 kernelBegin,
                 kernelEnd,
                 kernelEntityOrientation,
                 kernelEntityBasis,
                 &userDataPointer.template modifyData< Devices::Cuda >(),
                 &kernelData.template getData< Devices::Cuda >(),
                 gridXIdx,
                 gridYIdx,
                 gridZIdx );

   cudaThreadSynchronize();
   checkCudaDevice;
   //Devices::Cuda::freeFromDevice( kernelGrid );
   Devices::Cuda::freeFromDevice( kernelBegin );
   Devices::Cuda::freeFromDevice( kernelEnd );
   Devices::Cuda::freeFromDevice( kernelEntityOrientation );
   Devices::Cuda::freeFromDevice( kernelEntityBasis );
   Devices::Cuda::freeFromDevice( kernelUserData );
   checkCudaDevice;
#endif
}

+6 −6
Original line number Diff line number Diff line
@@ -32,17 +32,17 @@ class Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 1 >
      template< typename UserData,
                typename EntitiesProcessor >
      void processBoundaryEntities( const GridPointer& gridPointer,
                                    UserData& userData ) const;
                                    SharedPointer< UserData, DeviceType >& userDataPointer ) const;

      template< typename UserData,
                typename EntitiesProcessor >
      void processInteriorEntities( const GridPointer& gridPointer,
                                    UserData& userData ) const;
                                    SharedPointer< UserData, DeviceType >& userDataPointer ) const;
 
      template< typename UserData,
                typename EntitiesProcessor >
      void processAllEntities( const GridPointer& gridPointer,
                               UserData& userData ) const;
                               SharedPointer< UserData, DeviceType >& userDataPointer ) const;
 
};

@@ -64,17 +64,17 @@ class Traverser< Meshes::Grid< 1, Real, Device, Index >, GridEntity, 0 >
      template< typename UserData,
                typename EntitiesProcessor >
      void processBoundaryEntities( const GridPointer& gridPointer,
                                    UserData& userData ) const;
                                    SharedPointer< UserData, DeviceType >& userDataPointer ) const;

      template< typename UserData,
                typename EntitiesProcessor >
      void processInteriorEntities( const GridPointer& gridPointer,
                                    UserData& userData ) const;
                                    SharedPointer< UserData, DeviceType >& userDataPointer ) const;
 
      template< typename UserData,
                typename EntitiesProcessor >
      void processAllEntities( const GridPointer& gridPointer,
                               UserData& userData ) const;
                               SharedPointer< UserData, DeviceType >& userDataPointer ) const;
};

} // namespace Meshes
Loading