Commit 5ce9ef29 authored by Jakub Klinkovský
Browse files

Optimized SmartPointersRegister using std::unordered_set instead of std::list

parent 93db2ce6
Loading
Loading
Loading
Loading
+18 −15
Original line number Diff line number Diff line
@@ -9,6 +9,7 @@
/* See Copyright Notice in tnl/Copyright */

#include <TNL/Devices/Cuda.h>
#include <TNL/Devices/CudaDeviceInfo.h>
#include <TNL/core/mfuncs.h>
#include <TNL/tnlConfig.h>
#include <TNL/Config/ConfigDescription.h>
@@ -46,7 +47,8 @@ int Cuda::getNumberOfGrids( const int blocks,

}*/

void Cuda::configSetup( Config::ConfigDescription& config, const String& prefix )
void Cuda::configSetup( Config::ConfigDescription& config,
                        const String& prefix )
{
#ifdef HAVE_CUDA
   config.addEntry< int >( prefix + "cuda-device", "Choose CUDA device to run the computation.", 0 );
@@ -71,18 +73,19 @@ bool Cuda::setup( const Config::ParameterContainer& parameters,

// Adds a smart pointer to the static smartPointersRegister.
// NOTE(review): this listing looks like a rendered diff whose +/- markers were
// stripped, so the two insert() calls below are presumably the pre-commit line
// (device ID hard-coded to 0) followed by its replacement (device ID taken from
// CudaDeviceInfo::getActiveDevice()), not two calls that both execute —
// confirm against the commit.
void Cuda::insertSmartPointer( SmartPointer* pointer )
{
    smartPointersRegister.insert( pointer, 0 );
   smartPointersRegister.insert( pointer, Devices::CudaDeviceInfo::getActiveDevice() );
}

// Removes a smart pointer from the static smartPointersRegister.
// NOTE(review): this listing looks like a rendered diff whose +/- markers were
// stripped, so the two remove() calls below are presumably the pre-commit line
// (device ID hard-coded to 0) followed by its replacement (device ID taken from
// CudaDeviceInfo::getActiveDevice()) — confirm against the commit.
// NOTE(review): the replacement looks the pointer up under the device that is
// active at removal time; presumably that must match the device that was
// active when the pointer was inserted — verify against callers.
void Cuda::removeSmartPointer( SmartPointer* pointer )
{
    smartPointersRegister.remove( pointer, 0 );
   smartPointersRegister.remove( pointer, Devices::CudaDeviceInfo::getActiveDevice() );
}

// Forwards to smartPointersRegister.synchronizeDevice() for the given device.
// NOTE(review): this listing looks like a rendered diff whose +/- markers were
// stripped. The first two statements (the bare synchronizeDevice() call and
// `return checkCudaDevice;`) are presumably the removed pre-commit body; the
// remaining three lines are presumably the replacement — confirm against the
// commit.
bool Cuda::synchronizeDevice( int deviceId )
{
    smartPointersRegister.synchronizeDevice( deviceId );
    return checkCudaDevice;
   // A negative deviceId selects the device reported by getActiveDevice().
   if( deviceId < 0 )
      deviceId = Devices::CudaDeviceInfo::getActiveDevice();
   // The result of the register's synchronizeDevice() is returned unchanged.
   return smartPointersRegister.synchronizeDevice( deviceId );
}

} // namespace Devices
+4 −3
Original line number Diff line number Diff line
@@ -104,13 +104,14 @@ class Cuda
   
   static void removeSmartPointer( SmartPointer* pointer );
   
   // NOTE(review): the same function is declared twice with different default
   // arguments, which cannot coexist in one class. This looks like a rendered
   // diff where the `= 0` declaration is the removed line and the `= -1`
   // declaration (with its comment) is the replacement — confirm.
   static bool synchronizeDevice( int deviceId = 0  );
   // Negative deviceId means that CudaDeviceInfo::getActiveDevice will be
   // called to get the device ID.
   static bool synchronizeDevice( int deviceId = -1 );
   
   protected:
   
   static SmartPointersRegister smartPointersRegister;


};

#ifdef HAVE_CUDA
+0 −1
Original line number Diff line number Diff line
@@ -407,7 +407,6 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
         this->cuda_pointer = Devices::Cuda::passToDevice( *this->pointer );
         if( ! this->cuda_pointer )
            return false;
         // TODO: what if 'this' is already in the register?
         Devices::Cuda::insertSmartPointer( this );
         return true;
      }
+20 −18
Original line number Diff line number Diff line
@@ -17,34 +17,36 @@

#include <iostream>
#include <TNL/SmartPointersRegister.h>

// Constructs the register with one container slot per device: resizes
// pointersOnDevices to devicesCount and stores the count in the member.
// The Assert checks devicesCount > 0 and is given a stream expression that
// prints the offending value.
// NOTE(review): this listing looks like a rendered diff whose +/- markers were
// stripped; this constructor is presumably part of the removed pre-commit code
// (it relies on resize() and a devicesCount member) — confirm.
SmartPointersRegister::SmartPointersRegister( int devicesCount )
{
   Assert( devicesCount > 0, std::cerr << "devicesCount = " << devicesCount );
   pointersOnDevices.resize( devicesCount );
   this->devicesCount = devicesCount;
}
#include <TNL/Devices/Cuda.h>

// Adds `pointer` to the collection kept for device `deviceId`.
// NOTE(review): this listing looks like a rendered diff whose +/- markers were
// stripped. The Assert range check, the commented-out trace and the
// push_back() call presumably belong to the removed pre-commit version; the
// final insert() call is presumably the only statement in the new version —
// confirm against the commit.
void SmartPointersRegister::insert( SmartPointer* pointer, int deviceId )
{
   Assert( deviceId >= 0 && deviceId < this->devicesCount,
              std::cerr << "deviceId = " << deviceId << " devicesCount = " << this->devicesCount );
   //std::cerr << "Inserting pointer " << pointer << " to the register..." << std::endl;
   pointersOnDevices[ deviceId ].push_back( pointer );
   // If pointersOnDevices is the std::unordered_map declared in the header
   // hunk, operator[] creates an empty set for a deviceId that has none yet.
   pointersOnDevices[ deviceId ].insert( pointer );
}

// Removes `pointer` from the collection kept for device `deviceId`.
// NOTE(review): this listing looks like a rendered diff whose +/- markers were
// stripped. The Assert range check and the remove() call are presumably the
// removed pre-commit lines; the try/catch is presumably the replacement —
// confirm against the commit.
void SmartPointersRegister::remove( SmartPointer* pointer, int deviceId )
{
   Assert( deviceId >= 0 && deviceId < this->devicesCount,
              std::cerr << "deviceId = " << deviceId << " devicesCount = " << this->devicesCount );   
   pointersOnDevices[ deviceId ].remove( pointer );
   try {
      // at() throws std::out_of_range when there is no entry for deviceId,
      // unlike operator[], which would create one.
      pointersOnDevices.at( deviceId ).erase( pointer );
   }
   // NOTE(review): the exception is caught by value; catching by const
   // reference (const std::out_of_range&) is the usual convention. No
   // <stdexcept> include is visible in this listing — verify it is reachable.
   catch( std::out_of_range ) {
      std::cerr << "Given deviceId " << deviceId << " does not have any pointers yet. "
                << "Requested to remove pointer " << pointer << ". "
                << "This is most likely a bug in the smart pointer." << std::endl;
      // Rethrow after logging so the caller still sees the failure.
      throw;
   }
}


// Calls synchronize() on every smart pointer registered for `deviceId`.
// NOTE(review): this listing looks like a rendered diff whose +/- markers were
// stripped. The iterator-based for loop (using ListType) is presumably the
// removed pre-commit body; the try/catch is presumably the replacement —
// confirm against the commit.
bool SmartPointersRegister::synchronizeDevice( int deviceId )
{
   for( ListType::iterator it = pointersOnDevices[ deviceId ].begin();
        it != pointersOnDevices[ deviceId ].end();
        it++ )
      ( *it )->synchronize();            
   try {
      // at() throws std::out_of_range when nothing is registered for deviceId.
      const auto & set = pointersOnDevices.at( deviceId );
      // Despite its name, `it` is the stored element (a SmartPointer*), not an
      // iterator, so a single dereference reaches the SmartPointer object.
      for( auto&& it : set )
         ( *it ).synchronize();
      // checkCudaDevice is not defined in this listing; presumably a macro
      // from TNL/Devices/Cuda.h that reports the CUDA error state — confirm.
      return checkCudaDevice;
   }
   // NOTE(review): the exception is caught by value; catching by const
   // reference is the usual convention. A device with no registered pointers
   // is reported as false here — verify that callers expect this.
   catch( std::out_of_range ) {
      return false;
   }
}
+4 −8
Original line number Diff line number Diff line
@@ -17,8 +17,8 @@

#pragma once

#include <vector>
#include <list>
#include <unordered_set>
#include <unordered_map>
#include <TNL/SmartPointer.h>
#include <TNL/Assert.h>

@@ -27,8 +27,6 @@ class SmartPointersRegister
  
   public:
   
      // NOTE(review): this listing looks like a rendered diff whose +/- markers
      // were stripped; the hunk header (8 lines -> 6 lines) suggests this
      // constructor declaration and the following blank line are the removed
      // lines — confirm against the commit.
      SmartPointersRegister( int devicesCount = 1 );
      
      // Registers `pointer` under the given device ID.
      void insert( SmartPointer* pointer, int deviceId );
      
      // Unregisters `pointer` from the given device ID.
      void remove( SmartPointer* pointer, int deviceId );
@@ -37,10 +35,8 @@ class SmartPointersRegister
      
   protected:
      
      // NOTE(review): this listing looks like a rendered diff whose +/- markers
      // were stripped; pointersOnDevices is declared twice below, which cannot
      // coexist. Presumably ListType, the std::vector member and devicesCount
      // are the removed pre-commit declarations, while SetType and the
      // std::unordered_map member replace them — confirm against the commit.
      typedef std::list< SmartPointer* > ListType;   
      
      std::vector< ListType > pointersOnDevices;
      // Set of registered smart pointers for one device.
      typedef std::unordered_set< SmartPointer* > SetType;   
      
      int devicesCount;
      // Maps a device ID to the set of smart pointers registered for it.
      std::unordered_map< int, SetType > pointersOnDevices;
};