diff --git a/src/TNL/Containers/DistributedVectorExpressions.h b/src/TNL/Containers/DistributedVectorExpressions.h index 246d6cc89e440018c25bd5c987c5737a53ee490d..d87e5b0dbdb3f73ee2f0b90591bde73c33d1133a 100644 --- a/src/TNL/Containers/DistributedVectorExpressions.h +++ b/src/TNL/Containers/DistributedVectorExpressions.h @@ -785,6 +785,17 @@ sign( const Containers::DistributedVector< Real, Device, Index, Communicator >& return Containers::Expressions::DistributedUnaryExpressionTemplate< std::decay_t<decltype(a)>, Containers::Expressions::Sign >( a ); } +//// +// Cast +template< typename ResultType, typename Real, typename Device, typename Index, typename Communicator, + // workaround: templated type alias cannot be declared at block level + template<typename> class Operation = Containers::Expressions::Cast< ResultType >::template Operation > +auto +cast( const Containers::DistributedVector< Real, Device, Index, Communicator >& a ) +{ + return Containers::Expressions::DistributedUnaryExpressionTemplate< std::decay_t<decltype(a)>, Operation >( a ); +} + //// // Vertical operations - min template< typename Real, diff --git a/src/TNL/Containers/DistributedVectorViewExpressions.h b/src/TNL/Containers/DistributedVectorViewExpressions.h index f70d962aed8cec5f2c3e439c10ca763a8cdd17a9..d32d30d99addd82041f16f27b7ed0ddd0a468b21 100644 --- a/src/TNL/Containers/DistributedVectorViewExpressions.h +++ b/src/TNL/Containers/DistributedVectorViewExpressions.h @@ -590,6 +590,17 @@ sign( const Containers::DistributedVectorView< Real, Device, Index, Communicator return Containers::Expressions::DistributedUnaryExpressionTemplate< std::decay_t<decltype(a)>, Containers::Expressions::Sign >( a ); } +//// +// Cast +template< typename ResultType, typename Real, typename Device, typename Index, typename Communicator, + // workaround: templated type alias cannot be declared at block level + template<typename> class Operation = Containers::Expressions::Cast< ResultType >::template Operation > 
+auto +cast( const Containers::DistributedVectorView< Real, Device, Index, Communicator >& a ) +{ + return Containers::Expressions::DistributedUnaryExpressionTemplate< std::decay_t<decltype(a)>, Operation >( a ); +} + //// // Vertical operations - min template< typename Real, typename Device, typename Index, typename Communicator > diff --git a/src/TNL/Containers/Expressions/DistributedExpressionTemplates.h b/src/TNL/Containers/Expressions/DistributedExpressionTemplates.h index de4191d975fd9909937c0c777d40e6eff33ae7fd..ea80cce95537c2b89a69d034630fa4366cc1383c 100644 --- a/src/TNL/Containers/Expressions/DistributedExpressionTemplates.h +++ b/src/TNL/Containers/Expressions/DistributedExpressionTemplates.h @@ -1774,6 +1774,31 @@ exp( const Containers::Expressions::DistributedUnaryExpressionTemplate< L1, LOpe return Containers::Expressions::DistributedUnaryExpressionTemplate< std::decay_t<decltype(a)>, Containers::Expressions::Exp >( a ); } +//// +// Cast +template< typename ResultType, + typename L1, + typename L2, + template< typename, typename > class LOperation, + // workaround: templated type alias cannot be declared at block level + template<typename> class CastOperation = Containers::Expressions::Cast< ResultType >::template Operation > +auto +cast( const Containers::Expressions::DistributedBinaryExpressionTemplate< L1, L2, LOperation >& a ) +{ + return Containers::Expressions::DistributedUnaryExpressionTemplate< std::decay_t<decltype(a)>, CastOperation >( a ); +} + +template< typename ResultType, + typename L1, + template< typename > class LOperation, + // workaround: templated type alias cannot be declared at block level + template<typename> class CastOperation = Containers::Expressions::Cast< ResultType >::template Operation > +auto +cast( const Containers::Expressions::DistributedUnaryExpressionTemplate< L1, LOperation >& a ) +{ + return Containers::Expressions::DistributedUnaryExpressionTemplate< std::decay_t<decltype(a)>, CastOperation >( a ); +} + //// // 
Vertical operations - min template< typename L1, diff --git a/src/TNL/Containers/Expressions/ExpressionTemplates.h b/src/TNL/Containers/Expressions/ExpressionTemplates.h index f58dae28b28b44b6d0da4e49353a8586b196bf7e..dd7126b22c1807e76dd2f77120d9f34229523442 100644 --- a/src/TNL/Containers/Expressions/ExpressionTemplates.h +++ b/src/TNL/Containers/Expressions/ExpressionTemplates.h @@ -1693,6 +1693,31 @@ exp( const Containers::Expressions::UnaryExpressionTemplate< L1, LOperation >& a return Containers::Expressions::UnaryExpressionTemplate< std::decay_t<decltype(a)>, Containers::Expressions::Exp >( a ); } +//// +// Cast +template< typename ResultType, + typename L1, + typename L2, + template< typename, typename > class LOperation, + // workaround: templated type alias cannot be declared at block level + template<typename> class CastOperation = Containers::Expressions::Cast< ResultType >::template Operation > +auto +cast( const Containers::Expressions::BinaryExpressionTemplate< L1, L2, LOperation >& a ) +{ + return Containers::Expressions::UnaryExpressionTemplate< std::decay_t<decltype(a)>, CastOperation >( a ); +} + +template< typename ResultType, + typename L1, + template< typename > class LOperation, + // workaround: templated type alias cannot be declared at block level + template<typename> class CastOperation = Containers::Expressions::Cast< ResultType >::template Operation > +auto +cast( const Containers::Expressions::UnaryExpressionTemplate< L1, LOperation >& a ) +{ + return Containers::Expressions::UnaryExpressionTemplate< std::decay_t<decltype(a)>, CastOperation >( a ); +} + //// // Vertical operations - min template< typename L1, diff --git a/src/TNL/Containers/Expressions/HorizontalOperations.h b/src/TNL/Containers/Expressions/HorizontalOperations.h index f71e1b8b3291f2a5341335940da53d92c3f5408a..4bfe6a0146fde1f3ac52a2e2175ea56f422ad56c 100644 --- a/src/TNL/Containers/Expressions/HorizontalOperations.h +++ 
b/src/TNL/Containers/Expressions/HorizontalOperations.h @@ -316,6 +316,20 @@ struct Sign } }; +template< typename ResultType > +struct Cast +{ + template< typename T1 > + struct Operation + { + __cuda_callable__ + static auto evaluate( const T1& a ) -> ResultType + { + return static_cast<ResultType>( a ); + } + }; +}; + } // namespace Expressions } // namespace Containers } // namespace TNL diff --git a/src/TNL/Containers/Expressions/StaticExpressionTemplates.h b/src/TNL/Containers/Expressions/StaticExpressionTemplates.h index 6a54b2bd33f56f96bd9149feb5a8371d84117104..d1ef22db0b23048cbe45ee44368cdb7cb9e54288 100644 --- a/src/TNL/Containers/Expressions/StaticExpressionTemplates.h +++ b/src/TNL/Containers/Expressions/StaticExpressionTemplates.h @@ -1847,6 +1847,31 @@ exp( const Containers::Expressions::StaticUnaryExpressionTemplate< L1, LOperatio return Containers::Expressions::StaticUnaryExpressionTemplate< std::decay_t<decltype(a)>, Containers::Expressions::Exp >( a ); } +//// +// Cast +template< typename ResultType, + typename L1, + typename L2, + template< typename, typename > class LOperation, + // workaround: templated type alias cannot be declared at block level + template<typename> class CastOperation = Containers::Expressions::Cast< ResultType >::template Operation > +auto +cast( const Containers::Expressions::StaticBinaryExpressionTemplate< L1, L2, LOperation >& a ) +{ + return Containers::Expressions::StaticUnaryExpressionTemplate< std::decay_t<decltype(a)>, CastOperation >( a ); +} + +template< typename ResultType, + typename L1, + template< typename > class LOperation, + // workaround: templated type alias cannot be declared at block level + template<typename> class CastOperation = Containers::Expressions::Cast< ResultType >::template Operation > +auto +cast( const Containers::Expressions::StaticUnaryExpressionTemplate< L1, LOperation >& a ) +{ + return Containers::Expressions::StaticUnaryExpressionTemplate< std::decay_t<decltype(a)>, CastOperation >( a 
); +} + //// // Vertical operations - min template< typename L1, diff --git a/src/TNL/Containers/StaticVectorExpressions.h b/src/TNL/Containers/StaticVectorExpressions.h index 1891ad54d83b539630182abd33e6f89aa1623aee..db169e4eae0c4d2de08593ffb150059ac9095fd8 100644 --- a/src/TNL/Containers/StaticVectorExpressions.h +++ b/src/TNL/Containers/StaticVectorExpressions.h @@ -637,6 +637,18 @@ sign( const Containers::StaticVector< Size, Real >& a ) return Containers::Expressions::StaticUnaryExpressionTemplate< Containers::StaticVector< Size, Real >, Containers::Expressions::Sign >( a ); } +//// +// Cast +template< typename ResultType, int Size, typename Real, + // workaround: templated type alias cannot be declared at block level + template<typename> class Operation = Containers::Expressions::Cast< ResultType >::template Operation > +auto +__cuda_callable__ +cast( const Containers::StaticVector< Size, Real >& a ) +{ + return Containers::Expressions::StaticUnaryExpressionTemplate< std::decay_t<decltype(a)>, Operation >( a ); +} + //// // Vertical operations - min template< int Size, typename Real > diff --git a/src/TNL/Containers/VectorExpressions.h b/src/TNL/Containers/VectorExpressions.h index 3aca5dce2a6ff35604f0923709cd74b42d38fd57..3c5d128cc6f7a95efd75978c54993f456982d83c 100644 --- a/src/TNL/Containers/VectorExpressions.h +++ b/src/TNL/Containers/VectorExpressions.h @@ -915,6 +915,17 @@ sign( const Containers::Vector< Real, Device, Index, Allocator >& a ) return Containers::Expressions::UnaryExpressionTemplate< ConstView, Containers::Expressions::Sign >( a.getConstView() ); } +//// +// Cast +template< typename ResultType, typename Real, typename Device, typename Index, typename Allocator, + // workaround: templated type alias cannot be declared at block level + template<typename> class Operation = Containers::Expressions::Cast< ResultType >::template Operation > +auto +cast( const Containers::Vector< Real, Device, Index, Allocator >& a ) +{ + return 
Containers::Expressions::UnaryExpressionTemplate< decltype(a.getConstView()), Operation >( a.getConstView() ); +} + //// // Vertical operations - min template< typename Real, diff --git a/src/TNL/Containers/VectorViewExpressions.h b/src/TNL/Containers/VectorViewExpressions.h index 1b29ee78c5b37cf033f2ee5d8e9f71008f29b601..48bb42653499b2c25062cea816e9d099ec6adc1e 100644 --- a/src/TNL/Containers/VectorViewExpressions.h +++ b/src/TNL/Containers/VectorViewExpressions.h @@ -597,6 +597,17 @@ sign( const Containers::VectorView< Real, Device, Index >& a ) return Containers::Expressions::UnaryExpressionTemplate< Containers::VectorView< Real, Device, Index >, Containers::Expressions::Sign >( a ); } +//// +// Cast +template< typename ResultType, typename Real, typename Device, typename Index, + // workaround: templated type alias cannot be declared at block level + template<typename> class Operation = Containers::Expressions::Cast< ResultType >::template Operation > +auto +cast( const Containers::VectorView< Real, Device, Index >& a ) +{ + return Containers::Expressions::UnaryExpressionTemplate< std::decay_t<decltype(a)>, Operation >( a ); +} + //// // Vertical operations - min template< typename Real, diff --git a/src/UnitTests/Containers/VectorTest-1.h b/src/UnitTests/Containers/VectorTest-1.h index 79c9253f9d5c1393d7e2ae3b4e1b8c6426a79f63..b6602ba141dbeb3a45470c58cf6130544b85d944 100644 --- a/src/UnitTests/Containers/VectorTest-1.h +++ b/src/UnitTests/Containers/VectorTest-1.h @@ -137,50 +137,54 @@ TEST( VectorSpecialCasesTest, initializationOfVectorViewByArrayView ) TEST( VectorSpecialCasesTest, sumOfBoolVector ) { using VectorType = Containers::Vector< bool, Devices::Host >; - using ViewType = VectorView< bool, Devices::Host >; - const float epsilon = 64 * std::numeric_limits< float >::epsilon(); + using ViewType = typename VectorType::ViewType; + const double epsilon = 64 * std::numeric_limits< double >::epsilon(); + constexpr int size = 4999; - VectorType v( 512 ), w( 
512 ); - ViewType v_view( v ), w_view( w ); + VectorType v( size ); + ViewType v_view( v ); v.setValue( true ); - w.setValue( false ); - - const int sum = TNL::sum( v ); - const int l1norm = lpNorm( v, 1.0 ); - const float l2norm = lpNorm( v, 2.0 ); - const float l3norm = lpNorm( v, 3.0 ); - EXPECT_EQ( sum, 512 ); - EXPECT_EQ( l1norm, 512 ); - EXPECT_NEAR( l2norm, std::sqrt( 512 ), epsilon ); - EXPECT_NEAR( l3norm, std::cbrt( 512 ), epsilon ); - - const int diff_sum = TNL::sum( v - w ); - const int diff_l1norm = lpNorm( v - w, 1.0 ); - const float diff_l2norm = lpNorm( v - w, 2.0 ); - const float diff_l3norm = lpNorm( v - w, 3.0 ); - EXPECT_EQ( diff_sum, 512 ); - EXPECT_EQ( diff_l1norm, 512 ); - EXPECT_NEAR( diff_l2norm, std::sqrt( 512 ), epsilon ); - EXPECT_NEAR( diff_l3norm, std::cbrt( 512 ), epsilon ); + + // normal sum and lpNorm rely on built-in integral promotion + const auto sum = TNL::sum( v ); + const auto l1norm = l1Norm( v ); + const auto l2norm = l2Norm( v ); + const auto l3norm = lpNorm( v, 3.0 ); + EXPECT_EQ( sum, size ); + EXPECT_EQ( l1norm, size ); + EXPECT_EQ( l2norm, std::sqrt( size ) ); + EXPECT_NEAR( l3norm, std::cbrt( size ), epsilon ); + + // explicit cast to double + const auto sum_cast = TNL::sum( cast<double>( v ) ); + const auto l1norm_cast = l1Norm( cast<double>( v ) ); + const auto l2norm_cast = l2Norm( cast<double>( v ) ); + const auto l3norm_cast = lpNorm( cast<double>( v ), 3.0 ); + EXPECT_EQ( sum_cast, size ); + EXPECT_EQ( l1norm_cast, size ); + EXPECT_EQ( l2norm_cast, std::sqrt( size ) ); + EXPECT_NEAR( l3norm_cast, std::cbrt( size ), epsilon ); // test views - const int sum_view = TNL::sum( v_view ); - const int l1norm_view = lpNorm( v_view, 1.0 ); - const float l2norm_view = lpNorm( v_view, 2.0 ); - const float l3norm_view = lpNorm( v_view, 3.0 ); - EXPECT_EQ( sum_view, 512 ); - EXPECT_EQ( l1norm_view, 512 ); - EXPECT_NEAR( l2norm_view, std::sqrt( 512 ), epsilon ); - EXPECT_NEAR( l3norm_view, std::cbrt( 512 ), epsilon ); - - const 
int diff_sum_view = TNL::sum( v_view - w_view ); - const int diff_l1norm_view = lpNorm( v_view -w_view, 1.0 ); - const float diff_l2norm_view = lpNorm( v_view - w_view, 2.0 ); - const float diff_l3norm_view = lpNorm( v_view - w_view, 3.0 ); - EXPECT_EQ( diff_sum_view, 512 ); - EXPECT_EQ( diff_l1norm_view, 512 ); - EXPECT_NEAR( diff_l2norm_view, std::sqrt( 512 ), epsilon ); - EXPECT_NEAR( diff_l3norm_view, std::cbrt( 512 ), epsilon ); + // normal sum and lpNorm rely on built-in integral promotion + const auto sum_view = TNL::sum( v_view ); + const auto l1norm_view = l1Norm( v_view ); + const auto l2norm_view = l2Norm( v_view ); + const auto l3norm_view = lpNorm( v_view, 3.0 ); + EXPECT_EQ( sum_view, size ); + EXPECT_EQ( l1norm_view, size ); + EXPECT_EQ( l2norm_view, std::sqrt( size ) ); + EXPECT_NEAR( l3norm_view, std::cbrt( size ), epsilon ); + + // explicit cast to double + const auto sum_view_cast = TNL::sum( cast<double>( v_view ) ); + const auto l1norm_view_cast = l1Norm( cast<double>( v_view ) ); + const auto l2norm_view_cast = l2Norm( cast<double>( v_view ) ); + const auto l3norm_view_cast = lpNorm( cast<double>( v_view ), 3.0 ); + EXPECT_EQ( sum_view_cast, size ); + EXPECT_EQ( l1norm_view_cast, size); + EXPECT_EQ( l2norm_view_cast, std::sqrt( size ) ); + EXPECT_NEAR( l3norm_view_cast, std::cbrt( size ), epsilon ); } #endif // HAVE_GTEST diff --git a/src/UnitTests/Containers/VectorUnaryOperationsTest.h b/src/UnitTests/Containers/VectorUnaryOperationsTest.h index 8ee68dbde1bf984781da4636ad372955df083302..00b3f480fcfeb83333abca58dc6e5841c4cfdbcb 100644 --- a/src/UnitTests/Containers/VectorUnaryOperationsTest.h +++ b/src/UnitTests/Containers/VectorUnaryOperationsTest.h @@ -484,6 +484,34 @@ TYPED_TEST( VectorUnaryOperationsTest, sign ) EXPECT_EQ( sign(V1), expected ); } +TYPED_TEST( VectorUnaryOperationsTest, cast ) +{ + auto identity = [](int i) { return i; }; + SETUP_UNARY_VECTOR_TEST_FUNCTION( VECTOR_TEST_SIZE, 1, VECTOR_TEST_SIZE, identity ); + + // vector or 
vector view + auto expression1 = cast<bool>(V1); + static_assert( std::is_same< typename decltype(expression1)::RealType, bool >::value, + "BUG: the cast function does not work for vector or vector view." ); + EXPECT_EQ( expression1, true ); + + // binary expression + auto expression2( cast<bool>(V1 + V1) ); + static_assert( std::is_same< typename decltype(expression2)::RealType, bool >::value, + "BUG: the cast function does not work for binary expression." ); + // FIXME: expression2 cannot be reused, because expression templates for StaticVector and DistributedVector contain references and the test would crash in Release +// EXPECT_EQ( expression2, true ); + EXPECT_EQ( cast<bool>(V1 + V1), true ); + + // unary expression + auto expression3( cast<bool>(-V1) ); + static_assert( std::is_same< typename decltype(expression3)::RealType, bool >::value, + "BUG: the cast function does not work for unary expression." ); + // FIXME: expression3 cannot be reused, because expression templates for StaticVector and DistributedVector contain references and the test would crash in Release +// EXPECT_EQ( expression3, true ); + EXPECT_EQ( cast<bool>(-V1), true ); +} + TYPED_TEST( VectorUnaryOperationsTest, max ) {