Commit ff8f19a6 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber Committed by Jakub Klinkovský
Browse files

Reimplementing StaticArray using StaticFor, deleting StaticArray specialization by Size.

parent 82b95fa5
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -55,7 +55,7 @@ class StaticArray
   // reference: https://stackoverflow.com/q/4610503
   template< typename _unused = void >
   __cuda_callable__
   inline StaticArray( const Value v[ Size ] );
   StaticArray( const Value v[ Size ] );

   /**
    * \brief Constructor that sets all array components to value \e v.
+87 −31
Original line number Diff line number Diff line
@@ -13,11 +13,86 @@
#include <TNL/param-types.h>
#include <TNL/Math.h>
#include <TNL/Containers/StaticArray.h>
#include <TNL/TemplateStaticFor.h>
#include <TNL/StaticFor.h>

namespace TNL {
namespace Containers {

namespace Detail {

////
// Lambdas used together with StaticFor for static loop unrolling in the
// implementation of the StaticArray
template< typename LeftValue, typename RightValue = LeftValue >
auto assignArrayLambda = [] __cuda_callable__ ( int i, LeftValue* data, const RightValue* v ) { data[ i ] = v[ i ]; };

template< typename LeftValue, typename RightValue = LeftValue >
auto assignValueLambda = [] __cuda_callable__ ( int i, LeftValue* data, const RightValue v ) { data[ i ] = v; };

////
// StaticArrayComparator does static loop unrolling of array comparison
template< int Size, typename LeftValue, typename RightValue, int Index >
struct StaticArrayComparator
{
   __cuda_callable__
   static bool EQ( const StaticArray< Size, LeftValue >& left,
                   const StaticArray< Size, RightValue >& right )
   {
      if( left[ Index ] == right[ Index ] )
         return StaticArrayComparator< Size, LeftValue, RightValue, Index + 1 >::EQ( left, right );
      return false;
   }
};

template< int Size, typename LeftValue, typename RightValue >
struct StaticArrayComparator< Size, LeftValue, RightValue, Size >
{
   __cuda_callable__
   static bool EQ( const StaticArray< Size, LeftValue >& left,
                   const StaticArray< Size, RightValue >& right )
   {
      return true;
   }
};

////
// Static array sort does static loop unrolling of array sort.
// It performs static variant of bubble sort as follows:
// 
// for( int k = Size - 1; k > 0; k--)
//   for( int i = 0; i < k; i++ )
//      if( data[ i ] > data[ i+1 ] )
//         swap( data[ i ], data[ i+1 ] );
template< int k, int i, typename Value >
struct StaticArraySort
{
   __cuda_callable__
   static void exec( Value* data ) {
      if( data[ i ] > data[  i + 1 ] )
         swap( data[ i ], data[ i+1 ] );
      StaticArraySort< k, i + 1, Value >::exec( data );
   }
};

template< int k, typename Value >
struct StaticArraySort< k, k, Value >
{
   __cuda_callable__
   static void exec( Value* data ) {
      StaticArraySort< k - 1, 0, Value >::exec( data );
   }
};

template< typename Value >
struct StaticArraySort< 0, 0, Value >
{
   __cuda_callable__
   static void exec( Value* data ) {}
};

} //namespace Detail


template< int Size, typename Value >
__cuda_callable__
constexpr int StaticArray< Size, Value >::getSize()
@@ -31,29 +106,27 @@ inline StaticArray< Size, Value >::StaticArray()
{
};


template< int Size, typename Value >
   template< typename _unused >
__cuda_callable__
inline StaticArray< Size, Value >::StaticArray( const Value v[ Size ] )
StaticArray< Size, Value >::StaticArray( const Value v[ Size ] )
{
   for( int i = 0; i < Size; i++ )
      data[ i ] = v[ i ];
   StaticFor< 0, Size >::exec( Detail::assignArrayLambda< Value >, data, v );
}

template< int Size, typename Value >
__cuda_callable__
inline StaticArray< Size, Value >::StaticArray( const Value& v )
{
   for( int i = 0; i < Size; i++ )
      data[ i ] = v;
   StaticFor< 0, Size >::exec( Detail::assignValueLambda< Value >, data, v );
}

template< int Size, typename Value >
__cuda_callable__
inline StaticArray< Size, Value >::StaticArray( const StaticArray< Size, Value >& v )
{
   for( int i = 0; i < Size; i++ )
      data[ i ] = v[ i ];
   StaticFor< 0, Size >::exec( Detail::assignArrayLambda< Value >, data, v.getData() );
}

template< int Size, typename Value >
@@ -174,8 +247,7 @@ template< int Size, typename Value >
__cuda_callable__
inline StaticArray< Size, Value >& StaticArray< Size, Value >::operator = ( const StaticArray< Size, Value >& array )
{
   for( int i = 0; i < Size; i++ )
      data[ i ] = array[ i ];
   StaticFor< 0, Size >::exec( Detail::assignArrayLambda< Value >, data, array.getData() );
   return *this;
}

@@ -184,8 +256,7 @@ template< int Size, typename Value >
__cuda_callable__
inline StaticArray< Size, Value >& StaticArray< Size, Value >::operator = ( const Array& array )
{
   for( int i = 0; i < Size; i++ )
      data[ i ] = array[ i ];
   StaticFor< 0, Size >::exec( Detail::assignArrayLambda< Value, typename Array::ValueType >, data, array.getData() );
   return *this;
}

@@ -194,12 +265,7 @@ template< int Size, typename Value >
__cuda_callable__
inline bool StaticArray< Size, Value >::operator == ( const Array& array ) const
{
   if( ( int ) Size != ( int ) Array::getSize() )
      return false;
   for( int i = 0; i < Size; i++ )
      if( data[ i ] != array[ i ] )
         return false;
   return true;
   return Detail::StaticArrayComparator< Size, Value, typename Array::ValueType, 0 >::EQ( *this, array );
}

template< int Size, typename Value >
@@ -217,8 +283,7 @@ StaticArray< Size, Value >::
operator StaticArray< Size, OtherValue >() const
{
   StaticArray< Size, OtherValue > aux;
   for( int i = 0; i < Size; i++ )
      aux[ i ] = data[ i ];
   StaticFor< 0, Size >::exec( Detail::assignArrayLambda< OtherValue, Value >, aux.getData(), data );
   return aux;
}

@@ -226,8 +291,7 @@ template< int Size, typename Value >
__cuda_callable__
inline void StaticArray< Size, Value >::setValue( const ValueType& val )
{
   for( int i = 0; i < Size; i++ )
      data[ i ] = val;
   StaticFor< 0, Size >::exec( Detail::assignValueLambda< Value >, data, val );
}

template< int Size, typename Value >
@@ -247,14 +311,7 @@ bool StaticArray< Size, Value >::load( File& file)
template< int Size, typename Value >
void StaticArray< Size, Value >::sort()
{
   /****
    * We assume that the array data is small and so
    * may sort it with the bubble sort.
    */
   for( int k = Size - 1; k > 0; k--)
      for( int i = 0; i < k; i++ )
         if( data[ i ] > data[ i+1 ] )
            swap( data[ i ], data[ i+1 ] );
   Detail::StaticArraySort< Size - 1, 0, Value >::exec( data );
}

template< int Size, typename Value >
@@ -266,7 +323,6 @@ std::ostream& StaticArray< Size, Value >::write( std::ostream& str, const char*
   return str;
}


template< int Size, typename Value >
std::ostream& operator << ( std::ostream& str, const StaticArray< Size, Value >& a )
{