Commit 967345b2 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

CUDA version with smart pointers is working.

parent 5280b85e
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -269,7 +269,7 @@ class tnlSharedPointer< Object, tnlCuda > : public tnlSmartPointer
#ifdef HAVE_CUDA
         if( this->modified )
         {
            std::cerr << "Synchronizing data..." << std::endl;
            //std::cerr << "Synchronizing data..." << std::endl;
            tnlAssert( this->pointer, );
            tnlAssert( this->cuda_pointer, );
            cudaMemcpy( this->cuda_pointer, this->pointer, sizeof( ObjectType ), cudaMemcpyHostToDevice );
+4 −4
Original line number Diff line number Diff line
@@ -396,12 +396,12 @@ void
tnlTestFunction< FunctionDimensions, Real, Device >::
deleteFunction()
{
   if( Device::DeviceType == ( int ) tnlHostDevice )
   if( std::is_same< Device, tnlHost >::value )
   {
      if( function )
         delete ( FunctionType * ) function;
   }
   if( Device::DeviceType == ( int ) tnlCudaDevice )
   if( std::is_same< Device, tnlCuda >::value )
   {
      if( function )
         tnlCuda::freeFromDevice( ( FunctionType * ) function );
@@ -455,12 +455,12 @@ void
tnlTestFunction< FunctionDimensions, Real, Device >::
copyFunction( const void* function )
{
   if( Device::DeviceType == ( int ) tnlHostDevice ) 
   if( std::is_same< Device, tnlHost >::value ) 
   {
      FunctionType* f = new FunctionType;
      *f = * ( FunctionType* )function;
   }
   if( Device::DeviceType == ( int ) tnlCudaDevice )
   if( std::is_same< Device, tnlCuda >::value )
   {
      tnlAssert( false, );
      abort();
+3 −0
Original line number Diff line number Diff line
@@ -302,6 +302,9 @@ tnlGridTraverser2D(
   typedef tnlGrid< 2, Real, tnlCuda, Index > GridType;
   typename GridType::CoordinatesType coordinates;

   //if( threadIdx.x == 0 )
   //   printf( "%d x %d \n ", grid->getDimensions().x(), grid->getDimensions().y() );
   
   coordinates.x() = begin->x() + ( gridXIdx * tnlCuda::getMaxGridSize() + blockIdx.x ) * blockDim.x + threadIdx.x;
   coordinates.y() = begin->y() + ( gridYIdx * tnlCuda::getMaxGridSize() + blockIdx.y ) * blockDim.y + threadIdx.y;  
   
+39 −38
Original line number Diff line number Diff line
@@ -29,23 +29,24 @@ template< typename Real,
          typename RightHandSide >
class tnlExplicitUpdaterTraverserUserData
{
   public:

      /*const DifferentialOperator differentialOperator;
      const DifferentialOperator* differentialOperator;

      const BoundaryConditions boundaryConditions;
      const BoundaryConditions* boundaryConditions;

      const RightHandSide rightHandSide;
      const RightHandSide* rightHandSide;

      MeshFunction u, fu;*/
      MeshFunction *u, *fu;
      
      char data[ sizeof( DifferentialOperator ) + 
      /*char data[ sizeof( DifferentialOperator ) + 
                 sizeof( BoundaryConditions ) + 
                 sizeof( RightHandSide ) +
                 2 * sizeof( MeshFunction ) ];
                 2 * sizeof( MeshFunction ) ];*/

      public:

         const Real time;         
         const Real* time;         


      tnlExplicitUpdaterTraverserUserData( const Real& time,
@@ -54,14 +55,14 @@ class tnlExplicitUpdaterTraverserUserData
                                           const RightHandSide& rightHandSide,
                                           MeshFunction& u,
                                           MeshFunction& fu )
      : time( time )
        /*differentialOperator( differentialOperator ),
        boundaryConditions( boundaryConditions ),
        rightHandSide( rightHandSide ),
        u( u ),
        fu( fu )*/
      : time( &time ),
        differentialOperator( &differentialOperator ),
        boundaryConditions( &boundaryConditions ),
        rightHandSide( &rightHandSide ),
        u( &u ),
        fu( &fu )
      {
         char* ptr = data;
         /*char* ptr = data;
         memcpy( ptr, &differentialOperator, sizeof( DifferentialOperator ) );
         ptr +=  sizeof( DifferentialOperator );
         memcpy( ptr, &boundaryConditions, sizeof( BoundaryConditions ) );
@@ -70,39 +71,39 @@ class tnlExplicitUpdaterTraverserUserData
         ptr += sizeof( RightHandSide );
         memcpy( ptr, &u, sizeof( MeshFunction ) );
         ptr += sizeof( MeshFunction );
         memcpy( ptr, &fu, sizeof( MeshFunction ) );
         memcpy( ptr, &fu, sizeof( MeshFunction ) );*/
      };
      
      DifferentialOperator& differentialOperator()
      /*DifferentialOperator& differentialOperator()
      {
         return * ( DifferentialOperator* ) data;
         return this->differentialOperator; //* ( DifferentialOperator* ) data;
      }
      
      BoundaryConditions& boundaryConditions()
      {
         return * ( BoundaryConditions* ) & data[ sizeof( DifferentialOperator ) ];
         return this->boundaryConditions; //* ( BoundaryConditions* ) & data[ sizeof( DifferentialOperator ) ];
      }
      
      RightHandSide& rightHandSide()
      {
         return * ( RightHandSide* ) & data[ sizeof( DifferentialOperator ) +
                                             sizeof( BoundaryConditions ) ];
         return this->rightHandSide; //* ( RightHandSide* ) & data[ sizeof( DifferentialOperator ) +
                                     //        sizeof( BoundaryConditions ) ];
      }
      
      MeshFunction& u()
      {
         return * ( MeshFunction* ) & data[ sizeof( DifferentialOperator ) +
                                            sizeof( BoundaryConditions ) + 
                                            sizeof( RightHandSide )];
         return this->u; //* ( MeshFunction* ) & data[ sizeof( DifferentialOperator ) +
                         //                   sizeof( BoundaryConditions ) + 
                         //                   sizeof( RightHandSide )];
      }
      
      MeshFunction& fu()
      {
         return * ( MeshFunction* ) & data[ sizeof( DifferentialOperator ) +
                                            sizeof( BoundaryConditions ) + 
                                            sizeof( RightHandSide ) + 
                                            sizeof( MeshFunction ) ];
      }
         return this->fu; //* ( MeshFunction* ) & data[ sizeof( DifferentialOperator ) +
                          //                  sizeof( BoundaryConditions ) + 
                          //                  sizeof( RightHandSide ) + 
                          //                  sizeof( MeshFunction ) ];
      }*/
};


@@ -152,10 +153,10 @@ class tnlExplicitUpdater
                                              TraverserUserData& userData,
                                              const GridEntity& entity )
            {
               ( userData.u() )( entity ) = userData.boundaryConditions().operator()
               ( userData.u(),
               ( *userData.u )( entity ) = ( *userData.boundaryConditions )
               ( *userData.u,
                 entity,
                 userData.time );
                 *userData.time );
            }

      };
@@ -172,18 +173,18 @@ class tnlExplicitUpdater
                                              TraverserUserData& userData,
                                              const EntityType& entity )
            {
               ( userData.fu())( entity ) =
                  userData.differentialOperator().operator()(
                     userData.u(),
               ( *userData.fu )( entity ) =
                  ( *userData.differentialOperator )(
                     *userData.u,
                     entity,
                     userData.time );
                     *userData.time );

               typedef tnlFunctionAdapter< MeshType, RightHandSide > FunctionAdapter;
               (  userData.fu() )( entity ) += 
               (  *userData.fu )( entity ) += 
                  FunctionAdapter::getValue(
                     userData.rightHandSide(),
                     *userData.rightHandSide,
                     entity,
                     userData.time );
                     *userData.time );
            }
      };
      
+6 −6
Original line number Diff line number Diff line
@@ -46,7 +46,7 @@ update( const RealType& time,
                                           typename MeshFunction::DeviceType,
                                           typename MeshFunction::IndexType > >::value != true,
      "Error: I am getting tnlVector instead of tnlMeshFunction or similar object. You might forget to bind DofVector into tnlMeshFunction in you method getExplicitRHS."  );
   //if( std::is_same< DeviceType, tnlHost >::value )
   if( std::is_same< DeviceType, tnlHost >::value )
   {
      TraverserUserData userData( time, differentialOperator, boundaryConditions, rightHandSide, u, fu );
      tnlTraverser< MeshType, EntityType > meshTraverser;
@@ -60,7 +60,7 @@ update( const RealType& time,
                                                      userData );

   }
   /*if( std::is_same< DeviceType, tnlCuda >::value )
   if( std::is_same< DeviceType, tnlCuda >::value )
   {
      if( this->gpuTransferTimer ) 
         this->gpuTransferTimer->start();
@@ -73,16 +73,16 @@ update( const RealType& time,
     if( this->gpuTransferTimer ) 
         this->gpuTransferTimer->stop();

      //TraverserUserData userData( *kernelTime, *kernelDifferentialOperator, *kernelBoundaryConditions, *kernelRightHandSide, *kernelU, *kernelFu );
      TraverserUserData userData( *kernelTime, *kernelDifferentialOperator, *kernelBoundaryConditions, *kernelRightHandSide, *kernelU, *kernelFu );
      checkCudaDevice;
      tnlTraverser< MeshType, EntityType > meshTraverser;
      meshTraverser.template processBoundaryEntities< TraverserUserData,
                                                      TraverserBoundaryEntitiesProcessor >
                                                    ( mesh,
                                                    ( meshPointer,
                                                      userData );
      meshTraverser.template processInteriorEntities< TraverserUserData,
                                                      TraverserInteriorEntitiesProcessor >
                                                    ( mesh,
                                                    ( meshPointer,
                                                      userData );

      if( this->gpuTransferTimer ) 
@@ -100,7 +100,7 @@ update( const RealType& time,
      if( this->gpuTransferTimer ) 
         this->gpuTransferTimer->stop();

   }*/
   }
}

#endif /* TNLEXPLICITUPDATER_IMPL_H_ */