Skip to content
Snippets Groups Projects
Commit 5ce9ef29 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Optimized SmartPointersRegister using std::unordered_set instead of std::list

parent 93db2ce6
No related branches found
No related tags found
No related merge requests found
...@@ -9,6 +9,7 @@ ...@@ -9,6 +9,7 @@
/* See Copyright Notice in tnl/Copyright */ /* See Copyright Notice in tnl/Copyright */
#include <TNL/Devices/Cuda.h> #include <TNL/Devices/Cuda.h>
#include <TNL/Devices/CudaDeviceInfo.h>
#include <TNL/core/mfuncs.h> #include <TNL/core/mfuncs.h>
#include <TNL/tnlConfig.h> #include <TNL/tnlConfig.h>
#include <TNL/Config/ConfigDescription.h> #include <TNL/Config/ConfigDescription.h>
...@@ -16,9 +17,9 @@ ...@@ -16,9 +17,9 @@
namespace TNL { namespace TNL {
namespace Devices { namespace Devices {
SmartPointersRegister Cuda::smartPointersRegister; SmartPointersRegister Cuda::smartPointersRegister;
String Cuda::getDeviceType() String Cuda::getDeviceType()
{ {
return String( "Cuda" ); return String( "Cuda" );
...@@ -30,13 +31,13 @@ int Cuda::getGPUTransferBufferSize() ...@@ -30,13 +31,13 @@ int Cuda::getGPUTransferBufferSize()
} }
int Cuda::getNumberOfBlocks( const int threads, int Cuda::getNumberOfBlocks( const int threads,
const int blockSize ) const int blockSize )
{ {
return roundUpDivision( threads, blockSize ); return roundUpDivision( threads, blockSize );
} }
int Cuda::getNumberOfGrids( const int blocks, int Cuda::getNumberOfGrids( const int blocks,
const int gridSize ) const int gridSize )
{ {
return roundUpDivision( blocks, gridSize ); return roundUpDivision( blocks, gridSize );
} }
...@@ -46,17 +47,18 @@ int Cuda::getNumberOfGrids( const int blocks, ...@@ -46,17 +47,18 @@ int Cuda::getNumberOfGrids( const int blocks,
}*/ }*/
void Cuda::configSetup( Config::ConfigDescription& config, const String& prefix ) void Cuda::configSetup( Config::ConfigDescription& config,
const String& prefix )
{ {
#ifdef HAVE_CUDA #ifdef HAVE_CUDA
config.addEntry< int >( prefix + "cuda-device", "Choose CUDA device to run the computation.", 0 ); config.addEntry< int >( prefix + "cuda-device", "Choose CUDA device to run the computation.", 0 );
#else #else
config.addEntry< int >( prefix + "cuda-device", "Choose CUDA device to run the computation (not supported on this system).", 0 ); config.addEntry< int >( prefix + "cuda-device", "Choose CUDA device to run the computation (not supported on this system).", 0 );
#endif #endif
} }
bool Cuda::setup( const Config::ParameterContainer& parameters, bool Cuda::setup( const Config::ParameterContainer& parameters,
const String& prefix ) const String& prefix )
{ {
#ifdef HAVE_CUDA #ifdef HAVE_CUDA
int cudaDevice = parameters.getParameter< int >( "cuda-device" ); int cudaDevice = parameters.getParameter< int >( "cuda-device" );
...@@ -71,18 +73,19 @@ bool Cuda::setup( const Config::ParameterContainer& parameters, ...@@ -71,18 +73,19 @@ bool Cuda::setup( const Config::ParameterContainer& parameters,
void Cuda::insertSmartPointer( SmartPointer* pointer ) void Cuda::insertSmartPointer( SmartPointer* pointer )
{ {
smartPointersRegister.insert( pointer, 0 ); smartPointersRegister.insert( pointer, Devices::CudaDeviceInfo::getActiveDevice() );
} }
void Cuda::removeSmartPointer( SmartPointer* pointer ) void Cuda::removeSmartPointer( SmartPointer* pointer )
{ {
smartPointersRegister.remove( pointer, 0 ); smartPointersRegister.remove( pointer, Devices::CudaDeviceInfo::getActiveDevice() );
} }
bool Cuda::synchronizeDevice( int deviceId ) bool Cuda::synchronizeDevice( int deviceId )
{ {
smartPointersRegister.synchronizeDevice( deviceId ); if( deviceId < 0 )
return checkCudaDevice; deviceId = Devices::CudaDeviceInfo::getActiveDevice();
return smartPointersRegister.synchronizeDevice( deviceId );
} }
} // namespace Devices } // namespace Devices
......
...@@ -104,12 +104,13 @@ class Cuda ...@@ -104,12 +104,13 @@ class Cuda
static void removeSmartPointer( SmartPointer* pointer ); static void removeSmartPointer( SmartPointer* pointer );
static bool synchronizeDevice( int deviceId = 0 ); // Negative deviceId means that CudaDeviceInfo::getActiveDevice will be
// called to get the device ID.
static bool synchronizeDevice( int deviceId = -1 );
protected: protected:
static SmartPointersRegister smartPointersRegister; static SmartPointersRegister smartPointersRegister;
}; };
......
...@@ -407,7 +407,6 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer ...@@ -407,7 +407,6 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer
this->cuda_pointer = Devices::Cuda::passToDevice( *this->pointer ); this->cuda_pointer = Devices::Cuda::passToDevice( *this->pointer );
if( ! this->cuda_pointer ) if( ! this->cuda_pointer )
return false; return false;
// TODO: what if 'this' is already in the register?
Devices::Cuda::insertSmartPointer( this ); Devices::Cuda::insertSmartPointer( this );
return true; return true;
} }
......
...@@ -17,34 +17,36 @@ ...@@ -17,34 +17,36 @@
#include <iostream> #include <iostream>
#include <TNL/SmartPointersRegister.h> #include <TNL/SmartPointersRegister.h>
#include <TNL/Devices/Cuda.h>
SmartPointersRegister::SmartPointersRegister( int devicesCount )
{
Assert( devicesCount > 0, std::cerr << "devicesCount = " << devicesCount );
pointersOnDevices.resize( devicesCount );
this->devicesCount = devicesCount;
}
void SmartPointersRegister::insert( SmartPointer* pointer, int deviceId ) void SmartPointersRegister::insert( SmartPointer* pointer, int deviceId )
{ {
Assert( deviceId >= 0 && deviceId < this->devicesCount,
std::cerr << "deviceId = " << deviceId << " devicesCount = " << this->devicesCount );
//std::cerr << "Inserting pointer " << pointer << " to the register..." << std::endl; //std::cerr << "Inserting pointer " << pointer << " to the register..." << std::endl;
pointersOnDevices[ deviceId ].push_back( pointer ); pointersOnDevices[ deviceId ].insert( pointer );
} }
void SmartPointersRegister::remove( SmartPointer* pointer, int deviceId ) void SmartPointersRegister::remove( SmartPointer* pointer, int deviceId )
{ {
Assert( deviceId >= 0 && deviceId < this->devicesCount, try {
std::cerr << "deviceId = " << deviceId << " devicesCount = " << this->devicesCount ); pointersOnDevices.at( deviceId ).erase( pointer );
pointersOnDevices[ deviceId ].remove( pointer ); }
catch( std::out_of_range ) {
std::cerr << "Given deviceId " << deviceId << " does not have any pointers yet. "
<< "Requested to remove pointer " << pointer << ". "
<< "This is most likely a bug in the smart pointer." << std::endl;
throw;
}
} }
bool SmartPointersRegister::synchronizeDevice( int deviceId ) bool SmartPointersRegister::synchronizeDevice( int deviceId )
{ {
for( ListType::iterator it = pointersOnDevices[ deviceId ].begin(); try {
it != pointersOnDevices[ deviceId ].end(); const auto & set = pointersOnDevices.at( deviceId );
it++ ) for( auto&& it : set )
( *it )->synchronize(); ( *it ).synchronize();
return checkCudaDevice;
}
catch( std::out_of_range ) {
return false;
}
} }
...@@ -17,8 +17,8 @@ ...@@ -17,8 +17,8 @@
#pragma once #pragma once
#include <vector> #include <unordered_set>
#include <list> #include <unordered_map>
#include <TNL/SmartPointer.h> #include <TNL/SmartPointer.h>
#include <TNL/Assert.h> #include <TNL/Assert.h>
...@@ -27,8 +27,6 @@ class SmartPointersRegister ...@@ -27,8 +27,6 @@ class SmartPointersRegister
public: public:
SmartPointersRegister( int devicesCount = 1 );
void insert( SmartPointer* pointer, int deviceId ); void insert( SmartPointer* pointer, int deviceId );
void remove( SmartPointer* pointer, int deviceId ); void remove( SmartPointer* pointer, int deviceId );
...@@ -37,10 +35,8 @@ class SmartPointersRegister ...@@ -37,10 +35,8 @@ class SmartPointersRegister
protected: protected:
typedef std::list< SmartPointer* > ListType; typedef std::unordered_set< SmartPointer* > SetType;
std::vector< ListType > pointersOnDevices;
int devicesCount; std::unordered_map< int, SetType > pointersOnDevices;
}; };
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment