Skip to content
Snippets Groups Projects
Commit 695e3999 authored by Tomáš Oberhuber
Browse files

One more fix of reduction in distributed vector.

parent 718abdb4
No related branches found
No related tags found
1 merge request: !63 To/matrices
...@@ -39,7 +39,7 @@ auto DistributedExpressionArgMin( const Expression& expression ) ...@@ -39,7 +39,7 @@ auto DistributedExpressionArgMin( const Expression& expression )
{ {
using RealType = std::decay_t< decltype( expression[0] ) >; using RealType = std::decay_t< decltype( expression[0] ) >;
using IndexType = typename Expression::IndexType; using IndexType = typename Expression::IndexType;
using ResultType = std::pair< IndexType, RealType >; using ResultType = std::pair< RealType, IndexType >;
using CommunicatorType = typename Expression::CommunicatorType; using CommunicatorType = typename Expression::CommunicatorType;
ResultType result( -1, std::numeric_limits< RealType >::max() ); ResultType result( -1, std::numeric_limits< RealType >::max() );
...@@ -48,7 +48,7 @@ auto DistributedExpressionArgMin( const Expression& expression ) ...@@ -48,7 +48,7 @@ auto DistributedExpressionArgMin( const Expression& expression )
// compute local argMin // compute local argMin
ResultType localResult = ExpressionArgMin( expression.getConstLocalView() ); ResultType localResult = ExpressionArgMin( expression.getConstLocalView() );
// transform local index to global index // transform local index to global index
localResult.first += expression.getLocalRange().getBegin(); localResult.second += expression.getLocalRange().getBegin();
// scatter local result to all processes and gather their results // scatter local result to all processes and gather their results
const int nproc = CommunicatorType::GetSize( group ); const int nproc = CommunicatorType::GetSize( group );
...@@ -61,7 +61,7 @@ auto DistributedExpressionArgMin( const Expression& expression ) ...@@ -61,7 +61,7 @@ auto DistributedExpressionArgMin( const Expression& expression )
// reduce the gathered data // reduce the gathered data
const auto* _data = gatheredResults; // workaround for nvcc which does not allow to capture variable-length arrays (even in pure host code!) const auto* _data = gatheredResults; // workaround for nvcc which does not allow to capture variable-length arrays (even in pure host code!)
auto fetch = [_data] ( IndexType i ) { return _data[ i ].second; }; auto fetch = [_data] ( IndexType i ) { return _data[ i ].first; };
auto reduction = [] ( RealType& a, const RealType& b, IndexType& aIdx, const IndexType& bIdx ) { auto reduction = [] ( RealType& a, const RealType& b, IndexType& aIdx, const IndexType& bIdx ) {
if( a > b ) { if( a > b ) {
a = b; a = b;
...@@ -71,7 +71,7 @@ auto DistributedExpressionArgMin( const Expression& expression ) ...@@ -71,7 +71,7 @@ auto DistributedExpressionArgMin( const Expression& expression )
aIdx = bIdx; aIdx = bIdx;
}; };
result = Algorithms::Reduction< Devices::Host >::reduceWithArgument( (IndexType) 0, (IndexType) nproc, reduction, fetch, std::numeric_limits< RealType >::max() ); result = Algorithms::Reduction< Devices::Host >::reduceWithArgument( (IndexType) 0, (IndexType) nproc, reduction, fetch, std::numeric_limits< RealType >::max() );
result.first = gatheredResults[ result.first ].first; result.second = gatheredResults[ result.second ].second;
} }
return result; return result;
} }
...@@ -96,7 +96,7 @@ auto DistributedExpressionArgMax( const Expression& expression ) ...@@ -96,7 +96,7 @@ auto DistributedExpressionArgMax( const Expression& expression )
{ {
using RealType = std::decay_t< decltype( expression[0] ) >; using RealType = std::decay_t< decltype( expression[0] ) >;
using IndexType = typename Expression::IndexType; using IndexType = typename Expression::IndexType;
using ResultType = std::pair< IndexType, RealType >; using ResultType = std::pair< RealType, IndexType >;
using CommunicatorType = typename Expression::CommunicatorType; using CommunicatorType = typename Expression::CommunicatorType;
ResultType result( -1, std::numeric_limits< RealType >::lowest() ); ResultType result( -1, std::numeric_limits< RealType >::lowest() );
...@@ -105,7 +105,7 @@ auto DistributedExpressionArgMax( const Expression& expression ) ...@@ -105,7 +105,7 @@ auto DistributedExpressionArgMax( const Expression& expression )
// compute local argMax // compute local argMax
ResultType localResult = ExpressionArgMax( expression.getConstLocalView() ); ResultType localResult = ExpressionArgMax( expression.getConstLocalView() );
// transform local index to global index // transform local index to global index
localResult.first += expression.getLocalRange().getBegin(); localResult.second += expression.getLocalRange().getBegin();
// scatter local result to all processes and gather their results // scatter local result to all processes and gather their results
const int nproc = CommunicatorType::GetSize( group ); const int nproc = CommunicatorType::GetSize( group );
...@@ -118,7 +118,7 @@ auto DistributedExpressionArgMax( const Expression& expression ) ...@@ -118,7 +118,7 @@ auto DistributedExpressionArgMax( const Expression& expression )
// reduce the gathered data // reduce the gathered data
const auto* _data = gatheredResults; // workaround for nvcc which does not allow to capture variable-length arrays (even in pure host code!) const auto* _data = gatheredResults; // workaround for nvcc which does not allow to capture variable-length arrays (even in pure host code!)
auto fetch = [_data] ( IndexType i ) { return _data[ i ].second; }; auto fetch = [_data] ( IndexType i ) { return _data[ i ].first; };
auto reduction = [] ( RealType& a, const RealType& b, IndexType& aIdx, const IndexType& bIdx ) { auto reduction = [] ( RealType& a, const RealType& b, IndexType& aIdx, const IndexType& bIdx ) {
if( a < b ) { if( a < b ) {
a = b; a = b;
...@@ -128,7 +128,7 @@ auto DistributedExpressionArgMax( const Expression& expression ) ...@@ -128,7 +128,7 @@ auto DistributedExpressionArgMax( const Expression& expression )
aIdx = bIdx; aIdx = bIdx;
}; };
result = Algorithms::Reduction< Devices::Host >::reduceWithArgument( ( IndexType ) 0, (IndexType) nproc, reduction, fetch, std::numeric_limits< RealType >::lowest() ); result = Algorithms::Reduction< Devices::Host >::reduceWithArgument( ( IndexType ) 0, (IndexType) nproc, reduction, fetch, std::numeric_limits< RealType >::lowest() );
result.first = gatheredResults[ result.first ].first; result.second = gatheredResults[ result.second ].second;
} }
return result; return result;
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment