From 0f8452da3dfebf7079867847a6272ac0571c2648 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jakub=20Klinkovsk=C3=BD?= <klinkjak@fjfi.cvut.cz>
Date: Sat, 18 Feb 2017 12:02:54 +0100
Subject: [PATCH] Fixed mesh traverser for CUDA

---
 src/TNL/Meshes/MeshDetails/Traverser_impl.h | 28 ++++++++++-----------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/src/TNL/Meshes/MeshDetails/Traverser_impl.h b/src/TNL/Meshes/MeshDetails/Traverser_impl.h
index 666b19915e..9d449c8751 100644
--- a/src/TNL/Meshes/MeshDetails/Traverser_impl.h
+++ b/src/TNL/Meshes/MeshDetails/Traverser_impl.h
@@ -87,7 +87,7 @@ template< int EntitiesDimension,
           typename Mesh,
           typename UserData >
 __global__ void
-MeshTraverserBoundaryEntitiesKernel( Mesh* mesh,
+MeshTraverserBoundaryEntitiesKernel( const Mesh* mesh,
                                      UserData* userData,
                                      typename Mesh::GlobalIndexType entitiesCount )
 {
@@ -95,10 +95,10 @@ MeshTraverserBoundaryEntitiesKernel( Mesh* mesh,
         i < entitiesCount;
         i += blockDim.x * gridDim.x )
    {
-      const auto entityIndex = mesh.template getBoundaryEntityIndex< EntitiesDimension >( i );
-      auto& entity = mesh.template getEntity< EntitiesDimension >( entityIndex );
+      const auto entityIndex = mesh->template getBoundaryEntityIndex< EntitiesDimension >( i );
+      auto& entity = mesh->template getEntity< EntitiesDimension >( entityIndex );
       // TODO: if the Mesh::IdType is void, then we should also pass the entityIndex
-      EntitiesProcessor::processEntity( mesh, userData, entity );
+      EntitiesProcessor::processEntity( *mesh, *userData, entity );
    }
 }
 
@@ -107,7 +107,7 @@ template< int EntitiesDimension,
           typename Mesh,
           typename UserData >
 __global__ void
-MeshTraverserInteriorEntitiesKernel( Mesh* mesh,
+MeshTraverserInteriorEntitiesKernel( const Mesh* mesh,
                                      UserData* userData,
                                      typename Mesh::GlobalIndexType entitiesCount )
 {
@@ -115,10 +115,10 @@ MeshTraverserInteriorEntitiesKernel( Mesh* mesh,
         i < entitiesCount;
         i += blockDim.x * gridDim.x )
    {
-      const auto entityIndex = mesh.template getInteriorEntityIndex< EntitiesDimension >( i );
-      auto& entity = mesh.template getEntity< EntitiesDimension >( entityIndex );
+      const auto entityIndex = mesh->template getInteriorEntityIndex< EntitiesDimension >( i );
+      auto& entity = mesh->template getEntity< EntitiesDimension >( entityIndex );
       // TODO: if the Mesh::IdType is void, then we should also pass the entityIndex
-      EntitiesProcessor::processEntity( mesh, userData, entity );
+      EntitiesProcessor::processEntity( *mesh, *userData, entity );
    }
 }
 
@@ -127,7 +127,7 @@ template< int EntitiesDimension,
           typename Mesh,
           typename UserData >
 __global__ void
-MeshTraverserAllEntitiesKernel( Mesh* mesh,
+MeshTraverserAllEntitiesKernel( const Mesh* mesh,
                                 UserData* userData,
                                 typename Mesh::GlobalIndexType entitiesCount )
 {
@@ -135,9 +135,9 @@ MeshTraverserAllEntitiesKernel( Mesh* mesh,
         entityIndex < entitiesCount;
         entityIndex += blockDim.x * gridDim.x )
    {
-      auto& entity = mesh.template getEntity< EntitiesDimension >( entityIndex );
+      auto& entity = mesh->template getEntity< EntitiesDimension >( entityIndex );
       // TODO: if the Mesh::IdType is void, then we should also pass the entityIndex
-      EntitiesProcessor::processEntity( mesh, userData, entity );
+      EntitiesProcessor::processEntity( *mesh, *userData, entity );
    }
 }
 
@@ -170,7 +170,7 @@ processBoundaryEntities( const MeshPointer& meshPointer,
    Devices::Cuda::synchronizeDevice();
    MeshTraverserBoundaryEntitiesKernel< EntitiesDimension, EntitiesProcessor >
       <<< gridSize, blockSize >>>
-      ( &meshPointer.template modifyData< Devices::Cuda >(),
+      ( &meshPointer.template getData< Devices::Cuda >(),
         &userDataPointer.template modifyData< Devices::Cuda >(),
         entitiesCount );
 #else
@@ -200,7 +200,7 @@ processInteriorEntities( const MeshPointer& meshPointer,
    Devices::Cuda::synchronizeDevice();
    MeshTraverserInteriorEntitiesKernel< EntitiesDimension, EntitiesProcessor >
       <<< gridSize, blockSize >>>
-      ( &meshPointer.template modifyData< Devices::Cuda >(),
+      ( &meshPointer.template getData< Devices::Cuda >(),
         &userDataPointer.template modifyData< Devices::Cuda >(),
         entitiesCount );
 #else
@@ -230,7 +230,7 @@ processAllEntities( const MeshPointer& meshPointer,
    Devices::Cuda::synchronizeDevice();
    MeshTraverserAllEntitiesKernel< EntitiesDimension, EntitiesProcessor >
       <<< gridSize, blockSize >>>
-      ( &meshPointer.template modifyData< Devices::Cuda >(),
+      ( &meshPointer.template getData< Devices::Cuda >(),
         &userDataPointer.template modifyData< Devices::Cuda >(),
         entitiesCount );
 #else
-- 
GitLab