Commit ab534045 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Rewritten detection of modified data in smart pointers

The original 'modified' flag was not shared between instances of the
smart pointer, so theoretically the same data object might have been
synchronized more than once or even not at all.

Another problem with the detection was that e.g. access through the
modifyData method does not imply modification of the object's class
data. For example if the object is TNL::Vector, we need to go through
modifyData to be able to modify the vector data, but synchronization is
not needed since we go through another pointer. Another example is
repeated binding of the same object, e.g. bindDofs in the Problem
classes.
parent 0e4a66a1
Loading
Loading
Loading
Loading
+32 −20
Original line number Diff line number Diff line
@@ -21,6 +21,8 @@
#include <TNL/Devices/Cuda.h>
#include <TNL/SmartPointer.h>

#include <cstring>


namespace TNL {

@@ -211,13 +213,15 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer

      explicit  DevicePointer( ObjectType& obj )
      : pointer( 0 ), cuda_pointer( 0 ),
        counter( 0 ), modified( false )
        counter( 0 ), last_sync_state( 0 )
      {
         this->counter = new int( 1 );
         this->pointer = &obj;
         this->cuda_pointer = Devices::Cuda::passToDevice( *this->pointer );
         if( ! this->cuda_pointer )
            return;
         this->last_sync_state = ::operator new( sizeof( Object ) );
         this->set_last_sync_state();
         Devices::Cuda::insertSmartPointer( this );
      }

@@ -226,7 +230,7 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
      : pointer( pointer.pointer ),
        cuda_pointer( pointer.cuda_pointer ),
        counter( pointer.counter ),
        modified( pointer.modified )
        last_sync_state( pointer.last_sync_state )
      {
         *counter += 1;
      }
@@ -238,7 +242,7 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
      : pointer( pointer.pointer ),
        cuda_pointer( pointer.cuda_pointer ),
        counter( pointer.counter ),
        modified( pointer.modified )
        last_sync_state( pointer.last_sync_state )
      {
         *counter += 1;
      }
@@ -248,12 +252,12 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
      : pointer( pointer.pointer ),
        cuda_pointer( pointer.cuda_pointer ),
        counter( pointer.counter ),
        modified( pointer.modified )
        last_sync_state( pointer.last_sync_state )
      {
         pointer.pointer = nullptr;
         pointer.cuda_pointer = nullptr;
         pointer.counter = nullptr;
         pointer.modified = false;
         pointer.last_sync_state = nullptr;
      }

      // conditional constructor for non-const -> const data
@@ -263,12 +267,12 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
      : pointer( pointer.pointer ),
        cuda_pointer( pointer.cuda_pointer ),
        counter( pointer.counter ),
        modified( pointer.modified )
        last_sync_state( pointer.last_sync_state )
      {
         pointer.pointer = nullptr;
         pointer.cuda_pointer = nullptr;
         pointer.counter = nullptr;
         pointer.modified = false;
         pointer.last_sync_state = nullptr;
      }

      const Object* operator->() const
@@ -278,7 +282,6 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer

      Object* operator->()
      {
         this->modified = true;
         return this->pointer;
      }

@@ -289,7 +292,6 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer

      Object& operator *()
      {
         this->modified = true;
         return *( this->pointer );
      }

@@ -320,7 +322,6 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
         Assert( this->cuda_pointer, );
         if( std::is_same< Device, Devices::Host >::value )
         {
            this->modified = true;
            return *( this->pointer );
         }
         if( std::is_same< Device, Devices::Cuda >::value )
@@ -336,7 +337,7 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
         this->pointer = ptr.pointer;
         this->cuda_pointer = ptr.cuda_pointer;
         this->counter = ptr.counter;
         this->modified = ptr.modified;
         this->last_sync_state = ptr.last_sync_state;
         *( this->counter ) += 1;
         return *this;
      }
@@ -350,7 +351,7 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
         this->pointer = ptr.pointer;
         this->cuda_pointer = ptr.cuda_pointer;
         this->counter = ptr.counter;
         this->modified = ptr.modified;
         this->last_sync_state = ptr.last_sync_state;
         *( this->counter ) += 1;
         return *this;
      }
@@ -362,11 +363,11 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
         this->pointer = ptr.pointer;
         this->cuda_pointer = ptr.cuda_pointer;
         this->counter = ptr.counter;
         this->modified = ptr.modified;
         this->last_sync_state = ptr.last_sync_state;
         ptr.pointer = nullptr;
         ptr.cuda_pointer = nullptr;
         ptr.counter = nullptr;
         ptr.modified = false;
         ptr.last_sync_state = nullptr;
         return *this;
      }

@@ -379,18 +380,18 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
         this->pointer = ptr.pointer;
         this->cuda_pointer = ptr.cuda_pointer;
         this->counter = ptr.counter;
         this->modified = ptr.modified;
         this->last_sync_state = ptr.last_sync_state;
         ptr.pointer = nullptr;
         ptr.cuda_pointer = nullptr;
         ptr.counter = nullptr;
         ptr.modified = false;
         ptr.last_sync_state = nullptr;
         return *this;
      }

      bool synchronize()
      {
#ifdef HAVE_CUDA
         if( this->modified )
         if( this->modified() )
         {
            Assert( this->pointer, );
            Assert( this->cuda_pointer, );
@@ -398,7 +399,7 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
            if( ! checkCudaDevice ) {
               return false;
            }
            this->modified = false;
            this->set_last_sync_state();
            return true;
         }
         return true;
@@ -415,6 +416,16 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer

   protected:

      void set_last_sync_state()
      {
         std::memcpy( (void*) this->last_sync_state, (void*) this->pointer, sizeof( ObjectType ) );
      }

      bool modified()
      {
         return std::memcmp( (void*) this->last_sync_state, (void*) this->pointer, sizeof( ObjectType ) ) != 0;
      }

      void free()
      {
         if( this->counter )
@@ -425,16 +436,17 @@ class DevicePointer< Object, Devices::Cuda > : public SmartPointer
               this->counter = nullptr;
               if( this->cuda_pointer )
                  Devices::Cuda::freeFromDevice( this->cuda_pointer );
               if( this->last_sync_state )
                  ::operator delete( this->last_sync_state );
            }
         }

      }

      Object *pointer, *cuda_pointer;

      int* counter;

      bool modified;
      void* last_sync_state;
};

} // namespace TNL
+40 −21
Original line number Diff line number Diff line
@@ -21,6 +21,8 @@
#include <TNL/Devices/Cuda.h>
#include <TNL/SmartPointer.h>

#include <cstring>


//#define TNL_DEBUG_SHARED_POINTERS

@@ -309,7 +311,7 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
      template< typename... Args >
      explicit  SharedPointer( Args... args )
      : pointer( 0 ), cuda_pointer( 0 ),
        counter( 0 ), modified( false )
        counter( 0 ), last_sync_state( 0 )
      {
         if( ! lazy )
         {
@@ -318,6 +320,8 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
            this->cuda_pointer = Devices::Cuda::passToDevice( *this->pointer );
            if( ! this->cuda_pointer )
               return;
            this->last_sync_state = ::operator new( sizeof( Object ) );
            this->set_last_sync_state();
#ifdef TNL_DEBUG_SHARED_POINTERS
            std::cerr << "Created shared pointer to " << demangle(typeid(ObjectType).name()) << " (cuda_pointer = " << this->cuda_pointer << ")" << std::endl;
#endif
@@ -330,7 +334,7 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
      : pointer( pointer.pointer ),
        cuda_pointer( pointer.cuda_pointer ),
        counter( pointer.counter ),
        modified( pointer.modified )
        last_sync_state( pointer.last_sync_state )
      {
         *counter += 1;
      }
@@ -342,7 +346,7 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
      : pointer( pointer.pointer ),
        cuda_pointer( pointer.cuda_pointer ),
        counter( pointer.counter ),
        modified( pointer.modified )
        last_sync_state( pointer.last_sync_state )
      {
         *counter += 1;
      }
@@ -352,12 +356,12 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
      : pointer( pointer.pointer ),
        cuda_pointer( pointer.cuda_pointer ),
        counter( pointer.counter ),
        modified( pointer.modified )
        last_sync_state( pointer.last_sync_state )
      {
         pointer.pointer = nullptr;
         pointer.cuda_pointer = nullptr;
         pointer.counter = nullptr;
         pointer.modified = false;
         pointer.last_sync_state = nullptr;
      }

      // conditional constructor for non-const -> const data
@@ -367,12 +371,12 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
      : pointer( pointer.pointer ),
        cuda_pointer( pointer.cuda_pointer ),
        counter( pointer.counter ),
        modified( pointer.modified )
        last_sync_state( pointer.last_sync_state )
      {
         pointer.pointer = nullptr;
         pointer.cuda_pointer = nullptr;
         pointer.counter = nullptr;
         pointer.modified = false;
         pointer.last_sync_state = nullptr;
      }

      template< typename... Args >
@@ -388,6 +392,8 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
            this->cuda_pointer = Devices::Cuda::passToDevice( *this->pointer );
            if( ! this->cuda_pointer )
               return false;
            this->last_sync_state = ::operator new( sizeof( Object ) );
            this->set_last_sync_state();
            Devices::Cuda::insertSmartPointer( this );
            return true;
         }
@@ -401,10 +407,13 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
#ifdef HAVE_CUDA
            cudaMemcpy( (void*) this->cuda_pointer, (void*) this->pointer, sizeof( Object ), cudaMemcpyHostToDevice );
#endif
            this->set_last_sync_state();
            return true;
         }

         this->modified = false;
         // free will just decrement the counter
         this->free();

         this->counter= new int( 1 );
         this->pointer = new Object( args... );
         if( ! this->pointer || ! this->counter )
@@ -412,6 +421,8 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
         this->cuda_pointer = Devices::Cuda::passToDevice( *this->pointer );
         if( ! this->cuda_pointer )
            return false;
         this->last_sync_state = ::operator new( sizeof( Object ) );
         this->set_last_sync_state();
         Devices::Cuda::insertSmartPointer( this );
         return true;
      }
@@ -423,7 +434,6 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer

      Object* operator->()
      {
         this->modified = true;
         return this->pointer;
      }

@@ -434,7 +444,6 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer

      Object& operator *()
      {
         this->modified = true;
         return *( this->pointer );
      }

@@ -465,7 +474,6 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
         Assert( this->cuda_pointer, );
         if( std::is_same< Device, Devices::Host >::value )
         {
            this->modified = true;
            return *( this->pointer );
         }
         if( std::is_same< Device, Devices::Cuda >::value )
@@ -481,7 +489,7 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
         this->pointer = ptr.pointer;
         this->cuda_pointer = ptr.cuda_pointer;
         this->counter = ptr.counter;
         this->modified = ptr.modified;
         this->last_sync_state = ptr.last_sync_state;
         *( this->counter ) += 1;
#ifdef TNL_DEBUG_SHARED_POINTERS
         std::cerr << "Copy-assigned shared pointer: counter = " << *(this->counter) << ", type: " << demangle(typeid(ObjectType).name()) << std::endl;
@@ -498,7 +506,7 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
         this->pointer = ptr.pointer;
         this->cuda_pointer = ptr.cuda_pointer;
         this->counter = ptr.counter;
         this->modified = ptr.modified;
         this->last_sync_state = ptr.last_sync_state;
         *( this->counter ) += 1;
#ifdef TNL_DEBUG_SHARED_POINTERS
         std::cerr << "Copy-assigned shared pointer: counter = " << *(this->counter) << ", type: " << demangle(typeid(ObjectType).name()) << std::endl;
@@ -513,11 +521,11 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
         this->pointer = ptr.pointer;
         this->cuda_pointer = ptr.cuda_pointer;
         this->counter = ptr.counter;
         this->modified = ptr.modified;
         this->last_sync_state = ptr.last_sync_state;
         ptr.pointer = nullptr;
         ptr.cuda_pointer = nullptr;
         ptr.counter = nullptr;
         ptr.modified = false;
         ptr.last_sync_state = nullptr;
#ifdef TNL_DEBUG_SHARED_POINTERS
         std::cerr << "Move-assigned shared pointer: counter = " << *(this->counter) << ", type: " << demangle(typeid(ObjectType).name()) << std::endl;
#endif
@@ -533,11 +541,11 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
         this->pointer = ptr.pointer;
         this->cuda_pointer = ptr.cuda_pointer;
         this->counter = ptr.counter;
         this->modified = ptr.modified;
         this->last_sync_state = ptr.last_sync_state;
         ptr.pointer = nullptr;
         ptr.cuda_pointer = nullptr;
         ptr.counter = nullptr;
         ptr.modified = false;
         ptr.last_sync_state = nullptr;
#ifdef TNL_DEBUG_SHARED_POINTERS
         std::cerr << "Move-assigned shared pointer: counter = " << *(this->counter) << ", type: " << demangle(typeid(ObjectType).name()) << std::endl;
#endif
@@ -547,7 +555,7 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
      bool synchronize()
      {
#ifdef HAVE_CUDA
         if( this->modified )
         if( this->modified() )
         {
#ifdef TNL_DEBUG_SHARED_POINTERS
            std::cerr << "Synchronizing shared pointer: counter = " << *(this->counter) << ", type: " << demangle(typeid(ObjectType).name()) << std::endl;
@@ -559,7 +567,7 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
            if( ! checkCudaDevice ) {
               return false;
            }
            this->modified = false;
            this->set_last_sync_state();
            return true;
         }
         return true;
@@ -576,6 +584,16 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer

   protected:

      void set_last_sync_state()
      {
         std::memcpy( (void*) this->last_sync_state, (void*) this->pointer, sizeof( ObjectType ) );
      }

      bool modified()
      {
         return std::memcmp( (void*) this->last_sync_state, (void*) this->pointer, sizeof( ObjectType ) ) != 0;
      }

      void free()
      {
         if( this->counter )
@@ -591,19 +609,20 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
                  delete this->pointer;
               if( this->cuda_pointer )
                  Devices::Cuda::freeFromDevice( this->cuda_pointer );
               if( this->last_sync_state )
                  ::operator delete( this->last_sync_state );
#ifdef TNL_DEBUG_SHARED_POINTERS
               std::cerr << "...deleted data." << std::endl;
#endif
            }
         }

      }

      Object *pointer, *cuda_pointer;

      int* counter;

      bool modified;
      void* last_sync_state;
};

} // namespace TNL
+25 −14
Original line number Diff line number Diff line
@@ -21,6 +21,9 @@
#include <TNL/Devices/Cuda.h>
#include <TNL/SmartPointer.h>

#include <cstring>


namespace TNL { 

template< typename Object, typename Device = typename Object::DeviceType >
@@ -37,11 +40,6 @@ class UniquePointer< Object, Devices::Host > : public SmartPointer
      typedef Devices::Host DeviceType;
      typedef UniquePointer< Object, Devices::Host > ThisType;
         
      UniquePointer()
      {
         this->pointer = new Object();
      }
      
      template< typename... Args >
      UniquePointer( const Args... args )
      {
@@ -123,14 +121,17 @@ class UniquePointer< Object, Devices::Cuda > : public SmartPointer
      
      typedef Object ObjectType;
      typedef Devices::Cuda DeviceType;
      typedef UniquePointer< Object, Devices::Host > ThisType;
      typedef UniquePointer< Object, Devices::Cuda > ThisType;
         
      template< typename... Args >
      UniquePointer( const Args... args )
      : modified( false )
      : pointer( 0 ), cuda_pointer( 0 ),
        last_sync_state( 0 )
      {
         this->pointer = new Object( args... );
         this->cuda_pointer = Devices::Cuda::passToDevice( *this->pointer );
         this->last_sync_state = ::operator new( sizeof( Object ) );
         this->set_last_sync_state();
         Devices::Cuda::insertSmartPointer( this );
      }
      
@@ -141,7 +142,6 @@ class UniquePointer< Object, Devices::Cuda > : public SmartPointer
      
      Object* operator->()
      {
         this->modified = true;
         return this->pointer;
      }
      
@@ -152,7 +152,6 @@ class UniquePointer< Object, Devices::Cuda > : public SmartPointer
      
      Object& operator *()
      {
         this->modified = true;
         return *( this->pointer );
      }
      
@@ -177,7 +176,6 @@ class UniquePointer< Object, Devices::Cuda > : public SmartPointer
         static_assert( std::is_same< Device, Devices::Host >::value || std::is_same< Device, Devices::Cuda >::value, "Only Devices::Host or Devices::Cuda devices are accepted here." );
         if( std::is_same< Device, Devices::Host >::value )
         {
            this->modified = true;
            return *( this->pointer );
         }
         if( std::is_same< Device, Devices::Cuda >::value )
@@ -194,10 +192,10 @@ class UniquePointer< Object, Devices::Cuda > : public SmartPointer
            Devices::Cuda::freeFromDevice( this->cuda_pointer );
         this->pointer = ptr.pointer;
         this->cuda_pointer = ptr.cuda_pointer;
         this->modified = ptr.modified;
         this->last_sync_state = ptr.last_sync_state;
         ptr.pointer = nullptr;
         ptr.cuda_pointer = nullptr;
         ptr.modified = false;
         ptr.last_sync_state = nullptr;
         return *this;
      }
      
@@ -209,11 +207,12 @@ class UniquePointer< Object, Devices::Cuda > : public SmartPointer
      bool synchronize()
      {
#ifdef HAVE_CUDA
         if( this->modified )
         if( this->modified() )
         {
            cudaMemcpy( (void*) this->cuda_pointer, (void*) this->pointer, sizeof( Object ), cudaMemcpyHostToDevice );
            if( ! checkCudaDevice )
               return false;
            this->set_last_sync_state();
            return true;
         }
         return true;
@@ -228,14 +227,26 @@ class UniquePointer< Object, Devices::Cuda > : public SmartPointer
            delete this->pointer;
         if( this->cuda_pointer )
            Devices::Cuda::freeFromDevice( this->cuda_pointer );
         if( this->last_sync_state )
            ::operator delete( this->last_sync_state );
         Devices::Cuda::removeSmartPointer( this );
      }
      
   protected:

      void set_last_sync_state()
      {
         std::memcpy( (void*) this->last_sync_state, (void*) this->pointer, sizeof( ObjectType ) );
      }

      bool modified()
      {
         return std::memcmp( (void*) this->last_sync_state, (void*) this->pointer, sizeof( ObjectType ) ) != 0;
      }
      
      Object *pointer, *cuda_pointer;
      
      bool modified;      
      void* last_sync_state;
};

} // namespace TNL