Commit ef9b11e4 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Simplification of expression templates

parent 118a6f7f
Loading
Loading
Loading
Loading
+8 −1
Original line number Diff line number Diff line
@@ -134,8 +134,15 @@ public:
   void scan( IndexType begin = 0, IndexType end = 0 );
};

// Enable expression templates for DistributedVector
namespace Expressions {
   template< typename Real, typename Device, typename Index, typename Communicator >
   struct HasEnabledDistributedExpressionTemplates< DistributedVector< Real, Device, Index, Communicator > >
   : std::true_type
   {};
} // namespace Expressions

} // namespace Containers
} // namespace TNL

#include <TNL/Containers/DistributedVector.hpp>
#include <TNL/Containers/DistributedVectorExpressions.h>
+0 −936

File deleted.

Preview size limit exceeded, changes collapsed.

+9 −1
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@
#pragma once

#include <TNL/Containers/DistributedArrayView.h>
#include <TNL/Containers/Expressions/DistributedExpressionTemplates.h>
#include <TNL/Containers/VectorView.h>

namespace TNL {
@@ -137,8 +138,15 @@ public:
   void scan( IndexType begin = 0, IndexType end = 0 );
};

// Enable expression templates for DistributedVector
namespace Expressions {
   template< typename Real, typename Device, typename Index, typename Communicator >
   struct HasEnabledDistributedExpressionTemplates< DistributedVectorView< Real, Device, Index, Communicator > >
   : std::true_type
   {};
} // namespace Expressions

} // namespace Containers
} // namespace TNL

#include <TNL/Containers/DistributedVectorView.hpp>
#include <TNL/Containers/DistributedVectorViewExpressions.h>
+0 −718

File deleted.

Preview size limit exceeded, changes collapsed.

+39 −18
Original line number Diff line number Diff line
@@ -25,13 +25,14 @@ namespace Expressions {
// Non-static comparison
template< typename T1,
          typename T2,
          ExpressionVariableType T1Type = ExpressionVariableTypeGetter< T1 >::value,
          ExpressionVariableType T2Type = ExpressionVariableTypeGetter< T2 >::value >
          ExpressionVariableType T1Type = getExpressionVariableType< T1, T2 >(),
          ExpressionVariableType T2Type = getExpressionVariableType< T2, T1 >() >
struct Comparison;

template< typename T1,
          typename T2,
          bool BothAreVectors = IsArrayType< T1 >::value && IsArrayType< T2 >::value >
          bool BothAreNonstaticVectors = IsArrayType< T1 >::value && IsArrayType< T2 >::value &&
                                       ! IsStaticArrayType< T1 >::value && ! IsStaticArrayType< T2 >::value >
struct VectorComparison;

// If both operands are vectors we compare them using array operations.
@@ -64,7 +65,9 @@ struct VectorComparison< T1, T2, false >
      using DeviceType = typename T1::DeviceType;
      using IndexType = typename T1::IndexType;

      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a[ i ] == b[ i ]; };
      const auto view_a = a.getConstView();
      const auto view_b = b.getConstView();
      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return view_a[ i ] == view_b[ i ]; };
      return Algorithms::Reduction< DeviceType >::reduce( a.getSize(), std::logical_and<>{}, fetch, true );
   }
};
@@ -94,7 +97,9 @@ struct Comparison< T1, T2, VectorExpressionVariable, VectorExpressionVariable >
      using DeviceType = typename T1::DeviceType;
      using IndexType = typename T1::IndexType;

      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a[ i ] > b[ i ]; };
      const auto view_a = a.getConstView();
      const auto view_b = b.getConstView();
      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return view_a[ i ] > view_b[ i ]; };
      return Algorithms::Reduction< DeviceType >::reduce( a.getSize(), std::logical_and<>{}, fetch, true );
   }

@@ -107,7 +112,9 @@ struct Comparison< T1, T2, VectorExpressionVariable, VectorExpressionVariable >
      using DeviceType = typename T1::DeviceType;
      using IndexType = typename T1::IndexType;

      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a[ i ] >= b[ i ]; };
      const auto view_a = a.getConstView();
      const auto view_b = b.getConstView();
      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return view_a[ i ] >= view_b[ i ]; };
      return Algorithms::Reduction< DeviceType >::reduce( a.getSize(), std::logical_and<>{}, fetch, true );
   }

@@ -120,7 +127,9 @@ struct Comparison< T1, T2, VectorExpressionVariable, VectorExpressionVariable >
      using DeviceType = typename T1::DeviceType;
      using IndexType = typename T1::IndexType;

      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a[ i ] < b[ i ]; };
      const auto view_a = a.getConstView();
      const auto view_b = b.getConstView();
      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return view_a[ i ] < view_b[ i ]; };
      return Algorithms::Reduction< DeviceType >::reduce( a.getSize(), std::logical_and<>{}, fetch, true );
   }

@@ -133,7 +142,9 @@ struct Comparison< T1, T2, VectorExpressionVariable, VectorExpressionVariable >
      using DeviceType = typename T1::DeviceType;
      using IndexType = typename T1::IndexType;

      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a[ i ] <= b[ i ]; };
      const auto view_a = a.getConstView();
      const auto view_b = b.getConstView();
      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return view_a[ i ] <= view_b[ i ]; };
      return Algorithms::Reduction< DeviceType >::reduce( a.getSize(), std::logical_and<>{}, fetch, true );
   }
};
@@ -149,7 +160,8 @@ struct Comparison< T1, T2, ArithmeticVariable, VectorExpressionVariable >
      using DeviceType = typename T2::DeviceType;
      using IndexType = typename T2::IndexType;

      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a == b[ i ]; };
      const auto view_b = b.getConstView();
      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a == view_b[ i ]; };
      return Algorithms::Reduction< DeviceType >::reduce( b.getSize(), std::logical_and<>{}, fetch, true );
   }

@@ -163,7 +175,8 @@ struct Comparison< T1, T2, ArithmeticVariable, VectorExpressionVariable >
      using DeviceType = typename T2::DeviceType;
      using IndexType = typename T2::IndexType;

      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a > b[ i ]; };
      const auto view_b = b.getConstView();
      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a > view_b[ i ]; };
      return Algorithms::Reduction< DeviceType >::reduce( b.getSize(), std::logical_and<>{}, fetch, true );
   }

@@ -172,7 +185,8 @@ struct Comparison< T1, T2, ArithmeticVariable, VectorExpressionVariable >
      using DeviceType = typename T2::DeviceType;
      using IndexType = typename T2::IndexType;

      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a >= b[ i ]; };
      const auto view_b = b.getConstView();
      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a >= view_b[ i ]; };
      return Algorithms::Reduction< DeviceType >::reduce( b.getSize(), std::logical_and<>{}, fetch, true );
   }

@@ -181,7 +195,8 @@ struct Comparison< T1, T2, ArithmeticVariable, VectorExpressionVariable >
      using DeviceType = typename T2::DeviceType;
      using IndexType = typename T2::IndexType;

      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a < b[ i ]; };
      const auto view_b = b.getConstView();
      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a < view_b[ i ]; };
      return Algorithms::Reduction< DeviceType >::reduce( b.getSize(), std::logical_and<>{}, fetch, true );
   }

@@ -190,7 +205,8 @@ struct Comparison< T1, T2, ArithmeticVariable, VectorExpressionVariable >
      using DeviceType = typename T2::DeviceType;
      using IndexType = typename T2::IndexType;

      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a <= b[ i ]; };
      const auto view_b = b.getConstView();
      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a <= view_b[ i ]; };
      return Algorithms::Reduction< DeviceType >::reduce( b.getSize(), std::logical_and<>{}, fetch, true );
   }
};
@@ -206,7 +222,8 @@ struct Comparison< T1, T2, VectorExpressionVariable, ArithmeticVariable >
      using DeviceType = typename T1::DeviceType;
      using IndexType = typename T1::IndexType;

      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a[ i ] == b; };
      const auto view_a = a.getConstView();
      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return view_a[ i ] == b; };
      return Algorithms::Reduction< DeviceType >::reduce( a.getSize(), std::logical_and<>{}, fetch, true );
   }

@@ -220,7 +237,8 @@ struct Comparison< T1, T2, VectorExpressionVariable, ArithmeticVariable >
      using DeviceType = typename T1::DeviceType;
      using IndexType = typename T1::IndexType;

      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a[ i ] > b; };
      const auto view_a = a.getConstView();
      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return view_a[ i ] > b; };
      return Algorithms::Reduction< DeviceType >::reduce( a.getSize(), std::logical_and<>{}, fetch, true );
   }

@@ -229,7 +247,8 @@ struct Comparison< T1, T2, VectorExpressionVariable, ArithmeticVariable >
      using DeviceType = typename T1::DeviceType;
      using IndexType = typename T1::IndexType;

      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a[ i ] >= b; };
      const auto view_a = a.getConstView();
      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return view_a[ i ] >= b; };
      return Algorithms::Reduction< DeviceType >::reduce( a.getSize(), std::logical_and<>{}, fetch, true );
   }

@@ -238,7 +257,8 @@ struct Comparison< T1, T2, VectorExpressionVariable, ArithmeticVariable >
      using DeviceType = typename T1::DeviceType;
      using IndexType = typename T1::IndexType;

      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a[ i ] < b; };
      const auto view_a = a.getConstView();
      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return view_a[ i ] < b; };
      return Algorithms::Reduction< DeviceType >::reduce( a.getSize(), std::logical_and<>{}, fetch, true );
   }

@@ -247,7 +267,8 @@ struct Comparison< T1, T2, VectorExpressionVariable, ArithmeticVariable >
      using DeviceType = typename T1::DeviceType;
      using IndexType = typename T1::IndexType;

      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return a[ i ] <= b; };
      const auto view_a = a.getConstView();
      auto fetch = [=] __cuda_callable__ ( IndexType i ) -> bool { return view_a[ i ] <= b; };
      return Algorithms::Reduction< DeviceType >::reduce( a.getSize(), std::logical_and<>{}, fetch, true );
   }
};
Loading