/*************************************************************************** Atomic.h - description ------------------- begin : Sep 14, 2018 copyright : (C) 2018 by Tomas Oberhuber et al. email : tomas.oberhuber@fjfi.cvut.cz ***************************************************************************/ /* See Copyright Notice in tnl/Copyright */ // Implemented by: Jakub Klinkovský #pragma once #include <atomic> // std::atomic #include <TNL/Devices/Host.h> #include <TNL/Devices/Cuda.h> #include <TNL/param-types.h> namespace TNL { template< typename T, typename Device > class Atomic {}; template< typename T > class Atomic< T, Devices::Host > : public std::atomic< T > { public: Atomic() noexcept = default; // inherit constructors using std::atomic< T >::atomic; // NOTE: std::atomic is not copyable (see https://stackoverflow.com/a/15250851 for // an explanation), but we need copyability for TNL::Containers::Array. Note that // this copy-constructor and copy-assignment operator are not atomic as they // synchronize only with respect to one or the other object. Atomic( const Atomic& desired ) noexcept { this->store(desired.load()); } Atomic& operator=( const Atomic& desired ) noexcept { this->store(desired.load()); return *this; } // just for compatibility with TNL::Containers::Array... static String getType() { return "Atomic< " + TNL::getType< T >() + ", " + Devices::Host::getDeviceType() + " >"; } // CAS loops for updating maximum and minimum // reference: https://stackoverflow.com/a/16190791 T fetch_max( T value ) noexcept { const T old = *this; T prev_value = old; while(prev_value < value && ! this->compare_exchange_weak(prev_value, value)) ; return old; } T fetch_min( T value ) noexcept { const T old = *this; T prev_value = old; while(prev_value > value && ! this->compare_exchange_weak(prev_value, value)) ; return old; } }; template< typename T > class Atomic< T, Devices::Cuda > { public: using value_type = T; // FIXME // using difference_type = typename std::atomic< T >::difference_type; __cuda_callable__ Atomic() noexcept = default; __cuda_callable__ constexpr Atomic( T desired ) noexcept : value(desired) {} __cuda_callable__ T operator=( T desired ) noexcept { store( desired ); return desired; } // NOTE: std::atomic is not copyable (see https://stackoverflow.com/a/15250851 for // an explanation), but we need copyability for TNL::Containers::Array. Note that // this copy-constructor and copy-assignment operator are not atomic as they // synchronize only with respect to one or the other object. __cuda_callable__ Atomic( const Atomic& desired ) noexcept { // FIXME // *this = desired.load(); *this = desired.value; } __cuda_callable__ Atomic& operator=( const Atomic& desired ) noexcept { // FIXME // *this = desired.load(); *this = desired.value; return *this; } // just for compatibility with TNL::Containers::Array... static String getType() { return "Atomic< " + TNL::getType< T >() + ", " + Devices::Host::getDeviceType() + " >"; } bool is_lock_free() const noexcept { return true; } constexpr bool is_always_lock_free() const noexcept { return true; } __cuda_callable__ void store( T desired ) noexcept { // CUDA does not have a native atomic store, but it can be emulated with atomic exchange exchange( desired ); } __cuda_callable__ T load() const noexcept { // CUDA does not have a native atomic load: // https://stackoverflow.com/questions/32341081/how-to-have-atomic-load-in-cuda return const_cast<Atomic*>(this)->fetch_add( 0 ); } __cuda_callable__ operator T() const noexcept { return load(); } __cuda_callable__ T exchange( T desired ) noexcept { #ifdef __CUDA_ARCH__ return atomicExch( &value, desired ); #else const T old = value; value = desired; return old; #endif } __cuda_callable__ bool compare_exchange_weak( T& expected, T desired ) noexcept { return compare_exchange_strong( expected, desired ); } __cuda_callable__ bool compare_exchange_strong( T& expected, T desired ) noexcept { #ifdef __CUDA_ARCH__ const T old = atomicCAS( &value, expected, desired ); const bool result = old == expected; expected = old; return result; #else if( value == expected ) { value = desired; return true; } else { expected = value; return false; } #endif } __cuda_callable__ T fetch_add( T arg ) { #ifdef __CUDA_ARCH__ return atomicAdd( &value, arg ); #else const T old = value; value += arg; return old; #endif } __cuda_callable__ T fetch_sub( T arg ) { #ifdef __CUDA_ARCH__ return atomicSub( &value, arg ); #else const T old = value; value -= arg; return old; #endif } __cuda_callable__ T fetch_and( T arg ) { #ifdef __CUDA_ARCH__ return atomicAnd( &value, arg ); #else const T old = value; value = value & arg; return old; #endif } __cuda_callable__ T fetch_or( T arg ) { #ifdef __CUDA_ARCH__ return atomicOr( &value, arg ); #else const T old = value; value = value | arg; return old; #endif } __cuda_callable__ T fetch_xor( T arg ) { #ifdef __CUDA_ARCH__ return atomicXor( &value, arg ); #else const T old = value; value = value ^ arg; return old; #endif } __cuda_callable__ T operator+=( T arg ) noexcept { return fetch_add( arg ) + arg; } __cuda_callable__ T operator-=( T arg ) noexcept { return fetch_sub( arg ) - arg; } __cuda_callable__ T operator&=( T arg ) noexcept { return fetch_and( arg ) & arg; } __cuda_callable__ T operator|=( T arg ) noexcept { return fetch_or( arg ) | arg; } __cuda_callable__ T operator^=( T arg ) noexcept { return fetch_xor( arg ) ^ arg; } // pre-increment __cuda_callable__ T operator++() noexcept { return fetch_add(1) + 1; } // post-increment __cuda_callable__ T operator++(int) noexcept { return fetch_add(1); } // pre-decrement __cuda_callable__ T operator--() noexcept { return fetch_sub(1) - 1; } // post-decrement __cuda_callable__ T operator--(int) noexcept { return fetch_sub(1); } // extensions (methods not present in C++ standards) __cuda_callable__ T fetch_max( T arg ) noexcept { #ifdef __CUDA_ARCH__ return atomicMax( &value, arg ); #else const T old = value; value = ( value > arg ) ? value : arg; return old; #endif } __cuda_callable__ T fetch_min( T arg ) noexcept { #ifdef __CUDA_ARCH__ return atomicMin( &value, arg ); #else const T old = value; value = ( value < arg ) ? value : arg; return old; #endif } protected: T value; }; } // namespace TNL