Commit bb045e76 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Added CUDA switch.

parent 7d7e266e
Loading
Loading
Loading
Loading
+9 −8
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@

#include <TNL/Meshes/Grid.h>
#include <TNL/Functions/MeshFunction.h>
#include <TNL/Devices/Cuda.h>

using namespace TNL;

@@ -36,7 +37,7 @@ class tnlDirectEikonalMethodsBase< Meshes::Grid< 1, Real, Device, Index > >
                          InterfaceMapType& interfaceMap );
      
      template< typename MeshEntity >
      void updateCell( MeshFunctionType& u,
      __cuda_callable__ void updateCell( MeshFunctionType& u,
                                         const MeshEntity& cell );
      
};
@@ -61,7 +62,7 @@ class tnlDirectEikonalMethodsBase< Meshes::Grid< 2, Real, Device, Index > >
                          InterfaceMapType& interfaceMap );
      
      template< typename MeshEntity >
      void updateCell( MeshFunctionType& u,
      __cuda_callable__ void updateCell( MeshFunctionType& u,
                                         const MeshEntity& cell,
                                         const RealType velocity = 1.0 );
};
@@ -85,7 +86,7 @@ class tnlDirectEikonalMethodsBase< Meshes::Grid< 3, Real, Device, Index > >
                          InterfaceMapType& interfaceMap );
      
      template< typename MeshEntity >
      void updateCell( MeshFunctionType& u,
      __cuda_callable__ void updateCell( MeshFunctionType& u,
                                         const MeshEntity& cell,
                                         const RealType velocity = 1.0);
      
+120 −114
Original line number Diff line number Diff line
@@ -71,7 +71,8 @@ solve( const MeshPointer& mesh,
   IndexType iteration( 0 );
   while( iteration < this->maxIterations )
   {

      if( std::is_same< DeviceType, Devices::Host >::value )
      {
         for( cell.getCoordinates().y() = 0;
              cell.getCoordinates().y() < mesh->getDimensions().y();
              cell.getCoordinates().y()++ )
@@ -202,6 +203,11 @@ solve( const MeshPointer& mesh,
                     this->updateCell( aux, cell );
               }
            }*/
      }
      if( std::is_same< DeviceType, Devices::Cuda >::value )
      {
         // TODO: CUDA code
      }
      
      iteration++;
   }