Skip to content
Snippets Groups Projects
Commit ff8f19a6 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber Committed by Jakub Klinkovský
Browse files

Reimplementing StaticArray using StaticFor, deleting StaticArray specialization by Size.

parent 82b95fa5
No related branches found
No related tags found
1 merge request!35Static vector
......@@ -55,7 +55,7 @@ class StaticArray
// reference: https://stackoverflow.com/q/4610503
template< typename _unused = void >
__cuda_callable__
inline StaticArray( const Value v[ Size ] );
StaticArray( const Value v[ Size ] );
/**
* \brief Constructor that sets all array components to value \e v.
......
......@@ -13,11 +13,86 @@
#include <TNL/param-types.h>
#include <TNL/Math.h>
#include <TNL/Containers/StaticArray.h>
#include <TNL/TemplateStaticFor.h>
#include <TNL/StaticFor.h>
namespace TNL {
namespace Containers {
namespace Detail {
////
// Lambdas used together with StaticFor for static loop unrolling in the
// implementation of the StaticArray
template< typename LeftValue, typename RightValue = LeftValue >
auto assignArrayLambda = [] __cuda_callable__ ( int i, LeftValue* data, const RightValue* v ) { data[ i ] = v[ i ]; };
template< typename LeftValue, typename RightValue = LeftValue >
auto assignValueLambda = [] __cuda_callable__ ( int i, LeftValue* data, const RightValue v ) { data[ i ] = v; };
////
// StaticArrayComparator does static loop unrolling of array comparison
template< int Size, typename LeftValue, typename RightValue, int Index >
struct StaticArrayComparator
{
__cuda_callable__
static bool EQ( const StaticArray< Size, LeftValue >& left,
const StaticArray< Size, RightValue >& right )
{
if( left[ Index ] == right[ Index ] )
return StaticArrayComparator< Size, LeftValue, RightValue, Index + 1 >::EQ( left, right );
return false;
}
};
template< int Size, typename LeftValue, typename RightValue >
struct StaticArrayComparator< Size, LeftValue, RightValue, Size >
{
__cuda_callable__
static bool EQ( const StaticArray< Size, LeftValue >& left,
const StaticArray< Size, RightValue >& right )
{
return true;
}
};
////
// Static array sort does static loop unrolling of array sort.
// It performs static variant of bubble sort as follows:
//
// for( int k = Size - 1; k > 0; k--)
// for( int i = 0; i < k; i++ )
// if( data[ i ] > data[ i+1 ] )
// swap( data[ i ], data[ i+1 ] );
template< int k, int i, typename Value >
struct StaticArraySort
{
__cuda_callable__
static void exec( Value* data ) {
if( data[ i ] > data[ i + 1 ] )
swap( data[ i ], data[ i+1 ] );
StaticArraySort< k, i + 1, Value >::exec( data );
}
};
template< int k, typename Value >
struct StaticArraySort< k, k, Value >
{
__cuda_callable__
static void exec( Value* data ) {
StaticArraySort< k - 1, 0, Value >::exec( data );
}
};
template< typename Value >
struct StaticArraySort< 0, 0, Value >
{
__cuda_callable__
static void exec( Value* data ) {}
};
} //namespace Detail
template< int Size, typename Value >
__cuda_callable__
constexpr int StaticArray< Size, Value >::getSize()
......@@ -31,29 +106,27 @@ inline StaticArray< Size, Value >::StaticArray()
{
};
template< int Size, typename Value >
template< typename _unused >
__cuda_callable__
inline StaticArray< Size, Value >::StaticArray( const Value v[ Size ] )
StaticArray< Size, Value >::StaticArray( const Value v[ Size ] )
{
for( int i = 0; i < Size; i++ )
data[ i ] = v[ i ];
StaticFor< 0, Size >::exec( Detail::assignArrayLambda< Value >, data, v );
}
template< int Size, typename Value >
__cuda_callable__
inline StaticArray< Size, Value >::StaticArray( const Value& v )
{
for( int i = 0; i < Size; i++ )
data[ i ] = v;
StaticFor< 0, Size >::exec( Detail::assignValueLambda< Value >, data, v );
}
template< int Size, typename Value >
__cuda_callable__
inline StaticArray< Size, Value >::StaticArray( const StaticArray< Size, Value >& v )
{
for( int i = 0; i < Size; i++ )
data[ i ] = v[ i ];
StaticFor< 0, Size >::exec( Detail::assignArrayLambda< Value >, data, v.getData() );
}
template< int Size, typename Value >
......@@ -174,8 +247,7 @@ template< int Size, typename Value >
__cuda_callable__
inline StaticArray< Size, Value >& StaticArray< Size, Value >::operator = ( const StaticArray< Size, Value >& array )
{
for( int i = 0; i < Size; i++ )
data[ i ] = array[ i ];
StaticFor< 0, Size >::exec( Detail::assignArrayLambda< Value >, data, array.getData() );
return *this;
}
......@@ -184,8 +256,7 @@ template< int Size, typename Value >
__cuda_callable__
inline StaticArray< Size, Value >& StaticArray< Size, Value >::operator = ( const Array& array )
{
for( int i = 0; i < Size; i++ )
data[ i ] = array[ i ];
StaticFor< 0, Size >::exec( Detail::assignArrayLambda< Value, typename Array::ValueType >, data, array.getData() );
return *this;
}
......@@ -194,12 +265,7 @@ template< int Size, typename Value >
__cuda_callable__
inline bool StaticArray< Size, Value >::operator == ( const Array& array ) const
{
if( ( int ) Size != ( int ) Array::getSize() )
return false;
for( int i = 0; i < Size; i++ )
if( data[ i ] != array[ i ] )
return false;
return true;
return Detail::StaticArrayComparator< Size, Value, typename Array::ValueType, 0 >::EQ( *this, array );
}
template< int Size, typename Value >
......@@ -217,8 +283,7 @@ StaticArray< Size, Value >::
operator StaticArray< Size, OtherValue >() const
{
StaticArray< Size, OtherValue > aux;
for( int i = 0; i < Size; i++ )
aux[ i ] = data[ i ];
StaticFor< 0, Size >::exec( Detail::assignArrayLambda< OtherValue, Value >, aux.getData(), data );
return aux;
}
......@@ -226,8 +291,7 @@ template< int Size, typename Value >
__cuda_callable__
inline void StaticArray< Size, Value >::setValue( const ValueType& val )
{
for( int i = 0; i < Size; i++ )
data[ i ] = val;
StaticFor< 0, Size >::exec( Detail::assignValueLambda< Value >, data, val );
}
template< int Size, typename Value >
......@@ -247,14 +311,7 @@ bool StaticArray< Size, Value >::load( File& file)
template< int Size, typename Value >
void StaticArray< Size, Value >::sort()
{
/****
* We assume that the array data is small and so
* may sort it with the bubble sort.
*/
for( int k = Size - 1; k > 0; k--)
for( int i = 0; i < k; i++ )
if( data[ i ] > data[ i+1 ] )
swap( data[ i ], data[ i+1 ] );
Detail::StaticArraySort< Size - 1, 0, Value >::exec( data );
}
template< int Size, typename Value >
......@@ -266,7 +323,6 @@ std::ostream& StaticArray< Size, Value >::write( std::ostream& str, const char*
return str;
}
template< int Size, typename Value >
std::ostream& operator << ( std::ostream& str, const StaticArray< Size, Value >& a )
{
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment