Commit 695e3999 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

One more fix of reduction in distributed vector.

parent 718abdb4
Loading
Loading
Loading
Loading
+8 −8
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ auto DistributedExpressionArgMin( const Expression& expression )
{
   using RealType = std::decay_t< decltype( expression[0] ) >;
   using IndexType = typename Expression::IndexType;
   using ResultType = std::pair< IndexType, RealType >;
   using ResultType = std::pair< RealType, IndexType >;
   using CommunicatorType = typename Expression::CommunicatorType;

   ResultType result( -1, std::numeric_limits< RealType >::max() );
@@ -48,7 +48,7 @@ auto DistributedExpressionArgMin( const Expression& expression )
      // compute local argMin
      ResultType localResult = ExpressionArgMin( expression.getConstLocalView() );
      // transform local index to global index
      localResult.first += expression.getLocalRange().getBegin();
      localResult.second += expression.getLocalRange().getBegin();

      // scatter local result to all processes and gather their results
      const int nproc = CommunicatorType::GetSize( group );
@@ -61,7 +61,7 @@ auto DistributedExpressionArgMin( const Expression& expression )

      // reduce the gathered data
      const auto* _data = gatheredResults;  // workaround for nvcc which does not allow to capture variable-length arrays (even in pure host code!)
      auto fetch = [_data] ( IndexType i ) { return _data[ i ].second; };
      auto fetch = [_data] ( IndexType i ) { return _data[ i ].first; };
      auto reduction = [] ( RealType& a, const RealType& b, IndexType& aIdx, const IndexType& bIdx ) {
         if( a > b ) {
            a = b;
@@ -71,7 +71,7 @@ auto DistributedExpressionArgMin( const Expression& expression )
            aIdx = bIdx;
      };
      result = Algorithms::Reduction< Devices::Host >::reduceWithArgument( (IndexType) 0, (IndexType) nproc, reduction, fetch, std::numeric_limits< RealType >::max() );
      result.first = gatheredResults[ result.first ].first;
      result.second = gatheredResults[ result.second ].second;
   }
   return result;
}
@@ -96,7 +96,7 @@ auto DistributedExpressionArgMax( const Expression& expression )
{
   using RealType = std::decay_t< decltype( expression[0] ) >;
   using IndexType = typename Expression::IndexType;
   using ResultType = std::pair< IndexType, RealType >;
   using ResultType = std::pair< RealType, IndexType >;
   using CommunicatorType = typename Expression::CommunicatorType;

   ResultType result( -1, std::numeric_limits< RealType >::lowest() );
@@ -105,7 +105,7 @@ auto DistributedExpressionArgMax( const Expression& expression )
      // compute local argMax
      ResultType localResult = ExpressionArgMax( expression.getConstLocalView() );
      // transform local index to global index
      localResult.first += expression.getLocalRange().getBegin();
      localResult.second += expression.getLocalRange().getBegin();

      // scatter local result to all processes and gather their results
      const int nproc = CommunicatorType::GetSize( group );
@@ -118,7 +118,7 @@ auto DistributedExpressionArgMax( const Expression& expression )

      // reduce the gathered data
      const auto* _data = gatheredResults;  // workaround for nvcc which does not allow to capture variable-length arrays (even in pure host code!)
      auto fetch = [_data] ( IndexType i ) { return _data[ i ].second; };
      auto fetch = [_data] ( IndexType i ) { return _data[ i ].first; };
      auto reduction = [] ( RealType& a, const RealType& b, IndexType& aIdx, const IndexType& bIdx ) {
         if( a < b ) {
            a = b;
@@ -128,7 +128,7 @@ auto DistributedExpressionArgMax( const Expression& expression )
            aIdx = bIdx;
      };
      result = Algorithms::Reduction< Devices::Host >::reduceWithArgument( ( IndexType ) 0, (IndexType) nproc, reduction, fetch, std::numeric_limits< RealType >::lowest() );
      result.first = gatheredResults[ result.first ].first;
      result.second = gatheredResults[ result.second ].second;
   }
   return result;
}