Commit 34670339 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Fixed binary expression templates for nested vector types

Fixes #60
parent 4c5ddb35
Loading
Loading
Loading
Loading
+60 −22
Original line number Diff line number Diff line
@@ -85,13 +85,70 @@ using EnableIfDistributedBinaryExpression_t = std::enable_if_t<
      ) >;


// helper trait class for recursively turning expression template classes into compatible vectors
template<class T, class R = void>
struct enable_if_type { typedef R type; };

template< typename R, typename Enable = void >
struct RemoveExpressionTemplate
{
   using type = std::decay_t< R >;
};

template< typename R >
struct RemoveExpressionTemplate< R, typename enable_if_type< typename std::decay_t< R >::VectorOperandType >::type >
{
   using type = typename RemoveExpressionTemplate< typename std::decay_t< R >::VectorOperandType >::type;
};

template< typename R >
using RemoveET = typename RemoveExpressionTemplate< R >::type;


template< typename T1, typename T2 >
constexpr std::enable_if_t<
      ! ( std::is_arithmetic< T1 >::value && std::is_arithmetic< T2 >::value ) &&
      ! ( IsStaticArrayType< T1 >::value && IsStaticArrayType< T2 >::value ) &&
      ! ( IsArrayType< T1 >::value && IsArrayType< T2 >::value )
, bool >
compatibleForVectorAssignment()
{
   return false;
}

template< typename T1, typename T2 >
constexpr std::enable_if_t< std::is_arithmetic< T1 >::value && std::is_arithmetic< T2 >::value, bool >
compatibleForVectorAssignment()
{
   return true;
}

template< typename T1, typename T2 >
constexpr std::enable_if_t< IsStaticArrayType< T1 >::value && IsStaticArrayType< T2 >::value, bool >
compatibleForVectorAssignment()
{
   return T1::getSize() == T2::getSize() &&
          compatibleForVectorAssignment< typename RemoveET< T1 >::ValueType, typename RemoveET< T2 >::ValueType >();
}

template< typename T1, typename T2 >
constexpr std::enable_if_t< IsArrayType< T1 >::value && IsArrayType< T2 >::value, bool >
compatibleForVectorAssignment()
{
   return compatibleForVectorAssignment< typename RemoveET< T1 >::ValueType, typename RemoveET< T2 >::ValueType >();
}


// helper trait class for proper classification of expression operands using getExpressionVariableType
template< typename T, typename V,
          bool enabled = IsVectorType< V >::value >
          bool enabled = HasEnabledExpressionTemplates< V >::value ||
                         HasEnabledStaticExpressionTemplates< V >::value ||
                         HasEnabledDistributedExpressionTemplates< V >::value >
struct IsArithmeticSubtype
: public std::integral_constant< bool,
            // TODO: use std::is_assignable?
            std::is_same< T, typename std::decay_t< V >::RealType >::value >
            // Note that using std::is_same would not be general enough, because e.g.
            // StaticVector<3, int> may be assigned to StaticVector<3, double>
            compatibleForVectorAssignment< typename V::RealType, T >() >
{};

template< typename T >
@@ -110,25 +167,6 @@ struct IsArithmeticSubtype< T, V, false >
{};


// helper trait class (used in unit tests)
template<class T, class R = void>
struct enable_if_type { typedef R type; };

template< typename R, typename Enable = void >
struct RemoveExpressionTemplate
{
   using type = std::decay_t< R >;
};

template< typename R >
struct RemoveExpressionTemplate< R, typename enable_if_type< typename std::decay_t< R >::VectorOperandType >::type >
{
   using type = typename RemoveExpressionTemplate< typename std::decay_t< R >::VectorOperandType >::type;
};

template< typename R >
using RemoveET = typename RemoveExpressionTemplate< R >::type;

// helper trait class for Static*ExpressionTemplates classes
template< typename R, typename Enable = void >
struct OperandMemberType
+81 −6
Original line number Diff line number Diff line
@@ -58,18 +58,18 @@ class VectorBinaryOperationsTest : public ::testing::Test
protected:
   using Left = typename Pair::Left;
   using Right = typename Pair::Right;
   using LeftReal = std::remove_const_t< typename Left::RealType >;
   using RightReal = std::remove_const_t< typename Right::RealType >;
#ifndef STATIC_VECTOR
   using LeftNonConstReal = std::remove_const_t< typename Left::RealType >;
   using RightNonConstReal = std::remove_const_t< typename Right::RealType >;
   #ifdef DISTRIBUTED_VECTOR
      using CommunicatorType = typename Left::CommunicatorType;
      static_assert( std::is_same< typename Right::CommunicatorType, CommunicatorType >::value,
                     "CommunicatorType must be the same for both Left and Right vectors." );
      using LeftVector = DistributedVector< LeftNonConstReal, typename Left::DeviceType, typename Left::IndexType, CommunicatorType >;
      using RightVector = DistributedVector< RightNonConstReal, typename Right::DeviceType, typename Right::IndexType, CommunicatorType >;
      using LeftVector = DistributedVector< LeftReal, typename Left::DeviceType, typename Left::IndexType, CommunicatorType >;
      using RightVector = DistributedVector< RightReal, typename Right::DeviceType, typename Right::IndexType, CommunicatorType >;
   #else
      using LeftVector = Vector< LeftNonConstReal, typename Left::DeviceType, typename Left::IndexType >;
      using RightVector = Vector< RightNonConstReal, typename Right::DeviceType, typename Right::IndexType >;
      using LeftVector = Vector< LeftReal, typename Left::DeviceType, typename Left::IndexType >;
      using RightVector = Vector< RightReal, typename Right::DeviceType, typename Right::IndexType >;
   #endif
#endif

@@ -132,6 +132,8 @@ protected:
#define SETUP_BINARY_TEST_ALIASES \
   using Left = typename TestFixture::Left;                 \
   using Right = typename TestFixture::Right;               \
   using LeftReal = typename TestFixture::LeftReal;         \
   using RightReal = typename TestFixture::RightReal;       \
   Left& L1 = this->L1;                                     \
   Left& L2 = this->L2;                                     \
   Right& R1 = this->R1;                                    \
@@ -263,6 +265,8 @@ TYPED_TEST( VectorBinaryOperationsTest, EQ )
   EXPECT_EQ( L1, R1 );       // vector or vector view
   EXPECT_EQ( L1, 1 );        // right scalar
   EXPECT_EQ( 1, R1 );        // left scalar
   EXPECT_EQ( L1, RightReal(1) );   // right scalar
   EXPECT_EQ( LeftReal(1), R1 );    // left scalar
   EXPECT_EQ( L2, R1 + R1 );  // right expression
   EXPECT_EQ( L1 + L1, R2 );  // left expression
   EXPECT_EQ( L1 + L1, R1 + R1 );  // two expressions
@@ -282,6 +286,8 @@ TYPED_TEST( VectorBinaryOperationsTest, NE )
   EXPECT_NE( L1, R2 );       // vector or vector view
   EXPECT_NE( L1, 2 );        // right scalar
   EXPECT_NE( 2, R1 );        // left scalar
   EXPECT_NE( L1, RightReal(2) );   // right scalar
   EXPECT_NE( LeftReal(2), R1 );    // left scalar
   EXPECT_NE( L1, R1 + R1 );  // right expression
   EXPECT_NE( L1 + L1, R1 );  // left expression
   EXPECT_NE( L1 + L1, R2 + R2 );  // two expressions
@@ -301,6 +307,8 @@ TYPED_TEST( VectorBinaryOperationsTest, LT )
   EXPECT_LT( L1, R2 );       // vector or vector view
   EXPECT_LT( L1, 2 );        // right scalar
   EXPECT_LT( 1, R2 );        // left scalar
   EXPECT_LT( L1, RightReal(2) );   // right scalar
   EXPECT_LT( LeftReal(1), R2 );    // left scalar
   EXPECT_LT( L1, R1 + R1 );  // right expression
   EXPECT_LT( L1 - L1, R1 );  // left expression
   EXPECT_LT( L1 - L1, R1 + R1 );  // two expressions
@@ -313,6 +321,8 @@ TYPED_TEST( VectorBinaryOperationsTest, GT )
   EXPECT_GT( L2, R1 );       // vector or vector view
   EXPECT_GT( L2, 1 );        // right scalar
   EXPECT_GT( 2, R1 );        // left scalar
   EXPECT_GT( L2, RightReal(1) );   // right scalar
   EXPECT_GT( LeftReal(2), R1 );    // left scalar
   EXPECT_GT( L1, R1 - R1 );  // right expression
   EXPECT_GT( L1 + L1, R1 );  // left expression
   EXPECT_GT( L1 + L1, R1 - R1 );  // two expressions
@@ -326,6 +336,8 @@ TYPED_TEST( VectorBinaryOperationsTest, LE )
   EXPECT_LE( L1, R2 );       // vector or vector view
   EXPECT_LE( L1, 2 );        // right scalar
   EXPECT_LE( 1, R2 );        // left scalar
   EXPECT_LE( L1, RightReal(2) );   // right scalar
   EXPECT_LE( LeftReal(1), R2 );    // left scalar
   EXPECT_LE( L1, R1 + R1 );  // right expression
   EXPECT_LE( L1 - L1, R1 );  // left expression
   EXPECT_LE( L1 - L1, R1 + R1 );  // two expressions
@@ -334,6 +346,8 @@ TYPED_TEST( VectorBinaryOperationsTest, LE )
   EXPECT_LE( L1, R1 );       // vector or vector view
   EXPECT_LE( L1, 1 );        // right scalar
   EXPECT_LE( 1, R1 );        // left scalar
   EXPECT_LE( L1, RightReal(1) );   // right scalar
   EXPECT_LE( LeftReal(1), R1 );    // left scalar
   EXPECT_LE( L2, R1 + R1 );  // right expression
   EXPECT_LE( L1 + L1, R2 );  // left expression
   EXPECT_LE( L1 + L1, R1 + R2 );  // two expressions
@@ -347,6 +361,8 @@ TYPED_TEST( VectorBinaryOperationsTest, GE )
   EXPECT_GE( L2, R1 );       // vector or vector view
   EXPECT_GE( L2, 1 );        // right scalar
   EXPECT_GE( 2, R1 );        // left scalar
   EXPECT_GE( L2, RightReal(1) );   // right scalar
   EXPECT_GE( LeftReal(2), R1 );    // left scalar
   EXPECT_GE( L1, R1 - R1 );  // right expression
   EXPECT_GE( L1 + L1, R1 );  // left expression
   EXPECT_GE( L1 + L1, R1 - R1 );  // two expressions
@@ -355,6 +371,8 @@ TYPED_TEST( VectorBinaryOperationsTest, GE )
   EXPECT_LE( L1, R1 );       // vector or vector view
   EXPECT_LE( L1, 1 );        // right scalar
   EXPECT_LE( 1, R1 );        // left scalar
   EXPECT_LE( L1, RightReal(1) );   // right scalar
   EXPECT_LE( LeftReal(1), R1 );    // left scalar
   EXPECT_LE( L2, R1 + R1 );  // right expression
   EXPECT_LE( L1 + L1, R2 );  // left expression
   EXPECT_LE( L1 + L1, R1 + R2 );  // two expressions
@@ -369,6 +387,8 @@ TYPED_TEST( VectorBinaryOperationsTest, addition )
   // with scalar
   EXPECT_EQ( L1 + 1, 2 );
   EXPECT_EQ( 1 + L1, 2 );
   EXPECT_EQ( L1 + LeftReal(1), 2 );
   EXPECT_EQ( LeftReal(1) + L1, 2 );
   // with expression
   EXPECT_EQ( L1 + (L1 + L1), 3 );
   EXPECT_EQ( (L1 + L1) + L1, 3 );
@@ -376,6 +396,11 @@ TYPED_TEST( VectorBinaryOperationsTest, addition )
   EXPECT_EQ( (L1 + L1) + R1, 3 );
   // with two expressions
   EXPECT_EQ( (L1 + L1) + (L1 + L1), 4 );
   // with expression and scalar
   EXPECT_EQ( (L1 + L1) + 1, 3 );
   EXPECT_EQ( (L1 + L1) + RightReal(1), 3 );
   EXPECT_EQ( 1 + (R1 + R1), 3 );
   EXPECT_EQ( LeftReal(1) + (R1 + R1), 3 );
}

TYPED_TEST( VectorBinaryOperationsTest, subtraction )
@@ -387,6 +412,8 @@ TYPED_TEST( VectorBinaryOperationsTest, subtraction )
   // with scalar
   EXPECT_EQ( L1 - 1, 0 );
   EXPECT_EQ( 1 - L1, 0 );
   EXPECT_EQ( L1 - LeftReal(1), 0 );
   EXPECT_EQ( LeftReal(1) - L1, 0 );
   // with expression
   EXPECT_EQ( L2 - (L1 + L1), 0 );
   EXPECT_EQ( (L1 + L1) - L2, 0 );
@@ -394,6 +421,11 @@ TYPED_TEST( VectorBinaryOperationsTest, subtraction )
   EXPECT_EQ( (L1 + L1) - R2, 0 );
   // with two expressions
   EXPECT_EQ( (L1 + L1) - (L1 + L1), 0 );
   // with expression and scalar
   EXPECT_EQ( (L1 + L1) - 1, 1 );
   EXPECT_EQ( (L1 + L1) - RightReal(1), 1 );
   EXPECT_EQ( 1 - (R1 + R1), -1 );
   EXPECT_EQ( LeftReal(1) - (R1 + R1), -1 );
}

TYPED_TEST( VectorBinaryOperationsTest, multiplication )
@@ -405,6 +437,8 @@ TYPED_TEST( VectorBinaryOperationsTest, multiplication )
   // with scalar
   EXPECT_EQ( L1 * 2, L2 );
   EXPECT_EQ( 2 * L1, L2 );
   EXPECT_EQ( L1 * LeftReal(2), L2 );
   EXPECT_EQ( LeftReal(2) * L1, L2 );
   // with expression
   EXPECT_EQ( L1 * (L1 + L1), L2 );
   EXPECT_EQ( (L1 + L1) * L1, L2 );
@@ -412,6 +446,11 @@ TYPED_TEST( VectorBinaryOperationsTest, multiplication )
   EXPECT_EQ( (L1 + L1) * R1, L2 );
   // with two expressions
   EXPECT_EQ( (L1 + L1) * (L1 + L1), 4 );
   // with expression and scalar
   EXPECT_EQ( (L1 + L1) * 1, 2 );
   EXPECT_EQ( (L1 + L1) * RightReal(1), 2 );
   EXPECT_EQ( 1 * (R1 + R1), 2 );
   EXPECT_EQ( LeftReal(1) * (R1 + R1), 2 );
}

TYPED_TEST( VectorBinaryOperationsTest, division )
@@ -423,6 +462,8 @@ TYPED_TEST( VectorBinaryOperationsTest, division )
   // with scalar
   EXPECT_EQ( L2 / 2, L1 );
   EXPECT_EQ( 2 / L2, L1 );
   EXPECT_EQ( L2 / LeftReal(2), L1 );
   EXPECT_EQ( LeftReal(2) / L2, L1 );
   // with expression
   EXPECT_EQ( L2 / (L1 + L1), L1 );
   EXPECT_EQ( (L1 + L1) / L2, L1 );
@@ -430,6 +471,11 @@ TYPED_TEST( VectorBinaryOperationsTest, division )
   EXPECT_EQ( (L1 + L1) / R2, L1 );
   // with two expressions
   EXPECT_EQ( (L1 + L1) / (L1 + L1), L1 );
   // with expression and scalar
   EXPECT_EQ( (L1 + L1) / 1, 2 );
   EXPECT_EQ( (L1 + L1) / RightReal(1), 2 );
   EXPECT_EQ( 2 / (R1 + R1), 1 );
   EXPECT_EQ( LeftReal(2) / (R1 + R1), 1 );
}

template< typename Left, typename Right, std::enable_if_t< std::is_const<typename Left::RealType>::value, bool > = true >
@@ -438,12 +484,15 @@ void test_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
template< typename Left, typename Right, std::enable_if_t< ! std::is_const<typename Left::RealType>::value, bool > = true >
void test_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
{
   using RightReal = std::remove_const_t< typename Right::RealType >;
   // with vector or vector view
   L1 = R2;
   EXPECT_EQ( L1, R2 );
   // with scalar
   L1 = 1;
   EXPECT_EQ( L1, 1 );
   L1 = RightReal(1);
   EXPECT_EQ( L1, 1 );
   // with expression
   L1 = R1 + R1;
   EXPECT_EQ( L1, R1 + R1 );
@@ -460,6 +509,7 @@ void test_add_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
template< typename Left, typename Right, std::enable_if_t< ! std::is_const<typename Left::RealType>::value, bool > = true >
void test_add_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
{
   using RightReal = std::remove_const_t< typename Right::RealType >;
   // with vector or vector view
   L1 += R2;
   EXPECT_EQ( L1, R1 + R2 );
@@ -467,6 +517,9 @@ void test_add_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
   L1 = 1;
   L1 += 2;
   EXPECT_EQ( L1, 3 );
   L1 = 1;
   L1 += RightReal(2);
   EXPECT_EQ( L1, 3 );
   // with expression
   L1 = 1;
   L1 += R1 + R1;
@@ -484,6 +537,7 @@ void test_subtract_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
template< typename Left, typename Right, std::enable_if_t< ! std::is_const<typename Left::RealType>::value, bool > = true >
void test_subtract_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
{
   using RightReal = std::remove_const_t< typename Right::RealType >;
   // with vector or vector view
   L1 -= R2;
   EXPECT_EQ( L1, R1 - R2 );
@@ -491,6 +545,9 @@ void test_subtract_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
   L1 = 1;
   L1 -= 2;
   EXPECT_EQ( L1, -1 );
   L1 = 1;
   L1 -= RightReal(2);
   EXPECT_EQ( L1, -1 );
   // with expression
   L1 = 1;
   L1 -= R1 + R1;
@@ -508,6 +565,7 @@ void test_multiply_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
template< typename Left, typename Right, std::enable_if_t< ! std::is_const<typename Left::RealType>::value, bool > = true >
void test_multiply_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
{
   using RightReal = std::remove_const_t< typename Right::RealType >;
   // with vector or vector view
   L1 *= R2;
   EXPECT_EQ( L1, R2 );
@@ -515,6 +573,9 @@ void test_multiply_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
   L1 = 1;
   L1 *= 2;
   EXPECT_EQ( L1, 2 );
   L1 = 1;
   L1 *= RightReal(2);
   EXPECT_EQ( L1, 2 );
   // with expression
   L1 = 1;
   L1 *= R1 + R1;
@@ -532,6 +593,7 @@ void test_divide_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
template< typename Left, typename Right, std::enable_if_t< ! std::is_const<typename Left::RealType>::value, bool > = true >
void test_divide_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
{
   using RightReal = std::remove_const_t< typename Right::RealType >;
   // with vector or vector view
   L2 /= R2;
   EXPECT_EQ( L1, R1 );
@@ -539,6 +601,9 @@ void test_divide_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
   L2 = 2;
   L2 /= 2;
   EXPECT_EQ( L1, 1 );
   L1 = 2;
   L1 /= RightReal(2);
   EXPECT_EQ( L1, 1 );
   // with expression
   L2 = 2;
   L2 /= R1 + R1;
@@ -602,6 +667,11 @@ TYPED_TEST( VectorBinaryOperationsTest, min )
   EXPECT_EQ( TNL::min(L1 + L1, R1), R1 );
   // with two expressions
   EXPECT_EQ( TNL::min(L1 + L1, R1 + R2), L2 );
   // with expression and scalar
   EXPECT_EQ( TNL::min(L1 + L1, 1), L1 );
   EXPECT_EQ( TNL::min(L1 + L1, RightReal(1)), L1 );
   EXPECT_EQ( TNL::min(1, R1 + R1), L1 );
   EXPECT_EQ( TNL::min(LeftReal(1), R1 + R1), L1 );
}

TYPED_TEST( VectorBinaryOperationsTest, max )
@@ -620,6 +690,11 @@ TYPED_TEST( VectorBinaryOperationsTest, max )
   EXPECT_EQ( TNL::max(L1 + L1, R1), R2 );
   // with two expressions
   EXPECT_EQ( TNL::max(L1 - L1, R1 + R1), L2 );
   // with expression and scalar
   EXPECT_EQ( TNL::max(L1 + L1, 1), L2 );
   EXPECT_EQ( TNL::max(L1 + L1, RightReal(1)), L2 );
   EXPECT_EQ( TNL::max(1, R1 + R1), L2 );
   EXPECT_EQ( TNL::max(LeftReal(1), R1 + R1), L2 );
}

#if defined(HAVE_CUDA) && !defined(STATIC_VECTOR)
+1 −5
Original line number Diff line number Diff line
#define VECTOR_OF_STATIC_VECTORS
#include "VectorBinaryOperationsTest.h"
#include "VectorUnaryOperationsTest.h"
#include "VectorVerticalOperationsTest.h"
#include "../main.h"
#include "VectorOfStaticVectorsTest.h"
+1 −5
Original line number Diff line number Diff line
#define VECTOR_OF_STATIC_VECTORS
#include "VectorBinaryOperationsTest.h"
#include "VectorUnaryOperationsTest.h"
#include "VectorVerticalOperationsTest.h"
#include "../main.h"
#include "VectorOfStaticVectorsTest.h"
+5 −0
Original line number Diff line number Diff line
#define VECTOR_OF_STATIC_VECTORS
#include "VectorBinaryOperationsTest.h"
#include "VectorUnaryOperationsTest.h"
#include "VectorVerticalOperationsTest.h"
#include "../main.h"