Commit d37e5540 authored by Jakub Klinkovský
Browse files

Simplification of expression templates

parent 2fe47e38
Loading
Loading
Loading
Loading
+21 −25
Original line number Diff line number Diff line
@@ -22,11 +22,11 @@ namespace Expressions {
////
// Distributed unary expression template
template< typename T1,
          template< typename > class Operation >
          typename Operation >
struct DistributedUnaryExpressionTemplate;

template< typename T1,
          template< typename > class Operation >
          typename Operation >
struct HasEnabledDistributedExpressionTemplates< DistributedUnaryExpressionTemplate< T1, Operation > >
: std::true_type
{};
@@ -35,7 +35,7 @@ struct HasEnabledDistributedExpressionTemplates< DistributedUnaryExpressionTempl
// Distributed binary expression template
template< typename T1,
          typename T2,
          template< typename, typename > class Operation,
          typename Operation,
          ExpressionVariableType T1Type = getExpressionVariableType< T1, T2 >(),
          ExpressionVariableType T2Type = getExpressionVariableType< T2, T1 >() >
struct DistributedBinaryExpressionTemplate
@@ -43,7 +43,7 @@ struct DistributedBinaryExpressionTemplate

template< typename T1,
          typename T2,
          template< typename, typename > class Operation,
          typename Operation,
          ExpressionVariableType T1Type,
          ExpressionVariableType T2Type >
struct HasEnabledDistributedExpressionTemplates< DistributedBinaryExpressionTemplate< T1, T2, Operation, T1Type, T2Type > >
@@ -52,11 +52,10 @@ struct HasEnabledDistributedExpressionTemplates< DistributedBinaryExpressionTemp

template< typename T1,
          typename T2,
          template< typename, typename > class Operation >
          typename Operation >
struct DistributedBinaryExpressionTemplate< T1, T2, Operation, VectorExpressionVariable, VectorExpressionVariable >
{
   using RealType = decltype( Operation< typename T1::RealType, typename T2::RealType >::
                              evaluate( std::declval<T1>()[0], std::declval<T2>()[0] ) );
   using RealType = decltype( Operation::evaluate( std::declval<T1>()[0], std::declval<T2>()[0] ) );
   using DeviceType = typename T1::DeviceType;
   using IndexType = typename T1::IndexType;
   using CommunicatorType = typename T1::CommunicatorType;
@@ -121,11 +120,10 @@ protected:

template< typename T1,
          typename T2,
          template< typename, typename > class Operation >
          typename Operation >
struct DistributedBinaryExpressionTemplate< T1, T2, Operation, VectorExpressionVariable, ArithmeticVariable >
{
   using RealType = decltype( Operation< typename T1::RealType, T2 >::
                              evaluate( std::declval<T1>()[0], std::declval<T2>() ) );
   using RealType = decltype( Operation::evaluate( std::declval<T1>()[0], std::declval<T2>() ) );
   using DeviceType = typename T1::DeviceType;
   using IndexType = typename T1::IndexType;
   using CommunicatorType = typename T1::CommunicatorType;
@@ -175,11 +173,10 @@ protected:

template< typename T1,
          typename T2,
          template< typename, typename > class Operation >
          typename Operation >
struct DistributedBinaryExpressionTemplate< T1, T2, Operation, ArithmeticVariable, VectorExpressionVariable >
{
   using RealType = decltype( Operation< T1, typename T2::RealType >::
                              evaluate( std::declval<T1>(), std::declval<T2>()[0] ) );
   using RealType = decltype( Operation::evaluate( std::declval<T1>(), std::declval<T2>()[0] ) );
   using DeviceType = typename T2::DeviceType;
   using IndexType = typename T2::IndexType;
   using CommunicatorType = typename T2::CommunicatorType;
@@ -230,11 +227,10 @@ protected:
////
// Distributed unary expression template
template< typename T1,
          template< typename > class Operation >
          typename Operation >
struct DistributedUnaryExpressionTemplate
{
   using RealType = decltype( Operation< typename T1::RealType >::
                              evaluate( std::declval<T1>()[0] ) );
   using RealType = decltype( Operation::evaluate( std::declval<T1>()[0] ) );
   using DeviceType = typename T1::DeviceType;
   using IndexType = typename T1::IndexType;
   using CommunicatorType = typename T1::CommunicatorType;
@@ -667,7 +663,7 @@ template< typename ResultType,
          typename ET1,
          typename..., typename = EnableIfDistributedUnaryExpression_t< ET1 >,
          // workaround: templated type alias cannot be declared at block level
          template<typename> class CastOperation = Containers::Expressions::Cast< ResultType >::template Operation,
          typename CastOperation = typename Cast< ResultType >::Operation,
          typename = void, typename = void >
auto
cast( const ET1& a )
@@ -802,7 +798,7 @@ binaryAnd( const ET1& a )
// Output stream
template< typename T1,
          typename T2,
          template< typename, typename > class Operation >
          typename Operation >
std::ostream& operator<<( std::ostream& str, const DistributedBinaryExpressionTemplate< T1, T2, Operation >& expression )
{
   str << "[ ";
@@ -813,7 +809,7 @@ std::ostream& operator<<( std::ostream& str, const DistributedBinaryExpressionTe
}

template< typename T,
          template< typename > class Operation >
          typename Operation >
std::ostream& operator<<( std::ostream& str, const DistributedUnaryExpressionTemplate< T, Operation >& expression )
{
   str << "[ ";
@@ -930,7 +926,7 @@ using Containers::binaryOr;
template< typename Vector,
   typename T1,
   typename T2,
   template< typename, typename > class Operation,
   typename Operation,
   typename Reduction,
   typename Result >
Result evaluateAndReduce( Vector& lhs,
@@ -949,7 +945,7 @@ Result evaluateAndReduce( Vector& lhs,

template< typename Vector,
   typename T1,
   template< typename > class Operation,
   typename Operation,
   typename Reduction,
   typename Result >
Result evaluateAndReduce( Vector& lhs,
@@ -971,7 +967,7 @@ Result evaluateAndReduce( Vector& lhs,
template< typename Vector,
   typename T1,
   typename T2,
   template< typename, typename > class Operation,
   typename Operation,
   typename Reduction,
   typename Result >
Result addAndReduce( Vector& lhs,
@@ -994,7 +990,7 @@ Result addAndReduce( Vector& lhs,

template< typename Vector,
   typename T1,
   template< typename > class Operation,
   typename Operation,
   typename Reduction,
   typename Result >
Result addAndReduce( Vector& lhs,
@@ -1020,7 +1016,7 @@ Result addAndReduce( Vector& lhs,
template< typename Vector,
   typename T1,
   typename T2,
   template< typename, typename > class Operation,
   typename Operation,
   typename Reduction,
   typename Result >
Result addAndReduceAbs( Vector& lhs,
@@ -1043,7 +1039,7 @@ Result addAndReduceAbs( Vector& lhs,

template< typename Vector,
   typename T1,
   template< typename > class Operation,
   typename Operation,
   typename Reduction,
   typename Result >
Result addAndReduceAbs( Vector& lhs,
+29 −33
Original line number Diff line number Diff line
@@ -25,25 +25,25 @@ namespace Containers {
namespace Expressions {

template< typename T1,
          template< typename > class Operation >
          typename Operation >
struct UnaryExpressionTemplate;

template< typename T1,
          template< typename > class Operation >
          typename Operation >
struct HasEnabledExpressionTemplates< UnaryExpressionTemplate< T1, Operation > >
: std::true_type
{};

template< typename T1,
          typename T2,
          template< typename, typename > class Operation,
          typename Operation,
          ExpressionVariableType T1Type = getExpressionVariableType< T1, T2 >(),
          ExpressionVariableType T2Type = getExpressionVariableType< T2, T1 >() >
struct BinaryExpressionTemplate;

template< typename T1,
          typename T2,
          template< typename, typename > class Operation,
          typename Operation,
          ExpressionVariableType T1Type,
          ExpressionVariableType T2Type >
struct HasEnabledExpressionTemplates< BinaryExpressionTemplate< T1, T2, Operation, T1Type, T2Type > >
@@ -55,11 +55,10 @@ struct HasEnabledExpressionTemplates< BinaryExpressionTemplate< T1, T2, Operatio
// Non-static binary expression template
template< typename T1,
          typename T2,
          template< typename, typename > class Operation >
          typename Operation >
struct BinaryExpressionTemplate< T1, T2, Operation, VectorExpressionVariable, VectorExpressionVariable >
{
   using RealType = decltype( Operation< typename T1::RealType, typename T2::RealType >::
                              evaluate( std::declval<T1>()[0], std::declval<T2>()[0] ) );
   using RealType = decltype( Operation::evaluate( std::declval<T1>()[0], std::declval<T2>()[0] ) );
   using DeviceType = typename T1::DeviceType;
   using IndexType = typename T1::IndexType;
   using ConstViewType = BinaryExpressionTemplate;
@@ -78,13 +77,13 @@ struct BinaryExpressionTemplate< T1, T2, Operation, VectorExpressionVariable, Ve

   RealType getElement( const IndexType i ) const
   {
      return Operation< typename T1::RealType, typename T2::RealType >::evaluate( op1.getElement( i ), op2.getElement( i ) );
      return Operation::evaluate( op1.getElement( i ), op2.getElement( i ) );
   }

   __cuda_callable__
   RealType operator[]( const IndexType i ) const
   {
      return Operation< typename T1::RealType, typename T2::RealType >::evaluate( op1[ i ], op2[ i ] );
      return Operation::evaluate( op1[ i ], op2[ i ] );
   }

   __cuda_callable__
@@ -105,11 +104,10 @@ protected:

template< typename T1,
          typename T2,
          template< typename, typename > class Operation >
          typename Operation >
struct BinaryExpressionTemplate< T1, T2, Operation, VectorExpressionVariable, ArithmeticVariable >
{
   using RealType = decltype( Operation< typename T1::RealType, T2 >::
                              evaluate( std::declval<T1>()[0], std::declval<T2>() ) );
   using RealType = decltype( Operation::evaluate( std::declval<T1>()[0], std::declval<T2>() ) );
   using DeviceType = typename T1::DeviceType;
   using IndexType = typename T1::IndexType;
   using ConstViewType = BinaryExpressionTemplate;
@@ -119,13 +117,13 @@ struct BinaryExpressionTemplate< T1, T2, Operation, VectorExpressionVariable, Ar

   RealType getElement( const IndexType i ) const
   {
      return Operation< typename T1::RealType, T2 >::evaluate( op1.getElement( i ), op2 );
      return Operation::evaluate( op1.getElement( i ), op2 );
   }

   __cuda_callable__
   RealType operator[]( const IndexType i ) const
   {
      return Operation< typename T1::RealType, T2 >::evaluate( op1[ i ], op2 );
      return Operation::evaluate( op1[ i ], op2 );
   }

   __cuda_callable__
@@ -146,11 +144,10 @@ protected:

template< typename T1,
          typename T2,
          template< typename, typename > class Operation >
          typename Operation >
struct BinaryExpressionTemplate< T1, T2, Operation, ArithmeticVariable, VectorExpressionVariable >
{
   using RealType = decltype( Operation< T1, typename T2::RealType >::
                              evaluate( std::declval<T1>(), std::declval<T2>()[0] ) );
   using RealType = decltype( Operation::evaluate( std::declval<T1>(), std::declval<T2>()[0] ) );
   using DeviceType = typename T2::DeviceType;
   using IndexType = typename T2::IndexType;
   using ConstViewType = BinaryExpressionTemplate;
@@ -160,13 +157,13 @@ struct BinaryExpressionTemplate< T1, T2, Operation, ArithmeticVariable, VectorEx

   RealType getElement( const IndexType i ) const
   {
      return Operation< T1, typename T2::RealType >::evaluate( op1, op2.getElement( i ) );
      return Operation::evaluate( op1, op2.getElement( i ) );
   }

   __cuda_callable__
   RealType operator[]( const IndexType i ) const
   {
      return Operation< T1, typename T2::RealType >::evaluate( op1, op2[ i ] );
      return Operation::evaluate( op1, op2[ i ] );
   }

   __cuda_callable__
@@ -188,11 +185,10 @@ protected:
////
// Non-static unary expression template
template< typename T1,
          template< typename > class Operation >
          typename Operation >
struct UnaryExpressionTemplate
{
   using RealType = decltype( Operation< typename T1::RealType >::
                              evaluate( std::declval<T1>()[0] ) );
   using RealType = decltype( Operation::evaluate( std::declval<T1>()[0] ) );
   using DeviceType = typename T1::DeviceType;
   using IndexType = typename T1::IndexType;
   using ConstViewType = UnaryExpressionTemplate;
@@ -202,13 +198,13 @@ struct UnaryExpressionTemplate

   RealType getElement( const IndexType i ) const
   {
      return Operation< typename T1::RealType >::evaluate( operand.getElement( i ) );
      return Operation::evaluate( operand.getElement( i ) );
   }

   __cuda_callable__
   RealType operator[]( const IndexType i ) const
   {
      return Operation< typename T1::RealType >::evaluate( operand[ i ] );
      return Operation::evaluate( operand[ i ] );
   }

   __cuda_callable__
@@ -612,7 +608,7 @@ template< typename ResultType,
          typename ET1,
          typename..., typename = EnableIfUnaryExpression_t< ET1 >,
          // workaround: templated type alias cannot be declared at block level
          template<typename> class CastOperation = Containers::Expressions::Cast< ResultType >::template Operation,
          typename CastOperation = typename Cast< ResultType >::Operation,
          typename = void >
auto
cast( const ET1& a )
@@ -749,7 +745,7 @@ binaryOr( const ET1& a )
// Output stream
template< typename T1,
          typename T2,
          template< typename, typename > class Operation >
          typename Operation >
std::ostream& operator<<( std::ostream& str, const BinaryExpressionTemplate< T1, T2, Operation >& expression )
{
   str << "[ ";
@@ -760,7 +756,7 @@ std::ostream& operator<<( std::ostream& str, const BinaryExpressionTemplate< T1,
}

template< typename T,
          template< typename > class Operation >
          typename Operation >
std::ostream& operator<<( std::ostream& str, const UnaryExpressionTemplate< T, Operation >& expression )
{
   str << "[ ";
@@ -875,7 +871,7 @@ using Containers::binaryOr;
template< typename Vector,
   typename T1,
   typename T2,
   template< typename, typename > class Operation,
   typename Operation,
   typename Reduction,
   typename Result >
Result evaluateAndReduce( Vector& lhs,
@@ -894,7 +890,7 @@ Result evaluateAndReduce( Vector& lhs,

template< typename Vector,
   typename T1,
   template< typename > class Operation,
   typename Operation,
   typename Reduction,
   typename Result >
Result evaluateAndReduce( Vector& lhs,
@@ -916,7 +912,7 @@ Result evaluateAndReduce( Vector& lhs,
template< typename Vector,
   typename T1,
   typename T2,
   template< typename, typename > class Operation,
   typename Operation,
   typename Reduction,
   typename Result >
Result addAndReduce( Vector& lhs,
@@ -939,7 +935,7 @@ Result addAndReduce( Vector& lhs,

template< typename Vector,
   typename T1,
   template< typename > class Operation,
   typename Operation,
   typename Reduction,
   typename Result >
Result addAndReduce( Vector& lhs,
@@ -965,7 +961,7 @@ Result addAndReduce( Vector& lhs,
template< typename Vector,
   typename T1,
   typename T2,
   template< typename, typename > class Operation,
   typename Operation,
   typename Reduction,
   typename Result >
Result addAndReduceAbs( Vector& lhs,
@@ -988,7 +984,7 @@ Result addAndReduceAbs( Vector& lhs,

template< typename Vector,
   typename T1,
   template< typename > class Operation,
   typename Operation,
   typename Reduction,
   typename Result >
Result addAndReduceAbs( Vector& lhs,
+31 −31
Original line number Diff line number Diff line
@@ -16,9 +16,9 @@ namespace TNL {
namespace Containers {
namespace Expressions {

template< typename T1, typename T2 >
struct Addition
{
   template< typename T1, typename T2 >
   __cuda_callable__
   static auto evaluate( const T1& a, const T2& b ) -> decltype( a + b )
   {
@@ -26,9 +26,9 @@ struct Addition
   }
};

template< typename T1, typename T2 >
struct Subtraction
{
   template< typename T1, typename T2 >
   __cuda_callable__
   static auto evaluate( const T1& a, const T2& b ) -> decltype( a - b )
   {
@@ -36,9 +36,9 @@ struct Subtraction
   }
};

template< typename T1, typename T2 >
struct Multiplication
{
   template< typename T1, typename T2 >
   __cuda_callable__
   static auto evaluate( const T1& a, const T2& b ) -> decltype( a * b )
   {
@@ -46,9 +46,9 @@ struct Multiplication
   }
};

template< typename T1, typename T2 >
struct Division
{
   template< typename T1, typename T2 >
   __cuda_callable__
   static auto evaluate( const T1& a, const T2& b ) -> decltype( a / b )
   {
@@ -56,9 +56,9 @@ struct Division
   }
};

template< typename T1, typename T2 >
struct Min
{
   template< typename T1, typename T2 >
   __cuda_callable__
   static auto evaluate( const T1& a, const T2& b ) -> decltype( min( a , b ) )
   {
@@ -66,9 +66,9 @@ struct Min
   }
};

template< typename T1, typename T2 >
struct Max
{
   template< typename T1, typename T2 >
   __cuda_callable__
   static auto evaluate( const T1& a, const T2& b ) -> decltype( max( a, b ) )
   {
@@ -76,9 +76,9 @@ struct Max
   }
};

template< typename T1 >
struct Minus
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( -a )
   {
@@ -86,9 +86,9 @@ struct Minus
   }
};

template< typename T1 >
struct Abs
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( abs( a ) )
   {
@@ -96,9 +96,9 @@ struct Abs
   }
};

template< typename T1, typename T2 >
struct Pow
{
   template< typename T1, typename T2 >
   __cuda_callable__
   static auto evaluate( const T1& a, const T2& exp ) -> decltype( pow( a, exp ) )
   {
@@ -106,9 +106,9 @@ struct Pow
   }
};

template< typename T1 >
struct Exp
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( exp( a ) )
   {
@@ -116,9 +116,9 @@ struct Exp
   }
};

template< typename T1 >
struct Sqrt
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( sqrt( a ) )
   {
@@ -126,9 +126,9 @@ struct Sqrt
   }
};

template< typename T1 >
struct Cbrt
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( cbrt( a ) )
   {
@@ -136,9 +136,9 @@ struct Cbrt
   }
};

template< typename T1 >
struct Log
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( log( a ) )
   {
@@ -146,9 +146,9 @@ struct Log
   }
};

template< typename T1 >
struct Log10
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( log10( a ) )
   {
@@ -156,9 +156,9 @@ struct Log10
   }
};

template< typename T1 >
struct Log2
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( log2( a ) )
   {
@@ -166,9 +166,9 @@ struct Log2
   }
};

template< typename T1 >
struct Sin
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( sin( a ) )
   {
@@ -176,9 +176,9 @@ struct Sin
   }
};

template< typename T1 >
struct Cos
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( cos( a ) )
   {
@@ -186,9 +186,9 @@ struct Cos
   }
};

template< typename T1 >
struct Tan
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( tan( a ) )
   {
@@ -196,9 +196,9 @@ struct Tan
   }
};

template< typename T1 >
struct Asin
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( asin( a ) )
   {
@@ -206,9 +206,9 @@ struct Asin
   }
};

template< typename T1 >
struct Acos
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( acos( a ) )
   {
@@ -216,9 +216,9 @@ struct Acos
   }
};

template< typename T1 >
struct Atan
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( atan( a ) )
   {
@@ -226,9 +226,9 @@ struct Atan
   }
};

template< typename T1 >
struct Sinh
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( sinh( a ) )
   {
@@ -236,9 +236,9 @@ struct Sinh
   }
};

template< typename T1 >
struct Cosh
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( cosh( a ) )
   {
@@ -246,9 +246,9 @@ struct Cosh
   }
};

template< typename T1 >
struct Tanh
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( tanh( a ) )
   {
@@ -256,9 +256,9 @@ struct Tanh
   }
};

template< typename T1 >
struct Asinh
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( asinh( a ) )
   {
@@ -266,9 +266,9 @@ struct Asinh
   }
};

template< typename T1 >
struct Acosh
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( acosh( a ) )
   {
@@ -276,9 +276,9 @@ struct Acosh
   }
};

template< typename T1 >
struct Atanh
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( atanh( a ) )
   {
@@ -286,9 +286,9 @@ struct Atanh
   }
};

template< typename T1 >
struct Floor
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( floor( a ) )
   {
@@ -296,9 +296,9 @@ struct Floor
   }
};

template< typename T1 >
struct Ceil
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( ceil( a ) )
   {
@@ -306,9 +306,9 @@ struct Ceil
   }
};

template< typename T1 >
struct Sign
{
   template< typename T1 >
   __cuda_callable__
   static auto evaluate( const T1& a ) -> decltype( sign( a ) )
   {
@@ -319,9 +319,9 @@ struct Sign
template< typename ResultType >
struct Cast
{
   template< typename T1 >
   struct Operation
   {
      template< typename T1 >
      __cuda_callable__
      static auto evaluate( const T1& a ) -> ResultType
      {
+25 −29

File changed.

Preview size limit exceeded, changes collapsed.

+2 −2
Original line number Diff line number Diff line
@@ -73,7 +73,7 @@ public:
    */
   template< typename T1,
             typename T2,
             template< typename, typename > class Operation >
             typename Operation >
   __cuda_callable__
   StaticVector( const Expressions::StaticBinaryExpressionTemplate< T1, T2, Operation >& expr );

@@ -83,7 +83,7 @@ public:
    * \param expr is unary expression
    */
   template< typename T,
             template< typename > class Operation >
             typename Operation >
   __cuda_callable__
   StaticVector( const Expressions::StaticUnaryExpressionTemplate< T, Operation >& expr );

Loading