Commit 72ad8e30 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Removed unnecessary lambda functions from expression templates

parent a1e3a62d
Loading
Loading
Loading
Loading
+11 −60
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@
#include <limits>
#include <type_traits>

#include <TNL/Functional.h>
#include <TNL/Algorithms/reduce.h>
#include <TNL/Containers/Expressions/TypeTraits.h>

@@ -34,16 +35,7 @@ auto ExpressionMin( const Expression& expression )
   using IndexType = typename Expression::IndexType;

   const auto view = expression.getConstView();
   auto fetch = [=] __cuda_callable__ ( IndexType i ) { return view[ i ]; };
   auto reduction = [] __cuda_callable__ ( const ResultType& a, const ResultType& b )
   {
      // use argument-dependent lookup and make TNL::min available for unqualified calls
      using TNL::min;
      return min( a, b );
   };
   static_assert( std::numeric_limits< ResultType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's result type" );
   return Algorithms::reduce< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), fetch, reduction, std::numeric_limits< ResultType >::max() );
   return Algorithms::reduce< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), view, TNL::Min{}, TNL::Min::template getIdempotent< ResultType >() );
}

template< typename Expression >
@@ -54,18 +46,7 @@ auto ExpressionArgMin( const Expression& expression )
   using IndexType = typename Expression::IndexType;

   const auto view = expression.getConstView();
   auto fetch = [=] __cuda_callable__ ( IndexType i ) { return view[ i ]; };
   auto reduction = [] __cuda_callable__ ( ResultType& a, const ResultType& b, IndexType& aIdx, const IndexType& bIdx ) {
      if( a > b ) {
         a = b;
         aIdx = bIdx;
      }
      else if( a == b && bIdx < aIdx )
         aIdx = bIdx;
   };
   static_assert( std::numeric_limits< ResultType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's result type" );
   return Algorithms::reduceWithArgument< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), fetch, reduction, std::numeric_limits< ResultType >::max() );
   return Algorithms::reduceWithArgument< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), view, TNL::MinWithArg{}, TNL::MinWithArg::template getIdempotent< ResultType >() );
}

template< typename Expression >
@@ -76,16 +57,7 @@ auto ExpressionMax( const Expression& expression )
   using IndexType = typename Expression::IndexType;

   const auto view = expression.getConstView();
   auto fetch = [=] __cuda_callable__ ( IndexType i ) { return view[ i ]; };
   auto reduction = [] __cuda_callable__ ( const ResultType& a, const ResultType& b )
   {
      // use argument-dependent lookup and make TNL::max available for unqualified calls
      using TNL::max;
      return max( a, b );
   };
   static_assert( std::numeric_limits< ResultType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's result type" );
   return Algorithms::reduce< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), fetch, reduction, std::numeric_limits< ResultType >::lowest() );
   return Algorithms::reduce< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), view, TNL::Max{}, TNL::Max::template getIdempotent< ResultType >() );
}

template< typename Expression >
@@ -96,18 +68,7 @@ auto ExpressionArgMax( const Expression& expression )
   using IndexType = typename Expression::IndexType;

   const auto view = expression.getConstView();
   auto fetch = [=] __cuda_callable__ ( IndexType i ) { return view[ i ]; };
   auto reduction = [] __cuda_callable__ ( ResultType& a, const ResultType& b, IndexType& aIdx, const IndexType& bIdx ) {
      if( a < b ) {
         a = b;
         aIdx = bIdx;
      }
      else if( a == b && bIdx < aIdx )
         aIdx = bIdx;
   };
   static_assert( std::numeric_limits< ResultType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's result type" );
   return Algorithms::reduceWithArgument< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), fetch, reduction, std::numeric_limits< ResultType >::lowest() );
   return Algorithms::reduceWithArgument< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), view, TNL::MaxWithArg{}, TNL::MaxWithArg::template getIdempotent< ResultType >() );
}

template< typename Expression >
@@ -118,8 +79,7 @@ auto ExpressionSum( const Expression& expression )
   using IndexType = typename Expression::IndexType;

   const auto view = expression.getConstView();
   auto fetch = [=] __cuda_callable__ ( IndexType i ) { return view[ i ]; };
   return Algorithms::reduce< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), fetch, std::plus<>{}, (ResultType) 0 );
   return Algorithms::reduce< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), view, TNL::Plus{}, TNL::Plus::template getIdempotent< ResultType >() );
}

template< typename Expression >
@@ -130,8 +90,7 @@ auto ExpressionProduct( const Expression& expression )
   using IndexType = typename Expression::IndexType;

   const auto view = expression.getConstView();
   auto fetch = [=] __cuda_callable__ ( IndexType i ) { return view[ i ]; };
   return Algorithms::reduce< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), fetch, std::multiplies<>{}, (ResultType) 1 );
   return Algorithms::reduce< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), view, TNL::Multiplies{}, TNL::Multiplies::template getIdempotent< ResultType >() );
}

template< typename Expression >
@@ -142,10 +101,7 @@ auto ExpressionLogicalAnd( const Expression& expression )
   using IndexType = typename Expression::IndexType;

   const auto view = expression.getConstView();
   auto fetch = [=] __cuda_callable__ ( IndexType i ) { return view[ i ]; };
   static_assert( std::numeric_limits< ResultType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's result type" );
   return Algorithms::reduce< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), fetch, std::logical_and<>{}, std::numeric_limits< ResultType >::max() );
   return Algorithms::reduce< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), view, TNL::LogicalAnd{}, TNL::LogicalAnd::template getIdempotent< ResultType >() );
}

template< typename Expression >
@@ -156,8 +112,7 @@ auto ExpressionLogicalOr( const Expression& expression )
   using IndexType = typename Expression::IndexType;

   const auto view = expression.getConstView();
   auto fetch = [=] __cuda_callable__ ( IndexType i ) { return view[ i ]; };
   return Algorithms::reduce< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), fetch, std::logical_or<>{}, (ResultType) 0 );
   return Algorithms::reduce< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), view, TNL::LogicalOr{}, TNL::LogicalOr::template getIdempotent< ResultType >() );
}

template< typename Expression >
@@ -168,10 +123,7 @@ auto ExpressionBinaryAnd( const Expression& expression )
   using IndexType = typename Expression::IndexType;

   const auto view = expression.getConstView();
   auto fetch = [=] __cuda_callable__ ( IndexType i ) { return view[ i ]; };
   static_assert( std::numeric_limits< ResultType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's result type" );
   return Algorithms::reduce< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), fetch, std::bit_and<>{}, std::numeric_limits< ResultType >::max() );
   return Algorithms::reduce< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), view, TNL::BitAnd{}, TNL::BitAnd::template getIdempotent< ResultType >() );
}

template< typename Expression >
@@ -182,8 +134,7 @@ auto ExpressionBinaryOr( const Expression& expression )
   using IndexType = typename Expression::IndexType;

   const auto view = expression.getConstView();
   auto fetch = [=] __cuda_callable__ ( IndexType i ) { return view[ i ]; };
   return Algorithms::reduce< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), fetch, std::bit_or<>{}, (ResultType) 0 );
   return Algorithms::reduce< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), view, TNL::BitOr{}, TNL::BitOr::template getIdempotent< ResultType >() );
}

} // namespace Expressions
+2 −2
Original line number Diff line number Diff line
@@ -261,7 +261,7 @@ TEST( VectorSpecialCasesTest, reductionOfEmptyVector )
   EXPECT_EQ( product(v), 1 );
   EXPECT_EQ( logicalAnd(v), true );
   EXPECT_EQ( logicalOr(v), false );
   EXPECT_EQ( binaryAnd(v), std::numeric_limits< int >::max() );
   EXPECT_EQ( binaryAnd(v), ~0 );
   EXPECT_EQ( binaryOr(v), 0 );

   EXPECT_EQ( min(v_view), std::numeric_limits< int >::max() );
@@ -272,7 +272,7 @@ TEST( VectorSpecialCasesTest, reductionOfEmptyVector )
   EXPECT_EQ( product(v_view), 1 );
   EXPECT_EQ( logicalAnd(v_view), true );
   EXPECT_EQ( logicalOr(v_view), false );
   EXPECT_EQ( binaryAnd(v_view), std::numeric_limits< int >::max() );
   EXPECT_EQ( binaryAnd(v_view), ~0 );
   EXPECT_EQ( binaryOr(v_view), 0 );
}