Commit 05903a8f authored by Jakub Klinkovský
Browse files

Removed useless vertical operations and used RemoveET in reduce.h

parent 0d329226
Loading
Loading
Loading
Loading
+7 −6
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@

#include <TNL/Functional.h>  // extension of STL functionals for reduction
#include <TNL/Algorithms/detail/Reduction.h>
#include <TNL/Containers/Expressions/TypeTraits.h>  // RemoveET

namespace TNL {
namespace Algorithms {
@@ -131,7 +132,7 @@ auto reduce( const Index begin,
             Fetch&& fetch,
             Reduction&& reduction = TNL::Plus{} )
{
   using Result = std::decay_t< decltype( fetch( 0 ) ) >;
   using Result = Containers::Expressions::RemoveET< decltype( reduction( fetch(0), fetch(0) ) ) >;
   return reduce< Device >( begin,
                            end,
                            std::forward< Fetch >( fetch ),
@@ -204,10 +205,10 @@ template< typename Array,
auto reduce( const Array& array,
             Reduction&& reduction = TNL::Plus{} )
{
   using ValueType = typename Array::ValueType;
   using Result = Containers::Expressions::RemoveET< decltype( reduction( array(0), array(0) ) ) >;
   return reduce< Array, Device >( array,
                                   std::forward< Reduction >( reduction ),
                                   reduction.template getIdentity< ValueType >() );
                                   reduction.template getIdentity< Result >() );
}

/**
@@ -329,7 +330,7 @@ reduceWithArgument( const Index begin,
                    Fetch&& fetch,
                    Reduction&& reduction )
{
   using Result = std::decay_t< decltype( fetch( 0 ) ) >;
   using Result = Containers::Expressions::RemoveET< decltype( fetch(0) ) >;
   return reduceWithArgument< Device >( begin,
                                        end,
                                        std::forward< Fetch >( fetch ),
@@ -400,10 +401,10 @@ template< typename Array,
auto reduceWithArgument( const Array& array,
                         Reduction&& reduction )
{
   using ValueType = typename Array::ValueType;
   using Result = Containers::Expressions::RemoveET< decltype( array(0) ) >;
   return reduceWithArgument< Array, Device >( array,
                                               std::forward< Reduction >( reduction ),
                                               reduction.template getIdentity< ValueType >() );
                                               reduction.template getIdentity< Result >() );
}

} // namespace Algorithms
+13 −29
Original line number Diff line number Diff line
@@ -10,8 +10,8 @@

#pragma once

#include <TNL/Containers/Expressions/VerticalOperations.h>
#include <TNL/MPI/Wrappers.h>
#include <TNL/Algorithms/reduce.h>

namespace TNL {
namespace Containers {
@@ -26,7 +26,7 @@ auto DistributedExpressionMin( const Expression& expression ) -> std::decay_t< d
                  "std::numeric_limits is not specialized for the reduction's result type" );
   ResultType result = std::numeric_limits< ResultType >::max();
   if( expression.getCommunicationGroup() != MPI::NullGroup() ) {
      const ResultType localResult = ExpressionMin( expression.getConstLocalView() );
      const ResultType localResult = Algorithms::reduce( expression.getConstLocalView(), TNL::Min{} );
      MPI::Allreduce( &localResult, &result, 1, MPI_MIN, expression.getCommunicationGroup() );
   }
   return result;
@@ -46,7 +46,7 @@ auto DistributedExpressionArgMin( const Expression& expression )
   const auto group = expression.getCommunicationGroup();
   if( group != MPI::NullGroup() ) {
      // compute local argMin
      ResultType localResult = ExpressionArgMin( expression.getConstLocalView() );
      ResultType localResult = Algorithms::reduceWithArgument( expression.getConstLocalView(), TNL::MinWithArg{} );
      // transform local index to global index
      localResult.second += expression.getLocalRange().getBegin();

@@ -62,15 +62,7 @@ auto DistributedExpressionArgMin( const Expression& expression )
      // reduce the gathered data
      const auto* _data = gatheredResults;  // workaround for nvcc which does not allow to capture variable-length arrays (even in pure host code!)
      auto fetch = [_data] ( IndexType i ) { return _data[ i ].first; };
      auto reduction = [] ( RealType& a, const RealType& b, IndexType& aIdx, const IndexType& bIdx ) {
         if( a > b ) {
            a = b;
            aIdx = bIdx;
         }
         else if( a == b && bIdx < aIdx )
            aIdx = bIdx;
      };
      result = Algorithms::reduceWithArgument< Devices::Host >( (IndexType) 0, (IndexType) nproc, fetch, reduction, std::numeric_limits< RealType >::max() );
      result = Algorithms::reduceWithArgument< Devices::Host >( (IndexType) 0, (IndexType) nproc, fetch, TNL::MinWithArg{} );
      result.second = gatheredResults[ result.second ].second;
   }
   return result;
@@ -85,7 +77,7 @@ auto DistributedExpressionMax( const Expression& expression ) -> std::decay_t< d
                  "std::numeric_limits is not specialized for the reduction's result type" );
   ResultType result = std::numeric_limits< ResultType >::lowest();
   if( expression.getCommunicationGroup() != MPI::NullGroup() ) {
      const ResultType localResult = ExpressionMax( expression.getConstLocalView() );
      const ResultType localResult = Algorithms::reduce( expression.getConstLocalView(), TNL::Max{} );
      MPI::Allreduce( &localResult, &result, 1, MPI_MAX, expression.getCommunicationGroup() );
   }
   return result;
@@ -105,7 +97,7 @@ auto DistributedExpressionArgMax( const Expression& expression )
   const auto group = expression.getCommunicationGroup();
   if( group != MPI::NullGroup() ) {
      // compute local argMax
      ResultType localResult = ExpressionArgMax( expression.getConstLocalView() );
      ResultType localResult = Algorithms::reduceWithArgument( expression.getConstLocalView(), TNL::MaxWithArg{} );
      // transform local index to global index
      localResult.second += expression.getLocalRange().getBegin();

@@ -121,15 +113,7 @@ auto DistributedExpressionArgMax( const Expression& expression )
      // reduce the gathered data
      const auto* _data = gatheredResults;  // workaround for nvcc which does not allow to capture variable-length arrays (even in pure host code!)
      auto fetch = [_data] ( IndexType i ) { return _data[ i ].first; };
      auto reduction = [] ( RealType& a, const RealType& b, IndexType& aIdx, const IndexType& bIdx ) {
         if( a < b ) {
            a = b;
            aIdx = bIdx;
         }
         else if( a == b && bIdx < aIdx )
            aIdx = bIdx;
      };
      result = Algorithms::reduceWithArgument< Devices::Host >( ( IndexType ) 0, (IndexType) nproc, fetch, reduction, std::numeric_limits< RealType >::lowest() );
      result = Algorithms::reduceWithArgument< Devices::Host >( ( IndexType ) 0, (IndexType) nproc, fetch, TNL::MaxWithArg{} );
      result.second = gatheredResults[ result.second ].second;
   }
   return result;
@@ -142,7 +126,7 @@ auto DistributedExpressionSum( const Expression& expression ) -> std::decay_t< d

   ResultType result = 0;
   if( expression.getCommunicationGroup() != MPI::NullGroup() ) {
      const ResultType localResult = ExpressionSum( expression.getConstLocalView() );
      const ResultType localResult = Algorithms::reduce( expression.getConstLocalView(), TNL::Plus{} );
      MPI::Allreduce( &localResult, &result, 1, MPI_SUM, expression.getCommunicationGroup() );
   }
   return result;
@@ -155,7 +139,7 @@ auto DistributedExpressionProduct( const Expression& expression ) -> std::decay_

   ResultType result = 1;
   if( expression.getCommunicationGroup() != MPI::NullGroup() ) {
      const ResultType localResult = ExpressionProduct( expression.getConstLocalView() );
      const ResultType localResult = Algorithms::reduce( expression.getConstLocalView(), TNL::Multiplies{} );
      MPI::Allreduce( &localResult, &result, 1, MPI_PROD, expression.getCommunicationGroup() );
   }
   return result;
@@ -170,7 +154,7 @@ auto DistributedExpressionLogicalAnd( const Expression& expression ) -> std::dec
                  "std::numeric_limits is not specialized for the reduction's result type" );
   ResultType result = std::numeric_limits< ResultType >::max();
   if( expression.getCommunicationGroup() != MPI::NullGroup() ) {
      const ResultType localResult = ExpressionLogicalAnd( expression.getConstLocalView() );
      const ResultType localResult = Algorithms::reduce( expression.getConstLocalView(), TNL::LogicalAnd{} );
      MPI::Allreduce( &localResult, &result, 1, MPI_LAND, expression.getCommunicationGroup() );
   }
   return result;
@@ -183,7 +167,7 @@ auto DistributedExpressionLogicalOr( const Expression& expression ) -> std::deca

   ResultType result = 0;
   if( expression.getCommunicationGroup() != MPI::NullGroup() ) {
      const ResultType localResult = ExpressionLogicalOr( expression.getConstLocalView() );
      const ResultType localResult = Algorithms::reduce( expression.getConstLocalView(), TNL::LogicalOr{} );
      MPI::Allreduce( &localResult, &result, 1, MPI_LOR, expression.getCommunicationGroup() );
   }
   return result;
@@ -198,7 +182,7 @@ auto DistributedExpressionBinaryAnd( const Expression& expression ) -> std::deca
                  "std::numeric_limits is not specialized for the reduction's result type" );
   ResultType result = std::numeric_limits< ResultType >::max();
   if( expression.getCommunicationGroup() != MPI::NullGroup() ) {
      const ResultType localResult = ExpressionLogicalBinaryAnd( expression.getConstLocalView() );
      const ResultType localResult = Algorithms::reduce( expression.getConstLocalView(), TNL::BitAnd{} );
      MPI::Allreduce( &localResult, &result, 1, MPI_BAND, expression.getCommunicationGroup() );
   }
   return result;
@@ -211,7 +195,7 @@ auto DistributedExpressionBinaryOr( const Expression& expression ) -> std::decay

   ResultType result = 0;
   if( expression.getCommunicationGroup() != MPI::NullGroup() ) {
      const ResultType localResult = ExpressionBinaryOr( expression.getConstLocalView() );
      const ResultType localResult = Algorithms::reduce( expression.getConstLocalView(), TNL::BitOr{} );
      MPI::Allreduce( &localResult, &result, 1, MPI_BOR, expression.getCommunicationGroup() );
   }
   return result;
+12 −12
Original line number Diff line number Diff line
@@ -18,7 +18,7 @@
#include <TNL/Containers/Expressions/ExpressionVariableType.h>
#include <TNL/Containers/Expressions/Comparison.h>
#include <TNL/Containers/Expressions/HorizontalOperations.h>
#include <TNL/Containers/Expressions/VerticalOperations.h>
#include <TNL/Algorithms/reduce.h>

namespace TNL {
namespace Containers {
@@ -370,7 +370,7 @@ template< typename ET1, typename ET2,
auto
operator,( const ET1& a, const ET2& b )
{
   return ExpressionSum( a * b );
   return Algorithms::reduce( a * b, TNL::Plus{} );
}

template< typename ET1, typename ET2,
@@ -662,7 +662,7 @@ template< typename ET1,
auto
min( const ET1& a )
{
   return ExpressionMin( a );
   return Algorithms::reduce( a, TNL::Min{} );
}

template< typename ET1,
@@ -670,7 +670,7 @@ template< typename ET1,
auto
argMin( const ET1& a )
{
   return ExpressionArgMin( a );
   return Algorithms::reduceWithArgument( a, TNL::MinWithArg{} );
}

template< typename ET1,
@@ -678,7 +678,7 @@ template< typename ET1,
auto
max( const ET1& a )
{
   return ExpressionMax( a );
   return Algorithms::reduce( a, TNL::Max{} );
}

template< typename ET1,
@@ -686,7 +686,7 @@ template< typename ET1,
auto
argMax( const ET1& a )
{
   return ExpressionArgMax( a );
   return Algorithms::reduceWithArgument( a, TNL::MaxWithArg{} );
}

template< typename ET1,
@@ -694,7 +694,7 @@ template< typename ET1,
auto
sum( const ET1& a )
{
   return ExpressionSum( a );
   return Algorithms::reduce( a, TNL::Plus{} );
}

template< typename ET1,
@@ -743,7 +743,7 @@ template< typename ET1,
auto
product( const ET1& a )
{
   return ExpressionProduct( a );
   return Algorithms::reduce( a, TNL::Multiplies{} );
}

template< typename ET1,
@@ -751,7 +751,7 @@ template< typename ET1,
auto
logicalAnd( const ET1& a )
{
   return ExpressionLogicalAnd( a );
   return Algorithms::reduce( a, TNL::LogicalAnd{} );
}

template< typename ET1,
@@ -759,7 +759,7 @@ template< typename ET1,
auto
logicalOr( const ET1& a )
{
   return ExpressionLogicalOr( a );
   return Algorithms::reduce( a, TNL::LogicalOr{} );
}

template< typename ET1,
@@ -767,7 +767,7 @@ template< typename ET1,
auto
binaryAnd( const ET1& a )
{
   return ExpressionBinaryAnd( a );
   return Algorithms::reduce( a, TNL::BitAnd{} );
}

template< typename ET1,
@@ -775,7 +775,7 @@ template< typename ET1,
auto
binaryOr( const ET1& a )
{
   return ExpressionBinaryOr( a );
   return Algorithms::reduce( a, TNL::BitOr{} );
}

#endif // DOXYGEN_ONLY
+0 −142
Original line number Diff line number Diff line
/***************************************************************************
                          VerticalOperations.h  -  description
                             -------------------
    begin                : May 1, 2019
    copyright            : (C) 2019 by Tomas Oberhuber et al.
    email                : tomas.oberhuber@fjfi.cvut.cz
 ***************************************************************************/

/* See Copyright Notice in tnl/Copyright */

#pragma once

#include <limits>
#include <type_traits>

#include <TNL/Functional.h>
#include <TNL/Algorithms/reduce.h>
#include <TNL/Containers/Expressions/TypeTraits.h>

////
// By vertical operations we mean those applied across the elements of a
// vector or of a vector expression, for example the minimum or maximum of
// all vector elements.
namespace TNL {
namespace Containers {
namespace Expressions {

////
// Vertical operations
////
// Reduces the elements of `expression` over the index range
// [0, expression.getSize()) with the TNL::Min functional, running on
// Expression::DeviceType.
//
// The result type is the decayed type of `expression[0]` passed through
// RemoveET; TNL::Min's identity for that type is supplied as the last
// argument of the reduction.
template< typename Expression >
auto ExpressionMin( const Expression& expression )
-> RemoveET< std::decay_t< decltype( expression[0] ) > >
{
   using Device = typename Expression::DeviceType;
   using Index = typename Expression::IndexType;
   using Value = RemoveET< std::decay_t< decltype( expression[0] ) > >;

   // const view of the expression that is handed to Algorithms::reduce
   const auto constView = expression.getConstView();
   return Algorithms::reduce< Device >( Index( 0 ),
                                        expression.getSize(),
                                        constView,
                                        TNL::Min{},
                                        TNL::Min::template getIdentity< Value >() );
}

template< typename Expression >
auto ExpressionArgMin( const Expression& expression )
-> RemoveET< std::pair< std::decay_t< decltype( expression[0] ) >, typename Expression::IndexType > >
{
   using ResultType = RemoveET< std::decay_t< decltype( expression[0] ) > >;
   using IndexType = typename Expression::IndexType;

   const auto view = expression.getConstView();
   return Algorithms::reduceWithArgument< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), view, TNL::MinWithArg{}, TNL::MinWithArg::template getIdentity< ResultType >() );
}

////
// Reduces the elements of `expression` over the index range
// [0, expression.getSize()) with the TNL::Max functional, running on
// Expression::DeviceType.
//
// The result type is the decayed type of `expression[0]` passed through
// RemoveET; TNL::Max's identity for that type is supplied as the last
// argument of the reduction.
template< typename Expression >
auto ExpressionMax( const Expression& expression )
-> RemoveET< std::decay_t< decltype( expression[0] ) > >
{
   using Device = typename Expression::DeviceType;
   using Index = typename Expression::IndexType;
   using Value = RemoveET< std::decay_t< decltype( expression[0] ) > >;

   // const view of the expression that is handed to Algorithms::reduce
   const auto constView = expression.getConstView();
   return Algorithms::reduce< Device >( Index( 0 ),
                                        expression.getSize(),
                                        constView,
                                        TNL::Max{},
                                        TNL::Max::template getIdentity< Value >() );
}

template< typename Expression >
auto ExpressionArgMax( const Expression& expression )
-> RemoveET< std::pair< std::decay_t< decltype( expression[0] ) >, typename Expression::IndexType > >
{
   using ResultType = RemoveET< std::decay_t< decltype( expression[0] ) > >;
   using IndexType = typename Expression::IndexType;

   const auto view = expression.getConstView();
   return Algorithms::reduceWithArgument< typename Expression::DeviceType >( ( IndexType ) 0, expression.getSize(), view, TNL::MaxWithArg{}, TNL::MaxWithArg::template getIdentity< ResultType >() );
}

////
// Reduces the elements of `expression` over the index range
// [0, expression.getSize()) with the TNL::Plus functional, running on
// Expression::DeviceType.
//
// The result type is the decayed type of `expression[0] + expression[0]`
// passed through RemoveET; TNL::Plus's identity for that type is supplied as
// the last argument of the reduction.
template< typename Expression >
auto ExpressionSum( const Expression& expression )
-> RemoveET< std::decay_t< decltype( expression[0] + expression[0] ) > >
{
   using Device = typename Expression::DeviceType;
   using Index = typename Expression::IndexType;
   using Value = RemoveET< std::decay_t< decltype( expression[0] + expression[0] ) > >;

   // const view of the expression that is handed to Algorithms::reduce
   const auto constView = expression.getConstView();
   return Algorithms::reduce< Device >( Index( 0 ),
                                        expression.getSize(),
                                        constView,
                                        TNL::Plus{},
                                        TNL::Plus::template getIdentity< Value >() );
}

////
// Reduces the elements of `expression` over the index range
// [0, expression.getSize()) with the TNL::Multiplies functional, running on
// Expression::DeviceType.
//
// The result type is the decayed type of `expression[0] * expression[0]`
// passed through RemoveET; TNL::Multiplies's identity for that type is
// supplied as the last argument of the reduction.
template< typename Expression >
auto ExpressionProduct( const Expression& expression )
-> RemoveET< std::decay_t< decltype( expression[0] * expression[0] ) > >
{
   using Device = typename Expression::DeviceType;
   using Index = typename Expression::IndexType;
   using Value = RemoveET< std::decay_t< decltype( expression[0] * expression[0] ) > >;

   // const view of the expression that is handed to Algorithms::reduce
   const auto constView = expression.getConstView();
   return Algorithms::reduce< Device >( Index( 0 ),
                                        expression.getSize(),
                                        constView,
                                        TNL::Multiplies{},
                                        TNL::Multiplies::template getIdentity< Value >() );
}

////
// Reduces the elements of `expression` over the index range
// [0, expression.getSize()) with the TNL::LogicalAnd functional, running on
// Expression::DeviceType.
//
// The result type is the decayed type of `expression[0] && expression[0]`
// passed through RemoveET; TNL::LogicalAnd's identity for that type is
// supplied as the last argument of the reduction.
template< typename Expression >
auto ExpressionLogicalAnd( const Expression& expression )
-> RemoveET< std::decay_t< decltype( expression[0] && expression[0] ) > >
{
   using Device = typename Expression::DeviceType;
   using Index = typename Expression::IndexType;
   using Value = RemoveET< std::decay_t< decltype( expression[0] && expression[0] ) > >;

   // const view of the expression that is handed to Algorithms::reduce
   const auto constView = expression.getConstView();
   return Algorithms::reduce< Device >( Index( 0 ),
                                        expression.getSize(),
                                        constView,
                                        TNL::LogicalAnd{},
                                        TNL::LogicalAnd::template getIdentity< Value >() );
}

////
// Reduces the elements of `expression` over the index range
// [0, expression.getSize()) with the TNL::LogicalOr functional, running on
// Expression::DeviceType.
//
// The result type is the decayed type of `expression[0] || expression[0]`
// passed through RemoveET; TNL::LogicalOr's identity for that type is
// supplied as the last argument of the reduction.
template< typename Expression >
auto ExpressionLogicalOr( const Expression& expression )
-> RemoveET< std::decay_t< decltype( expression[0] || expression[0] ) > >
{
   using Device = typename Expression::DeviceType;
   using Index = typename Expression::IndexType;
   using Value = RemoveET< std::decay_t< decltype( expression[0] || expression[0] ) > >;

   // const view of the expression that is handed to Algorithms::reduce
   const auto constView = expression.getConstView();
   return Algorithms::reduce< Device >( Index( 0 ),
                                        expression.getSize(),
                                        constView,
                                        TNL::LogicalOr{},
                                        TNL::LogicalOr::template getIdentity< Value >() );
}

////
// Reduces the elements of `expression` over the index range
// [0, expression.getSize()) with the TNL::BitAnd functional, running on
// Expression::DeviceType.
//
// The result type is the decayed type of `expression[0] & expression[0]`
// passed through RemoveET; TNL::BitAnd's identity for that type is supplied
// as the last argument of the reduction.
template< typename Expression >
auto ExpressionBinaryAnd( const Expression& expression )
-> RemoveET< std::decay_t< decltype( expression[0] & expression[0] ) > >
{
   using Device = typename Expression::DeviceType;
   using Index = typename Expression::IndexType;
   using Value = RemoveET< std::decay_t< decltype( expression[0] & expression[0] ) > >;

   // const view of the expression that is handed to Algorithms::reduce
   const auto constView = expression.getConstView();
   return Algorithms::reduce< Device >( Index( 0 ),
                                        expression.getSize(),
                                        constView,
                                        TNL::BitAnd{},
                                        TNL::BitAnd::template getIdentity< Value >() );
}

////
// Reduces the elements of `expression` over the index range
// [0, expression.getSize()) with the TNL::BitOr functional, running on
// Expression::DeviceType.
//
// The result type is the decayed type of `expression[0] | expression[0]`
// passed through RemoveET; TNL::BitOr's identity for that type is supplied
// as the last argument of the reduction.
template< typename Expression >
auto ExpressionBinaryOr( const Expression& expression )
-> RemoveET< std::decay_t< decltype( expression[0] | expression[0] ) > >
{
   using Device = typename Expression::DeviceType;
   using Index = typename Expression::IndexType;
   using Value = RemoveET< std::decay_t< decltype( expression[0] | expression[0] ) > >;

   // const view of the expression that is handed to Algorithms::reduce
   const auto constView = expression.getConstView();
   return Algorithms::reduce< Device >( Index( 0 ),
                                        expression.getSize(),
                                        constView,
                                        TNL::BitOr{},
                                        TNL::BitOr::template getIdentity< Value >() );
}

} // namespace Expressions
} // namespace Containers
} // namespace TNL
+2 −2
Original line number Diff line number Diff line
@@ -80,7 +80,7 @@ struct Max
};

/**
 * \brief Extension of \ref std::min<void> for use with \ref TNL::Algorithms::reduceWithArgument.
 * \brief Function object implementing `argmin(x, y, i, j)` for use with \ref TNL::Algorithms::reduceWithArgument.
 */
struct MinWithArg
{
@@ -108,7 +108,7 @@ struct MinWithArg
};

/**
 * \brief Extension of \ref std::max<void> for use with \ref TNL::Algorithms::reduceWithArgument.
 * \brief Function object implementing `argmax(x, y, i, j)` for use with \ref TNL::Algorithms::reduceWithArgument.
 */
struct MaxWithArg
{