/***************************************************************************
                          Atomic.h  -  description
                             -------------------
    begin                : Sep 14, 2018
    copyright            : (C) 2018 by Tomas Oberhuber et al.
    email                : tomas.oberhuber@fjfi.cvut.cz
 ***************************************************************************/

/* See Copyright Notice in tnl/Copyright */

// Implemented by: Jakub Klinkovský

#pragma once

#include <atomic>  // std::atomic

#include <TNL/Devices/Host.h>
#include <TNL/Devices/Cuda.h>
#include <TNL/param-types.h>

namespace TNL {

// Primary template — intentionally empty. Only the specializations for
// Devices::Host and Devices::Cuda below provide a usable atomic interface;
// instantiating Atomic with any other device yields an empty class.
template< typename T, typename Device >
class Atomic
{};

// Host-side atomic wrapper around std::atomic< T > that adds copyability
// (required by TNL::Containers::Array) and fetch_max/fetch_min extensions.
template< typename T >
class Atomic< T, Devices::Host >
: public std::atomic< T >
{
public:
   Atomic() noexcept = default;

   // inherit constructors
   using std::atomic< T >::atomic;

   // NOTE: std::atomic is not copyable (see https://stackoverflow.com/a/15250851 for
   // an explanation), but we need copyability for TNL::Containers::Array. Note that
   // this copy-constructor and copy-assignment operator are not atomic as they
   // synchronize only with respect to one or the other object.
   Atomic( const Atomic& desired ) noexcept
   {
      this->store(desired.load());
   }
   Atomic& operator=( const Atomic& desired ) noexcept
   {
      this->store(desired.load());
      return *this;
   }

   // just for compatibility with TNL::Containers::Array...
   static String getType()
   {
      return "Atomic< " +
             TNL::getType< T >() + ", " +
             Devices::Host::getDeviceType() + " >";
   }

   // CAS loops for updating maximum and minimum
   // reference: https://stackoverflow.com/a/16190791

   // Atomically replaces the stored value with max(stored, value).
   // Returns the value observed immediately before the operation completed,
   // mirroring the return convention of std::atomic's fetch_* operations.
   T fetch_max( T value ) noexcept
   {
      T prev_value = this->load();
      while(prev_value < value &&
            ! this->compare_exchange_weak(prev_value, value))
         ;
      // BUGFIX: return prev_value rather than a snapshot taken before the
      // CAS loop. compare_exchange_weak updates prev_value with the current
      // stored value on every failure, so after the loop prev_value is the
      // value that was actually held just before the successful exchange
      // (or the last observed value when no exchange was needed). A
      // pre-loop snapshot could be stale if another thread raced with us.
      return prev_value;
   }

   // Atomically replaces the stored value with min(stored, value).
   // Returns the value observed immediately before the operation completed.
   T fetch_min( T value ) noexcept
   {
      T prev_value = this->load();
      while(prev_value > value &&
            ! this->compare_exchange_weak(prev_value, value))
         ;
      // BUGFIX: see fetch_max — prev_value is the latest observed value.
      return prev_value;
   }
};

// CUDA-side atomic wrapper providing a std::atomic-like interface on top of
// CUDA's atomic* intrinsics. On the host (no __CUDA_ARCH__) the operations
// fall back to plain, NON-atomic reads/writes — host-side use is only safe
// single-threaded (e.g. for Array initialization/copies).
template< typename T >
class Atomic< T, Devices::Cuda >
{
public:
   using value_type = T;
   // FIXME
//   using difference_type = typename std::atomic< T >::difference_type;

   __cuda_callable__
   Atomic() noexcept = default;

   __cuda_callable__
   constexpr Atomic( T desired ) noexcept : value(desired) {}

   __cuda_callable__
   T operator=( T desired ) noexcept
   {
      store( desired );
      return desired;
   }

   // NOTE: std::atomic is not copyable (see https://stackoverflow.com/a/15250851 for
   // an explanation), but we need copyability for TNL::Containers::Array. Note that
   // this copy-constructor and copy-assignment operator are not atomic as they
   // synchronize only with respect to one or the other object.
   __cuda_callable__
   Atomic( const Atomic& desired ) noexcept
   {
      // FIXME
//      *this = desired.load();
      *this = desired.value;
   }
   __cuda_callable__
   Atomic& operator=( const Atomic& desired ) noexcept
   {
      // FIXME
//      *this = desired.load();
      *this = desired.value;
      return *this;
   }

   // just for compatibility with TNL::Containers::Array...
   static String getType()
   {
      // BUGFIX: report the Cuda device type — this was copy-pasted from the
      // Host specialization and misidentified the device.
      return "Atomic< " +
             TNL::getType< T >() + ", " +
             Devices::Cuda::getDeviceType() + " >";
   }

   // CUDA atomic intrinsics operate directly on the memory location,
   // so the implementation is always lock-free.
   bool is_lock_free() const noexcept
   {
      return true;
   }

   constexpr bool is_always_lock_free() const noexcept
   {
      return true;
   }

   __cuda_callable__
   void store( T desired ) noexcept
   {
      // CUDA does not have a native atomic store, but it can be emulated with atomic exchange
      exchange( desired );
   }

   __cuda_callable__
   T load() const noexcept
   {
      // CUDA does not have a native atomic load:
      // https://stackoverflow.com/questions/32341081/how-to-have-atomic-load-in-cuda
      // Emulated by fetch_add(0); the const_cast is needed because the
      // read-modify-write intrinsic requires a non-const pointer.
      return const_cast<Atomic*>(this)->fetch_add( 0 );
   }

   __cuda_callable__
   operator T() const noexcept
   {
      return load();
   }

   // Atomically replaces the stored value and returns the previous one.
   __cuda_callable__
   T exchange( T desired ) noexcept
   {
#ifdef __CUDA_ARCH__
      return atomicExch( &value, desired );
#else
      // host fallback: non-atomic
      const T old = value;
      value = desired;
      return old;
#endif
   }

   // Weak CAS is implemented via the strong variant — atomicCAS never
   // fails spuriously, so weak and strong are equivalent here.
   __cuda_callable__
   bool compare_exchange_weak( T& expected, T desired ) noexcept
   {
      return compare_exchange_strong( expected, desired );
   }

   // Atomically compares the stored value with expected; on match stores
   // desired and returns true, otherwise loads the current value into
   // expected and returns false (std::atomic convention).
   __cuda_callable__
   bool compare_exchange_strong( T& expected, T desired ) noexcept
   {
#ifdef __CUDA_ARCH__
      const T old = atomicCAS( &value, expected, desired );
      const bool result = old == expected;
      expected = old;
      return result;
#else
      // host fallback: non-atomic
      if( value == expected ) {
         value = desired;
         return true;
      }
      else {
         expected = value;
         return false;
      }
#endif
   }

   // NOTE: the fetch_* operations below compile only for the types that the
   // corresponding CUDA intrinsic supports (e.g. atomicSub is integer-only).
   // noexcept added for consistency with the rest of the interface.

   __cuda_callable__
   T fetch_add( T arg ) noexcept
   {
#ifdef __CUDA_ARCH__
      return atomicAdd( &value, arg );
#else
      const T old = value;
      value += arg;
      return old;
#endif
   }

   __cuda_callable__
   T fetch_sub( T arg ) noexcept
   {
#ifdef __CUDA_ARCH__
      return atomicSub( &value, arg );
#else
      const T old = value;
      value -= arg;
      return old;
#endif
   }

   __cuda_callable__
   T fetch_and( T arg ) noexcept
   {
#ifdef __CUDA_ARCH__
      return atomicAnd( &value, arg );
#else
      const T old = value;
      value = value & arg;
      return old;
#endif
   }

   __cuda_callable__
   T fetch_or( T arg ) noexcept
   {
#ifdef __CUDA_ARCH__
      return atomicOr( &value, arg );
#else
      const T old = value;
      value = value | arg;
      return old;
#endif
   }

   __cuda_callable__
   T fetch_xor( T arg ) noexcept
   {
#ifdef __CUDA_ARCH__
      return atomicXor( &value, arg );
#else
      const T old = value;
      value = value ^ arg;
      return old;
#endif
   }

   // Compound-assignment operators return the NEW value (std::atomic
   // convention), reconstructed from the fetch_* result.

   __cuda_callable__
   T operator+=( T arg ) noexcept
   {
      return fetch_add( arg ) + arg;
   }

   __cuda_callable__
   T operator-=( T arg ) noexcept
   {
      return fetch_sub( arg ) - arg;
   }

   __cuda_callable__
   T operator&=( T arg ) noexcept
   {
      return fetch_and( arg ) & arg;
   }

   __cuda_callable__
   T operator|=( T arg ) noexcept
   {
      return fetch_or( arg ) | arg;
   }

   __cuda_callable__
   T operator^=( T arg ) noexcept
   {
      return fetch_xor( arg ) ^ arg;
   }

   // pre-increment
   __cuda_callable__
   T operator++() noexcept
   {
      return fetch_add(1) + 1;
   }

   // post-increment
   __cuda_callable__
   T operator++(int) noexcept
   {
      return fetch_add(1);
   }

   // pre-decrement
   __cuda_callable__
   T operator--() noexcept
   {
      return fetch_sub(1) - 1;
   }

   // post-decrement
   __cuda_callable__
   T operator--(int) noexcept
   {
      return fetch_sub(1);
   }

   // extensions (methods not present in C++ standards)

   // Atomically replaces the stored value with max(stored, arg) and
   // returns the previous value.
   __cuda_callable__
   T fetch_max( T arg ) noexcept
   {
#ifdef __CUDA_ARCH__
      return atomicMax( &value, arg );
#else
      const T old = value;
      value = ( value > arg ) ? value : arg;
      return old;
#endif
   }

   // Atomically replaces the stored value with min(stored, arg) and
   // returns the previous value.
   __cuda_callable__
   T fetch_min( T arg ) noexcept
   {
#ifdef __CUDA_ARCH__
      return atomicMin( &value, arg );
#else
      const T old = value;
      value = ( value < arg ) ? value : arg;
      return old;
#endif
   }

protected:
   T value;
};

} // namespace TNL