Commit 9e6b9ade authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Refactoring Subrange and Partitioner

parent e55b52d3
Loading
Loading
Loading
Loading
+10 −10
Original line number Diff line number Diff line
@@ -17,16 +17,15 @@
#include <TNL/Containers/Array.h>
#include <TNL/Containers/ArrayView.h>
#include <TNL/Communicators/MpiCommunicator.h>
#include <TNL/DistributedContainers/IndexMap.h>
#include <TNL/DistributedContainers/Subrange.h>

namespace TNL {
namespace DistributedContainers {

template< typename Value,
          typename Device = Devices::Host,
          typename Communicator = Communicators::MpiCommunicator,
          typename Index = int,
          typename IndexMap = Subrange< Index > >
          typename Communicator = Communicators::MpiCommunicator >
class DistributedArray
: public Object
{
@@ -36,22 +35,22 @@ public:
   using DeviceType = Device;
   using CommunicatorType = Communicator;
   using IndexType = Index;
   using IndexMapType = IndexMap;
   using LocalRangeType = Subrange< Index >;
   using LocalArrayType = Containers::Array< Value, Device, Index >;
   using LocalArrayViewType = Containers::ArrayView< Value, Device, Index >;
   using ConstLocalArrayViewType = Containers::ArrayView< typename std::add_const< Value >::type, Device, Index >;
   using HostType = DistributedArray< Value, Devices::Host, Communicator, Index, IndexMap >;
   using CudaType = DistributedArray< Value, Devices::Cuda, Communicator, Index, IndexMap >;
   using HostType = DistributedArray< Value, Devices::Host, Index, Communicator >;
   using CudaType = DistributedArray< Value, Devices::Cuda, Index, Communicator >;

   DistributedArray() = default;

   DistributedArray( DistributedArray& ) = default;

   DistributedArray( IndexMap indexMap, CommunicationGroup group = Communicator::AllGroup );
   DistributedArray( LocalRangeType localRange, Index globalSize, CommunicationGroup group = Communicator::AllGroup );

   void setDistribution( IndexMap indexMap, CommunicationGroup group = Communicator::AllGroup );
   void setDistribution( LocalRangeType localRange, Index globalSize, CommunicationGroup group = Communicator::AllGroup );

   const IndexMap& getIndexMap() const;
   const LocalRangeType& getLocalRange() const;

   CommunicationGroup getCommunicationGroup() const;

@@ -125,7 +124,8 @@ public:
   // TODO: serialization (save, load, boundLoad)

protected:
   IndexMap indexMap;
   LocalRangeType localRange;
   IndexType globalSize = 0;
   CommunicationGroup group = Communicator::NullGroup;
   LocalArrayType localData;

+79 −100
Original line number Diff line number Diff line
@@ -22,49 +22,47 @@ namespace DistributedContainers {

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray( IndexMap indexMap, CommunicationGroup group )
          typename Communicator >
DistributedArray< Value, Device, Index, Communicator >::
DistributedArray( LocalRangeType localRange, IndexType globalSize, CommunicationGroup group )
{
   setDistribution( indexMap, group );
   setDistribution( localRange, globalSize, group );
}

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
void
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
setDistribution( IndexMap indexMap, CommunicationGroup group )
DistributedArray< Value, Device, Index, Communicator >::
setDistribution( LocalRangeType localRange, IndexType globalSize, CommunicationGroup group )
{
   this->indexMap = indexMap;
   TNL_ASSERT_LE( localRange.getEnd(), globalSize, "end of the local range is outside of the global range" );
   this->localRange = localRange;
   this->globalSize = globalSize;
   this->group = group;
   if( group != Communicator::NullGroup )
      localData.setSize( indexMap.getLocalSize() );
      localData.setSize( localRange.getSize() );
}

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
const IndexMap&
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
getIndexMap() const
          typename Communicator >
const Subrange< Index >&
DistributedArray< Value, Device, Index, Communicator >::
getLocalRange() const
{
   return indexMap;
   return localRange;
}

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
typename Communicator::CommunicationGroup
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray< Value, Device, Index, Communicator >::
getCommunicationGroup() const
{
   return group;
@@ -72,11 +70,10 @@ getCommunicationGroup() const

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
typename DistributedArray< Value, Device, Communicator, Index, IndexMap >::LocalArrayViewType
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
          typename Communicator >
typename DistributedArray< Value, Device, Index, Communicator >::LocalArrayViewType
DistributedArray< Value, Device, Index, Communicator >::
getLocalArrayView()
{
   return localData;
@@ -84,11 +81,10 @@ getLocalArrayView()

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
typename DistributedArray< Value, Device, Communicator, Index, IndexMap >::ConstLocalArrayViewType
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
          typename Communicator >
typename DistributedArray< Value, Device, Index, Communicator >::ConstLocalArrayViewType
DistributedArray< Value, Device, Index, Communicator >::
getLocalArrayView() const
{
   return localData;
@@ -96,26 +92,24 @@ getLocalArrayView() const

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
void
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray< Value, Device, Index, Communicator >::
copyFromGlobal( ConstLocalArrayViewType globalArray )
{
   TNL_ASSERT_EQ( indexMap.getGlobalSize(), globalArray.getSize(),
   TNL_ASSERT_EQ( getSize(), globalArray.getSize(),
                  "given global array has different size than the distributed array" );

   LocalArrayViewType localView( localData );
   const IndexMap indexMap = getIndexMap();
   const LocalRangeType localRange = getLocalRange();

   auto kernel = [=] __cuda_callable__ ( IndexType i ) mutable
   {
      if( indexMap.isLocal( i ) )
         localView[ indexMap.getLocalIndex( i ) ] = globalArray[ i ];
      localView[ i ] = globalArray[ localRange.getGlobalIndex( i ) ];
   };

   ParallelFor< DeviceType >::exec( (IndexType) 0, indexMap.getGlobalSize(), kernel );
   ParallelFor< DeviceType >::exec( (IndexType) 0, localRange.getSize(), kernel );
}


@@ -125,29 +119,26 @@ copyFromGlobal( ConstLocalArrayViewType globalArray )

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
String
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray< Value, Device, Index, Communicator >::
getType()
{
   return String( "DistributedContainers::DistributedArray< " ) +
          TNL::getType< Value >() + ", " +
          Device::getDeviceType() + ", " +
          // TODO: communicators don't have a getType method
          "<Communicator>, " +
          TNL::getType< Index >() + ", " +
          IndexMap::getType() + " >";
          // TODO: communicators don't have a getType method
          "<Communicator> >";
}

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
String
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray< Value, Device, Index, Communicator >::
getTypeVirtual() const
{
   return getType();
@@ -155,52 +146,50 @@ getTypeVirtual() const

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
   template< typename Array >
void
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray< Value, Device, Index, Communicator >::
setLike( const Array& array )
{
   indexMap = array.getIndexMap();
   localRange = array.getLocalRange();
   globalSize = array.getSize();
   group = array.getCommunicationGroup();
   localData.setLike( array.getLocalArrayView() );
}

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
void
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray< Value, Device, Index, Communicator >::
reset()
{
   indexMap.reset();
   localRange.reset();
   globalSize = 0;
   group = Communicator::NullGroup;
   localData.reset();
}

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
Index
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray< Value, Device, Index, Communicator >::
getSize() const
{
   return indexMap.getGlobalSize();
   return globalSize;
}

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
void
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray< Value, Device, Index, Communicator >::
setValue( ValueType value )
{
   localData.setValue( value );
@@ -208,65 +197,60 @@ setValue( ValueType value )

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
void
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray< Value, Device, Index, Communicator >::
setElement( IndexType i, ValueType value )
{
   const IndexType li = indexMap.getLocalIndex( i );
   const IndexType li = localRange.getLocalIndex( i );
   localData.setElement( li, value );
}

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
Value
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray< Value, Device, Index, Communicator >::
getElement( IndexType i ) const
{
   const IndexType li = indexMap.getLocalIndex( i );
   const IndexType li = localRange.getLocalIndex( i );
   return localData.getElement( li );
}

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
__cuda_callable__
Value&
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray< Value, Device, Index, Communicator >::
operator[]( IndexType i )
{
   const IndexType li = indexMap.getLocalIndex( i );
   const IndexType li = localRange.getLocalIndex( i );
   return localData[ li ];
}

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
__cuda_callable__
const Value&
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray< Value, Device, Index, Communicator >::
operator[]( IndexType i ) const
{
   const IndexType li = indexMap.getLocalIndex( i );
   const IndexType li = localRange.getLocalIndex( i );
   return localData[ li ];
}

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
DistributedArray< Value, Device, Communicator, Index, IndexMap >&
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
          typename Communicator >
DistributedArray< Value, Device, Index, Communicator >&
DistributedArray< Value, Device, Index, Communicator >::
operator=( const DistributedArray& array )
{
   setLike( array );
@@ -276,12 +260,11 @@ operator=( const DistributedArray& array )

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
   template< typename Array >
DistributedArray< Value, Device, Communicator, Index, IndexMap >&
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray< Value, Device, Index, Communicator >&
DistributedArray< Value, Device, Index, Communicator >::
operator=( const Array& array )
{
   setLike( array );
@@ -291,19 +274,19 @@ operator=( const Array& array )

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
   template< typename Array >
bool
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray< Value, Device, Index, Communicator >::
operator==( const Array& array ) const
{
   // we can't run allreduce if the communication groups are different
   if( group != array.getCommunicationGroup() )
      return false;
   const bool localResult =
         indexMap == array.getIndexMap() &&
         localRange == array.getLocalRange() &&
         globalSize == array.getSize() &&
         localData == array.getLocalArrayView();
   bool result = true;
   if( group != CommunicatorType::NullGroup )
@@ -313,12 +296,11 @@ operator==( const Array& array ) const

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
   template< typename Array >
bool
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray< Value, Device, Index, Communicator >::
operator!=( const Array& array ) const
{
   return ! (*this == array);
@@ -326,11 +308,10 @@ operator!=( const Array& array ) const

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
bool
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray< Value, Device, Index, Communicator >::
containsValue( ValueType value ) const
{
   bool result = false;
@@ -343,11 +324,10 @@ containsValue( ValueType value ) const

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
          typename Communicator >
bool
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
DistributedArray< Value, Device, Index, Communicator >::
containsOnlyValue( ValueType value ) const
{
   bool result = true;
@@ -360,10 +340,9 @@ containsOnlyValue( ValueType value ) const

template< typename Value,
          typename Device,
          typename Communicator,
          typename Index,
          typename IndexMap >
DistributedArray< Value, Device, Communicator, Index, IndexMap >::
          typename Communicator >
DistributedArray< Value, Device, Index, Communicator >::
operator bool() const
{
   return getSize() != 0;
+13 −11
Original line number Diff line number Diff line
/***************************************************************************
                          DistributedArray.h  -  description
                          Partitioner.h  -  description
                             -------------------
    begin                : Sep 6, 2018
    copyright            : (C) 2018 by Tomas Oberhuber et al.
@@ -12,35 +12,31 @@

#pragma once

#include "IndexMap.h"
#include "Subrange.h"

#include <TNL/Math.h>

namespace TNL {
namespace DistributedContainers {

template< typename IndexMap, typename Communicator >
class Partitioner
{};

template< typename Index, typename Communicator >
class Partitioner< Subrange< Index >, Communicator >
class Partitioner
{
   using CommunicationGroup = typename Communicator::CommunicationGroup;
public:
   using IndexMap = Subrange< Index >;
   using SubrangeType = Subrange< Index >;

   static IndexMap splitRange( Index globalSize, CommunicationGroup group )
   static SubrangeType splitRange( Index globalSize, CommunicationGroup group )
   {
      if( group != Communicator::NullGroup ) {
         const int rank = Communicator::GetRank( group );
         const int partitions = Communicator::GetSize( group );
         const Index begin = min( globalSize, rank * globalSize / partitions );
         const Index end = min( globalSize, (rank + 1) * globalSize / partitions );
         return IndexMap( begin, end, globalSize );
         return SubrangeType( begin, end );
      }
      else
         return IndexMap( 0, 0, globalSize );
         return SubrangeType( 0, 0 );
   }

   // Gets the owner of given global index.
@@ -67,5 +63,11 @@ public:
   }
};

// TODO:
// - partitioner in deal.II stores also ghost indices:
//   https://www.dealii.org/8.4.0/doxygen/deal.II/classUtilities_1_1MPI_1_1Partitioner.html
// - ghost indices are stored in a general IndexMap class (based on collection of subranges):
//   https://www.dealii.org/8.4.0/doxygen/deal.II/classIndexSet.html

} // namespace DistributedContainers
} // namespace TNL
+27 −34
Original line number Diff line number Diff line
/***************************************************************************
                          IndexMap.h  -  description
                          Subrange.h  -  description
                             -------------------
    begin                : Sep 6, 2018
    copyright            : (C) 2018 by Tomas Oberhuber et al.
@@ -30,29 +30,26 @@ public:
   Subrange() = default;

   __cuda_callable__
   Subrange( Index begin, Index end, Index globalSize )
   Subrange( Index begin, Index end )
   {
      setSubrange( begin, end, globalSize );
      setSubrange( begin, end );
   }

   // Sets the local subrange and global range size.
   __cuda_callable__
   void setSubrange( Index begin, Index end, Index globalSize )
   void setSubrange( Index begin, Index end )
   {
      TNL_ASSERT_LE( begin, end, "begin must be before end" );
      TNL_ASSERT_GE( begin, 0, "begin must be non-negative" );
      TNL_ASSERT_LE( end - begin, globalSize, "end of the subrange is outside of gloabl range" );
      offset = begin;
      localSize = end - begin;
      this->globalSize = globalSize;
      this->begin = begin;
      this->end = end;
   }

   __cuda_callable__
   void reset()
   {
      offset = 0;
      localSize = 0;
      globalSize = 0;
      begin = 0;
      end = 0;
   }

   static String getType()
@@ -64,36 +61,37 @@ public:
   __cuda_callable__
   bool isLocal( Index i ) const
   {
      return offset <= i && i < offset + localSize;
      return begin <= i && i < end;
   }

   // Gets the offset of the subrange.
   // Gets the begin of the subrange.
   __cuda_callable__
   Index getOffset() const
   Index getBegin() const
   {
      return offset;
      return begin;
   }

   // Gets number of local indices.
   // Gets the begin of the subrange.
   __cuda_callable__
   Index getLocalSize() const
   Index getEnd() const
   {
      return localSize;
      return end;
   }

   // Gets number of global indices.
   // Gets number of local indices.
   __cuda_callable__
   Index getGlobalSize() const
   Index getSize() const
   {
      return globalSize;
      return end - begin;
   }

   // Gets local index for given global index.
   __cuda_callable__
   Index getLocalIndex( Index i ) const
   {
      TNL_ASSERT_TRUE( isLocal( i ), "Given global index was not found in the local index set." );
      return i - offset;
      TNL_ASSERT_GE( i, getBegin(), "Given global index was not found in the local index set." );
      TNL_ASSERT_LT( i, getEnd(), "Given global index was not found in the local index set." );
      return i - begin;
   }

   // Gets global index for given local index.
@@ -101,15 +99,14 @@ public:
   Index getGlobalIndex( Index i ) const
   {
      TNL_ASSERT_GE( i, 0, "Given local index was not found in the local index set." );
      TNL_ASSERT_LT( i, localSize, "Given local index was not found in the local index set." );
      return i + offset;
      TNL_ASSERT_LT( i, getSize(), "Given local index was not found in the local index set." );
      return i + begin;
   }

   bool operator==( const Subrange& other ) const
   {
      return offset == other.offset &&
             localSize == other.localSize &&
             globalSize == other.globalSize;
      return begin == other.begin &&
             end == other.end;
   }

   bool operator!=( const Subrange& other ) const
@@ -118,13 +115,9 @@ public:
   }

protected:
   Index offset = 0;
   Index localSize = 0;
   Index globalSize = 0;
   Index begin = 0;
   Index end = 0;
};

// TODO: implement a general IndexMap class, e.g. based on collection of subranges as in deal.II:
// https://www.dealii.org/8.4.0/doxygen/deal.II/classIndexSet.html

} // namespace DistributedContainers
} // namespace TNL
+24 −26

File changed.

Preview size limit exceeded, changes collapsed.