From bb045e76958f3b54a244b993e8267d49cbbee4c1 Mon Sep 17 00:00:00 2001
From: Tomas Oberhuber <tomas.oberhuber@fjfi.cvut.cz>
Date: Thu, 5 Apr 2018 11:48:06 +0200
Subject: [PATCH] Added CUDA switch.

---
 .../tnlDirectEikonalMethodsBase.h             |  17 +-
 .../tnlFastSweepingMethod2D_impl.h            | 234 +++++++++---------
 2 files changed, 129 insertions(+), 122 deletions(-)

diff --git a/src/TNL/Experimental/Hamilton-Jacobi/Solvers/hamilton-jacobi/tnlDirectEikonalMethodsBase.h b/src/TNL/Experimental/Hamilton-Jacobi/Solvers/hamilton-jacobi/tnlDirectEikonalMethodsBase.h
index 4dc8e9228a..2dabb5f1a7 100644
--- a/src/TNL/Experimental/Hamilton-Jacobi/Solvers/hamilton-jacobi/tnlDirectEikonalMethodsBase.h
+++ b/src/TNL/Experimental/Hamilton-Jacobi/Solvers/hamilton-jacobi/tnlDirectEikonalMethodsBase.h
@@ -9,6 +9,7 @@
 
 #include <TNL/Meshes/Grid.h>
 #include <TNL/Functions/MeshFunction.h>
+#include <TNL/Devices/Cuda.h>
 
 using namespace TNL;
 
@@ -36,8 +37,8 @@ class tnlDirectEikonalMethodsBase< Meshes::Grid< 1, Real, Device, Index > >
                           InterfaceMapType& interfaceMap );
       
       template< typename MeshEntity >
-      void updateCell( MeshFunctionType& u,
-                       const MeshEntity& cell );
+      __cuda_callable__ void updateCell( MeshFunctionType& u,
+                                         const MeshEntity& cell );
       
 };
 
@@ -61,9 +62,9 @@ class tnlDirectEikonalMethodsBase< Meshes::Grid< 2, Real, Device, Index > >
                           InterfaceMapType& interfaceMap );
       
       template< typename MeshEntity >
-      void updateCell( MeshFunctionType& u,
-                       const MeshEntity& cell,
-                       const RealType velocity = 1.0 );
+      __cuda_callable__ void updateCell( MeshFunctionType& u,
+                                         const MeshEntity& cell,
+                                         const RealType velocity = 1.0 );
 };
 
 template< typename Real,
@@ -85,9 +86,9 @@ class tnlDirectEikonalMethodsBase< Meshes::Grid< 3, Real, Device, Index > >
                           InterfaceMapType& interfaceMap );
       
       template< typename MeshEntity >
-      void updateCell( MeshFunctionType& u,
-                       const MeshEntity& cell,
-                       const RealType velocity = 1.0);
+      __cuda_callable__ void updateCell( MeshFunctionType& u,
+                                         const MeshEntity& cell,
+                                         const RealType velocity = 1.0);
       
       /*Real sort( Real a, Real b, Real c,
                  const RealType& ha,
diff --git a/src/TNL/Experimental/Hamilton-Jacobi/Solvers/hamilton-jacobi/tnlFastSweepingMethod2D_impl.h b/src/TNL/Experimental/Hamilton-Jacobi/Solvers/hamilton-jacobi/tnlFastSweepingMethod2D_impl.h
index e90743a80d..b727f0cd99 100644
--- a/src/TNL/Experimental/Hamilton-Jacobi/Solvers/hamilton-jacobi/tnlFastSweepingMethod2D_impl.h
+++ b/src/TNL/Experimental/Hamilton-Jacobi/Solvers/hamilton-jacobi/tnlFastSweepingMethod2D_impl.h
@@ -71,137 +71,143 @@ solve( const MeshPointer& mesh,
    IndexType iteration( 0 );
    while( iteration < this->maxIterations )
    {
-
-      for( cell.getCoordinates().y() = 0;
-           cell.getCoordinates().y() < mesh->getDimensions().y();
-           cell.getCoordinates().y()++ )
+      if( std::is_same< DeviceType, Devices::Host >::value )
       {
-         for( cell.getCoordinates().x() = 0;
-              cell.getCoordinates().x() < mesh->getDimensions().x();
-              cell.getCoordinates().x()++ )
+         for( cell.getCoordinates().y() = 0;
+              cell.getCoordinates().y() < mesh->getDimensions().y();
+              cell.getCoordinates().y()++ )
+         {
+            for( cell.getCoordinates().x() = 0;
+                 cell.getCoordinates().x() < mesh->getDimensions().x();
+                 cell.getCoordinates().x()++ )
+               {
+                  cell.refresh();
+                  if( ! interfaceMap( cell ) )
+                     this->updateCell( aux, cell );
+               }
+         }
+
+         //aux.save( "aux-1.tnl" );
+
+         for( cell.getCoordinates().y() = 0;
+              cell.getCoordinates().y() < mesh->getDimensions().y();
+              cell.getCoordinates().y()++ )
+         {
+            for( cell.getCoordinates().x() = mesh->getDimensions().x() - 1;
+                 cell.getCoordinates().x() >= 0 ;
+                 cell.getCoordinates().x()-- )		
+               {
+                  //std::cerr << "2 -> ";
+                  cell.refresh();
+                  if( ! interfaceMap( cell ) )            
+                     this->updateCell( aux, cell );
+               }
+         }
+
+         //aux.save( "aux-2.tnl" );
+
+         for( cell.getCoordinates().y() = mesh->getDimensions().y() - 1;
+              cell.getCoordinates().y() >= 0 ;
+              cell.getCoordinates().y()-- )
             {
-               cell.refresh();
-               if( ! interfaceMap( cell ) )
-                  this->updateCell( aux, cell );
+            for( cell.getCoordinates().x() = 0;
+                 cell.getCoordinates().x() < mesh->getDimensions().x();
+                 cell.getCoordinates().x()++ )
+               {
+                  //std::cerr << "3 -> ";
+                  cell.refresh();
+                  if( ! interfaceMap( cell ) )            
+                     this->updateCell( aux, cell );
+               }
             }
-      }
-      
-      //aux.save( "aux-1.tnl" );
 
-      for( cell.getCoordinates().y() = 0;
-           cell.getCoordinates().y() < mesh->getDimensions().y();
-           cell.getCoordinates().y()++ )
-      {
-         for( cell.getCoordinates().x() = mesh->getDimensions().x() - 1;
-              cell.getCoordinates().x() >= 0 ;
-              cell.getCoordinates().x()-- )		
+         //aux.save( "aux-3.tnl" );
+
+         for( cell.getCoordinates().y() = mesh->getDimensions().y() - 1;
+              cell.getCoordinates().y() >= 0;
+              cell.getCoordinates().y()-- )
             {
-               //std::cerr << "2 -> ";
-               cell.refresh();
-               if( ! interfaceMap( cell ) )            
-                  this->updateCell( aux, cell );
+            for( cell.getCoordinates().x() = mesh->getDimensions().x() - 1;
+                 cell.getCoordinates().x() >= 0 ;
+                 cell.getCoordinates().x()-- )		
+               {
+                  //std::cerr << "4 -> ";
+                  cell.refresh();
+                  if( ! interfaceMap( cell ) )            
+                     this->updateCell( aux, cell );
+               }
             }
-      }
-      
-      //aux.save( "aux-2.tnl" );
 
-      for( cell.getCoordinates().y() = mesh->getDimensions().y() - 1;
-           cell.getCoordinates().y() >= 0 ;
-           cell.getCoordinates().y()-- )
+         //aux.save( "aux-4.tnl" );
+
+         /*for( cell.getCoordinates().x() = 0;
+              cell.getCoordinates().x() < mesh->getDimensions().y();
+              cell.getCoordinates().x()++ )
          {
+            for( cell.getCoordinates().y() = 0;
+                 cell.getCoordinates().y() < mesh->getDimensions().x();
+                 cell.getCoordinates().y()++ )
+               {
+                  cell.refresh();
+                  if( ! interfaceMap( cell ) )
+                     this->updateCell( aux, cell );
+               }
+         }     
+
+
+         aux.save( "aux-5.tnl" );
+
          for( cell.getCoordinates().x() = 0;
-              cell.getCoordinates().x() < mesh->getDimensions().x();
+              cell.getCoordinates().x() < mesh->getDimensions().y();
               cell.getCoordinates().x()++ )
-            {
-               //std::cerr << "3 -> ";
-               cell.refresh();
-               if( ! interfaceMap( cell ) )            
-                  this->updateCell( aux, cell );
-            }
-         }
-      
-      //aux.save( "aux-3.tnl" );
-      
-      for( cell.getCoordinates().y() = mesh->getDimensions().y() - 1;
-           cell.getCoordinates().y() >= 0;
-           cell.getCoordinates().y()-- )
          {
-         for( cell.getCoordinates().x() = mesh->getDimensions().x() - 1;
+            for( cell.getCoordinates().y() = mesh->getDimensions().x() - 1;
+                 cell.getCoordinates().y() >= 0 ;
+                 cell.getCoordinates().y()-- )		
+               {
+                  //std::cerr << "2 -> ";
+                  cell.refresh();
+                  if( ! interfaceMap( cell ) )            
+                     this->updateCell( aux, cell );
+               }
+         }
+         aux.save( "aux-6.tnl" );
+
+         for( cell.getCoordinates().x() = mesh->getDimensions().y() - 1;
               cell.getCoordinates().x() >= 0 ;
-              cell.getCoordinates().x()-- )		
+              cell.getCoordinates().x()-- )
             {
-               //std::cerr << "4 -> ";
-               cell.refresh();
-               if( ! interfaceMap( cell ) )            
-                  this->updateCell( aux, cell );
+            for( cell.getCoordinates().y() = 0;
+                 cell.getCoordinates().y() < mesh->getDimensions().x();
+                 cell.getCoordinates().y()++ )
+               {
+                  //std::cerr << "3 -> ";
+                  cell.refresh();
+                  if( ! interfaceMap( cell ) )            
+                     this->updateCell( aux, cell );
+               }
             }
-         }
-            
-      //aux.save( "aux-4.tnl" );
-      
-      /*for( cell.getCoordinates().x() = 0;
-           cell.getCoordinates().x() < mesh->getDimensions().y();
-           cell.getCoordinates().x()++ )
-      {
-         for( cell.getCoordinates().y() = 0;
-              cell.getCoordinates().y() < mesh->getDimensions().x();
-              cell.getCoordinates().y()++ )
+         aux.save( "aux-7.tnl" );
+
+         for( cell.getCoordinates().x() = mesh->getDimensions().y() - 1;
+              cell.getCoordinates().x() >= 0;
+              cell.getCoordinates().x()-- )
             {
-               cell.refresh();
-               if( ! interfaceMap( cell ) )
-                  this->updateCell( aux, cell );
-            }
-      }     
-        
-      
-      aux.save( "aux-5.tnl" );
-      
-      for( cell.getCoordinates().x() = 0;
-           cell.getCoordinates().x() < mesh->getDimensions().y();
-           cell.getCoordinates().x()++ )
+            for( cell.getCoordinates().y() = mesh->getDimensions().x() - 1;
+                 cell.getCoordinates().y() >= 0 ;
+                 cell.getCoordinates().y()-- )		
+               {
+                  //std::cerr << "4 -> ";
+                  cell.refresh();
+                  if( ! interfaceMap( cell ) )            
+                     this->updateCell( aux, cell );
+               }
+            }*/
+      }
+      if( std::is_same< DeviceType, Devices::Cuda >::value )
       {
-         for( cell.getCoordinates().y() = mesh->getDimensions().x() - 1;
-              cell.getCoordinates().y() >= 0 ;
-              cell.getCoordinates().y()-- )		
-            {
-               //std::cerr << "2 -> ";
-               cell.refresh();
-               if( ! interfaceMap( cell ) )            
-                  this->updateCell( aux, cell );
-            }
+         // TODO: CUDA code
       }
-      aux.save( "aux-6.tnl" );
-
-      for( cell.getCoordinates().x() = mesh->getDimensions().y() - 1;
-           cell.getCoordinates().x() >= 0 ;
-           cell.getCoordinates().x()-- )
-         {
-         for( cell.getCoordinates().y() = 0;
-              cell.getCoordinates().y() < mesh->getDimensions().x();
-              cell.getCoordinates().y()++ )
-            {
-               //std::cerr << "3 -> ";
-               cell.refresh();
-               if( ! interfaceMap( cell ) )            
-                  this->updateCell( aux, cell );
-            }
-         }
-      aux.save( "aux-7.tnl" );
-      
-      for( cell.getCoordinates().x() = mesh->getDimensions().y() - 1;
-           cell.getCoordinates().x() >= 0;
-           cell.getCoordinates().x()-- )
-         {
-         for( cell.getCoordinates().y() = mesh->getDimensions().x() - 1;
-              cell.getCoordinates().y() >= 0 ;
-              cell.getCoordinates().y()-- )		
-            {
-               //std::cerr << "4 -> ";
-               cell.refresh();
-               if( ! interfaceMap( cell ) )            
-                  this->updateCell( aux, cell );
-            }
-         }*/
       
       iteration++;
    }
-- 
GitLab