diff --git a/src/TNL/Containers/Expressions/DistributedVerticalOperations.h b/src/TNL/Containers/Expressions/DistributedVerticalOperations.h
index ef793900ae48020242495c1b5becc58a727525c7..4e569a45d46ea5db0a7bb24e9418ae064e30d5ac 100644
--- a/src/TNL/Containers/Expressions/DistributedVerticalOperations.h
+++ b/src/TNL/Containers/Expressions/DistributedVerticalOperations.h
@@ -12,7 +12,8 @@
 
+#include <vector>
+
 #include <TNL/Containers/Expressions/VerticalOperations.h>
 #include <TNL/Communicators/MpiDefs.h>
-#include <TNL/Exceptions/NotImplementedError.h>
 
 namespace TNL {
 namespace Containers {
@@ -38,8 +39,51 @@ template< typename Expression >
 auto DistributedExpressionArgMin( const Expression& expression )
 -> std::pair< typename Expression::IndexType, std::decay_t< decltype( expression[0] ) > >
 {
-   using ResultType = std::decay_t< decltype( expression[0] ) >;
-   throw Exceptions::NotImplementedError("DistributedExpressionArgMin is not implemented yet");
+   using RealType = std::decay_t< decltype( expression[0] ) >;
+   using IndexType = typename Expression::IndexType;
+   using ResultType = std::pair< IndexType, RealType >;
+   using CommunicatorType = typename Expression::CommunicatorType;
+
+   ResultType result( -1, std::numeric_limits< RealType >::max() );
+   const auto group = expression.getCommunicationGroup();
+   if( group != CommunicatorType::NullGroup ) {
+      // compute local argMin
+      ResultType localResult = ExpressionArgMin( expression.getConstLocalView() );
+      // transform local index to global index
+      localResult.first += expression.getLocalRange().getBegin();
+
+      // scatter local result to all processes and gather their results
+      const int nproc = CommunicatorType::GetSize( group );
+      // std::vector is used instead of variable-length arrays, which are not standard C++
+      std::vector< ResultType > dataForScatter( nproc, localResult );
+      std::vector< ResultType > gatheredResults( nproc );
+      // NOTE: exchanging general data types does not work with MPI, so the pairs are sent as raw bytes
+      CommunicatorType::Alltoall( reinterpret_cast< char* >( dataForScatter.data() ), sizeof(ResultType),
+                                  reinterpret_cast< char* >( gatheredResults.data() ), sizeof(ResultType), group );
+
+      // reduce the gathered data
+      const ResultType* _data = gatheredResults.data();  // raw pointer, so the lambda captures the buffer by address instead of copying the vector
+      auto fetch = [_data] ( IndexType i ) { return _data[ i ].second; };
+      auto reduction = [] ( IndexType& aIdx, const IndexType& bIdx, RealType& a, const RealType& b ) {
+         if( a > b ) {
+            a = b;
+            aIdx = bIdx;
+         }
+         else if( a == b && bIdx < aIdx )
+            aIdx = bIdx;
+      };
+      auto volatileReduction = [] ( volatile IndexType& aIdx, volatile IndexType& bIdx, volatile RealType& a, volatile RealType& b ) {
+         if( a > b ) {
+            a = b;
+            aIdx = bIdx;
+         }
+         else if( a == b && bIdx < aIdx )
+            aIdx = bIdx;
+      };
+      result = Algorithms::Reduction< Devices::Host >::reduceWithArgument( static_cast< IndexType >( nproc ), reduction, volatileReduction, fetch, std::numeric_limits< RealType >::max() );
+      result.first = gatheredResults[ result.first ].first;
+   }
+   return result;
 }
 
 template< typename Expression >
@@ -60,8 +104,51 @@ template< typename Expression >
 auto DistributedExpressionArgMax( const Expression& expression )
 -> std::pair< typename Expression::IndexType, std::decay_t< decltype( expression[0] ) > >
 {
-   using ResultType = std::decay_t< decltype( expression[0] ) >;
-   throw Exceptions::NotImplementedError("DistributedExpressionArgMax is not implemented yet");
+   using RealType = std::decay_t< decltype( expression[0] ) >;
+   using IndexType = typename Expression::IndexType;
+   using ResultType = std::pair< IndexType, RealType >;
+   using CommunicatorType = typename Expression::CommunicatorType;
+
+   ResultType result( -1, std::numeric_limits< RealType >::lowest() );
+   const auto group = expression.getCommunicationGroup();
+   if( group != CommunicatorType::NullGroup ) {
+      // compute local argMax
+      ResultType localResult = ExpressionArgMax( expression.getConstLocalView() );
+      // transform local index to global index
+      localResult.first += expression.getLocalRange().getBegin();
+
+      // scatter local result to all processes and gather their results
+      const int nproc = CommunicatorType::GetSize( group );
+      // std::vector is used instead of variable-length arrays, which are not standard C++
+      std::vector< ResultType > dataForScatter( nproc, localResult );
+      std::vector< ResultType > gatheredResults( nproc );
+      // NOTE: exchanging general data types does not work with MPI, so the pairs are sent as raw bytes
+      CommunicatorType::Alltoall( reinterpret_cast< char* >( dataForScatter.data() ), sizeof(ResultType),
+                                  reinterpret_cast< char* >( gatheredResults.data() ), sizeof(ResultType), group );
+
+      // reduce the gathered data
+      const ResultType* _data = gatheredResults.data();  // raw pointer, so the lambda captures the buffer by address instead of copying the vector
+      auto fetch = [_data] ( IndexType i ) { return _data[ i ].second; };
+      auto reduction = [] ( IndexType& aIdx, const IndexType& bIdx, RealType& a, const RealType& b ) {
+         if( a < b ) {
+            a = b;
+            aIdx = bIdx;
+         }
+         else if( a == b && bIdx < aIdx )
+            aIdx = bIdx;
+      };
+      auto volatileReduction = [] ( volatile IndexType& aIdx, volatile IndexType& bIdx, volatile RealType& a, volatile RealType& b ) {
+         if( a < b ) {
+            a = b;
+            aIdx = bIdx;
+         }
+         else if( a == b && bIdx < aIdx )
+            aIdx = bIdx;
+      };
+      result = Algorithms::Reduction< Devices::Host >::reduceWithArgument( static_cast< IndexType >( nproc ), reduction, volatileReduction, fetch, std::numeric_limits< RealType >::lowest() );
+      result.first = gatheredResults[ result.first ].first;
+   }
+   return result;
 }
 
 template< typename Expression >
diff --git a/src/TNL/Containers/Expressions/VerticalOperations.h b/src/TNL/Containers/Expressions/VerticalOperations.h
index e1d0f900a0fec85213766b2b5b77b0d413cf8732..a1250578084b4f95d6c4de23884d9396fea397b0 100644
--- a/src/TNL/Containers/Expressions/VerticalOperations.h
+++ b/src/TNL/Containers/Expressions/VerticalOperations.h
@@ -52,7 +52,6 @@ auto ExpressionArgMin( const Expression& expression )
       }
       else if( a == b && bIdx < aIdx )
          aIdx = bIdx;
-
    };
    auto volatileReduction = [=] __cuda_callable__ ( volatile IndexType& aIdx, volatile IndexType& bIdx, volatile ResultType& a, volatile ResultType& b ) {
       if( a > b ) {
@@ -61,7 +60,6 @@ auto ExpressionArgMin( const Expression& expression )
       }
       else if( a == b && bIdx < aIdx )
          aIdx = bIdx;
-
    };
    return Algorithms::Reduction< typename Expression::DeviceType >::reduceWithArgument( expression.getSize(), reduction, volatileReduction, fetch, std::numeric_limits< ResultType >::max() );
 }
diff --git a/src/UnitTests/Containers/VectorVerticalOperationsTest.h b/src/UnitTests/Containers/VectorVerticalOperationsTest.h
index 8eb0c5b71561605def0b72370260c33ead1d14c5..758ee1d50026dd0c23ccbf51f37c268732e5e16b 100644
--- a/src/UnitTests/Containers/VectorVerticalOperationsTest.h
+++ b/src/UnitTests/Containers/VectorVerticalOperationsTest.h
@@ -166,12 +166,7 @@ TYPED_TEST( VectorVerticalOperationsTest, max )
    EXPECT_EQ( max(V1 + 2), size - 1 + 2 );
 }
 
-// FIXME: distributed argMax is not implemented yet
-#ifdef DISTRIBUTED_VECTOR
-TYPED_TEST( VectorVerticalOperationsTest, DISABLED_argMax )
-#else
 TYPED_TEST( VectorVerticalOperationsTest, argMax )
-#endif
 {
    SETUP_VERTICAL_TEST_ALIASES;
    using RealType = typename TestFixture::VectorOrView::RealType;
@@ -196,12 +191,7 @@ TYPED_TEST( VectorVerticalOperationsTest, min )
    EXPECT_EQ( min(V1 + 2), 2 );
 }
 
-// FIXME: distributed argMin is not implemented yet
-#ifdef DISTRIBUTED_VECTOR
-TYPED_TEST( VectorVerticalOperationsTest, DISABLED_argMin )
-#else
 TYPED_TEST( VectorVerticalOperationsTest, argMin )
-#endif
 {
    SETUP_VERTICAL_TEST_ALIASES;
    using RealType = typename TestFixture::VectorOrView::RealType;