diff --git a/src/TNL/Containers/Algorithms/ArrayAssignment.h b/src/TNL/Containers/Algorithms/ArrayAssignment.h index 5d45bbbceb86b1468da9efeccdec15e3a4b37de8..9a67a36b9190d3243332bf985ac978fcb5b7cae9 100644 --- a/src/TNL/Containers/Algorithms/ArrayAssignment.h +++ b/src/TNL/Containers/Algorithms/ArrayAssignment.h @@ -10,36 +10,16 @@ #pragma once -#include <type_traits> -#include <utility> +#include <TNL/TypeTraits.h> #include <TNL/Containers/Algorithms/ArrayOperations.h> namespace TNL { namespace Containers { namespace Algorithms { -namespace detail { -/** - * SFINAE for checking if T has getArrayData method - */ -template< typename T > -class HasGetArrayData -{ -private: - typedef char YesType[1]; - typedef char NoType[2]; - - template< typename C > static YesType& test( decltype(std::declval< C >().getArrayData()) ); - template< typename C > static NoType& test(...); - -public: - static constexpr bool value = ( sizeof( test< T >(0) ) == sizeof( YesType ) ); -}; -} // namespace detail - template< typename Array, typename T, - bool hasGetArrayData = detail::HasGetArrayData< T >::value > + bool isArrayType = IsArrayType< T >::value > struct ArrayAssignment; /** diff --git a/src/TNL/Containers/Algorithms/VectorAssignment.h b/src/TNL/Containers/Algorithms/VectorAssignment.h index e9de7241f8faff2ee8e9ada81e12c9df1ff74446..f2b63bf0d78710af9e5a94d38e718e78f1722cae 100644 --- a/src/TNL/Containers/Algorithms/VectorAssignment.h +++ b/src/TNL/Containers/Algorithms/VectorAssignment.h @@ -10,41 +10,20 @@ #pragma once -#include <type_traits> -#include <utility> -#include <TNL/Containers/Algorithms/VectorOperations.h> +#include <TNL/TypeTraits.h> #include <TNL/ParallelFor.h> +#include <TNL/Containers/Algorithms/VectorOperations.h> namespace TNL { namespace Containers { namespace Algorithms { -namespace detail { -/** - * SFINAE for checking if T has getSize method - * TODO: We should better test operator[] but we need to know the indexing type. - */ -template< typename T > -class HasSubscriptOperator -{ -private: - typedef char YesType[1]; - typedef char NoType[2]; - - template< typename C > static YesType& test( decltype(std::declval< C >().getSize() ) ); - template< typename C > static NoType& test(...); - -public: - static constexpr bool value = ( sizeof( test< T >(0) ) == sizeof( YesType ) ); -}; -} // namespace detail - /** * \brief Vector assignment */ template< typename Vector, typename T, - bool hasSubscriptOperator = detail::HasSubscriptOperator< T >::value > + bool hasSubscriptOperator = HasSubscriptOperator< T >::value > struct VectorAssignment{}; /** @@ -52,7 +31,7 @@ struct VectorAssignment{}; */ template< typename Vector, typename T, - bool hasSubscriptOperator = detail::HasSubscriptOperator< T >::value > + bool hasSubscriptOperator = HasSubscriptOperator< T >::value > struct VectorAddition{}; /** @@ -60,7 +39,7 @@ struct VectorAddition{}; */ template< typename Vector, typename T, - bool hasSubscriptOperator = detail::HasSubscriptOperator< T >::value > + bool hasSubscriptOperator = HasSubscriptOperator< T >::value > struct VectorSubtraction{}; /** @@ -68,7 +47,7 @@ struct VectorSubtraction{}; */ template< typename Vector, typename T, - bool hasSubscriptOperator = detail::HasSubscriptOperator< T >::value > + bool hasSubscriptOperator = HasSubscriptOperator< T >::value > struct VectorMultiplication{}; /** @@ -76,7 +55,7 @@ struct VectorMultiplication{}; */ template< typename Vector, typename T, - bool hasSubscriptOperator = detail::HasSubscriptOperator< T >::value > + bool hasSubscriptOperator = HasSubscriptOperator< T >::value > struct VectorDivision{}; /** @@ -95,7 +74,7 @@ struct VectorAssignment< Vector, T, true > static void assignStatic( Vector& v, const T& t ) { TNL_ASSERT_EQ( v.getSize(), t.getSize(), "The sizes of the vectors must be equal." ); - for( decltype( v.getSize() ) i = 0; i < v.getSize(); i ++ ) + for( decltype( v.getSize() ) i = 0; i < v.getSize(); i++ ) v[ i ] = t[ i ]; } @@ -107,11 +86,11 @@ struct VectorAssignment< Vector, T, true > using IndexType = typename Vector::IndexType; RealType* data = v.getData(); - auto ass = [=] __cuda_callable__ ( IndexType i ) + auto assignment = [=] __cuda_callable__ ( IndexType i ) { data[ i ] = t[ i ]; }; - ParallelFor< DeviceType >::exec( ( IndexType ) 0, v.getSize(), ass ); + ParallelFor< DeviceType >::exec( ( IndexType ) 0, v.getSize(), assignment ); } }; @@ -131,7 +110,7 @@ struct VectorAssignment< Vector, T, false > static void assignStatic( Vector& v, const T& t ) { TNL_ASSERT_GT( v.getSize(), 0, "Cannot assign value to empty vector." ); - for( decltype( v.getSize() ) i = 0; i < v.getSize(); i ++ ) + for( decltype( v.getSize() ) i = 0; i < v.getSize(); i++ ) v[ i ] = t; } @@ -142,11 +121,11 @@ struct VectorAssignment< Vector, T, false > using IndexType = typename Vector::IndexType; RealType* data = v.getData(); - auto ass = [=] __cuda_callable__ ( IndexType i ) + auto assignment = [=] __cuda_callable__ ( IndexType i ) { data[ i ] = t; }; - ParallelFor< DeviceType >::exec( ( IndexType ) 0, v.getSize(), ass ); + ParallelFor< DeviceType >::exec( ( IndexType ) 0, v.getSize(), assignment ); } }; @@ -161,7 +140,7 @@ struct VectorAddition< Vector, T, true > static void additionStatic( Vector& v, const T& t ) { TNL_ASSERT_EQ( v.getSize(), t.getSize(), "The sizes of the vectors must be equal." ); - for( decltype( v.getSize() ) i = 0; i < v.getSize(); i ++ ) + for( decltype( v.getSize() ) i = 0; i < v.getSize(); i++ ) v[ i ] += t[ i ]; } @@ -193,7 +172,7 @@ struct VectorAddition< Vector, T, false > static void additionStatic( Vector& v, const T& t ) { TNL_ASSERT_GT( v.getSize(), 0, "Cannot assign value to empty vector." ); - for( decltype( v.getSize() ) i = 0; i < v.getSize(); i ++ ) + for( decltype( v.getSize() ) i = 0; i < v.getSize(); i++ ) v[ i ] += t; } @@ -223,7 +202,7 @@ struct VectorSubtraction< Vector, T, true > static void subtractionStatic( Vector& v, const T& t ) { TNL_ASSERT_EQ( v.getSize(), t.getSize(), "The sizes of the vectors must be equal." ); - for( decltype( v.getSize() ) i = 0; i < v.getSize(); i ++ ) + for( decltype( v.getSize() ) i = 0; i < v.getSize(); i++ ) v[ i ] -= t[ i ]; } @@ -255,7 +234,7 @@ struct VectorSubtraction< Vector, T, false > static void subtractionStatic( Vector& v, const T& t ) { TNL_ASSERT_GT( v.getSize(), 0, "Cannot assign value to empty vector." ); - for( decltype( v.getSize() ) i = 0; i < v.getSize(); i ++ ) + for( decltype( v.getSize() ) i = 0; i < v.getSize(); i++ ) v[ i ] -= t; } @@ -285,7 +264,7 @@ struct VectorMultiplication< Vector, T, true > static void multiplicationStatic( Vector& v, const T& t ) { TNL_ASSERT_EQ( v.getSize(), t.getSize(), "The sizes of the vectors must be equal." ); - for( decltype( v.getSize() ) i = 0; i < v.getSize(); i ++ ) + for( decltype( v.getSize() ) i = 0; i < v.getSize(); i++ ) v[ i ] *= t[ i ]; } @@ -317,7 +296,7 @@ struct VectorMultiplication< Vector, T, false > static void multiplicationStatic( Vector& v, const T& t ) { TNL_ASSERT_GT( v.getSize(), 0, "Cannot assign value to empty vector." ); - for( decltype( v.getSize() ) i = 0; i < v.getSize(); i ++ ) + for( decltype( v.getSize() ) i = 0; i < v.getSize(); i++ ) v[ i ] *= t; } @@ -348,7 +327,7 @@ struct VectorDivision< Vector, T, true > static void divisionStatic( Vector& v, const T& t ) { TNL_ASSERT_EQ( v.getSize(), t.getSize(), "The sizes of the vectors must be equal." ); - for( decltype( v.getSize() ) i = 0; i < v.getSize(); i ++ ) + for( decltype( v.getSize() ) i = 0; i < v.getSize(); i++ ) v[ i ] /= t[ i ]; } @@ -380,7 +359,7 @@ struct VectorDivision< Vector, T, false > static void divisionStatic( Vector& v, const T& t ) { TNL_ASSERT_GT( v.getSize(), 0, "Cannot assign value to empty vector." ); - for( decltype( v.getSize() ) i = 0; i < v.getSize(); i ++ ) + for( decltype( v.getSize() ) i = 0; i < v.getSize(); i++ ) v[ i ] /= t; } diff --git a/src/TNL/Containers/Array.h b/src/TNL/Containers/Array.h index 9e304a8311f3982feb63926df08f34252b18dd5a..135e29ff12038c6b533a59660edb8eb445869755 100644 --- a/src/TNL/Containers/Array.h +++ b/src/TNL/Containers/Array.h @@ -478,7 +478,9 @@ class Array * \param data Reference to the source array or value. * \return Reference to this array. */ - template< typename T > + template< typename T, + typename..., + typename = std::enable_if_t< std::is_convertible< T, ValueType >::value || IsArrayType< T >::value > > Array& operator=( const T& data ); /** @@ -680,13 +682,6 @@ template< typename Value, typename Device, typename Index, typename Allocator > File& operator>>( File&& file, Array< Value, Device, Index, Allocator >& array ); } // namespace Containers - -template< typename Value_, typename Device, typename Index > -struct IsStatic< Containers::Array< Value_, Device, Index > > -{ - static constexpr bool Value = false; -}; - } // namespace TNL #include <TNL/Containers/Array.hpp> diff --git a/src/TNL/Containers/Array.hpp b/src/TNL/Containers/Array.hpp index 76ce6384439c7ecb4f7dcc62a76d183ad4fd7e91..758de1846356b914ac82bb4a6c2ca7d09160545a 100644 --- a/src/TNL/Containers/Array.hpp +++ b/src/TNL/Containers/Array.hpp @@ -584,7 +584,7 @@ template< typename Value, typename Device, typename Index, typename Allocator > - template< typename T > + template< typename T, typename..., typename > Array< Value, Device, Index, Allocator >& Array< Value, Device, Index, Allocator >:: operator=( const T& data ) diff --git a/src/TNL/Containers/ArrayView.h b/src/TNL/Containers/ArrayView.h index 3584efd217a2463e8dd5e622ecc82ebb032edaac..d6381d443981ec4805b784df9826e44c06487d3a 100644 --- a/src/TNL/Containers/ArrayView.h +++ b/src/TNL/Containers/ArrayView.h @@ -14,6 +14,7 @@ #include <type_traits> // std::add_const_t +#include <TNL/TypeTraits.h> #include <TNL/File.h> #include <TNL/Devices/Host.h> #include <TNL/Devices/Cuda.h> @@ -221,7 +222,9 @@ public: * \param data Reference to the source array or value. * \return Reference to this array view. */ - template< typename T > + template< typename T, + typename..., + typename = std::enable_if_t< std::is_convertible< T, ValueType >::value || IsArrayType< T >::value > > ArrayView& operator=( const T& array ); /** @@ -490,14 +493,6 @@ template< typename Value, typename Device, typename Index > File& operator>>( File&& file, ArrayView< Value, Device, Index > view ); } // namespace Containers - -template< typename Value_, typename Device, typename Index > -struct IsStatic< Containers::ArrayView< Value_, Device, Index > > -{ - static constexpr bool Value = false; -}; - - } // namespace TNL #include <TNL/Containers/ArrayView.hpp> diff --git a/src/TNL/Containers/ArrayView.hpp b/src/TNL/Containers/ArrayView.hpp index 8c70b5c0548c38d2c473edcff4bd877ed2de9409..e1070860cd327bfb0bcb03d232058d0c141b8ee6 100644 --- a/src/TNL/Containers/ArrayView.hpp +++ b/src/TNL/Containers/ArrayView.hpp @@ -132,7 +132,7 @@ operator=( const ArrayView& view ) template< typename Value, typename Device, typename Index > - template< typename T > + template< typename T, typename..., typename > ArrayView< Value, Device, Index >& ArrayView< Value, Device, Index >:: operator=( const T& data ) diff --git a/src/TNL/Containers/Expressions/DistributedExpressionTemplates.h b/src/TNL/Containers/Expressions/DistributedExpressionTemplates.h index e4e96e86fc81ae1b861eec27978af7b934a9b7ff..f2bed09812172b1e6d41f3a8b80af7ba6113d508 100644 --- a/src/TNL/Containers/Expressions/DistributedExpressionTemplates.h +++ b/src/TNL/Containers/Expressions/DistributedExpressionTemplates.h @@ -15,6 +15,7 @@ #include <TNL/Containers/Expressions/ExpressionVariableType.h> #include <TNL/Containers/Expressions/DistributedComparison.h> #include <TNL/Containers/Expressions/IsStatic.h> +#include <TNL/Containers/Expressions/TypeTraits.h> #include <TNL/Communicators/MPIPrint.h> @@ -34,8 +35,16 @@ template< typename T1, typename Communicator, ExpressionVariableType T1Type = ExpressionVariableTypeGetter< T1 >::value > struct DistributedUnaryExpressionTemplate -{ -}; +{}; + +template< typename T1, + template< typename > class Operation, + typename Parameter, + typename Communicator, + ExpressionVariableType T1Type > +struct IsExpressionTemplate< DistributedUnaryExpressionTemplate< T1, Operation, Parameter, Communicator, T1Type > > +: std::true_type +{}; //// // Distributed binary expression template @@ -46,8 +55,17 @@ template< typename T1, ExpressionVariableType T1Type = ExpressionVariableTypeGetter< T1 >::value, ExpressionVariableType T2Type = ExpressionVariableTypeGetter< T2 >::value > struct DistributedBinaryExpressionTemplate -{ -}; +{}; + +template< typename T1, + typename T2, + template< typename, typename > class Operation, + typename Communicator, + ExpressionVariableType T1Type, + ExpressionVariableType T2Type > +struct IsExpressionTemplate< DistributedBinaryExpressionTemplate< T1, T2, Operation, Communicator, T1Type, T2Type > > +: std::true_type +{}; template< typename T1, typename T2, @@ -58,7 +76,6 @@ struct DistributedBinaryExpressionTemplate< T1, T2, Operation, Communicator, Vec using RealType = typename std::remove_const< typename T1::RealType >::type; using DeviceType = typename T1::DeviceType; using IndexType = typename T1::IndexType; - using IsExpressionTemplate = bool; using CommunicatorType = Communicator; //Communicators::MpiCommunicator; using CommunicationGroup = typename CommunicatorType::CommunicationGroup; @@ -116,7 +133,6 @@ struct DistributedBinaryExpressionTemplate< T1, T2, Operation, Communicator, Vec using CommunicatorType = Communicator; using CommunicationGroup = typename CommunicatorType::CommunicationGroup; - using IsExpressionTemplate = bool; static constexpr bool isStatic() { return false; } DistributedBinaryExpressionTemplate( const T1& a, const T2& b, const CommunicationGroup& group ) @@ -169,7 +185,6 @@ struct DistributedBinaryExpressionTemplate< T1, T2, Operation, Communicator, Ari using CommunicatorType = Communicator; using CommunicationGroup = typename CommunicatorType::CommunicationGroup; - using IsExpressionTemplate = bool; static constexpr bool isStatic() { return false; } DistributedBinaryExpressionTemplate( const T1& a, const T2& b, const CommunicationGroup& group ) @@ -224,7 +239,6 @@ struct DistributedUnaryExpressionTemplate< T1, Operation, Parameter, Communicato using RealType = typename std::remove_const< typename T1::RealType >::type; using DeviceType = typename T1::DeviceType; using IndexType = typename T1::IndexType; - using IsExpressionTemplate = bool; using CommunicatorType = Communicator; using CommunicationGroup = typename CommunicatorType::CommunicationGroup; static constexpr bool isStatic() { return false; } @@ -280,7 +294,6 @@ struct DistributedUnaryExpressionTemplate< T1, Operation, void, Communicator, Ve using RealType = typename std::remove_const< typename T1::RealType >::type; using DeviceType = typename T1::DeviceType; using IndexType = typename T1::IndexType; - using IsExpressionTemplate = bool; using CommunicatorType = Communicator; using CommunicationGroup = typename CommunicatorType::CommunicationGroup; static constexpr bool isStatic() { return false; } diff --git a/src/TNL/Containers/Expressions/ExpressionTemplates.h b/src/TNL/Containers/Expressions/ExpressionTemplates.h index 93525009d71e2a7f4b29112e22279b082e1278c3..12e763a7586e3792827ee57391dff74383bb49d9 100644 --- a/src/TNL/Containers/Expressions/ExpressionTemplates.h +++ b/src/TNL/Containers/Expressions/ExpressionTemplates.h @@ -15,7 +15,7 @@ #include <TNL/Containers/Expressions/ExpressionVariableType.h> #include <TNL/Containers/Expressions/Comparison.h> #include <TNL/Containers/Expressions/IsStatic.h> -#include <TNL/Containers/Expressions/IsNumericExpression.h> +#include <TNL/Containers/Expressions/TypeTraits.h> namespace TNL { namespace Containers { @@ -34,7 +34,7 @@ template< typename T1, template< typename > class Operation, typename Parameter, ExpressionVariableType T1Type > -struct IsNumericExpression< UnaryExpressionTemplate< T1, Operation, Parameter, T1Type > > +struct IsExpressionTemplate< UnaryExpressionTemplate< T1, Operation, Parameter, T1Type > > : std::true_type {}; @@ -53,7 +53,7 @@ template< typename T1, template< typename, typename > class Operation, ExpressionVariableType T1Type, ExpressionVariableType T2Type > -struct IsNumericExpression< BinaryExpressionTemplate< T1, T2, Operation, T1Type, T2Type > > +struct IsExpressionTemplate< BinaryExpressionTemplate< T1, T2, Operation, T1Type, T2Type > > : std::true_type {}; @@ -65,7 +65,6 @@ struct BinaryExpressionTemplate< T1, T2, Operation, VectorExpressionVariable, Ve using RealType = typename T1::RealType; using DeviceType = typename T1::DeviceType; using IndexType = typename T1::IndexType; - using IsExpressionTemplate = bool; static_assert( std::is_same< typename T1::DeviceType, typename T2::DeviceType >::value, "Attempt to mix operands allocated on different device types." ); static_assert( IsStaticType< T1 >::value == IsStaticType< T2 >::value, "Attempt to mix static and non-static operands in binary expression templates." ); @@ -111,7 +110,6 @@ struct BinaryExpressionTemplate< T1, T2, Operation, VectorExpressionVariable, Ar using DeviceType = typename T1::DeviceType; using IndexType = typename T1::IndexType; - using IsExpressionTemplate = bool; static constexpr bool is() { return false; } BinaryExpressionTemplate( const T1& a, const T2& b ): op1( a ), op2( b ){} @@ -154,7 +152,6 @@ struct BinaryExpressionTemplate< T1, T2, Operation, ArithmeticVariable, VectorEx using DeviceType = typename T2::DeviceType; using IndexType = typename T2::IndexType; - using IsExpressionTemplate = bool; static constexpr bool is() { return false; } BinaryExpressionTemplate( const T1& a, const T2& b ): op1( a ), op2( b ){} @@ -201,7 +198,7 @@ struct UnaryExpressionTemplate< T1, Operation, Parameter, VectorExpressionVariab using RealType = typename T1::RealType; using DeviceType = typename T1::DeviceType; using IndexType = typename T1::IndexType; - using IsExpressionTemplate = bool; + static constexpr bool is() { return false; } UnaryExpressionTemplate( const T1& a, const Parameter& p ) @@ -248,7 +245,7 @@ struct UnaryExpressionTemplate< T1, Operation, void, VectorExpressionVariable > using RealType = typename T1::RealType; using DeviceType = typename T1::DeviceType; using IndexType = typename T1::IndexType; - using IsExpressionTemplate = bool; + static constexpr bool is() { return false; } UnaryExpressionTemplate( const T1& a ): operand( a ){} diff --git a/src/TNL/Containers/Expressions/ExpressionVariableType.h b/src/TNL/Containers/Expressions/ExpressionVariableType.h index 0a9bbf0681c104284ebf54fd69cdd14e04a81df8..0d503f3f22c1c2a216057e0333a4adffa953b595 100644 --- a/src/TNL/Containers/Expressions/ExpressionVariableType.h +++ b/src/TNL/Containers/Expressions/ExpressionVariableType.h @@ -11,6 +11,7 @@ #pragma once #include <type_traits> +#include <TNL/Containers/Expressions/TypeTraits.h> namespace TNL { namespace Containers { @@ -51,23 +52,6 @@ namespace Expressions { enum ExpressionVariableType { ArithmeticVariable, VectorVariable, VectorExpressionVariable, OtherVariable }; -/** - * SFINAE for checking if T has getSize method - */ -template< typename T > -class IsExpressionTemplate -{ -private: - typedef char YesType[1]; - typedef char NoType[2]; - - template< typename C > static YesType& test( typename C::IsExpressionTemplate ); - template< typename C > static NoType& test(...); - -public: - static constexpr bool value = ( sizeof( test< typename std::remove_reference< T >::type >(0) ) == sizeof( YesType ) ); -}; - template< typename T > struct IsVectorType { diff --git a/src/TNL/Containers/Expressions/StaticExpressionTemplates.h b/src/TNL/Containers/Expressions/StaticExpressionTemplates.h index 7981f54f148362d48c9b9361bd9987961179d09a..33941cc1e082392e1d8d7e7292660683dbf3fdb9 100644 --- a/src/TNL/Containers/Expressions/StaticExpressionTemplates.h +++ b/src/TNL/Containers/Expressions/StaticExpressionTemplates.h @@ -15,6 +15,7 @@ #include <TNL/Containers/Expressions/ExpressionVariableType.h> #include <TNL/Containers/Expressions/StaticComparison.h> #include <TNL/Containers/Expressions/IsStatic.h> +#include <TNL/Containers/Expressions/TypeTraits.h> #include <TNL/Containers/Expressions/VerticalOperations.h> namespace TNL { @@ -27,8 +28,15 @@ template< typename T1, typename Parameter = void, ExpressionVariableType T1Type = ExpressionVariableTypeGetter< T1 >::value > struct StaticUnaryExpressionTemplate -{ -}; +{}; + +template< typename T1, + template< typename > class Operation, + typename Parameter, + ExpressionVariableType T1Type > +struct IsExpressionTemplate< StaticUnaryExpressionTemplate< T1, Operation, Parameter, T1Type > > +: std::true_type +{}; template< typename T1, typename T2, @@ -36,8 +44,16 @@ template< typename T1, ExpressionVariableType T1Type = ExpressionVariableTypeGetter< T1 >::value, ExpressionVariableType T2Type = ExpressionVariableTypeGetter< T2 >::value > struct StaticBinaryExpressionTemplate -{ -}; +{}; + +template< typename T1, + typename T2, + template< typename, typename > class Operation, + ExpressionVariableType T1Type, + ExpressionVariableType T2Type > +struct IsExpressionTemplate< StaticBinaryExpressionTemplate< T1, T2, Operation, T1Type, T2Type > > +: std::true_type +{}; template< typename T1, @@ -67,7 +83,7 @@ struct StaticBinaryExpressionTemplate< T1, T2, Operation, VectorExpressionVariab static_assert( IsStaticType< T1 >::value, "Left-hand side operand of static expression is not static, i.e. based on static vector." ); static_assert( IsStaticType< T2 >::value, "Right-hand side operand of static expression is not static, i.e. based on static vector." ); using RealType = typename T1::RealType; - using IsExpressionTemplate = bool; + static_assert( IsStaticType< T1 >::value == IsStaticType< T2 >::value, "Attempt to mix static and non-static operands in binary expression templates" ); static_assert( T1::getSize() == T2::getSize(), "Attempt to mix static operands with different sizes." ); @@ -128,7 +144,6 @@ struct StaticBinaryExpressionTemplate< T1, T2, Operation, VectorExpressionVariab static_assert( IsStaticType< T1 >::value, "Left-hand side operand of static expression is not static, i.e. based on static vector." ); using RealType = typename T1::RealType; - using IsExpressionTemplate = bool; static constexpr bool isStatic() { return true; } @@ -189,7 +204,6 @@ struct StaticBinaryExpressionTemplate< T1, T2, Operation, ArithmeticVariable, Ve static_assert( IsStaticType< T2 >::value, "Right-hand side operand of static expression is not static, i.e. based on static vector." ); using RealType = typename T2::RealType; - using IsExpressionTemplate = bool; static constexpr bool isStatic() { return true; } @@ -254,7 +268,6 @@ struct StaticUnaryExpressionTemplate< T1, Operation, Parameter, VectorExpression static_assert( IsStaticType< T1 >::value, "Operand of static expression is not static, i.e. based on static vector." ); using RealType = typename T1::RealType; - using IsExpressionTemplate = bool; static constexpr bool isStatic() { return true; } @@ -317,7 +330,6 @@ template< typename T1, struct StaticUnaryExpressionTemplate< T1, Operation, void, VectorExpressionVariable > { using RealType = typename T1::RealType; - using IsExpressionTemplate = bool; static constexpr bool isStatic() { return true; } diff --git a/src/TNL/Containers/Expressions/IsNumericExpression.h b/src/TNL/Containers/Expressions/TypeTraits.h similarity index 69% rename from src/TNL/Containers/Expressions/IsNumericExpression.h rename to src/TNL/Containers/Expressions/TypeTraits.h index a710574692140f69b519f58b624a527a46c7695a..951df0fdc62babfc1e3de77e9beadb525b28ce8b 100644 --- a/src/TNL/Containers/Expressions/IsNumericExpression.h +++ b/src/TNL/Containers/Expressions/TypeTraits.h @@ -1,5 +1,5 @@ /*************************************************************************** - IsNumericExpression.h - description + TypeTraits.h - description ------------------- begin : Jul 26, 2019 copyright : (C) 2019 by Tomas Oberhuber et al. @@ -17,7 +17,14 @@ namespace Containers { namespace Expressions { template< typename T > -struct IsNumericExpression : std::is_arithmetic< T > +struct IsExpressionTemplate : std::false_type +{}; + +template< typename T > +struct IsNumericExpression +: std::integral_constant< bool, + std::is_arithmetic< T >::value || + IsExpressionTemplate< T >::value > {}; } //namespace Expressions diff --git a/src/TNL/Containers/Vector.h b/src/TNL/Containers/Vector.h index e199d8e7bfeb004385ea60b8185f6f27b2729dd1..2c94080ecb5ea10a725c019422b56bbb6091ad40 100644 --- a/src/TNL/Containers/Vector.h +++ b/src/TNL/Containers/Vector.h @@ -10,7 +10,6 @@ #pragma once -#include <TNL/TypeTraits.h> #include <TNL/Containers/Array.h> #include <TNL/Containers/VectorView.h> @@ -146,13 +145,14 @@ public: const RealType& value, const Scalar thisElementMultiplicator ); - template< typename Real_, typename Device_, typename Index_, typename Allocator_ > - Vector& operator=( const Vector< Real_, Device_, Index_, Allocator_ >& v ); - - template< typename Real_, typename Device_, typename Index_ > - Vector& operator=( const VectorView< Real_, Device_, Index_ >& v ); - - template< typename VectorExpression > + /** + * \brief Assigns a vector expression to this vector. + */ + template< typename VectorExpression, + typename..., + typename = std::enable_if_t< Expressions::IsExpressionTemplate< VectorExpression >::value >, + // workaround for nvcc 10.1: adding one more template parameter fixes a problem with inheriting operator= from the base class + typename = void > Vector& operator=( const VectorExpression& expression ); /** @@ -257,14 +257,6 @@ public: }; } // namespace Containers - -template< typename Real, typename Device, typename Index > -struct IsStatic< Containers::Vector< Real, Device, Index > > -{ - static constexpr bool Value = false; -}; - - } // namespace TNL #include <TNL/Containers/Vector.hpp> diff --git a/src/TNL/Containers/Vector.hpp b/src/TNL/Containers/Vector.hpp index c692e0921882aa6344dc1e930f3fa64798f164ea..a3c191aa6455f9b8bf42dd2c0e356390f6a3aef7 100644 --- a/src/TNL/Containers/Vector.hpp +++ b/src/TNL/Containers/Vector.hpp @@ -141,7 +141,7 @@ template< typename Real, typename Device, typename Index, typename Allocator > - template< typename VectorExpression > + template< typename VectorExpression, typename..., typename, typename > Vector< Real, Device, Index, Allocator >& Vector< Real, Device, Index, Allocator >:: operator=( const VectorExpression& expression ) @@ -150,32 +150,6 @@ operator=( const VectorExpression& expression ) return *this; } -template< typename Real, - typename Device, - typename Index, - typename Allocator > - template< typename Real_, typename Device_, typename Index_, typename Allocator_ > -Vector< Real, Device, Index, Allocator >& -Vector< Real, Device, Index, Allocator >:: -operator=( const Vector< Real_, Device_, Index_, Allocator_ >& vector ) -{ - Array< Real, Device, Index, Allocator >::operator=( vector ); - return *this; -} - -template< typename Real, - typename Device, - typename Index, - typename Allocator > - template< typename Real_, typename Device_, typename Index_ > -Vector< Real, Device, Index, Allocator >& -Vector< Real, Device, Index, Allocator >:: -operator=( const VectorView< Real_, Device_, Index_ >& view ) -{ - Array< Real, Device, Index, Allocator >::operator=( view ); - return *this; -} - template< typename Real, typename Device, typename Index, diff --git a/src/TNL/Containers/VectorView.h b/src/TNL/Containers/VectorView.h index 81f6f50757ca99811e43105a5be17ee5a32c1dc3..4f4686cd43d503d4ddc08a0e02b85ab5ae7ae155 100644 --- a/src/TNL/Containers/VectorView.h +++ b/src/TNL/Containers/VectorView.h @@ -19,9 +19,6 @@ namespace TNL { namespace Containers { -template< typename Real, typename Device, typename Index, typename Allocator > -class Vector; - template< typename Real = double, typename Device = Devices::Host, typename Index = int > @@ -55,17 +52,6 @@ public: VectorView( const ArrayView< Real_, Device, Index >& view ) : BaseType::ArrayView( view ) {} - template< typename T1, - typename T2, - template< typename, typename > class Operation > - __cuda_callable__ - VectorView( const Expressions::BinaryExpressionTemplate< T1, T2, Operation >& expression ); - - template< typename T, - template< typename > class Operation > - __cuda_callable__ - VectorView( const Expressions::UnaryExpressionTemplate< T, Operation >& expression ); - /** * \brief Returns a modifiable view of the vector view. * @@ -113,12 +99,6 @@ public: static String getType(); - //template< typename VectorOperationType > - //void evaluate( const VectorOperationType& vo ); - - template< typename VectorOperationType > - void evaluateFor( const VectorOperationType& vo ); - // All other Vector methods follow... void addElement( IndexType i, RealType value ); @@ -127,13 +107,11 @@ public: RealType value, Scalar thisElementMultiplicator ); - template< typename Real_, typename Device_, typename Index_ > - VectorView& operator=( const VectorView< Real_, Device_, Index_ >& v ); - - template< typename Real_, typename Device_, typename Index_, typename Allocator_ > - VectorView& operator=( const Vector< Real_, Device_, Index_, Allocator_ >& v ); - - template< typename VectorExpression > + template< typename VectorExpression, + typename..., + typename = std::enable_if_t< Expressions::IsExpressionTemplate< VectorExpression >::value >, + // workaround for nvcc 10.1: adding one more template parameter fixes a problem with inheriting operator= from the base class + typename = void > VectorView& operator=( const VectorExpression& expression ); template< typename VectorExpression > @@ -197,13 +175,6 @@ public: }; } // namespace Containers - -template< typename Real, typename Device, typename Index > -struct IsStatic< Containers::VectorView< Real, Device, Index > > -{ - static constexpr bool Value = false; -}; - } // namespace TNL #include <TNL/Containers/VectorViewExpressions.h> diff --git a/src/TNL/Containers/VectorView.hpp b/src/TNL/Containers/VectorView.hpp index 435e70a8646a6657f9f6898651c606846a7b740c..a768b3961930d03d1693e840903996e0581bca11 100644 --- a/src/TNL/Containers/VectorView.hpp +++ b/src/TNL/Containers/VectorView.hpp @@ -19,28 +19,6 @@ namespace TNL { namespace Containers { -template< typename Real, - typename Device, - typename Index > - template< typename T1, - typename T2, - template< typename, typename > class Operation > -VectorView< Real, Device, Index >::VectorView( const Expressions::BinaryExpressionTemplate< T1, T2, Operation >& expression ) -{ - Algorithms::VectorAssignment< VectorView< Real, Device, Index >, Expressions::BinaryExpressionTemplate< T1, T2, Operation > >::assign( *this, expression ); -}; - -template< typename Real, - typename Device, - typename Index > - template< typename T, - template< typename > class Operation > -__cuda_callable__ -VectorView< Real, Device, Index >::VectorView( const Expressions::UnaryExpressionTemplate< T, Operation >& expression ) -{ - Algorithms::VectorAssignment< VectorView< Real, Device, Index >, Expressions::UnaryExpressionTemplate< T, Operation > >::assign( *this, expression ); -}; - template< typename Real, typename Device, typename Index > @@ -116,29 +94,7 @@ addElement( IndexType i, RealType value, Scalar thisElementMultiplicator ) template< typename Real, typename Device, typename Index > - template< typename Real_, typename Device_, typename Index_ > -VectorView< Real, Device, Index >& -VectorView< Real, Device, Index >::operator=( const VectorView< Real_, Device_, Index_ >& v ) -{ - ArrayView< Real, Device, Index >::operator=( v ); - return *this; -} - -template< typename Real, - typename Device, - typename Index > - template< typename Real_, typename Device_, typename Index_, typename Allocator_ > -VectorView< Real, Device, Index >& -VectorView< Real, Device, Index >::operator=( const Vector< Real_, Device_, Index_, Allocator_ >& v ) -{ - ArrayView< Real, Device, Index >::operator=( v ); - return *this; -} - -template< typename Real, - typename Device, - typename Index > - template< typename VectorExpression > + template< typename VectorExpression, typename..., typename, typename > VectorView< Real, Device, Index >& VectorView< Real, Device, Index >::operator=( const VectorExpression& expression ) { @@ -202,7 +158,6 @@ typename VectorView< Real, Device, Index >::NonConstReal VectorView< Real, Device, Index >:: operator,( const Vector_& v ) const { - static_assert( std::is_same< DeviceType, typename Vector_::DeviceType >::value, "Cannot compute product of vectors allocated on different devices." ); return dot( *this, v ); } diff --git a/src/TNL/TypeTraits.h b/src/TNL/TypeTraits.h index 134ebc4cac4f9f21f600dd2f54ef62f952b25e08..6a7affca60c5033597c27a4aa8281b61be7ff9e0 100644 --- a/src/TNL/TypeTraits.h +++ b/src/TNL/TypeTraits.h @@ -11,9 +11,86 @@ #pragma once #include <type_traits> +#include <utility> namespace TNL { +/** + * \brief Type trait for checking if T has getArrayData method. + */ +template< typename T > +class HasGetArrayDataMethod +{ +private: + typedef char YesType[1]; + typedef char NoType[2]; + + template< typename C > static YesType& test( decltype(std::declval< C >().getArrayData()) ); + template< typename C > static NoType& test(...); + +public: + static constexpr bool value = ( sizeof( test< T >(0) ) == sizeof( YesType ) ); +}; + +/** + * \brief Type trait for checking if T has getSize method. + */ +template< typename T > +class HasGetSizeMethod +{ +private: + typedef char YesType[1]; + typedef char NoType[2]; + + template< typename C > static YesType& test( decltype(std::declval< C >().getSize() ) ); + template< typename C > static NoType& test(...); + +public: + static constexpr bool value = ( sizeof( test< T >(0) ) == sizeof( YesType ) ); +}; + +/** + * \brief Type trait for checking if T has operator[] taking one index argument. + */ +template< typename T > +class HasSubscriptOperator +{ +private: + template< typename U > + static constexpr auto check(U*) + -> typename + std::enable_if_t< + ! std::is_same< + decltype( std::declval<U>()[ std::declval<U>().getSize() ] ), + void + >::value, + std::true_type + >; + + template< typename > + static constexpr std::false_type check(...); + + using type = decltype(check<T>(0)); + +public: + static constexpr bool value = type::value; +}; + +/** + * \brief Type trait for checking if T is an array type, e.g. + * \ref Containers::Array or \ref Containers::Vector. + * + * The trait combines \ref HasGetArrayDataMethod, \ref HasGetSizeMethod, + * and \ref HasSubscriptOperator. + */ +template< typename T > +struct IsArrayType +: public std::integral_constant< bool, + HasGetArrayDataMethod< T >::value && + HasGetSizeMethod< T >::value && + HasSubscriptOperator< T >::value > +{}; + template< typename T > struct IsStatic { diff --git a/src/UnitTests/Containers/StaticVectorTest.cpp b/src/UnitTests/Containers/StaticVectorTest.cpp index ee2251b11963069a9704e0c31e8174c5cad58127..3f57237281b0df36aa17f32979b187ed2383866b 100644 --- a/src/UnitTests/Containers/StaticVectorTest.cpp +++ b/src/UnitTests/Containers/StaticVectorTest.cpp @@ -94,7 +94,7 @@ TYPED_TEST( StaticVectorTest, operators ) using VectorType = typename TestFixture::VectorType; constexpr int size = VectorType::size; - static_assert( Algorithms::detail::HasSubscriptOperator< VectorType >::value, "Subscript operator detection by SFINAE does not work for StaticVector." ); + static_assert( HasSubscriptOperator< VectorType >::value, "Subscript operator detection by SFINAE does not work for StaticVector." ); VectorType u1( 1 ), u2( 2 ), u3( 3 ); diff --git a/src/UnitTests/Containers/VectorTest-1.h b/src/UnitTests/Containers/VectorTest-1.h index ced6051ab2fb2a3a61b1af51c9f547d67284488d..32932470be0c6a5e1c585791b756825d7b614fff 100644 --- a/src/UnitTests/Containers/VectorTest-1.h +++ b/src/UnitTests/Containers/VectorTest-1.h @@ -99,8 +99,8 @@ TEST( VectorSpecialCasesTest, assignmentThroughView ) using VectorType = Containers::Vector< int, Devices::Host >; using ViewType = VectorView< int, Devices::Host >; - static_assert( Algorithms::detail::HasSubscriptOperator< VectorType >::value, "Subscript operator detection by SFINAE does not work for Vector." ); - static_assert( Algorithms::detail::HasSubscriptOperator< ViewType >::value, "Subscript operator detection by SFINAE does not work for VectorView." ); + static_assert( HasSubscriptOperator< VectorType >::value, "Subscript operator detection by SFINAE does not work for Vector." ); + static_assert( HasSubscriptOperator< ViewType >::value, "Subscript operator detection by SFINAE does not work for VectorView." ); VectorType u( 100 ), v( 100 ); ViewType u_view( u ), v_view( v );