From e4ed26289d07fafd221fcb7119d7b807573ab98e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Klinkovsk=C3=BD?= <klinkjak@fjfi.cvut.cz> Date: Sat, 10 Aug 2019 18:06:32 +0200 Subject: [PATCH] Implemented distributed argMin and argMax Fixes #41 --- .../DistributedVerticalOperations.h | 95 ++++++++++++++++++- .../Expressions/VerticalOperations.h | 2 - .../Containers/VectorVerticalOperationsTest.h | 10 -- 3 files changed, 90 insertions(+), 17 deletions(-) diff --git a/src/TNL/Containers/Expressions/DistributedVerticalOperations.h b/src/TNL/Containers/Expressions/DistributedVerticalOperations.h index ef793900ae..4e569a45d4 100644 --- a/src/TNL/Containers/Expressions/DistributedVerticalOperations.h +++ b/src/TNL/Containers/Expressions/DistributedVerticalOperations.h @@ -12,7 +12,6 @@ #include <TNL/Containers/Expressions/VerticalOperations.h> #include <TNL/Communicators/MpiDefs.h> -#include <TNL/Exceptions/NotImplementedError.h> namespace TNL { namespace Containers { @@ -38,8 +37,51 @@ template< typename Expression > auto DistributedExpressionArgMin( const Expression& expression ) -> std::pair< typename Expression::IndexType, std::decay_t< decltype( expression[0] ) > > { - using ResultType = std::decay_t< decltype( expression[0] ) >; - throw Exceptions::NotImplementedError("DistributedExpressionArgMin is not implemented yet"); + using RealType = std::decay_t< decltype( expression[0] ) >; + using IndexType = typename Expression::IndexType; + using ResultType = std::pair< IndexType, RealType >; + using CommunicatorType = typename Expression::CommunicatorType; + + ResultType result( -1, std::numeric_limits< RealType >::max() ); + const auto group = expression.getCommunicationGroup(); + if( group != CommunicatorType::NullGroup ) { + // compute local argMin + ResultType localResult = ExpressionArgMin( expression.getConstLocalView() ); + // transform local index to global index + localResult.first += expression.getLocalRange().getBegin(); + + // scatter local result to all processes and gather their results + const int nproc = CommunicatorType::GetSize( group ); + ResultType dataForScatter[ nproc ]; + for( int i = 0; i < nproc; i++ ) dataForScatter[ i ] = localResult; + ResultType gatheredResults[ nproc ]; + // NOTE: exchanging general data types does not work with MPI + //CommunicatorType::Alltoall( dataForScatter, 1, gatheredResults, 1, group ); + CommunicatorType::Alltoall( (char*) dataForScatter, sizeof(ResultType), (char*) gatheredResults, sizeof(ResultType), group ); + + // reduce the gathered data + const auto* _data = gatheredResults; // workaround for nvcc which does not allow to capture variable-length arrays (even in pure host code!) + auto fetch = [_data] ( IndexType i ) { return _data[ i ].second; }; + auto reduction = [] ( IndexType& aIdx, const IndexType& bIdx, RealType& a, const RealType& b ) { + if( a > b ) { + a = b; + aIdx = bIdx; + } + else if( a == b && bIdx < aIdx ) + aIdx = bIdx; + }; + auto volatileReduction = [] ( volatile IndexType& aIdx, volatile IndexType& bIdx, volatile RealType& a, volatile RealType& b ) { + if( a > b ) { + a = b; + aIdx = bIdx; + } + else if( a == b && bIdx < aIdx ) + aIdx = bIdx; + }; + result = Algorithms::Reduction< Devices::Host >::reduceWithArgument( (IndexType) nproc, reduction, volatileReduction, fetch, std::numeric_limits< RealType >::max() ); + result.first = gatheredResults[ result.first ].first; + } + return result; } template< typename Expression > @@ -60,8 +102,51 @@ template< typename Expression > auto DistributedExpressionArgMax( const Expression& expression ) -> std::pair< typename Expression::IndexType, std::decay_t< decltype( expression[0] ) > > { - using ResultType = std::decay_t< decltype( expression[0] ) >; - throw Exceptions::NotImplementedError("DistributedExpressionArgMax is not implemented yet"); + using RealType = std::decay_t< decltype( expression[0] ) >; + using IndexType = typename Expression::IndexType; + using ResultType = std::pair< IndexType, RealType >; + using CommunicatorType = typename Expression::CommunicatorType; + + ResultType result( -1, std::numeric_limits< RealType >::lowest() ); + const auto group = expression.getCommunicationGroup(); + if( group != CommunicatorType::NullGroup ) { + // compute local argMax + ResultType localResult = ExpressionArgMax( expression.getConstLocalView() ); + // transform local index to global index + localResult.first += expression.getLocalRange().getBegin(); + + // scatter local result to all processes and gather their results + const int nproc = CommunicatorType::GetSize( group ); + ResultType dataForScatter[ nproc ]; + for( int i = 0; i < nproc; i++ ) dataForScatter[ i ] = localResult; + ResultType gatheredResults[ nproc ]; + // NOTE: exchanging general data types does not work with MPI + //CommunicatorType::Alltoall( dataForScatter, 1, gatheredResults, 1, group ); + CommunicatorType::Alltoall( (char*) dataForScatter, sizeof(ResultType), (char*) gatheredResults, sizeof(ResultType), group ); + + // reduce the gathered data + const auto* _data = gatheredResults; // workaround for nvcc which does not allow to capture variable-length arrays (even in pure host code!) + auto fetch = [_data] ( IndexType i ) { return _data[ i ].second; }; + auto reduction = [] ( IndexType& aIdx, const IndexType& bIdx, RealType& a, const RealType& b ) { + if( a < b ) { + a = b; + aIdx = bIdx; + } + else if( a == b && bIdx < aIdx ) + aIdx = bIdx; + }; + auto volatileReduction = [] ( volatile IndexType& aIdx, volatile IndexType& bIdx, volatile RealType& a, volatile RealType& b ) { + if( a < b ) { + a = b; + aIdx = bIdx; + } + else if( a == b && bIdx < aIdx ) + aIdx = bIdx; + }; + result = Algorithms::Reduction< Devices::Host >::reduceWithArgument( (IndexType) nproc, reduction, volatileReduction, fetch, std::numeric_limits< RealType >::lowest() ); + result.first = gatheredResults[ result.first ].first; + } + return result; } template< typename Expression > diff --git a/src/TNL/Containers/Expressions/VerticalOperations.h b/src/TNL/Containers/Expressions/VerticalOperations.h index e1d0f900a0..a125057808 100644 --- a/src/TNL/Containers/Expressions/VerticalOperations.h +++ b/src/TNL/Containers/Expressions/VerticalOperations.h @@ -52,7 +52,6 @@ auto ExpressionArgMin( const Expression& expression ) } else if( a == b && bIdx < aIdx ) aIdx = bIdx; - }; auto volatileReduction = [=] __cuda_callable__ ( volatile IndexType& aIdx, volatile IndexType& bIdx, volatile ResultType& a, volatile ResultType& b ) { if( a > b ) { @@ -61,7 +60,6 @@ auto ExpressionArgMin( const Expression& expression ) } else if( a == b && bIdx < aIdx ) aIdx = bIdx; - }; return Algorithms::Reduction< typename Expression::DeviceType >::reduceWithArgument( expression.getSize(), reduction, volatileReduction, fetch, std::numeric_limits< ResultType >::max() ); } diff --git a/src/UnitTests/Containers/VectorVerticalOperationsTest.h b/src/UnitTests/Containers/VectorVerticalOperationsTest.h index 8eb0c5b715..758ee1d500 100644 --- a/src/UnitTests/Containers/VectorVerticalOperationsTest.h +++ b/src/UnitTests/Containers/VectorVerticalOperationsTest.h @@ -166,12 +166,7 @@ TYPED_TEST( VectorVerticalOperationsTest, max ) EXPECT_EQ( max(V1 + 2), size - 1 + 2 ); } -// FIXME: distributed argMax is not implemented yet -#ifdef DISTRIBUTED_VECTOR -TYPED_TEST( VectorVerticalOperationsTest, DISABLED_argMax ) -#else TYPED_TEST( VectorVerticalOperationsTest, argMax ) -#endif { SETUP_VERTICAL_TEST_ALIASES; using RealType = typename TestFixture::VectorOrView::RealType; @@ -196,12 +191,7 @@ TYPED_TEST( VectorVerticalOperationsTest, min ) EXPECT_EQ( min(V1 + 2), 2 ); } -// FIXME: distributed argMin is not implemented yet -#ifdef DISTRIBUTED_VECTOR -TYPED_TEST( VectorVerticalOperationsTest, DISABLED_argMin ) -#else TYPED_TEST( VectorVerticalOperationsTest, argMin ) -#endif { SETUP_VERTICAL_TEST_ALIASES; using RealType = typename TestFixture::VectorOrView::RealType; -- GitLab