From c93ac9c9cb6da03869790b9dccea7126d43a80ce Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jakub=20Klinkovsk=C3=BD?= <klinkjak@fjfi.cvut.cz>
Date: Thu, 8 Aug 2019 11:59:11 +0200
Subject: [PATCH] Added cast function for vectors

This makes it easy to change the element type of a vector expression,
which is useful for forcing a specific type when combining vectors with
different element types. For example, a vector of doubles can be cast
to float to avoid expensive computations in double precision.
---
 .../Containers/DistributedVectorExpressions.h | 11 +++
 .../DistributedVectorViewExpressions.h        | 11 +++
 .../DistributedExpressionTemplates.h          | 25 ++++++
 .../Expressions/ExpressionTemplates.h         | 25 ++++++
 .../Expressions/HorizontalOperations.h        | 14 ++++
 .../Expressions/StaticExpressionTemplates.h   | 25 ++++++
 src/TNL/Containers/StaticVectorExpressions.h  | 12 +++
 src/TNL/Containers/VectorExpressions.h        | 11 +++
 src/TNL/Containers/VectorViewExpressions.h    | 11 +++
 src/UnitTests/Containers/VectorTest-1.h       | 84 ++++++++++---------
 .../Containers/VectorUnaryOperationsTest.h    | 28 +++++++
 11 files changed, 217 insertions(+), 40 deletions(-)

diff --git a/src/TNL/Containers/DistributedVectorExpressions.h b/src/TNL/Containers/DistributedVectorExpressions.h
index 246d6cc89e..d87e5b0dbd 100644
--- a/src/TNL/Containers/DistributedVectorExpressions.h
+++ b/src/TNL/Containers/DistributedVectorExpressions.h
@@ -785,6 +785,17 @@ sign( const Containers::DistributedVector< Real, Device, Index, Communicator >&
    return Containers::Expressions::DistributedUnaryExpressionTemplate< std::decay_t<decltype(a)>, Containers::Expressions::Sign >( a );
 }
 
+////
+// Cast: returns a distributed unary expression that applies Cast< ResultType >::Operation to the vector `a`
+template< typename ResultType, typename Real, typename Device, typename Index, typename Communicator,
+          // workaround: an alias template cannot be declared at block scope, so the operation is bound by a defaulted template template parameter
+          template<typename> class Operation = Containers::Expressions::Cast< ResultType >::template Operation >
+auto
+cast( const Containers::DistributedVector< Real, Device, Index, Communicator >& a )
+{
+   return Containers::Expressions::DistributedUnaryExpressionTemplate< std::decay_t<decltype(a)>, Operation >( a );
+}
+
 ////
 // Vertical operations - min
 template< typename Real,
diff --git a/src/TNL/Containers/DistributedVectorViewExpressions.h b/src/TNL/Containers/DistributedVectorViewExpressions.h
index f70d962aed..d32d30d99a 100644
--- a/src/TNL/Containers/DistributedVectorViewExpressions.h
+++ b/src/TNL/Containers/DistributedVectorViewExpressions.h
@@ -590,6 +590,17 @@ sign( const Containers::DistributedVectorView< Real, Device, Index, Communicator
    return Containers::Expressions::DistributedUnaryExpressionTemplate< std::decay_t<decltype(a)>, Containers::Expressions::Sign >( a );
 }
 
+////
+// Cast: returns a distributed unary expression that applies Cast< ResultType >::Operation to the vector view `a`
+template< typename ResultType, typename Real, typename Device, typename Index, typename Communicator,
+          // workaround: an alias template cannot be declared at block scope, so the operation is bound by a defaulted template template parameter
+          template<typename> class Operation = Containers::Expressions::Cast< ResultType >::template Operation >
+auto
+cast( const Containers::DistributedVectorView< Real, Device, Index, Communicator >& a )
+{
+   return Containers::Expressions::DistributedUnaryExpressionTemplate< std::decay_t<decltype(a)>, Operation >( a );
+}
+
 ////
 // Vertical operations - min
 template< typename Real, typename Device, typename Index, typename Communicator >
diff --git a/src/TNL/Containers/Expressions/DistributedExpressionTemplates.h b/src/TNL/Containers/Expressions/DistributedExpressionTemplates.h
index de4191d975..ea80cce955 100644
--- a/src/TNL/Containers/Expressions/DistributedExpressionTemplates.h
+++ b/src/TNL/Containers/Expressions/DistributedExpressionTemplates.h
@@ -1774,6 +1774,31 @@ exp( const Containers::Expressions::DistributedUnaryExpressionTemplate< L1, LOpe
    return Containers::Expressions::DistributedUnaryExpressionTemplate< std::decay_t<decltype(a)>, Containers::Expressions::Exp >( a );
 }
 
+////
+// Cast: wraps a distributed binary expression in a unary expression whose operation is Cast< ResultType >::Operation
+template< typename ResultType,
+          typename L1,
+          typename L2,
+          template< typename, typename > class LOperation,
+          // workaround: an alias template cannot be declared at block scope, so the operation is bound by a defaulted template template parameter
+          template<typename> class CastOperation = Containers::Expressions::Cast< ResultType >::template Operation >
+auto
+cast( const Containers::Expressions::DistributedBinaryExpressionTemplate< L1, L2, LOperation >& a )
+{
+   return Containers::Expressions::DistributedUnaryExpressionTemplate< std::decay_t<decltype(a)>, CastOperation >( a );
+}
+
+template< typename ResultType,
+          typename L1,
+          template< typename > class LOperation,
+          // cast of a distributed unary expression; workaround: an alias template cannot be declared at block scope, so the operation is bound by a defaulted template template parameter
+          template<typename> class CastOperation = Containers::Expressions::Cast< ResultType >::template Operation >
+auto
+cast( const Containers::Expressions::DistributedUnaryExpressionTemplate< L1, LOperation >& a )
+{
+   return Containers::Expressions::DistributedUnaryExpressionTemplate< std::decay_t<decltype(a)>, CastOperation >( a );
+}
+
 ////
 // Vertical operations - min
 template< typename L1,
diff --git a/src/TNL/Containers/Expressions/ExpressionTemplates.h b/src/TNL/Containers/Expressions/ExpressionTemplates.h
index f58dae28b2..dd7126b22c 100644
--- a/src/TNL/Containers/Expressions/ExpressionTemplates.h
+++ b/src/TNL/Containers/Expressions/ExpressionTemplates.h
@@ -1693,6 +1693,31 @@ exp( const Containers::Expressions::UnaryExpressionTemplate< L1, LOperation >& a
    return Containers::Expressions::UnaryExpressionTemplate< std::decay_t<decltype(a)>, Containers::Expressions::Exp >( a );
 }
 
+////
+// Cast: wraps a binary expression in a unary expression whose operation is Cast< ResultType >::Operation
+template< typename ResultType,
+          typename L1,
+          typename L2,
+          template< typename, typename > class LOperation,
+          // workaround: an alias template cannot be declared at block scope, so the operation is bound by a defaulted template template parameter
+          template<typename> class CastOperation = Containers::Expressions::Cast< ResultType >::template Operation >
+auto
+cast( const Containers::Expressions::BinaryExpressionTemplate< L1, L2, LOperation >& a )
+{
+   return Containers::Expressions::UnaryExpressionTemplate< std::decay_t<decltype(a)>, CastOperation >( a );
+}
+
+template< typename ResultType,
+          typename L1,
+          template< typename > class LOperation,
+          // cast of a unary expression; workaround: an alias template cannot be declared at block scope, so the operation is bound by a defaulted template template parameter
+          template<typename> class CastOperation = Containers::Expressions::Cast< ResultType >::template Operation >
+auto
+cast( const Containers::Expressions::UnaryExpressionTemplate< L1, LOperation >& a )
+{
+   return Containers::Expressions::UnaryExpressionTemplate< std::decay_t<decltype(a)>, CastOperation >( a );
+}
+
 ////
 // Vertical operations - min
 template< typename L1,
diff --git a/src/TNL/Containers/Expressions/HorizontalOperations.h b/src/TNL/Containers/Expressions/HorizontalOperations.h
index f71e1b8b32..4bfe6a0146 100644
--- a/src/TNL/Containers/Expressions/HorizontalOperations.h
+++ b/src/TNL/Containers/Expressions/HorizontalOperations.h
@@ -316,6 +316,20 @@ struct Sign
    }
 };
 
+template< typename ResultType >
+struct Cast  // family of unary operations converting a value to ResultType
+{
+   template< typename T1 >
+   struct Operation  // nested so that it has a single type parameter (T1, the type of evaluate's argument)
+   {
+      __cuda_callable__
+      static auto evaluate( const T1& a ) -> ResultType
+      {
+         return static_cast<ResultType>( a );  // plain static_cast, so narrowing conversions are performed as requested
+      }
+   };
+};
+
 } // namespace Expressions
 } // namespace Containers
 } // namespace TNL
diff --git a/src/TNL/Containers/Expressions/StaticExpressionTemplates.h b/src/TNL/Containers/Expressions/StaticExpressionTemplates.h
index 6a54b2bd33..d1ef22db0b 100644
--- a/src/TNL/Containers/Expressions/StaticExpressionTemplates.h
+++ b/src/TNL/Containers/Expressions/StaticExpressionTemplates.h
@@ -1847,6 +1847,31 @@ exp( const Containers::Expressions::StaticUnaryExpressionTemplate< L1, LOperatio
    return Containers::Expressions::StaticUnaryExpressionTemplate< std::decay_t<decltype(a)>, Containers::Expressions::Exp >( a );
 }
 
+////
+// Cast: wraps a static binary expression in a unary expression whose operation is Cast< ResultType >::Operation
+template< typename ResultType,
+          typename L1,
+          typename L2,
+          template< typename, typename > class LOperation,
+          // workaround: an alias template cannot be declared at block scope, so the operation is bound by a defaulted template template parameter
+          template<typename> class CastOperation = Containers::Expressions::Cast< ResultType >::template Operation >
+auto
+cast( const Containers::Expressions::StaticBinaryExpressionTemplate< L1, L2, LOperation >& a )
+{
+   return Containers::Expressions::StaticUnaryExpressionTemplate< std::decay_t<decltype(a)>, CastOperation >( a );
+}
+
+template< typename ResultType,
+          typename L1,
+          template< typename > class LOperation,
+          // cast of a static unary expression; workaround: an alias template cannot be declared at block scope, so the operation is bound by a defaulted template template parameter
+          template<typename> class CastOperation = Containers::Expressions::Cast< ResultType >::template Operation >
+auto
+cast( const Containers::Expressions::StaticUnaryExpressionTemplate< L1, LOperation >& a )
+{
+   return Containers::Expressions::StaticUnaryExpressionTemplate< std::decay_t<decltype(a)>, CastOperation >( a );
+}
+
 ////
 // Vertical operations - min
 template< typename L1,
diff --git a/src/TNL/Containers/StaticVectorExpressions.h b/src/TNL/Containers/StaticVectorExpressions.h
index 1891ad54d8..db169e4eae 100644
--- a/src/TNL/Containers/StaticVectorExpressions.h
+++ b/src/TNL/Containers/StaticVectorExpressions.h
@@ -637,6 +637,18 @@ sign( const Containers::StaticVector< Size, Real >& a )
    return Containers::Expressions::StaticUnaryExpressionTemplate< Containers::StaticVector< Size, Real >, Containers::Expressions::Sign >( a );
 }
 
+////
+// Cast: returns a static unary expression that applies Cast< ResultType >::Operation to the static vector `a`
+template< typename ResultType, int Size, typename Real,
+          // workaround: an alias template cannot be declared at block scope, so the operation is bound by a defaulted template template parameter
+          template<typename> class Operation = Containers::Expressions::Cast< ResultType >::template Operation >
+auto
+__cuda_callable__
+cast( const Containers::StaticVector< Size, Real >& a )
+{
+   return Containers::Expressions::StaticUnaryExpressionTemplate< std::decay_t<decltype(a)>, Operation >( a );
+}
+
 ////
 // Vertical operations - min
 template< int Size, typename Real >
diff --git a/src/TNL/Containers/VectorExpressions.h b/src/TNL/Containers/VectorExpressions.h
index 3aca5dce2a..3c5d128cc6 100644
--- a/src/TNL/Containers/VectorExpressions.h
+++ b/src/TNL/Containers/VectorExpressions.h
@@ -915,6 +915,17 @@ sign( const Containers::Vector< Real, Device, Index, Allocator >& a )
    return Containers::Expressions::UnaryExpressionTemplate< ConstView, Containers::Expressions::Sign >( a.getConstView() );
 }
 
+////
+// Cast: returns a unary expression that applies Cast< ResultType >::Operation to a.getConstView() (the view, not the vector itself, is wrapped)
+template< typename ResultType, typename Real, typename Device, typename Index, typename Allocator,
+          // workaround: an alias template cannot be declared at block scope, so the operation is bound by a defaulted template template parameter
+          template<typename> class Operation = Containers::Expressions::Cast< ResultType >::template Operation >
+auto
+cast( const Containers::Vector< Real, Device, Index, Allocator >& a )
+{
+   return Containers::Expressions::UnaryExpressionTemplate< decltype(a.getConstView()), Operation >( a.getConstView() );
+}
+
 ////
 // Vertical operations - min
 template< typename Real,
diff --git a/src/TNL/Containers/VectorViewExpressions.h b/src/TNL/Containers/VectorViewExpressions.h
index 1b29ee78c5..48bb426534 100644
--- a/src/TNL/Containers/VectorViewExpressions.h
+++ b/src/TNL/Containers/VectorViewExpressions.h
@@ -597,6 +597,17 @@ sign( const Containers::VectorView< Real, Device, Index >& a )
    return Containers::Expressions::UnaryExpressionTemplate< Containers::VectorView< Real, Device, Index >, Containers::Expressions::Sign >( a );
 }
 
+////
+// Cast: returns a unary expression that applies Cast< ResultType >::Operation to the vector view `a`
+template< typename ResultType, typename Real, typename Device, typename Index,
+          // workaround: an alias template cannot be declared at block scope, so the operation is bound by a defaulted template template parameter
+          template<typename> class Operation = Containers::Expressions::Cast< ResultType >::template Operation >
+auto
+cast( const Containers::VectorView< Real, Device, Index >& a )
+{
+   return Containers::Expressions::UnaryExpressionTemplate< std::decay_t<decltype(a)>, Operation >( a );
+}
+
 ////
 // Vertical operations - min
 template< typename Real,
diff --git a/src/UnitTests/Containers/VectorTest-1.h b/src/UnitTests/Containers/VectorTest-1.h
index 79c9253f9d..b6602ba141 100644
--- a/src/UnitTests/Containers/VectorTest-1.h
+++ b/src/UnitTests/Containers/VectorTest-1.h
@@ -137,50 +137,54 @@ TEST( VectorSpecialCasesTest, initializationOfVectorViewByArrayView )
 TEST( VectorSpecialCasesTest, sumOfBoolVector )
 {
    using VectorType = Containers::Vector< bool, Devices::Host >;
-   using ViewType = VectorView< bool, Devices::Host >;
-   const float epsilon = 64 * std::numeric_limits< float >::epsilon();
+   using ViewType = typename VectorType::ViewType;
+   const double epsilon = 64 * std::numeric_limits< double >::epsilon();
+   constexpr int size = 4999;
 
-   VectorType v( 512 ), w( 512 );
-   ViewType v_view( v ), w_view( w );
+   VectorType v( size );
+   ViewType v_view( v );
    v.setValue( true );
-   w.setValue( false );
-
-   const int sum = TNL::sum( v );
-   const int l1norm = lpNorm( v, 1.0 );
-   const float l2norm = lpNorm( v, 2.0 );
-   const float l3norm = lpNorm( v, 3.0 );
-   EXPECT_EQ( sum, 512 );
-   EXPECT_EQ( l1norm, 512 );
-   EXPECT_NEAR( l2norm, std::sqrt( 512 ), epsilon );
-   EXPECT_NEAR( l3norm, std::cbrt( 512 ), epsilon );
-
-   const int diff_sum = TNL::sum( v - w );
-   const int diff_l1norm = lpNorm( v - w, 1.0 );
-   const float diff_l2norm = lpNorm( v - w, 2.0 );
-   const float diff_l3norm = lpNorm( v - w, 3.0 );
-   EXPECT_EQ( diff_sum, 512 );
-   EXPECT_EQ( diff_l1norm, 512 );
-   EXPECT_NEAR( diff_l2norm, std::sqrt( 512 ), epsilon );
-   EXPECT_NEAR( diff_l3norm, std::cbrt( 512 ), epsilon );
+
+   // normal sum and lpNorm rely on built-in integral promotion
+   const auto sum = TNL::sum( v );
+   const auto l1norm = l1Norm( v );
+   const auto l2norm = l2Norm( v );
+   const auto l3norm = lpNorm( v, 3.0 );
+   EXPECT_EQ( sum, size );
+   EXPECT_EQ( l1norm, size );
+   EXPECT_EQ( l2norm, std::sqrt( size ) );
+   EXPECT_NEAR( l3norm, std::cbrt( size ), epsilon );
+
+   // explicit cast to double
+   const auto sum_cast = TNL::sum( cast<double>( v ) );
+   const auto l1norm_cast = l1Norm( cast<double>( v ) );
+   const auto l2norm_cast = l2Norm( cast<double>( v ) );
+   const auto l3norm_cast = lpNorm( cast<double>( v ), 3.0 );
+   EXPECT_EQ( sum_cast, size );
+   EXPECT_EQ( l1norm_cast, size );
+   EXPECT_EQ( l2norm_cast, std::sqrt( size ) );
+   EXPECT_NEAR( l3norm_cast, std::cbrt( size ), epsilon );
 
    // test views
-   const int sum_view = TNL::sum( v_view );
-   const int l1norm_view = lpNorm( v_view, 1.0 );
-   const float l2norm_view = lpNorm( v_view, 2.0 );
-   const float l3norm_view = lpNorm( v_view, 3.0 );
-   EXPECT_EQ( sum_view, 512 );
-   EXPECT_EQ( l1norm_view, 512 );
-   EXPECT_NEAR( l2norm_view, std::sqrt( 512 ), epsilon );
-   EXPECT_NEAR( l3norm_view, std::cbrt( 512 ), epsilon );
-
-   const int diff_sum_view = TNL::sum( v_view - w_view );
-   const int diff_l1norm_view = lpNorm( v_view -w_view, 1.0 );
-   const float diff_l2norm_view = lpNorm( v_view - w_view, 2.0 );
-   const float diff_l3norm_view = lpNorm( v_view - w_view, 3.0 );
-   EXPECT_EQ( diff_sum_view, 512 );
-   EXPECT_EQ( diff_l1norm_view, 512 );
-   EXPECT_NEAR( diff_l2norm_view, std::sqrt( 512 ), epsilon );
-   EXPECT_NEAR( diff_l3norm_view, std::cbrt( 512 ), epsilon );
+   // normal sum and lpNorm rely on built-in integral promotion
+   const auto sum_view = TNL::sum( v_view );
+   const auto l1norm_view = l1Norm( v_view );
+   const auto l2norm_view = l2Norm( v_view );
+   const auto l3norm_view = lpNorm( v_view, 3.0 );
+   EXPECT_EQ( sum_view, size );
+   EXPECT_EQ( l1norm_view, size );
+   EXPECT_EQ( l2norm_view, std::sqrt( size ) );
+   EXPECT_NEAR( l3norm_view, std::cbrt( size ), epsilon );
+
+   // explicit cast to double
+   const auto sum_view_cast = TNL::sum( cast<double>( v_view ) );
+   const auto l1norm_view_cast = l1Norm( cast<double>( v_view ) );
+   const auto l2norm_view_cast = l2Norm( cast<double>( v_view ) );
+   const auto l3norm_view_cast = lpNorm( cast<double>( v_view ), 3.0 );
+   EXPECT_EQ( sum_view_cast, size );
+   EXPECT_EQ( l1norm_view_cast, size);
+   EXPECT_EQ( l2norm_view_cast, std::sqrt( size ) );
+   EXPECT_NEAR( l3norm_view_cast, std::cbrt( size ), epsilon );
 }
 
 #endif // HAVE_GTEST
diff --git a/src/UnitTests/Containers/VectorUnaryOperationsTest.h b/src/UnitTests/Containers/VectorUnaryOperationsTest.h
index 8ee68dbde1..00b3f480fc 100644
--- a/src/UnitTests/Containers/VectorUnaryOperationsTest.h
+++ b/src/UnitTests/Containers/VectorUnaryOperationsTest.h
@@ -484,6 +484,34 @@ TYPED_TEST( VectorUnaryOperationsTest, sign )
    EXPECT_EQ( sign(V1), expected );
 }
 
+TYPED_TEST( VectorUnaryOperationsTest, cast )
+{
+   auto identity = [](int i) { return i; };
+   SETUP_UNARY_VECTOR_TEST_FUNCTION( VECTOR_TEST_SIZE, 1, VECTOR_TEST_SIZE, identity );
+
+   // vector or vector view
+   auto expression1 = cast<bool>(V1);
+   static_assert( std::is_same< typename decltype(expression1)::RealType, bool >::value,
+                  "BUG: the cast function does not work for vector or vector view." );
+   EXPECT_EQ( expression1, true );
+
+   // binary expression
+   auto expression2( cast<bool>(V1 + V1) );
+   static_assert( std::is_same< typename decltype(expression2)::RealType, bool >::value,
+                  "BUG: the cast function does not work for binary expression." );
+   // FIXME: expression2 cannot be reused, because expression templates for StaticVector and DistributedVector contain references and the test would crash in Release
+//   EXPECT_EQ( expression2, true );
+   EXPECT_EQ( cast<bool>(V1 + V1), true );
+
+   // unary expression
+   auto expression3( cast<bool>(-V1) );
+   static_assert( std::is_same< typename decltype(expression3)::RealType, bool >::value,
+                  "BUG: the cast function does not work for unary expression." );
+   // FIXME: expression3 cannot be reused, because expression templates for StaticVector and DistributedVector contain references and the test would crash in Release
+//   EXPECT_EQ( expression3, true );
+   EXPECT_EQ( cast<bool>(-V1), true );
+}
+
 
 TYPED_TEST( VectorUnaryOperationsTest, max )
 {
-- 
GitLab