diff --git a/src/TNL/Devices/Cuda.cpp b/src/TNL/Devices/Cuda.cpp index c180825c37bee8ec05fd8de89cc022cb47bd85aa..baf6bd6176a12611b2e8e6584c13a7687e85461b 100644 --- a/src/TNL/Devices/Cuda.cpp +++ b/src/TNL/Devices/Cuda.cpp @@ -9,6 +9,7 @@ /* See Copyright Notice in tnl/Copyright */ #include <TNL/Devices/Cuda.h> +#include <TNL/Devices/CudaDeviceInfo.h> #include <TNL/core/mfuncs.h> #include <TNL/tnlConfig.h> #include <TNL/Config/ConfigDescription.h> @@ -16,9 +17,9 @@ namespace TNL { namespace Devices { - -SmartPointersRegister Cuda::smartPointersRegister; - + +SmartPointersRegister Cuda::smartPointersRegister; + String Cuda::getDeviceType() { return String( "Cuda" ); @@ -30,13 +31,13 @@ int Cuda::getGPUTransferBufferSize() } int Cuda::getNumberOfBlocks( const int threads, - const int blockSize ) + const int blockSize ) { return roundUpDivision( threads, blockSize ); } int Cuda::getNumberOfGrids( const int blocks, - const int gridSize ) + const int gridSize ) { return roundUpDivision( blocks, gridSize ); } @@ -46,17 +47,18 @@ int Cuda::getNumberOfGrids( const int blocks, }*/ -void Cuda::configSetup( Config::ConfigDescription& config, const String& prefix ) +void Cuda::configSetup( Config::ConfigDescription& config, + const String& prefix ) { #ifdef HAVE_CUDA - config.addEntry< int >( prefix + "cuda-device", "Choose CUDA device to run the computation.", 0 ); + config.addEntry< int >( prefix + "cuda-device", "Choose CUDA device to run the computation.", 0 ); #else - config.addEntry< int >( prefix + "cuda-device", "Choose CUDA device to run the computation (not supported on this system).", 0 ); + config.addEntry< int >( prefix + "cuda-device", "Choose CUDA device to run the computation (not supported on this system).", 0 ); #endif } - + bool Cuda::setup( const Config::ParameterContainer& parameters, - const String& prefix ) + const String& prefix ) { #ifdef HAVE_CUDA int cudaDevice = parameters.getParameter< int >( "cuda-device" ); @@ -71,18 +73,19 @@ bool Cuda::setup( const Config::ParameterContainer& parameters, void Cuda::insertSmartPointer( SmartPointer* pointer ) { - smartPointersRegister.insert( pointer, 0 ); + smartPointersRegister.insert( pointer, Devices::CudaDeviceInfo::getActiveDevice() ); } void Cuda::removeSmartPointer( SmartPointer* pointer ) { - smartPointersRegister.remove( pointer, 0 ); + smartPointersRegister.remove( pointer, Devices::CudaDeviceInfo::getActiveDevice() ); } - + bool Cuda::synchronizeDevice( int deviceId ) { - smartPointersRegister.synchronizeDevice( deviceId ); - return checkCudaDevice; + if( deviceId < 0 ) + deviceId = Devices::CudaDeviceInfo::getActiveDevice(); + return smartPointersRegister.synchronizeDevice( deviceId ); } } // namespace Devices diff --git a/src/TNL/Devices/Cuda.h b/src/TNL/Devices/Cuda.h index 20ea0e0e355b730558330495d89babe5393c982c..c8ee52a105509ff5e193f9a24348116b379f4a85 100644 --- a/src/TNL/Devices/Cuda.h +++ b/src/TNL/Devices/Cuda.h @@ -104,12 +104,13 @@ class Cuda static void removeSmartPointer( SmartPointer* pointer ); - static bool synchronizeDevice( int deviceId = 0 ); + // Negative deviceId means that CudaDeviceInfo::getActiveDevice will be + // called to get the device ID. + static bool synchronizeDevice( int deviceId = -1 ); protected: - static SmartPointersRegister smartPointersRegister; - + static SmartPointersRegister smartPointersRegister; }; diff --git a/src/TNL/SharedPointer.h b/src/TNL/SharedPointer.h index d9106b7bc8cda36ca2d3e5acf4b3d2e4e8a0acc1..72b05d66d8c8219ab57beff9f79d3b538966c2c0 100644 --- a/src/TNL/SharedPointer.h +++ b/src/TNL/SharedPointer.h @@ -407,7 +407,6 @@ class SharedPointer< Object, Devices::Cuda, lazy > : public SmartPointer this->cuda_pointer = Devices::Cuda::passToDevice( *this->pointer ); if( ! this->cuda_pointer ) return false; - // TODO: what if 'this' is already in the register? Devices::Cuda::insertSmartPointer( this ); return true; } diff --git a/src/TNL/SmartPointersRegister.cpp b/src/TNL/SmartPointersRegister.cpp index 7ed57eb437c4992202ecae41193deb28d06ede40..7fe8a654f4a6ac6c435b4d8ccb02aadc84df37fb 100644 --- a/src/TNL/SmartPointersRegister.cpp +++ b/src/TNL/SmartPointersRegister.cpp @@ -17,34 +17,36 @@ #include <iostream> #include <TNL/SmartPointersRegister.h> - -SmartPointersRegister::SmartPointersRegister( int devicesCount ) -{ - Assert( devicesCount > 0, std::cerr << "devicesCount = " << devicesCount ); - pointersOnDevices.resize( devicesCount ); - this->devicesCount = devicesCount; -} +#include <TNL/Devices/Cuda.h> void SmartPointersRegister::insert( SmartPointer* pointer, int deviceId ) { - Assert( deviceId >= 0 && deviceId < this->devicesCount, - std::cerr << "deviceId = " << deviceId << " devicesCount = " << this->devicesCount ); //std::cerr << "Inserting pointer " << pointer << " to the register..." << std::endl; - pointersOnDevices[ deviceId ].push_back( pointer ); + pointersOnDevices[ deviceId ].insert( pointer ); } void SmartPointersRegister::remove( SmartPointer* pointer, int deviceId ) { - Assert( deviceId >= 0 && deviceId < this->devicesCount, - std::cerr << "deviceId = " << deviceId << " devicesCount = " << this->devicesCount ); - pointersOnDevices[ deviceId ].remove( pointer ); + try { + pointersOnDevices.at( deviceId ).erase( pointer ); + } + catch( std::out_of_range ) { + std::cerr << "Given deviceId " << deviceId << " does not have any pointers yet. " + << "Requested to remove pointer " << pointer << ". " + << "This is most likely a bug in the smart pointer." << std::endl; + throw; + } } - bool SmartPointersRegister::synchronizeDevice( int deviceId ) { - for( ListType::iterator it = pointersOnDevices[ deviceId ].begin(); - it != pointersOnDevices[ deviceId ].end(); - it++ ) - ( *it )->synchronize(); + try { + const auto & set = pointersOnDevices.at( deviceId ); + for( auto&& it : set ) + ( *it ).synchronize(); + return checkCudaDevice; + } + catch( std::out_of_range ) { + return false; + } } diff --git a/src/TNL/SmartPointersRegister.h b/src/TNL/SmartPointersRegister.h index ebbc031d1217bb80dfda5e9abc96087a109e5f82..a2e9c43639501cf0d287e531cc65842c85b36c59 100644 --- a/src/TNL/SmartPointersRegister.h +++ b/src/TNL/SmartPointersRegister.h @@ -17,8 +17,8 @@ #pragma once -#include <vector> -#include <list> +#include <unordered_set> +#include <unordered_map> #include <TNL/SmartPointer.h> #include <TNL/Assert.h> @@ -27,8 +27,6 @@ class SmartPointersRegister public: - SmartPointersRegister( int devicesCount = 1 ); - void insert( SmartPointer* pointer, int deviceId ); void remove( SmartPointer* pointer, int deviceId ); @@ -37,10 +35,8 @@ class SmartPointersRegister protected: - typedef std::list< SmartPointer* > ListType; - - std::vector< ListType > pointersOnDevices; + typedef std::unordered_set< SmartPointer* > SetType; - int devicesCount; + std::unordered_map< int, SetType > pointersOnDevices; };