Commit a3298c46 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Improved smart pointers

All dynamically allocated objects in the smart pointers were aggregated
into a single structure to reduce the number of dynamic allocations,
decrease the size of the smart pointer classes, and make the code more
readable.
parent bd5d5b99
Loading
Loading
Loading
Loading
+54 −56
Original line number Diff line number Diff line
@@ -212,8 +212,9 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
      typedef DevicePointer< Object, Devices::Cuda > ThisType;

      explicit  DevicePointer( ObjectType& obj )
      : pointer( 0 ), cuda_pointer( 0 ),
        counter( 0 ), last_sync_state( 0 )
      : pointer( nullptr ),
        pd( nullptr ),
        cuda_pointer( nullptr )
      {
         this->allocate( obj );
      }
@@ -221,11 +222,10 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
      // this is needed only to avoid the default compiler-generated constructor
      DevicePointer( const ThisType& pointer )
      : pointer( pointer.pointer ),
        cuda_pointer( pointer.cuda_pointer ),
        counter( pointer.counter ),
        last_sync_state( pointer.last_sync_state )
        pd( (PointerData*) pointer.pd ),
        cuda_pointer( pointer.cuda_pointer )
      {
         *counter += 1;
         this->pd->counter += 1;
      }

      // conditional constructor for non-const -> const data
@@ -233,24 +233,21 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
                typename = typename Enabler< Object_ >::type >
      DevicePointer( const DevicePointer< Object_, DeviceType >& pointer )
      : pointer( pointer.pointer ),
        cuda_pointer( pointer.cuda_pointer ),
        counter( pointer.counter ),
        last_sync_state( pointer.last_sync_state )
        pd( (PointerData*) pointer.pd ),
        cuda_pointer( pointer.cuda_pointer )
      {
         *counter += 1;
         this->pd->counter += 1;
      }

      // this is needed only to avoid the default compiler-generated constructor
      DevicePointer( ThisType&& pointer )
      : pointer( pointer.pointer ),
        cuda_pointer( pointer.cuda_pointer ),
        counter( pointer.counter ),
        last_sync_state( pointer.last_sync_state )
        pd( (PointerData*) pointer.pd ),
        cuda_pointer( pointer.cuda_pointer )
      {
         pointer.pointer = nullptr;
         pointer.pd = nullptr;
         pointer.cuda_pointer = nullptr;
         pointer.counter = nullptr;
         pointer.last_sync_state = nullptr;
      }

      // conditional constructor for non-const -> const data
@@ -258,14 +255,12 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
                typename = typename Enabler< Object_ >::type >
      DevicePointer( DevicePointer< Object_, DeviceType >&& pointer )
      : pointer( pointer.pointer ),
        cuda_pointer( pointer.cuda_pointer ),
        counter( pointer.counter ),
        last_sync_state( pointer.last_sync_state )
        pd( (PointerData*) pointer.pd ),
        cuda_pointer( pointer.cuda_pointer )
      {
         pointer.pointer = nullptr;
         pointer.pd = nullptr;
         pointer.cuda_pointer = nullptr;
         pointer.counter = nullptr;
         pointer.last_sync_state = nullptr;
      }

      const Object* operator->() const
@@ -290,7 +285,7 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer

      operator bool()
      {
         return this->pointer;
         return this->pd;
      }

      template< typename Device = Devices::Host >
@@ -299,6 +294,7 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
      {
         static_assert( std::is_same< Device, Devices::Host >::value || std::is_same< Device, Devices::Cuda >::value, "Only Devices::Host or Devices::Cuda devices are accepted here." );
         Assert( this->pointer, );
         Assert( this->pd, );
         Assert( this->cuda_pointer, );
         if( std::is_same< Device, Devices::Host >::value )
            return *( this->pointer );
@@ -312,26 +308,22 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
      {
         static_assert( std::is_same< Device, Devices::Host >::value || std::is_same< Device, Devices::Cuda >::value, "Only Devices::Host or Devices::Cuda devices are accepted here." );
         Assert( this->pointer, );
         Assert( this->pd, );
         Assert( this->cuda_pointer, );
         if( std::is_same< Device, Devices::Host >::value )
         {
            return *( this->pointer );
         }
         if( std::is_same< Device, Devices::Cuda >::value )
         {
            return *( this->cuda_pointer );
      }
      }

      // this is needed only to avoid the default compiler-generated operator
      const ThisType& operator=( const ThisType& ptr )
      {
         this->free();
         this->pointer = ptr.pointer;
         this->pd = (PointerData*) ptr.pd;
         this->cuda_pointer = ptr.cuda_pointer;
         this->counter = ptr.counter;
         this->last_sync_state = ptr.last_sync_state;
         *( this->counter ) += 1;
         this->pd->counter += 1;
         return *this;
      }

@@ -342,10 +334,9 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
      {
         this->free();
         this->pointer = ptr.pointer;
         this->pd = (PointerData*) ptr.pd;
         this->cuda_pointer = ptr.cuda_pointer;
         this->counter = ptr.counter;
         this->last_sync_state = ptr.last_sync_state;
         *( this->counter ) += 1;
         this->pd->counter += 1;
         return *this;
      }

@@ -354,13 +345,11 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
      {
         this->free();
         this->pointer = ptr.pointer;
         this->pd = (PointerData*) ptr.pd;
         this->cuda_pointer = ptr.cuda_pointer;
         this->counter = ptr.counter;
         this->last_sync_state = ptr.last_sync_state;
         ptr.pointer = nullptr;
         ptr.pd = nullptr;
         ptr.cuda_pointer = nullptr;
         ptr.counter = nullptr;
         ptr.last_sync_state = nullptr;
         return *this;
      }

@@ -371,18 +360,18 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
      {
         this->free();
         this->pointer = ptr.pointer;
         this->pd = (PointerData*) ptr.pd;
         this->cuda_pointer = ptr.cuda_pointer;
         this->counter = ptr.counter;
         this->last_sync_state = ptr.last_sync_state;
         ptr.pointer = nullptr;
         ptr.pd = nullptr;
         ptr.cuda_pointer = nullptr;
         ptr.counter = nullptr;
         ptr.last_sync_state = nullptr;
         return *this;
      }

      bool synchronize()
      {
         if( ! this->pd )
            return true;
#ifdef HAVE_CUDA
         if( this->modified() )
         {
@@ -409,16 +398,23 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer

   protected:

      // Bookkeeping block referenced through the `pd` member. It is created
      // by allocate() and deleted by free() once `counter` drops to zero.
      struct PointerData
      {
         // Raw byte image of the host object: written by set_last_sync_state()
         // and compared against the current object bytes in modified().
         char data_image[ sizeof(Object) ];
         // Reference count: starts at 1, is incremented by the copy
         // constructors and copy assignments, and decremented in free().
         int counter = 1;
      };

      bool allocate( ObjectType& obj )
      {
         this->counter = new int( 1 );
         if( ! this->counter )
            return false;
         this->pointer = &obj;
         this->pd = new PointerData();
         if( ! this->pd )
            return false;
         // pass to device
         this->cuda_pointer = Devices::Cuda::passToDevice( *this->pointer );
         if( ! this->cuda_pointer )
            return false;
         this->last_sync_state = ::operator new( sizeof( Object ) );
         // set last-sync state
         this->set_last_sync_state();
         Devices::Cuda::insertSmartPointer( this );
         return true;
@@ -426,37 +422,39 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer

      void set_last_sync_state()
      {
         std::memcpy( (void*) this->last_sync_state, (void*) this->pointer, sizeof( ObjectType ) );
         Assert( this->pointer, );
         Assert( this->pd, );
         std::memcpy( (void*) &this->pd->data_image, (void*) this->pointer, sizeof( Object ) );
      }

      bool modified()
      {
         if( ! this->pointer || ! this->last_sync_state )
            return false;
         return std::memcmp( (void*) this->last_sync_state, (void*) this->pointer, sizeof( ObjectType ) ) != 0;
         Assert( this->pointer, );
         Assert( this->pd, );
         return std::memcmp( (void*) &this->pd->data_image, (void*) this->pointer, sizeof( Object ) ) != 0;
      }

      void free()
      {
         if( this->counter )
         if( this->pd )
         {
            if( ! --*( this->counter ) )
            if( ! --this->pd->counter )
            {
               delete this->counter;
               this->counter = nullptr;
               delete this->pd;
               this->pd = nullptr;
               if( this->cuda_pointer )
                  Devices::Cuda::freeFromDevice( this->cuda_pointer );
               if( this->last_sync_state )
                  ::operator delete( this->last_sync_state );
            }
         }
      }

      Object *pointer, *cuda_pointer;
      Object* pointer;

      int* counter;
      PointerData* pd;

      void* last_sync_state;
      // cuda_pointer can't be part of the PointerData structure, since we would
      // be unable to dereference this->pd on the device
      Object* cuda_pointer;
};

} // namespace TNL
+135 −149

File changed.

Preview size limit exceeded, changes collapsed.

+44 −33
Original line number Diff line number Diff line
@@ -124,43 +124,46 @@ class UniquePointer< Object, Devices::Cuda > : public SmartPointer
      typedef UniquePointer< Object, Devices::Cuda > ThisType;
         
      template< typename... Args >
      UniquePointer( const Args... args )
      : pointer( 0 ), cuda_pointer( 0 )
      explicit  UniquePointer( const Args... args )
      : pd( nullptr ),
        cuda_pointer( nullptr )
      {
         this->allocate( args... );
      }
      
      const Object* operator->() const
      {
         return this->pointer;
         return &this->pd->data;
      }
      
      Object* operator->()
      {
         return this->pointer;
         return &this->pd->data;
      }
      
      const Object& operator *() const
      {
         return *( this->pointer );
         return this->pd->data;
      }
      
      Object& operator *()
      {
         return *( this->pointer );
         return this->pd->data;
      }
      
      operator bool()
      {
         return this->pointer;
         return this->pd;
      }

      template< typename Device = Devices::Host >      
      const Object& getData() const
      {
         static_assert( std::is_same< Device, Devices::Host >::value || std::is_same< Device, Devices::Cuda >::value, "Only Devices::Host or Devices::Cuda devices are accepted here." );
         Assert( this->pd, );
         Assert( this->cuda_pointer, );
         if( std::is_same< Device, Devices::Host >::value )
            return *( this->pointer );
            return this->pd->data;
         if( std::is_same< Device, Devices::Cuda >::value )
            return *( this->cuda_pointer );            
      }
@@ -169,22 +172,20 @@ class UniquePointer< Object, Devices::Cuda > : public SmartPointer
      Object& modifyData()
      {
         static_assert( std::is_same< Device, Devices::Host >::value || std::is_same< Device, Devices::Cuda >::value, "Only Devices::Host or Devices::Cuda devices are accepted here." );
         Assert( this->pd, );
         Assert( this->cuda_pointer, );
         if( std::is_same< Device, Devices::Host >::value )
         {
            return *( this->pointer );
         }
            return this->pd->data;
         if( std::is_same< Device, Devices::Cuda >::value )
         {
            return *( this->cuda_pointer );
      }
      }
      
      const ThisType& operator=( ThisType& ptr )
      {
         this->free();
         this->pointer = ptr.pointer;
         this->pd = ptr.pd;
         this->cuda_pointer = ptr.cuda_pointer;
         ptr.pointer = nullptr;
         ptr.pd = nullptr;
         ptr.cuda_pointer = nullptr;
         return *this;
      }
@@ -196,10 +197,12 @@ class UniquePointer< Object, Devices::Cuda > : public SmartPointer
      
      bool synchronize()
      {
         if( ! this->pd )
            return true;
#ifdef HAVE_CUDA
         if( this->modified() )
         {
            cudaMemcpy( (void*) this->cuda_pointer, (void*) this->pointer, sizeof( Object ), cudaMemcpyHostToDevice );
            cudaMemcpy( (void*) this->cuda_pointer, (void*) &this->pd->data, sizeof( Object ), cudaMemcpyHostToDevice );
            if( ! checkCudaDevice )
               return false;
            this->set_last_sync_state();
@@ -219,18 +222,25 @@ class UniquePointer< Object, Devices::Cuda > : public SmartPointer
      
   protected:

      // Single allocation referenced through the `pd` member; it holds the
      // owned object together with the byte image used for change detection.
      struct PointerData
      {
         // The owned host-side object, returned by operator->, operator* and
         // the host branch of getData()/modifyData().
         Object data;
         // Raw byte image of `data`: written by set_last_sync_state() and
         // compared against the current bytes of `data` in modified().
         char data_image[ sizeof(Object) ];

         // Constructs `data` from the given arguments, which are taken by
         // value and passed on to Object's constructor; `data_image` is left
         // uninitialized here.
         template< typename... Args >
         explicit PointerData( Args... args )
         : data( args... )
         {}
      };

      template< typename... Args >
      bool allocate( Args... args )
      {
         // Allocate space for two objects: the first one is the "real" object,
         // the second is a "mirror" used to set last-sync state.
         this->pointer = (Object*) ::operator new( 2 * sizeof( Object ) );
         if( ! this->pointer )
         this->pd = new PointerData( args... );
         if( ! this->pd )
            return false;
         // construct the object
         new( (void*) this->pointer ) Object( args... );
         // pass to device
         this->cuda_pointer = Devices::Cuda::passToDevice( *this->pointer );
         this->cuda_pointer = Devices::Cuda::passToDevice( this->pd->data );
         if( ! this->cuda_pointer )
            return false;
         // set last-sync state
@@ -241,28 +251,29 @@ class UniquePointer< Object, Devices::Cuda > : public SmartPointer

      void set_last_sync_state()
      {
         std::memcpy( (void*) (this->pointer + 1), (void*) this->pointer, sizeof( ObjectType ) );
         Assert( this->pd, );
         std::memcpy( (void*) &this->pd->data_image, (void*) &this->pd->data, sizeof( ObjectType ) );
      }

      bool modified()
      {
         if( ! this->pointer )
            return false;
         return std::memcmp( (void*) (this->pointer + 1), (void*) this->pointer, sizeof( ObjectType ) ) != 0;
         Assert( this->pd, );
         return std::memcmp( (void*) &this->pd->data_image, (void*) &this->pd->data, sizeof( ObjectType ) ) != 0;
      }

      void free()
      {
         if( this->pointer ) {
            // call destructor on the "real" object, but not on the "mirror"
            ( (Object*)this->pointer )->~Object();
            ::operator delete( (void*) this->pointer );
         }
         if( this->pd )
            delete this->pd;
         if( this->cuda_pointer )
            Devices::Cuda::freeFromDevice( this->cuda_pointer );
      }
      
      Object *pointer, *cuda_pointer;
      PointerData* pd;

      // cuda_pointer can't be part of the PointerData structure, since we would
      // be unable to dereference this->pd on the device
      Object* cuda_pointer;
};

} // namespace TNL