Skip to content
Snippets Groups Projects
Commit 34670339 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Fixed binary expression templates for nested vector types

Fixes #60
parent 4c5ddb35
No related branches found
No related tags found
1 merge request!68Expression templates for nested vectors
......@@ -85,13 +85,70 @@ using EnableIfDistributedBinaryExpression_t = std::enable_if_t<
) >;
// helper trait class for recursively turning expression template classes into compatible vectors
template<class T, class R = void>
struct enable_if_type { typedef R type; };
template< typename R, typename Enable = void >
struct RemoveExpressionTemplate
{
using type = std::decay_t< R >;
};
template< typename R >
struct RemoveExpressionTemplate< R, typename enable_if_type< typename std::decay_t< R >::VectorOperandType >::type >
{
using type = typename RemoveExpressionTemplate< typename std::decay_t< R >::VectorOperandType >::type;
};
template< typename R >
using RemoveET = typename RemoveExpressionTemplate< R >::type;
template< typename T1, typename T2 >
constexpr std::enable_if_t<
! ( std::is_arithmetic< T1 >::value && std::is_arithmetic< T2 >::value ) &&
! ( IsStaticArrayType< T1 >::value && IsStaticArrayType< T2 >::value ) &&
! ( IsArrayType< T1 >::value && IsArrayType< T2 >::value )
, bool >
compatibleForVectorAssignment()
{
return false;
}
template< typename T1, typename T2 >
constexpr std::enable_if_t< std::is_arithmetic< T1 >::value && std::is_arithmetic< T2 >::value, bool >
compatibleForVectorAssignment()
{
return true;
}
template< typename T1, typename T2 >
constexpr std::enable_if_t< IsStaticArrayType< T1 >::value && IsStaticArrayType< T2 >::value, bool >
compatibleForVectorAssignment()
{
return T1::getSize() == T2::getSize() &&
compatibleForVectorAssignment< typename RemoveET< T1 >::ValueType, typename RemoveET< T2 >::ValueType >();
}
template< typename T1, typename T2 >
constexpr std::enable_if_t< IsArrayType< T1 >::value && IsArrayType< T2 >::value, bool >
compatibleForVectorAssignment()
{
return compatibleForVectorAssignment< typename RemoveET< T1 >::ValueType, typename RemoveET< T2 >::ValueType >();
}
// helper trait class for proper classification of expression operands using getExpressionVariableType
template< typename T, typename V,
bool enabled = IsVectorType< V >::value >
bool enabled = HasEnabledExpressionTemplates< V >::value ||
HasEnabledStaticExpressionTemplates< V >::value ||
HasEnabledDistributedExpressionTemplates< V >::value >
struct IsArithmeticSubtype
: public std::integral_constant< bool,
// TODO: use std::is_assignable?
std::is_same< T, typename std::decay_t< V >::RealType >::value >
// Note that using std::is_same would not be general enough, because e.g.
// StaticVector<3, int> may be assigned to StaticVector<3, double>
compatibleForVectorAssignment< typename V::RealType, T >() >
{};
template< typename T >
......@@ -110,25 +167,6 @@ struct IsArithmeticSubtype< T, V, false >
{};
// helper trait class (used in unit tests)
template<class T, class R = void>
struct enable_if_type { typedef R type; };
template< typename R, typename Enable = void >
struct RemoveExpressionTemplate
{
using type = std::decay_t< R >;
};
template< typename R >
struct RemoveExpressionTemplate< R, typename enable_if_type< typename std::decay_t< R >::VectorOperandType >::type >
{
using type = typename RemoveExpressionTemplate< typename std::decay_t< R >::VectorOperandType >::type;
};
template< typename R >
using RemoveET = typename RemoveExpressionTemplate< R >::type;
// helper trait class for Static*ExpressionTemplates classes
template< typename R, typename Enable = void >
struct OperandMemberType
......
......@@ -58,18 +58,18 @@ class VectorBinaryOperationsTest : public ::testing::Test
protected:
using Left = typename Pair::Left;
using Right = typename Pair::Right;
using LeftReal = std::remove_const_t< typename Left::RealType >;
using RightReal = std::remove_const_t< typename Right::RealType >;
#ifndef STATIC_VECTOR
using LeftNonConstReal = std::remove_const_t< typename Left::RealType >;
using RightNonConstReal = std::remove_const_t< typename Right::RealType >;
#ifdef DISTRIBUTED_VECTOR
using CommunicatorType = typename Left::CommunicatorType;
static_assert( std::is_same< typename Right::CommunicatorType, CommunicatorType >::value,
"CommunicatorType must be the same for both Left and Right vectors." );
using LeftVector = DistributedVector< LeftNonConstReal, typename Left::DeviceType, typename Left::IndexType, CommunicatorType >;
using RightVector = DistributedVector< RightNonConstReal, typename Right::DeviceType, typename Right::IndexType, CommunicatorType >;
using LeftVector = DistributedVector< LeftReal, typename Left::DeviceType, typename Left::IndexType, CommunicatorType >;
using RightVector = DistributedVector< RightReal, typename Right::DeviceType, typename Right::IndexType, CommunicatorType >;
#else
using LeftVector = Vector< LeftNonConstReal, typename Left::DeviceType, typename Left::IndexType >;
using RightVector = Vector< RightNonConstReal, typename Right::DeviceType, typename Right::IndexType >;
using LeftVector = Vector< LeftReal, typename Left::DeviceType, typename Left::IndexType >;
using RightVector = Vector< RightReal, typename Right::DeviceType, typename Right::IndexType >;
#endif
#endif
......@@ -132,6 +132,8 @@ protected:
#define SETUP_BINARY_TEST_ALIASES \
using Left = typename TestFixture::Left; \
using Right = typename TestFixture::Right; \
using LeftReal = typename TestFixture::LeftReal; \
using RightReal = typename TestFixture::RightReal; \
Left& L1 = this->L1; \
Left& L2 = this->L2; \
Right& R1 = this->R1; \
......@@ -263,6 +265,8 @@ TYPED_TEST( VectorBinaryOperationsTest, EQ )
EXPECT_EQ( L1, R1 ); // vector or vector view
EXPECT_EQ( L1, 1 ); // right scalar
EXPECT_EQ( 1, R1 ); // left scalar
EXPECT_EQ( L1, RightReal(1) ); // right scalar
EXPECT_EQ( LeftReal(1), R1 ); // left scalar
EXPECT_EQ( L2, R1 + R1 ); // right expression
EXPECT_EQ( L1 + L1, R2 ); // left expression
EXPECT_EQ( L1 + L1, R1 + R1 ); // two expressions
......@@ -282,6 +286,8 @@ TYPED_TEST( VectorBinaryOperationsTest, NE )
EXPECT_NE( L1, R2 ); // vector or vector view
EXPECT_NE( L1, 2 ); // right scalar
EXPECT_NE( 2, R1 ); // left scalar
EXPECT_NE( L1, RightReal(2) ); // right scalar
EXPECT_NE( LeftReal(2), R1 ); // left scalar
EXPECT_NE( L1, R1 + R1 ); // right expression
EXPECT_NE( L1 + L1, R1 ); // left expression
EXPECT_NE( L1 + L1, R2 + R2 ); // two expressions
......@@ -301,6 +307,8 @@ TYPED_TEST( VectorBinaryOperationsTest, LT )
EXPECT_LT( L1, R2 ); // vector or vector view
EXPECT_LT( L1, 2 ); // right scalar
EXPECT_LT( 1, R2 ); // left scalar
EXPECT_LT( L1, RightReal(2) ); // right scalar
EXPECT_LT( LeftReal(1), R2 ); // left scalar
EXPECT_LT( L1, R1 + R1 ); // right expression
EXPECT_LT( L1 - L1, R1 ); // left expression
EXPECT_LT( L1 - L1, R1 + R1 ); // two expressions
......@@ -313,6 +321,8 @@ TYPED_TEST( VectorBinaryOperationsTest, GT )
EXPECT_GT( L2, R1 ); // vector or vector view
EXPECT_GT( L2, 1 ); // right scalar
EXPECT_GT( 2, R1 ); // left scalar
EXPECT_GT( L2, RightReal(1) ); // right scalar
EXPECT_GT( LeftReal(2), R1 ); // left scalar
EXPECT_GT( L1, R1 - R1 ); // right expression
EXPECT_GT( L1 + L1, R1 ); // left expression
EXPECT_GT( L1 + L1, R1 - R1 ); // two expressions
......@@ -326,6 +336,8 @@ TYPED_TEST( VectorBinaryOperationsTest, LE )
EXPECT_LE( L1, R2 ); // vector or vector view
EXPECT_LE( L1, 2 ); // right scalar
EXPECT_LE( 1, R2 ); // left scalar
EXPECT_LE( L1, RightReal(2) ); // right scalar
EXPECT_LE( LeftReal(1), R2 ); // left scalar
EXPECT_LE( L1, R1 + R1 ); // right expression
EXPECT_LE( L1 - L1, R1 ); // left expression
EXPECT_LE( L1 - L1, R1 + R1 ); // two expressions
......@@ -334,6 +346,8 @@ TYPED_TEST( VectorBinaryOperationsTest, LE )
EXPECT_LE( L1, R1 ); // vector or vector view
EXPECT_LE( L1, 1 ); // right scalar
EXPECT_LE( 1, R1 ); // left scalar
EXPECT_LE( L1, RightReal(1) ); // right scalar
EXPECT_LE( LeftReal(1), R1 ); // left scalar
EXPECT_LE( L2, R1 + R1 ); // right expression
EXPECT_LE( L1 + L1, R2 ); // left expression
EXPECT_LE( L1 + L1, R1 + R2 ); // two expressions
......@@ -347,6 +361,8 @@ TYPED_TEST( VectorBinaryOperationsTest, GE )
EXPECT_GE( L2, R1 ); // vector or vector view
EXPECT_GE( L2, 1 ); // right scalar
EXPECT_GE( 2, R1 ); // left scalar
EXPECT_GE( L2, RightReal(1) ); // right scalar
EXPECT_GE( LeftReal(2), R1 ); // left scalar
EXPECT_GE( L1, R1 - R1 ); // right expression
EXPECT_GE( L1 + L1, R1 ); // left expression
EXPECT_GE( L1 + L1, R1 - R1 ); // two expressions
......@@ -355,6 +371,8 @@ TYPED_TEST( VectorBinaryOperationsTest, GE )
EXPECT_LE( L1, R1 ); // vector or vector view
EXPECT_LE( L1, 1 ); // right scalar
EXPECT_LE( 1, R1 ); // left scalar
EXPECT_LE( L1, RightReal(1) ); // right scalar
EXPECT_LE( LeftReal(1), R1 ); // left scalar
EXPECT_LE( L2, R1 + R1 ); // right expression
EXPECT_LE( L1 + L1, R2 ); // left expression
EXPECT_LE( L1 + L1, R1 + R2 ); // two expressions
......@@ -369,6 +387,8 @@ TYPED_TEST( VectorBinaryOperationsTest, addition )
// with scalar
EXPECT_EQ( L1 + 1, 2 );
EXPECT_EQ( 1 + L1, 2 );
EXPECT_EQ( L1 + LeftReal(1), 2 );
EXPECT_EQ( LeftReal(1) + L1, 2 );
// with expression
EXPECT_EQ( L1 + (L1 + L1), 3 );
EXPECT_EQ( (L1 + L1) + L1, 3 );
......@@ -376,6 +396,11 @@ TYPED_TEST( VectorBinaryOperationsTest, addition )
EXPECT_EQ( (L1 + L1) + R1, 3 );
// with two expressions
EXPECT_EQ( (L1 + L1) + (L1 + L1), 4 );
// with expression and scalar
EXPECT_EQ( (L1 + L1) + 1, 3 );
EXPECT_EQ( (L1 + L1) + RightReal(1), 3 );
EXPECT_EQ( 1 + (R1 + R1), 3 );
EXPECT_EQ( LeftReal(1) + (R1 + R1), 3 );
}
TYPED_TEST( VectorBinaryOperationsTest, subtraction )
......@@ -387,6 +412,8 @@ TYPED_TEST( VectorBinaryOperationsTest, subtraction )
// with scalar
EXPECT_EQ( L1 - 1, 0 );
EXPECT_EQ( 1 - L1, 0 );
EXPECT_EQ( L1 - LeftReal(1), 0 );
EXPECT_EQ( LeftReal(1) - L1, 0 );
// with expression
EXPECT_EQ( L2 - (L1 + L1), 0 );
EXPECT_EQ( (L1 + L1) - L2, 0 );
......@@ -394,6 +421,11 @@ TYPED_TEST( VectorBinaryOperationsTest, subtraction )
EXPECT_EQ( (L1 + L1) - R2, 0 );
// with two expressions
EXPECT_EQ( (L1 + L1) - (L1 + L1), 0 );
// with expression and scalar
EXPECT_EQ( (L1 + L1) - 1, 1 );
EXPECT_EQ( (L1 + L1) - RightReal(1), 1 );
EXPECT_EQ( 1 - (R1 + R1), -1 );
EXPECT_EQ( LeftReal(1) - (R1 + R1), -1 );
}
TYPED_TEST( VectorBinaryOperationsTest, multiplication )
......@@ -405,6 +437,8 @@ TYPED_TEST( VectorBinaryOperationsTest, multiplication )
// with scalar
EXPECT_EQ( L1 * 2, L2 );
EXPECT_EQ( 2 * L1, L2 );
EXPECT_EQ( L1 * LeftReal(2), L2 );
EXPECT_EQ( LeftReal(2) * L1, L2 );
// with expression
EXPECT_EQ( L1 * (L1 + L1), L2 );
EXPECT_EQ( (L1 + L1) * L1, L2 );
......@@ -412,6 +446,11 @@ TYPED_TEST( VectorBinaryOperationsTest, multiplication )
EXPECT_EQ( (L1 + L1) * R1, L2 );
// with two expressions
EXPECT_EQ( (L1 + L1) * (L1 + L1), 4 );
// with expression and scalar
EXPECT_EQ( (L1 + L1) * 1, 2 );
EXPECT_EQ( (L1 + L1) * RightReal(1), 2 );
EXPECT_EQ( 1 * (R1 + R1), 2 );
EXPECT_EQ( LeftReal(1) * (R1 + R1), 2 );
}
TYPED_TEST( VectorBinaryOperationsTest, division )
......@@ -423,6 +462,8 @@ TYPED_TEST( VectorBinaryOperationsTest, division )
// with scalar
EXPECT_EQ( L2 / 2, L1 );
EXPECT_EQ( 2 / L2, L1 );
EXPECT_EQ( L2 / LeftReal(2), L1 );
EXPECT_EQ( LeftReal(2) / L2, L1 );
// with expression
EXPECT_EQ( L2 / (L1 + L1), L1 );
EXPECT_EQ( (L1 + L1) / L2, L1 );
......@@ -430,6 +471,11 @@ TYPED_TEST( VectorBinaryOperationsTest, division )
EXPECT_EQ( (L1 + L1) / R2, L1 );
// with two expressions
EXPECT_EQ( (L1 + L1) / (L1 + L1), L1 );
// with expression and scalar
EXPECT_EQ( (L1 + L1) / 1, 2 );
EXPECT_EQ( (L1 + L1) / RightReal(1), 2 );
EXPECT_EQ( 2 / (R1 + R1), 1 );
EXPECT_EQ( LeftReal(2) / (R1 + R1), 1 );
}
template< typename Left, typename Right, std::enable_if_t< std::is_const<typename Left::RealType>::value, bool > = true >
......@@ -438,12 +484,15 @@ void test_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
template< typename Left, typename Right, std::enable_if_t< ! std::is_const<typename Left::RealType>::value, bool > = true >
void test_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
{
using RightReal = std::remove_const_t< typename Right::RealType >;
// with vector or vector view
L1 = R2;
EXPECT_EQ( L1, R2 );
// with scalar
L1 = 1;
EXPECT_EQ( L1, 1 );
L1 = RightReal(1);
EXPECT_EQ( L1, 1 );
// with expression
L1 = R1 + R1;
EXPECT_EQ( L1, R1 + R1 );
......@@ -460,6 +509,7 @@ void test_add_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
template< typename Left, typename Right, std::enable_if_t< ! std::is_const<typename Left::RealType>::value, bool > = true >
void test_add_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
{
using RightReal = std::remove_const_t< typename Right::RealType >;
// with vector or vector view
L1 += R2;
EXPECT_EQ( L1, R1 + R2 );
......@@ -467,6 +517,9 @@ void test_add_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
L1 = 1;
L1 += 2;
EXPECT_EQ( L1, 3 );
L1 = 1;
L1 += RightReal(2);
EXPECT_EQ( L1, 3 );
// with expression
L1 = 1;
L1 += R1 + R1;
......@@ -484,6 +537,7 @@ void test_subtract_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
template< typename Left, typename Right, std::enable_if_t< ! std::is_const<typename Left::RealType>::value, bool > = true >
void test_subtract_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
{
using RightReal = std::remove_const_t< typename Right::RealType >;
// with vector or vector view
L1 -= R2;
EXPECT_EQ( L1, R1 - R2 );
......@@ -491,6 +545,9 @@ void test_subtract_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
L1 = 1;
L1 -= 2;
EXPECT_EQ( L1, -1 );
L1 = 1;
L1 -= RightReal(2);
EXPECT_EQ( L1, -1 );
// with expression
L1 = 1;
L1 -= R1 + R1;
......@@ -508,6 +565,7 @@ void test_multiply_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
template< typename Left, typename Right, std::enable_if_t< ! std::is_const<typename Left::RealType>::value, bool > = true >
void test_multiply_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
{
using RightReal = std::remove_const_t< typename Right::RealType >;
// with vector or vector view
L1 *= R2;
EXPECT_EQ( L1, R2 );
......@@ -515,6 +573,9 @@ void test_multiply_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
L1 = 1;
L1 *= 2;
EXPECT_EQ( L1, 2 );
L1 = 1;
L1 *= RightReal(2);
EXPECT_EQ( L1, 2 );
// with expression
L1 = 1;
L1 *= R1 + R1;
......@@ -532,6 +593,7 @@ void test_divide_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
template< typename Left, typename Right, std::enable_if_t< ! std::is_const<typename Left::RealType>::value, bool > = true >
void test_divide_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
{
using RightReal = std::remove_const_t< typename Right::RealType >;
// with vector or vector view
L2 /= R2;
EXPECT_EQ( L1, R1 );
......@@ -539,6 +601,9 @@ void test_divide_assignment( Left& L1, Left& L2, Right& R1, Right& R2 )
L2 = 2;
L2 /= 2;
EXPECT_EQ( L1, 1 );
L1 = 2;
L1 /= RightReal(2);
EXPECT_EQ( L1, 1 );
// with expression
L2 = 2;
L2 /= R1 + R1;
......@@ -602,6 +667,11 @@ TYPED_TEST( VectorBinaryOperationsTest, min )
EXPECT_EQ( TNL::min(L1 + L1, R1), R1 );
// with two expressions
EXPECT_EQ( TNL::min(L1 + L1, R1 + R2), L2 );
// with expression and scalar
EXPECT_EQ( TNL::min(L1 + L1, 1), L1 );
EXPECT_EQ( TNL::min(L1 + L1, RightReal(1)), L1 );
EXPECT_EQ( TNL::min(1, R1 + R1), L1 );
EXPECT_EQ( TNL::min(LeftReal(1), R1 + R1), L1 );
}
TYPED_TEST( VectorBinaryOperationsTest, max )
......@@ -620,6 +690,11 @@ TYPED_TEST( VectorBinaryOperationsTest, max )
EXPECT_EQ( TNL::max(L1 + L1, R1), R2 );
// with two expressions
EXPECT_EQ( TNL::max(L1 - L1, R1 + R1), L2 );
// with expression and scalar
EXPECT_EQ( TNL::max(L1 + L1, 1), L2 );
EXPECT_EQ( TNL::max(L1 + L1, RightReal(1)), L2 );
EXPECT_EQ( TNL::max(1, R1 + R1), L2 );
EXPECT_EQ( TNL::max(LeftReal(1), R1 + R1), L2 );
}
#if defined(HAVE_CUDA) && !defined(STATIC_VECTOR)
......
#define VECTOR_OF_STATIC_VECTORS
#include "VectorBinaryOperationsTest.h"
#include "VectorUnaryOperationsTest.h"
#include "VectorVerticalOperationsTest.h"
#include "../main.h"
#include "VectorOfStaticVectorsTest.h"
#define VECTOR_OF_STATIC_VECTORS
#include "VectorBinaryOperationsTest.h"
#include "VectorUnaryOperationsTest.h"
#include "VectorVerticalOperationsTest.h"
#include "../main.h"
#include "VectorOfStaticVectorsTest.h"
#define VECTOR_OF_STATIC_VECTORS
#include "VectorBinaryOperationsTest.h"
#include "VectorUnaryOperationsTest.h"
#include "VectorVerticalOperationsTest.h"
#include "../main.h"
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment