From 9e6b9ade48291bfb0d3e421ebcebf97fdcea1128 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jakub=20Klinkovsk=C3=BD?= <klinkjak@fjfi.cvut.cz> Date: Tue, 18 Sep 2018 11:05:55 +0200 Subject: [PATCH] Refactoring Subrange and Partitioner --- .../DistributedContainers/DistributedArray.h | 20 +- .../DistributedArray_impl.h | 179 ++++++++---------- src/TNL/DistributedContainers/Partitioner.h | 24 +-- .../{IndexMap.h => Subrange.h} | 61 +++--- .../DistributedArrayTest.h | 50 +++-- 5 files changed, 153 insertions(+), 181 deletions(-) rename src/TNL/DistributedContainers/{IndexMap.h => Subrange.h} (60%) diff --git a/src/TNL/DistributedContainers/DistributedArray.h b/src/TNL/DistributedContainers/DistributedArray.h index 8c252807b9..0fc7704a4d 100644 --- a/src/TNL/DistributedContainers/DistributedArray.h +++ b/src/TNL/DistributedContainers/DistributedArray.h @@ -17,16 +17,15 @@ #include <TNL/Containers/Array.h> #include <TNL/Containers/ArrayView.h> #include <TNL/Communicators/MpiCommunicator.h> -#include <TNL/DistributedContainers/IndexMap.h> +#include <TNL/DistributedContainers/Subrange.h> namespace TNL { namespace DistributedContainers { template< typename Value, typename Device = Devices::Host, - typename Communicator = Communicators::MpiCommunicator, typename Index = int, - typename IndexMap = Subrange< Index > > + typename Communicator = Communicators::MpiCommunicator > class DistributedArray : public Object { @@ -36,22 +35,22 @@ public: using DeviceType = Device; using CommunicatorType = Communicator; using IndexType = Index; - using IndexMapType = IndexMap; + using LocalRangeType = Subrange< Index >; using LocalArrayType = Containers::Array< Value, Device, Index >; using LocalArrayViewType = Containers::ArrayView< Value, Device, Index >; using ConstLocalArrayViewType = Containers::ArrayView< typename std::add_const< Value >::type, Device, Index >; - using HostType = DistributedArray< Value, Devices::Host, Communicator, Index, IndexMap >; - using CudaType = DistributedArray< Value, Devices::Cuda, Communicator, Index, IndexMap >; + using HostType = DistributedArray< Value, Devices::Host, Index, Communicator >; + using CudaType = DistributedArray< Value, Devices::Cuda, Index, Communicator >; DistributedArray() = default; DistributedArray( DistributedArray& ) = default; - DistributedArray( IndexMap indexMap, CommunicationGroup group = Communicator::AllGroup ); + DistributedArray( LocalRangeType localRange, Index globalSize, CommunicationGroup group = Communicator::AllGroup ); - void setDistribution( IndexMap indexMap, CommunicationGroup group = Communicator::AllGroup ); + void setDistribution( LocalRangeType localRange, Index globalSize, CommunicationGroup group = Communicator::AllGroup ); - const IndexMap& getIndexMap() const; + const LocalRangeType& getLocalRange() const; CommunicationGroup getCommunicationGroup() const; @@ -125,7 +124,8 @@ public: // TODO: serialization (save, load, boundLoad) protected: - IndexMap indexMap; + LocalRangeType localRange; + IndexType globalSize = 0; CommunicationGroup group = Communicator::NullGroup; LocalArrayType localData; diff --git a/src/TNL/DistributedContainers/DistributedArray_impl.h b/src/TNL/DistributedContainers/DistributedArray_impl.h index 11a79eecf4..7731bd6bee 100644 --- a/src/TNL/DistributedContainers/DistributedArray_impl.h +++ b/src/TNL/DistributedContainers/DistributedArray_impl.h @@ -22,49 +22,47 @@ namespace DistributedContainers { template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: -DistributedArray( IndexMap indexMap, CommunicationGroup group ) + typename Communicator > +DistributedArray< Value, Device, Index, Communicator >:: +DistributedArray( LocalRangeType localRange, IndexType globalSize, CommunicationGroup group ) { - setDistribution( indexMap, group ); + setDistribution( localRange, globalSize, group ); } template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > void -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: -setDistribution( IndexMap indexMap, CommunicationGroup group ) +DistributedArray< Value, Device, Index, Communicator >:: +setDistribution( LocalRangeType localRange, IndexType globalSize, CommunicationGroup group ) { - this->indexMap = indexMap; + TNL_ASSERT_LE( localRange.getEnd(), globalSize, "end of the local range is outside of the global range" ); + this->localRange = localRange; + this->globalSize = globalSize; this->group = group; if( group != Communicator::NullGroup ) - localData.setSize( indexMap.getLocalSize() ); + localData.setSize( localRange.getSize() ); } template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > -const IndexMap& -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: -getIndexMap() const + typename Communicator > +const Subrange< Index >& +DistributedArray< Value, Device, Index, Communicator >:: +getLocalRange() const { - return indexMap; + return localRange; } template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > typename Communicator::CommunicationGroup -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: +DistributedArray< Value, Device, Index, Communicator >:: getCommunicationGroup() const { return group; @@ -72,11 +70,10 @@ getCommunicationGroup() const template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > -typename DistributedArray< Value, Device, Communicator, Index, IndexMap >::LocalArrayViewType -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: + typename Communicator > +typename DistributedArray< Value, Device, Index, Communicator >::LocalArrayViewType +DistributedArray< Value, Device, Index, Communicator >:: getLocalArrayView() { return localData; @@ -84,11 +81,10 @@ getLocalArrayView() template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > -typename DistributedArray< Value, Device, Communicator, Index, IndexMap >::ConstLocalArrayViewType -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: + typename Communicator > +typename DistributedArray< Value, Device, Index, Communicator >::ConstLocalArrayViewType +DistributedArray< Value, Device, Index, Communicator >:: getLocalArrayView() const { return localData; @@ -96,26 +92,24 @@ getLocalArrayView() const template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > void -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: +DistributedArray< Value, Device, Index, Communicator >:: copyFromGlobal( ConstLocalArrayViewType globalArray ) { - TNL_ASSERT_EQ( indexMap.getGlobalSize(), globalArray.getSize(), + TNL_ASSERT_EQ( getSize(), globalArray.getSize(), "given global array has different size than the distributed array" ); LocalArrayViewType localView( localData ); - const IndexMap indexMap = getIndexMap(); + const LocalRangeType localRange = getLocalRange(); auto kernel = [=] __cuda_callable__ ( IndexType i ) mutable { - if( indexMap.isLocal( i ) ) - localView[ indexMap.getLocalIndex( i ) ] = globalArray[ i ]; + localView[ i ] = globalArray[ localRange.getGlobalIndex( i ) ]; }; - ParallelFor< DeviceType >::exec( (IndexType) 0, indexMap.getGlobalSize(), kernel ); + ParallelFor< DeviceType >::exec( (IndexType) 0, localRange.getSize(), kernel ); } @@ -125,29 +119,26 @@ copyFromGlobal( ConstLocalArrayViewType globalArray ) template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > String -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: +DistributedArray< Value, Device, Index, Communicator >:: getType() { return String( "DistributedContainers::DistributedArray< " ) + TNL::getType< Value >() + ", " + Device::getDeviceType() + ", " + - // TODO: communicators don't have a getType method - "<Communicator>, " + TNL::getType< Index >() + ", " + - IndexMap::getType() + " >"; + // TODO: communicators don't have a getType method + "<Communicator> >"; } template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > String -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: +DistributedArray< Value, Device, Index, Communicator >:: getTypeVirtual() const { return getType(); @@ -155,52 +146,50 @@ getTypeVirtual() const template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > template< typename Array > void -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: +DistributedArray< Value, Device, Index, Communicator >:: setLike( const Array& array ) { - indexMap = array.getIndexMap(); + localRange = array.getLocalRange(); + globalSize = array.getSize(); group = array.getCommunicationGroup(); localData.setLike( array.getLocalArrayView() ); } template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > void -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: +DistributedArray< Value, Device, Index, Communicator >:: reset() { - indexMap.reset(); + localRange.reset(); + globalSize = 0; group = Communicator::NullGroup; localData.reset(); } template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > Index -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: +DistributedArray< Value, Device, Index, Communicator >:: getSize() const { - return indexMap.getGlobalSize(); + return globalSize; } template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > void -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: +DistributedArray< Value, Device, Index, Communicator >:: setValue( ValueType value ) { localData.setValue( value ); @@ -208,65 +197,60 @@ setValue( ValueType value ) template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > void -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: +DistributedArray< Value, Device, Index, Communicator >:: setElement( IndexType i, ValueType value ) { - const IndexType li = indexMap.getLocalIndex( i ); + const IndexType li = localRange.getLocalIndex( i ); localData.setElement( li, value ); } template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > Value -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: +DistributedArray< Value, Device, Index, Communicator >:: getElement( IndexType i ) const { - const IndexType li = indexMap.getLocalIndex( i ); + const IndexType li = localRange.getLocalIndex( i ); return localData.getElement( li ); } template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > __cuda_callable__ Value& -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: +DistributedArray< Value, Device, Index, Communicator >:: operator[]( IndexType i ) { - const IndexType li = indexMap.getLocalIndex( i ); + const IndexType li = localRange.getLocalIndex( i ); return localData[ li ]; } template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > __cuda_callable__ const Value& -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: +DistributedArray< Value, Device, Index, Communicator >:: operator[]( IndexType i ) const { - const IndexType li = indexMap.getLocalIndex( i ); + const IndexType li = localRange.getLocalIndex( i ); return localData[ li ]; } template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > -DistributedArray< Value, Device, Communicator, Index, IndexMap >& -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: + typename Communicator > +DistributedArray< Value, Device, Index, Communicator >& +DistributedArray< Value, Device, Index, Communicator >:: operator=( const DistributedArray& array ) { setLike( array ); @@ -276,12 +260,11 @@ operator=( const DistributedArray& array ) template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > template< typename Array > -DistributedArray< Value, Device, Communicator, Index, IndexMap >& -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: +DistributedArray< Value, Device, Index, Communicator >& +DistributedArray< Value, Device, Index, Communicator >:: operator=( const Array& array ) { setLike( array ); @@ -291,19 +274,19 @@ operator=( const Array& array ) template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > template< typename Array > bool -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: +DistributedArray< Value, Device, Index, Communicator >:: operator==( const Array& array ) const { // we can't run allreduce if the communication groups are different if( group != array.getCommunicationGroup() ) return false; const bool localResult = - indexMap == array.getIndexMap() && + localRange == array.getLocalRange() && + globalSize == array.getSize() && localData == array.getLocalArrayView(); bool result = true; if( group != CommunicatorType::NullGroup ) @@ -313,12 +296,11 @@ operator==( const Array& array ) const template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > template< typename Array > bool -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: +DistributedArray< Value, Device, Index, Communicator >:: operator!=( const Array& array ) const { return ! (*this == array); @@ -326,11 +308,10 @@ operator!=( const Array& array ) const template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > bool -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: +DistributedArray< Value, Device, Index, Communicator >:: containsValue( ValueType value ) const { bool result = false; @@ -343,11 +324,10 @@ containsValue( ValueType value ) const template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > + typename Communicator > bool -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: +DistributedArray< Value, Device, Index, Communicator >:: containsOnlyValue( ValueType value ) const { bool result = true; @@ -360,10 +340,9 @@ containsOnlyValue( ValueType value ) const template< typename Value, typename Device, - typename Communicator, typename Index, - typename IndexMap > -DistributedArray< Value, Device, Communicator, Index, IndexMap >:: + typename Communicator > +DistributedArray< Value, Device, Index, Communicator >:: operator bool() const { return getSize() != 0; diff --git a/src/TNL/DistributedContainers/Partitioner.h b/src/TNL/DistributedContainers/Partitioner.h index 68d0c9d3fa..3c8cace6fb 100644 --- a/src/TNL/DistributedContainers/Partitioner.h +++ b/src/TNL/DistributedContainers/Partitioner.h @@ -1,5 +1,5 @@ /*************************************************************************** - DistributedArray.h - description + Partitioner.h - description ------------------- begin : Sep 6, 2018 copyright : (C) 2018 by Tomas Oberhuber et al. @@ -12,35 +12,31 @@ #pragma once -#include "IndexMap.h" +#include "Subrange.h" #include <TNL/Math.h> namespace TNL { namespace DistributedContainers { -template< typename IndexMap, typename Communicator > -class Partitioner -{}; - template< typename Index, typename Communicator > -class Partitioner< Subrange< Index >, Communicator > +class Partitioner { using CommunicationGroup = typename Communicator::CommunicationGroup; public: - using IndexMap = Subrange< Index >; + using SubrangeType = Subrange< Index >; - static IndexMap splitRange( Index globalSize, CommunicationGroup group ) + static SubrangeType splitRange( Index globalSize, CommunicationGroup group ) { if( group != Communicator::NullGroup ) { const int rank = Communicator::GetRank( group ); const int partitions = Communicator::GetSize( group ); const Index begin = min( globalSize, rank * globalSize / partitions ); const Index end = min( globalSize, (rank + 1) * globalSize / partitions ); - return IndexMap( begin, end, globalSize ); + return SubrangeType( begin, end ); } else - return IndexMap( 0, 0, globalSize ); + return SubrangeType( 0, 0 ); } // Gets the owner of given global index. @@ -67,5 +63,11 @@ public: } }; +// TODO: +// - partitioner in deal.II stores also ghost indices: +// https://www.dealii.org/8.4.0/doxygen/deal.II/classUtilities_1_1MPI_1_1Partitioner.html +// - ghost indices are stored in a general IndexMap class (based on collection of subranges): +// https://www.dealii.org/8.4.0/doxygen/deal.II/classIndexSet.html + } // namespace DistributedContainers } // namespace TNL diff --git a/src/TNL/DistributedContainers/IndexMap.h b/src/TNL/DistributedContainers/Subrange.h similarity index 60% rename from src/TNL/DistributedContainers/IndexMap.h rename to src/TNL/DistributedContainers/Subrange.h index cc5444fd8f..8cff45b495 100644 --- a/src/TNL/DistributedContainers/IndexMap.h +++ b/src/TNL/DistributedContainers/Subrange.h @@ -1,5 +1,5 @@ /*************************************************************************** - IndexMap.h - description + Subrange.h - description ------------------- begin : Sep 6, 2018 copyright : (C) 2018 by Tomas Oberhuber et al. @@ -30,29 +30,26 @@ public: Subrange() = default; __cuda_callable__ - Subrange( Index begin, Index end, Index globalSize ) + Subrange( Index begin, Index end ) { - setSubrange( begin, end, globalSize ); + setSubrange( begin, end ); } // Sets the local subrange and global range size. __cuda_callable__ - void setSubrange( Index begin, Index end, Index globalSize ) + void setSubrange( Index begin, Index end ) { TNL_ASSERT_LE( begin, end, "begin must be before end" ); TNL_ASSERT_GE( begin, 0, "begin must be non-negative" ); - TNL_ASSERT_LE( end - begin, globalSize, "end of the subrange is outside of gloabl range" ); - offset = begin; - localSize = end - begin; - this->globalSize = globalSize; + this->begin = begin; + this->end = end; } __cuda_callable__ void reset() { - offset = 0; - localSize = 0; - globalSize = 0; + begin = 0; + end = 0; } static String getType() @@ -64,36 +61,37 @@ public: __cuda_callable__ bool isLocal( Index i ) const { - return offset <= i && i < offset + localSize; + return begin <= i && i < end; } - // Gets the offset of the subrange. + // Gets the begin of the subrange. __cuda_callable__ - Index getOffset() const + Index getBegin() const { - return offset; + return begin; } - // Gets number of local indices. + // Gets the begin of the subrange. __cuda_callable__ - Index getLocalSize() const + Index getEnd() const { - return localSize; + return end; } - // Gets number of global indices. + // Gets number of local indices. __cuda_callable__ - Index getGlobalSize() const + Index getSize() const { - return globalSize; + return end - begin; } // Gets local index for given global index. __cuda_callable__ Index getLocalIndex( Index i ) const { - TNL_ASSERT_TRUE( isLocal( i ), "Given global index was not found in the local index set." ); - return i - offset; + TNL_ASSERT_GE( i, getBegin(), "Given global index was not found in the local index set." ); + TNL_ASSERT_LT( i, getEnd(), "Given global index was not found in the local index set." ); + return i - begin; } // Gets global index for given local index. @@ -101,15 +99,14 @@ public: Index getGlobalIndex( Index i ) const { TNL_ASSERT_GE( i, 0, "Given local index was not found in the local index set." ); - TNL_ASSERT_LT( i, localSize, "Given local index was not found in the local index set." ); - return i + offset; + TNL_ASSERT_LT( i, getSize(), "Given local index was not found in the local index set." ); + return i + begin; } bool operator==( const Subrange& other ) const { - return offset == other.offset && - localSize == other.localSize && - globalSize == other.globalSize; + return begin == other.begin && + end == other.end; } bool operator!=( const Subrange& other ) const @@ -118,13 +115,9 @@ public: } protected: - Index offset = 0; - Index localSize = 0; - Index globalSize = 0; + Index begin = 0; + Index end = 0; }; -// TODO: implement a general IndexMap class, e.g. based on collection of subranges as in deal.II: -// https://www.dealii.org/8.4.0/doxygen/deal.II/classIndexSet.html - } // namespace DistributedContainers } // namespace TNL diff --git a/src/UnitTests/DistributedContainers/DistributedArrayTest.h b/src/UnitTests/DistributedContainers/DistributedArrayTest.h index 1d6f7b12e2..4a5b4007bb 100644 --- a/src/UnitTests/DistributedContainers/DistributedArrayTest.h +++ b/src/UnitTests/DistributedContainers/DistributedArrayTest.h @@ -34,7 +34,6 @@ protected: using DeviceType = typename DistributedArray::DeviceType; using CommunicatorType = typename DistributedArray::CommunicatorType; using IndexType = typename DistributedArray::IndexType; - using IndexMap = typename DistributedArray::IndexMapType; using DistributedArrayType = DistributedArray; using ArrayViewType = typename DistributedArrayType::LocalArrayViewType; using ArrayType = typename DistributedArrayType::LocalArrayType; @@ -48,24 +47,25 @@ protected: const int rank = CommunicatorType::GetRank(group); const int nproc = CommunicatorType::GetSize(group); - void SetUp() override + DistributedArrayTest() { - const IndexMap map = DistributedContainers::Partitioner< IndexMap, CommunicatorType >::splitRange( globalSize, group ); - distributedArray.setDistribution( map, group ); + using LocalRangeType = typename DistributedArray::LocalRangeType; + const LocalRangeType localRange = DistributedContainers::Partitioner< IndexType, CommunicatorType >::splitRange( globalSize, group ); + distributedArray.setDistribution( localRange, globalSize, group ); - ASSERT_EQ( distributedArray.getIndexMap(), map ); - ASSERT_EQ( distributedArray.getCommunicationGroup(), group ); + EXPECT_EQ( distributedArray.getLocalRange(), localRange ); + EXPECT_EQ( distributedArray.getCommunicationGroup(), group ); } }; // types for which DistributedArrayTest is instantiated using DistributedArrayTypes = ::testing::Types< - DistributedArray< double, Devices::Host, Communicators::MpiCommunicator, int, Subrange< int > >, - DistributedArray< double, Devices::Host, Communicators::NoDistrCommunicator, int, Subrange< int > > + DistributedArray< double, Devices::Host, int, Communicators::MpiCommunicator >, + DistributedArray< double, Devices::Host, int, Communicators::NoDistrCommunicator > #ifdef HAVE_CUDA , - DistributedArray< double, Devices::Cuda, Communicators::MpiCommunicator, int, Subrange< int > >, - DistributedArray< double, Devices::Cuda, Communicators::NoDistrCommunicator, int, Subrange< int > > + DistributedArray< double, Devices::Cuda, int, Communicators::MpiCommunicator >, + DistributedArray< double, Devices::Cuda, int, Communicators::NoDistrCommunicator > #endif >; @@ -132,16 +132,15 @@ TYPED_TEST( DistributedArrayTest, setValue ) TYPED_TEST( DistributedArrayTest, elementwiseAccess ) { using ArrayViewType = typename TestFixture::ArrayViewType; - using IndexMap = typename TestFixture::IndexMap; using IndexType = typename TestFixture::IndexType; this->distributedArray.setValue( 0 ); ArrayViewType localArrayView = this->distributedArray.getLocalArrayView(); - const IndexMap map = this->distributedArray.getIndexMap(); + const auto localRange = this->distributedArray.getLocalRange(); // check initial value for( IndexType i = 0; i < localArrayView.getSize(); i++ ) { - const IndexType gi = map.getGlobalIndex( i ); + const IndexType gi = localRange.getGlobalIndex( i ); EXPECT_EQ( localArrayView.getElement( i ), 0 ); EXPECT_EQ( this->distributedArray.getElement( gi ), 0 ); if( std::is_same< typename TestFixture::DeviceType, Devices::Host >::value ) @@ -150,13 +149,13 @@ TYPED_TEST( DistributedArrayTest, elementwiseAccess ) // use setValue for( IndexType i = 0; i < localArrayView.getSize(); i++ ) { - const IndexType gi = map.getGlobalIndex( i ); + const IndexType gi = localRange.getGlobalIndex( i ); this->distributedArray.setElement( gi, i + 1 ); } // check set value for( IndexType i = 0; i < localArrayView.getSize(); i++ ) { - const IndexType gi = map.getGlobalIndex( i ); + const IndexType gi = localRange.getGlobalIndex( i ); EXPECT_EQ( localArrayView.getElement( i ), i + 1 ); EXPECT_EQ( this->distributedArray.getElement( gi ), i + 1 ); if( std::is_same< typename TestFixture::DeviceType, Devices::Host >::value ) @@ -168,13 +167,13 @@ TYPED_TEST( DistributedArrayTest, elementwiseAccess ) // use operator[] if( std::is_same< typename TestFixture::DeviceType, Devices::Host >::value ) { for( IndexType i = 0; i < localArrayView.getSize(); i++ ) { - const IndexType gi = map.getGlobalIndex( i ); + const IndexType gi = localRange.getGlobalIndex( i ); this->distributedArray[ gi ] = i + 1; } // check set value for( IndexType i = 0; i < localArrayView.getSize(); i++ ) { - const IndexType gi = map.getGlobalIndex( i ); + const IndexType gi = localRange.getGlobalIndex( i ); EXPECT_EQ( localArrayView.getElement( i ), i + 1 ); EXPECT_EQ( this->distributedArray.getElement( gi ), i + 1 ); EXPECT_EQ( this->distributedArray[ gi ], i + 1 ); @@ -207,17 +206,16 @@ TYPED_TEST( DistributedArrayTest, copyAssignment ) TYPED_TEST( DistributedArrayTest, comparisonOperators ) { using DistributedArrayType = typename TestFixture::DistributedArrayType; - using IndexMap = typename TestFixture::IndexMap; using IndexType = typename TestFixture::IndexType; - const IndexMap map = this->distributedArray.getIndexMap(); + const auto localRange = this->distributedArray.getLocalRange(); DistributedArrayType& u = this->distributedArray; DistributedArrayType v, w; v.setLike( u ); w.setLike( u ); for( int i = 0; i < u.getLocalArrayView().getSize(); i ++ ) { - const IndexType gi = map.getGlobalIndex( i ); + const IndexType gi = localRange.getGlobalIndex( i ); u.setElement( gi, i ); v.setElement( gi, i ); w.setElement( gi, 2 * i ); @@ -242,11 +240,11 @@ TYPED_TEST( DistributedArrayTest, comparisonOperators ) TYPED_TEST( DistributedArrayTest, containsValue ) { using IndexType = typename TestFixture::IndexType; - using IndexMap = typename TestFixture::IndexMap; - const IndexMap map = this->distributedArray.getIndexMap(); + + const auto localRange = this->distributedArray.getLocalRange(); for( int i = 0; i < this->distributedArray.getLocalArrayView().getSize(); i++ ) { - const IndexType gi = map.getGlobalIndex( i ); + const IndexType gi = localRange.getGlobalIndex( i ); this->distributedArray.setElement( gi, i % 10 ); } @@ -260,11 +258,11 @@ TYPED_TEST( DistributedArrayTest, containsValue ) TYPED_TEST( DistributedArrayTest, containsOnlyValue ) { using IndexType = typename TestFixture::IndexType; - using IndexMap = typename TestFixture::IndexMap; - const IndexMap map = this->distributedArray.getIndexMap(); + + const auto localRange = this->distributedArray.getLocalRange(); for( int i = 0; i < this->distributedArray.getLocalArrayView().getSize(); i++ ) { - const IndexType gi = map.getGlobalIndex( i ); + const IndexType gi = localRange.getGlobalIndex( i ); this->distributedArray.setElement( gi, i % 10 ); } -- GitLab