Commit ee68bdc9 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Added static asserts for std::numeric_limits<ResultType>::is_specialized to vector reductions

parent 2df931ad
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -25,6 +25,8 @@ auto DistributedExpressionMin( const Expression& expression ) -> std::decay_t< d
   using ResultType = std::decay_t< decltype( expression[0] ) >;
   using CommunicatorType = typename Expression::CommunicatorType;

   static_assert( std::numeric_limits< ResultType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's result type" );
   ResultType result = std::numeric_limits< ResultType >::max();
   if( expression.getCommunicationGroup() != CommunicatorType::NullGroup ) {
      const ResultType localResult = ExpressionMin( expression.getConstLocalView() );
@@ -42,6 +44,8 @@ auto DistributedExpressionArgMin( const Expression& expression )
   using ResultType = std::pair< RealType, IndexType >;
   using CommunicatorType = typename Expression::CommunicatorType;

   static_assert( std::numeric_limits< RealType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's real type" );
   ResultType result( -1, std::numeric_limits< RealType >::max() );
   const auto group = expression.getCommunicationGroup();
   if( group != CommunicatorType::NullGroup ) {
@@ -82,6 +86,8 @@ auto DistributedExpressionMax( const Expression& expression ) -> std::decay_t< d
   using ResultType = std::decay_t< decltype( expression[0] ) >;
   using CommunicatorType = typename Expression::CommunicatorType;

   static_assert( std::numeric_limits< ResultType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's result type" );
   ResultType result = std::numeric_limits< ResultType >::lowest();
   if( expression.getCommunicationGroup() != CommunicatorType::NullGroup ) {
      const ResultType localResult = ExpressionMax( expression.getConstLocalView() );
@@ -99,6 +105,8 @@ auto DistributedExpressionArgMax( const Expression& expression )
   using ResultType = std::pair< RealType, IndexType >;
   using CommunicatorType = typename Expression::CommunicatorType;

   static_assert( std::numeric_limits< RealType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's real type" );
   ResultType result( -1, std::numeric_limits< RealType >::lowest() );
   const auto group = expression.getCommunicationGroup();
   if( group != CommunicatorType::NullGroup ) {
@@ -168,6 +176,8 @@ auto DistributedExpressionLogicalAnd( const Expression& expression ) -> std::dec
   using ResultType = std::decay_t< decltype( expression[0] && expression[0] ) >;
   using CommunicatorType = typename Expression::CommunicatorType;

   static_assert( std::numeric_limits< ResultType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's result type" );
   ResultType result = std::numeric_limits< ResultType >::max();
   if( expression.getCommunicationGroup() != CommunicatorType::NullGroup ) {
      const ResultType localResult = ExpressionLogicalAnd( expression.getConstLocalView() );
@@ -196,6 +206,8 @@ auto DistributedExpressionBinaryAnd( const Expression& expression ) -> std::deca
   using ResultType = std::decay_t< decltype( expression[0] & expression[0] ) >;
   using CommunicatorType = typename Expression::CommunicatorType;

   static_assert( std::numeric_limits< ResultType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's result type" );
   ResultType result = std::numeric_limits< ResultType >::max();
   if( expression.getCommunicationGroup() != CommunicatorType::NullGroup ) {
      const ResultType localResult = ExpressionLogicalBinaryAnd( expression.getConstLocalView() );
+12 −0
Original line number Diff line number Diff line
@@ -36,6 +36,8 @@ auto ExpressionMin( const Expression& expression )
   const auto view = expression.getConstView();
   auto fetch = [=] __cuda_callable__ ( IndexType i ) { return view[ i ]; };
   auto reduction = [] __cuda_callable__ ( const ResultType& a, const ResultType& b ) { return TNL::min( a, b ); };
   static_assert( std::numeric_limits< ResultType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's result type" );
   return Algorithms::Reduction< typename Expression::DeviceType >::reduce( ( IndexType ) 0, expression.getSize(), reduction, fetch, std::numeric_limits< ResultType >::max() );
}

@@ -56,6 +58,8 @@ auto ExpressionArgMin( const Expression& expression )
      else if( a == b && bIdx < aIdx )
         aIdx = bIdx;
   };
   static_assert( std::numeric_limits< ResultType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's result type" );
   return Algorithms::Reduction< typename Expression::DeviceType >::reduceWithArgument( ( IndexType ) 0, expression.getSize(), reduction, fetch, std::numeric_limits< ResultType >::max() );
}

@@ -69,6 +73,8 @@ auto ExpressionMax( const Expression& expression )
   const auto view = expression.getConstView();
   auto fetch = [=] __cuda_callable__ ( IndexType i ) { return view[ i ]; };
   auto reduction = [] __cuda_callable__ ( const ResultType& a, const ResultType& b ) { return TNL::max( a, b ); };
   static_assert( std::numeric_limits< ResultType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's result type" );
   return Algorithms::Reduction< typename Expression::DeviceType >::reduce( ( IndexType ) 0, expression.getSize(), reduction, fetch, std::numeric_limits< ResultType >::lowest() );
}

@@ -89,6 +95,8 @@ auto ExpressionArgMax( const Expression& expression )
      else if( a == b && bIdx < aIdx )
         aIdx = bIdx;
   };
   static_assert( std::numeric_limits< ResultType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's result type" );
   return Algorithms::Reduction< typename Expression::DeviceType >::reduceWithArgument( ( IndexType ) 0, expression.getSize(), reduction, fetch, std::numeric_limits< ResultType >::lowest() );
}

@@ -125,6 +133,8 @@ auto ExpressionLogicalAnd( const Expression& expression )

   const auto view = expression.getConstView();
   auto fetch = [=] __cuda_callable__ ( IndexType i ) { return view[ i ]; };
   static_assert( std::numeric_limits< ResultType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's result type" );
   return Algorithms::Reduction< typename Expression::DeviceType >::reduce( ( IndexType ) 0, expression.getSize(), std::logical_and<>{}, fetch, std::numeric_limits< ResultType >::max() );
}

@@ -149,6 +159,8 @@ auto ExpressionBinaryAnd( const Expression& expression )

   const auto view = expression.getConstView();
   auto fetch = [=] __cuda_callable__ ( IndexType i ) { return view[ i ]; };
   static_assert( std::numeric_limits< ResultType >::is_specialized,
                  "std::numeric_limits is not specialized for the reduction's result type" );
   return Algorithms::Reduction< typename Expression::DeviceType >::reduce( ( IndexType ) 0, expression.getSize(), std::bit_and<>{}, fetch, std::numeric_limits< ResultType >::max() );
}