Commit 0f8452da authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Fixed mesh traverser for CUDA

parent bf51a8ee
Loading
Loading
Loading
Loading
+14 −14
Original line number Diff line number Diff line
@@ -87,7 +87,7 @@ template< int EntitiesDimension,
          typename Mesh,
          typename UserData >
__global__ void
MeshTraverserBoundaryEntitiesKernel( Mesh* mesh,
MeshTraverserBoundaryEntitiesKernel( const Mesh* mesh,
                                     UserData* userData,
                                     typename Mesh::GlobalIndexType entitiesCount )
{
@@ -95,10 +95,10 @@ MeshTraverserBoundaryEntitiesKernel( Mesh* mesh,
        i < entitiesCount;
        i += blockDim.x * gridDim.x )
   {
      const auto entityIndex = mesh.template getBoundaryEntityIndex< EntitiesDimension >( i );
      auto& entity = mesh.template getEntity< EntitiesDimension >( entityIndex );
      const auto entityIndex = mesh->template getBoundaryEntityIndex< EntitiesDimension >( i );
      auto& entity = mesh->template getEntity< EntitiesDimension >( entityIndex );
      // TODO: if the Mesh::IdType is void, then we should also pass the entityIndex
      EntitiesProcessor::processEntity( mesh, userData, entity );
      EntitiesProcessor::processEntity( *mesh, *userData, entity );
   }
}

@@ -107,7 +107,7 @@ template< int EntitiesDimension,
          typename Mesh,
          typename UserData >
__global__ void
MeshTraverserInteriorEntitiesKernel( Mesh* mesh,
MeshTraverserInteriorEntitiesKernel( const Mesh* mesh,
                                     UserData* userData,
                                     typename Mesh::GlobalIndexType entitiesCount )
{
@@ -115,10 +115,10 @@ MeshTraverserInteriorEntitiesKernel( Mesh* mesh,
        i < entitiesCount;
        i += blockDim.x * gridDim.x )
   {
      const auto entityIndex = mesh.template getInteriorEntityIndex< EntitiesDimension >( i );
      auto& entity = mesh.template getEntity< EntitiesDimension >( entityIndex );
      const auto entityIndex = mesh->template getInteriorEntityIndex< EntitiesDimension >( i );
      auto& entity = mesh->template getEntity< EntitiesDimension >( entityIndex );
      // TODO: if the Mesh::IdType is void, then we should also pass the entityIndex
      EntitiesProcessor::processEntity( mesh, userData, entity );
      EntitiesProcessor::processEntity( *mesh, *userData, entity );
   }
}

@@ -127,7 +127,7 @@ template< int EntitiesDimension,
          typename Mesh,
          typename UserData >
__global__ void
MeshTraverserAllEntitiesKernel( Mesh* mesh,
MeshTraverserAllEntitiesKernel( const Mesh* mesh,
                                UserData* userData,
                                typename Mesh::GlobalIndexType entitiesCount )
{
@@ -135,9 +135,9 @@ MeshTraverserAllEntitiesKernel( Mesh* mesh,
        entityIndex < entitiesCount;
        entityIndex += blockDim.x * gridDim.x )
   {
      auto& entity = mesh.template getEntity< EntitiesDimension >( entityIndex );
      auto& entity = mesh->template getEntity< EntitiesDimension >( entityIndex );
      // TODO: if the Mesh::IdType is void, then we should also pass the entityIndex
      EntitiesProcessor::processEntity( mesh, userData, entity );
      EntitiesProcessor::processEntity( *mesh, *userData, entity );
   }
}

@@ -170,7 +170,7 @@ processBoundaryEntities( const MeshPointer& meshPointer,
   Devices::Cuda::synchronizeDevice();
   MeshTraverserBoundaryEntitiesKernel< EntitiesDimension, EntitiesProcessor >
      <<< gridSize, blockSize >>>
      ( &meshPointer.template modifyData< Devices::Cuda >(),
      ( &meshPointer.template getData< Devices::Cuda >(),
        &userDataPointer.template modifyData< Devices::Cuda >(),
        entitiesCount );
#else
@@ -200,7 +200,7 @@ processInteriorEntities( const MeshPointer& meshPointer,
   Devices::Cuda::synchronizeDevice();
   MeshTraverserInteriorEntitiesKernel< EntitiesDimension, EntitiesProcessor >
      <<< gridSize, blockSize >>>
      ( &meshPointer.template modifyData< Devices::Cuda >(),
      ( &meshPointer.template getData< Devices::Cuda >(),
        &userDataPointer.template modifyData< Devices::Cuda >(),
        entitiesCount );
#else
@@ -230,7 +230,7 @@ processAllEntities( const MeshPointer& meshPointer,
   Devices::Cuda::synchronizeDevice();
   MeshTraverserAllEntitiesKernel< EntitiesDimension, EntitiesProcessor >
      <<< gridSize, blockSize >>>
      ( &meshPointer.template modifyData< Devices::Cuda >(),
      ( &meshPointer.template getData< Devices::Cuda >(),
        &userDataPointer.template modifyData< Devices::Cuda >(),
        entitiesCount );
#else