Commit a228f31c authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Refactoring with the shared pointers.

parent b09ddd79
Loading
Loading
Loading
Loading
+363 −6
Original line number Diff line number Diff line
@@ -28,20 +28,25 @@
 */
template< typename Object,
          typename Device = typename Object::DeviceType,
          bool lazy = false >
          bool lazy = false,
          bool isConst = std::is_const< Object >::value >
class tnlSharedPointer
{
   static_assert( ! std::is_same< Device, void >::value, "The device cannot be void. You need to specify the device explicitly in your code." );
};

/****
 * Non-const specialization
 */
template< typename Object, bool lazy >
class tnlSharedPointer< Object, tnlHost, lazy > : public tnlSmartPointer
class tnlSharedPointer< Object, tnlHost, lazy, false > : public tnlSmartPointer
{   
   public:
      
      typedef Object ObjectType;
      typedef tnlHost DeviceType;
      typedef tnlSharedPointer< Object, tnlHost > ThisType;
      typedef tnlSharedPointer< Object, tnlHost, lazy, false > ThisType;
      typedef tnlSharedPointer< const Object, tnlHost, lazy, true > ConstThisType;
         
      template< typename... Args >
      explicit  tnlSharedPointer( Args... args )
@@ -176,14 +181,165 @@ class tnlSharedPointer< Object, tnlHost, lazy > : public tnlSmartPointer
      int* counter;
};

/****
 * Const specialization
 */
template< typename Object, bool lazy >
class tnlSharedPointer< Object, tnlHost, lazy, true > : public tnlSmartPointer
{   
   public:
      
      typedef Object ObjectType;
      typedef tnlHost DeviceType;
      typedef tnlSharedPointer< Object, tnlHost, lazy, true > ThisType;
      typedef typename std::remove_const< Object >::type NonConstObjectType;
         
      template< typename... Args >
      explicit  tnlSharedPointer( Args... args )
      : counter( 0 ), pointer( 0 )
      {
         if( ! lazy )
         {
            this->counter = new int;
            this->pointer = new Object( args... );
            *( this->counter ) = 1;
         }
      }
      
      tnlSharedPointer( const ThisType& pointer )
      : pointer( pointer.pointer ),
        counter( pointer.counter )
      {
         *counter++;
      }
      
      tnlSharedPointer( const tnlSharedPointer< NonConstObjectType, tnlHost, lazy >& pointer )
      : pointer( pointer.pointer ),
        counter( pointer.counter )
      {
         *counter++;
      }

      
      template< typename... Args >
      bool recreate( Args... args )
      {         
         std::cerr << "Creating new shared pointer..." << std::endl;
         if( ! this->counter )
         {
            this->counter = new int;
            *this->counter = 1;
            this->pointer = new ObjectType( args... );
            return true;
         }
         if( *this->counter == 1 )
         {
            /****
             * The object is not shared
             */
            this->pointer->~ObjectType();
            new ( this->pointer ) ObjectType( args... );
            return true;
         }
         ( *this->counter )--;
         this->pointer = new Object( args... );
         this->counter = new int;
         if( ! this->pointer || ! this->counter )
            return false;
         *( this->counter ) = 1;
         return true;         
      }      
      
      const Object* operator->() const
      {
         return this->pointer;
      }
            
      const Object& operator *() const
      {
         return *( this->pointer );
      }
      
      template< typename Device = tnlHost >
      const Object& getData() const
      {
         return *( this->pointer );
      }
      
      const ThisType& operator=( const tnlSharedPointer< NonConstObjectType, tnlHost >& ptr )
      {
         this->free();
         this->pointer = ptr.pointer;
         this->counter = ptr.counter;
         *( this->counter )++;
         return *this;
      }      

      
      const ThisType& operator=( const ThisType& ptr )
      {
         this->free();
         this->pointer = ptr.pointer;
         this->counter = ptr.counter;
         *( this->counter )++;
         return *this;
      }      
      
      const ThisType& operator=( const ThisType&& ptr )
      {
         if( this-> pointer )
            delete this->pointer;
         this->pointer = ptr.pointer;
         ptr.pointer= NULL;
         this->counter = ptr.counter;
         ptr.counter = NULL;
         return *this;
      }
            
      bool synchronize()
      {
         return true;
      }
      
      ~tnlSharedPointer()
      {
         this->free();
      }

      
   protected:
      
      void free()
      {
         if( ! this->pointer )
            return;
         if( this->counter )
         {
            if( ! --*( this->counter ) )
            {
               delete this->pointer;
               std::cerr << "Deleting data..." << std::endl;
            }
         }

      }
      
      const Object* pointer;
      
      int* counter;
};

/****
 * Non-const specialization for CUDA
 */
template< typename Object, bool lazy >
class tnlSharedPointer< Object, tnlCuda, lazy > : public tnlSmartPointer
class tnlSharedPointer< Object, tnlCuda, lazy, false > : public tnlSmartPointer
{
   public:
      
      typedef Object ObjectType;
      typedef tnlHost DeviceType;
      typedef tnlSharedPointer< Object, tnlCuda > ThisType;
      typedef tnlSharedPointer< Object, tnlCuda, lazy > ThisType;

      template< typename... Args >
      explicit  tnlSharedPointer( Args... args )
@@ -397,3 +553,204 @@ class tnlSharedPointer< Object, tnlCuda, lazy > : public tnlSmartPointer
};


/****
 * Const specialization for CUDA
 */
template< typename Object, bool lazy >
class tnlSharedPointer< Object, tnlCuda, lazy, true > : public tnlSmartPointer
{
   public:
      
      typedef Object ObjectType;
      typedef tnlHost DeviceType;
      typedef tnlSharedPointer< Object, tnlCuda, lazy > ThisType;
      typedef typename std::remove_const< Object >::type NonConstObjectType;      

      template< typename... Args >
      explicit  tnlSharedPointer( Args... args )
      : counter( 0 ), cuda_pointer( 0 ), 
        pointer( 0 ), modified( false )
      {
         if( ! lazy )
         {
            this->counter = new int;
            this->pointer = new Object( args... );
#ifdef HAVE_CUDA         
            this->cuda_pointer = tnlCuda::passToDevice( *this->pointer );
            if( ! checkCudaDevice )
               return;
            tnlCuda::insertSmartPointer( this );
#endif            
         }
      }
                  
      tnlSharedPointer( const ThisType& pointer )
      : pointer( pointer.pointer ),
        cuda_pointer( pointer.cuda_pointer ),
        counter( pointer.counter ),
        modified( pointer.modified )
      {
         *counter++;
      }
      
      tnlSharedPointer( const tnlSharedPointer< NonConstObjectType, tnlCuda, lazy >& pointer )
      : pointer( pointer.pointer ),
        cuda_pointer( pointer.cuda_pointer ),
        counter( pointer.counter )
      {
         *counter++;
      }      

      template< typename... Args >
      bool recreate( Args... args )
      {
         std::cerr << "Creating new shared pointer..." << std::endl;
         if( ! this->counter )
         {
            this->counter = new int;
            *this->counter = 1;
            this->pointer = new ObjectType( args... );
#ifdef HAVE_CUDA         
            this->cuda_pointer = tnlCuda::passToDevice( *this->object );
            if( ! checkCudaDevice )
               return false;
            tnlCuda::insertSmartPointer( this );
#endif                 
            return true;
         }
         if( *this->counter == 1 )
         {
            /****
             * The object is not shared
             */
            this->pointer->~ObjectType();
            new ( this->pointer ) ObjectType( args... );
#ifdef HAVE_CUDA                     
            cudaMemcpy( this->cuda_pointer, this->pointer, sizeof( Object ), cudaMemcpyHostToDevice );
#endif            
            return true;
         }

         this->modified = false;
         this->counter= new int;
         this->pointer = new Object( args... );
         if( ! this->pointer || ! this->counter )
            return false;
         *( this->counter )= 1;         
#ifdef HAVE_CUDA         
         cudaMalloc( ( void** )  &this->cuda_pointer, sizeof( Object ) );
         cudaMemcpy( this->cuda_pointer, this->pointer, sizeof( Object ), cudaMemcpyHostToDevice );
         if( ! checkCudaDevice )
            return false;
         tnlCuda::insertSmartPointer( this );
#endif
         return true;
      }
      
      const Object* operator->() const
      {
         return this->pointer;
      }
      
      const Object& operator *() const
      {
         return *( this->pointer );
      }
      
      template< typename Device = tnlHost >   
      __cuda_callable__
      const Object& getData() const
      {
         static_assert( std::is_same< Device, tnlHost >::value || std::is_same< Device, tnlCuda >::value, "Only tnlHost or tnlCuda devices are accepted here." );
         tnlAssert( this->pointer, );
         tnlAssert( this->cuda_pointer, );
         if( std::is_same< Device, tnlHost >::value )
            return *( this->pointer );
         if( std::is_same< Device, tnlCuda >::value )
            return *( this->cuda_pointer );            
      }

      
      /*const ThisType& operator=( ThisType&& ptr )
      {
         if( this-> pointer )
            delete this->pointer;
#ifdef HAVE_CUDA
         if( this->cuda_pointer )
            cudaFree( this->cuda_pointer );
#endif                  
         this->pointer = ptr.pointer;
         this->cuda_pointer = ptr.cuda_pointer;
         this->modified = ptr.modified;
         this->counter = ptr.counter;
         ptr.pointer= NULL;
         ptr.cuda_pointer = NULL;
         ptr.modified = false;
         ptr.counter = NULL;
         return *this;
      }*/

      const ThisType& operator=( const tnlSharedPointer< NonConstObjectType, tnlCuda >& ptr )
      {
         this->free();
         this->pointer = ptr.pointer;
         this->counter = ptr.counter;
         this->cuda_pointer = ptr.cuda_pointer;
         this->modified = ptr.modified;
         *( this->counter )++;
         return *this;
      }      
      
      const ThisType& operator=( const ThisType& ptr )
      {
         this->free();
         this->pointer = ptr.pointer;
         this->cuda_pointer = ptr.cuda_pointer;
         this->modified = ptr.modified;
         this->counter = ptr.counter;
         *( this->counter )++;
         return *this;
      }      
      
      bool synchronize()
      {
         return true;
      }
            
      ~tnlSharedPointer()
      {
         this->free();
#ifdef HAVE_CUDA         
         tnlCuda::removeSmartPointer( this );
#endif         
      }
      
   protected:
      
      void free()
      {
         if( ! this->pointer )
            return;
         if( this->counter )
         {
            if( ! --*( this->counter ) )
            {
               if( this->pointer )
                  delete this->pointer;
#ifdef HAVE_CUDA
               if( this->cuda_pointer )
                  cudaFree( this->cuda_pointer );
               checkCudaDevice;
#endif         
               std::cerr << "Deleting data..." << std::endl;
            }
         }
         
      }
      
      Object *pointer, *cuda_pointer;
      
      bool modified;
      
      int* counter;
};
+6 −0
Original line number Diff line number Diff line
@@ -75,6 +75,12 @@ class tnlMeshFunction :
                 const Vector& data,
                 const IndexType& offset = 0 );
      
      template< typename Vector >
      void bind( const MeshPointer& meshPointer,
                 const tnlSharedPointer< Vector >& dataPtr,
                 const IndexType& offset = 0 );
      
      
      void setMesh( const MeshPointer& meshPointer );
      
      template< typename Device = tnlHost >
+18 −0
Original line number Diff line number Diff line
@@ -157,6 +157,24 @@ bind( const MeshPointer& meshPointer,
                << "this->mesh->template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >() = " << this->meshPointer->template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >() );   
}

template< typename Mesh,
          int MeshEntityDimensions,
          typename Real >
   template< typename Vector >
void
tnlMeshFunction< Mesh, MeshEntityDimensions, Real >::
bind( const MeshPointer& meshPointer,
      const tnlSharedPointer< Vector >& data,
      const IndexType& offset )
{
   this->meshPointer = meshPointer;
   this->data.bind( *data, offset, meshPointer->template getEntitiesCount< typename Mesh::template MeshEntity< MeshEntityDimensions > >() );
   tnlAssert( this->data.getSize() == this->meshPointer.getData().template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >(), 
      std::cerr << "this->data.getSize() = " << this->data.getSize() << std::endl
                << "this->mesh->template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >() = " << this->meshPointer->template getEntitiesCount< typename MeshType::template MeshEntity< MeshEntityDimensions > >() );   
}


template< typename Mesh,
          int MeshEntityDimensions,
          typename Real >
+15 −2
Original line number Diff line number Diff line
@@ -76,6 +76,7 @@ class tnlOperatorFunction< Operator, MeshFunction, void, true >
      typedef typename OperatorType::IndexType IndexType;
      typedef typename OperatorType::ExactOperatorType ExactOperatorType;
      typedef tnlMeshFunction< MeshType, OperatorType::getPreimageEntitiesDimensions() > PreimageFunctionType;
      typedef tnlSharedPointer< MeshType, DeviceType > MeshPointer;
      
      static constexpr int getEntitiesDimensions() { return OperatorType::getImageEntitiesDimensions(); };     
      
@@ -92,6 +93,13 @@ class tnlOperatorFunction< Operator, MeshFunction, void, true >
         return this->preimageFunction->getMesh(); 
      };
      
      const MeshPointer& getMeshPointer() const
      { 
         tnlAssert( this->preimageFunction, std::cerr << "The preimage function was not set." << std::endl );
         return this->preimageFunction->getMeshPointer(); 
      };

      
      void setPreimageFunction( const FunctionType& preimageFunction ) { this->preimageFunction = &preimageFunction; }
      
      Operator& getOperator() { return this->operator_; }
@@ -147,11 +155,12 @@ class tnlOperatorFunction< Operator, PreimageFunction, void, false >
      typedef tnlMeshFunction< MeshType, Operator::getImageEntitiesDimensions() > ImageFunctionType;
      typedef tnlOperatorFunction< Operator, PreimageFunction, void, true > OperatorFunction;
      typedef typename OperatorType::ExactOperatorType ExactOperatorType;
      typedef tnlSharedPointer< MeshType, DeviceType > MeshPointer;
      
      static constexpr int getEntitiesDimensions() { return OperatorType::getImageEntitiesDimensions(); };     
      
      tnlOperatorFunction( OperatorType& operator_,
                           const MeshType& mesh )
                           const MeshPointer& mesh )
      :  operator_( operator_ ), imageFunction( mesh )
      {};
      
@@ -162,6 +171,8 @@ class tnlOperatorFunction< Operator, PreimageFunction, void, false >
      
      const MeshType& getMesh() const { return this->imageFunction.getMesh(); };
      
      const MeshPointer& getMeshPointer() const { return this->imageFunction.getMeshPointer(); };
      
      ImageFunctionType& getImageFunction() { return this->imageFunction; };
      
      const ImageFunctionType& getImageFunction() const { return this->imageFunction; };
@@ -272,12 +283,14 @@ class tnlOperatorFunction< Operator, PreimageFunction, BoundaryConditions, false
                           const PreimageFunctionType& preimageFunction )
      :  operator_( operator_ ),
         boundaryConditions( boundaryConditions ),
         imageFunction( preimageFunction.getMesh() ),
         imageFunction( preimageFunction.getMeshPointer() ),
         preimageFunction( &preimageFunction )
      {};
      
      const MeshType& getMesh() const { return imageFunction.getMesh(); };
      
      const MeshPointer& getMeshPointer() const { return imageFunction.getMeshPointer(); };
      
      void setPreimageFunction( const PreimageFunction& preimageFunction )
      { 
         this->preimageFunction = &preimageFunction;
+2 −1
Original line number Diff line number Diff line
@@ -55,10 +55,11 @@ class tnlCoFVMGradientNorm< tnlGrid< MeshDimensions, MeshReal, Device, MeshIndex
      typedef tnlMeshEntitiesInterpolants< MeshType, MeshDimensions - 1, MeshDimensions > OuterOperator;
      typedef tnlOperatorComposition< OuterOperator, InnerOperator > BaseType;
      typedef tnlExactGradientNorm< MeshDimensions, RealType > ExactOperatorType;
      typedef tnlSharedPointer< MeshType > MeshPointer;
         
      tnlCoFVMGradientNorm( const OuterOperator& outerOperator,
                            InnerOperator& innerOperator,
                            const MeshType& mesh )
                            const MeshPointer& mesh )
      : BaseType( outerOperator, innerOperator, mesh )
      {}
      
Loading