Commit 4cf35454 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Refactoring DistributedArray - implementation via DistributedArrayView as a data member

parent 40cd3071
Loading
Loading
Loading
Loading
+1 −3
Original line number Diff line number Diff line
@@ -168,9 +168,7 @@ public:
   // TODO: serialization (save, load)

protected:
   LocalRangeType localRange;
   IndexType globalSize = 0;
   CommunicationGroup group = Communicator::NullGroup;
   ViewType view;
   LocalArrayType localData;
};

+21 −62
Original line number Diff line number Diff line
@@ -39,11 +39,9 @@ DistributedArray< Value, Device, Index, Communicator >::
setDistribution( LocalRangeType localRange, IndexType globalSize, CommunicationGroup group )
{
   TNL_ASSERT_LE( localRange.getEnd(), globalSize, "end of the local range is outside of the global range" );
   this->localRange = localRange;
   this->globalSize = globalSize;
   this->group = group;
   if( group != Communicator::NullGroup )
      localData.setSize( localRange.getSize() );
   view.bind( localRange, globalSize, group, localData.getView() );
}

template< typename Value,
@@ -54,7 +52,7 @@ const Subrange< Index >&
DistributedArray< Value, Device, Index, Communicator >::
getLocalRange() const
{
   return localRange;
   return view.getLocalRange();
}

template< typename Value,
@@ -65,7 +63,7 @@ typename Communicator::CommunicationGroup
DistributedArray< Value, Device, Index, Communicator >::
getCommunicationGroup() const
{
   return group;
   return view.getCommunicationGroup();
}

template< typename Value,
@@ -99,18 +97,7 @@ void
DistributedArray< Value, Device, Index, Communicator >::
copyFromGlobal( ConstLocalViewType globalArray )
{
   TNL_ASSERT_EQ( getSize(), globalArray.getSize(),
                  "given global array has different size than the distributed array" );

   LocalViewType localView( localData );
   const LocalRangeType localRange = getLocalRange();

   auto kernel = [=] __cuda_callable__ ( IndexType i ) mutable
   {
      localView[ i ] = globalArray[ localRange.getGlobalIndex( i ) ];
   };

   Algorithms::ParallelFor< DeviceType >::exec( (IndexType) 0, localRange.getSize(), kernel );
   view.copyFromGlobal( globalArray );
}


@@ -126,7 +113,7 @@ typename DistributedArray< Value, Device, Index, Communicator >::ViewType
DistributedArray< Value, Device, Index, Communicator >::
getView()
{
   return ViewType( getLocalRange(), getSize(), getCommunicationGroup(), getLocalView() );
   return view;
}

template< typename Value,
@@ -137,7 +124,7 @@ typename DistributedArray< Value, Device, Index, Communicator >::ConstViewType
DistributedArray< Value, Device, Index, Communicator >::
getConstView() const
{
   return ConstViewType( getLocalRange(), getSize(), getCommunicationGroup(), getConstLocalView() );
   return view.getConstView();
}

template< typename Value,
@@ -169,10 +156,8 @@ void
DistributedArray< Value, Device, Index, Communicator >::
setLike( const Array& array )
{
   localRange = array.getLocalRange();
   globalSize = array.getSize();
   group = array.getCommunicationGroup();
   localData.setLike( array.getConstLocalView() );
   view.bind( array.getLocalRange(), array.getSize(), array.getCommunicationGroup(), localData.getView() );
}

template< typename Value,
@@ -183,9 +168,7 @@ void
DistributedArray< Value, Device, Index, Communicator >::
reset()
{
   localRange.reset();
   globalSize = 0;
   group = Communicator::NullGroup;
   view.reset();
   localData.reset();
}

@@ -197,7 +180,7 @@ bool
DistributedArray< Value, Device, Index, Communicator >::
empty() const
{
   return getSize() == 0;
   return view.empty();
}

template< typename Value,
@@ -208,7 +191,7 @@ Index
DistributedArray< Value, Device, Index, Communicator >::
getSize() const
{
   return globalSize;
   return view.getSize();
}

template< typename Value,
@@ -219,7 +202,7 @@ void
DistributedArray< Value, Device, Index, Communicator >::
setValue( ValueType value )
{
   localData.setValue( value );
   view.setValue( value );
}

template< typename Value,
@@ -230,8 +213,7 @@ void
DistributedArray< Value, Device, Index, Communicator >::
setElement( IndexType i, ValueType value )
{
   const IndexType li = localRange.getLocalIndex( i );
   localData.setElement( li, value );
   view.setElement( i, value );
}

template< typename Value,
@@ -242,8 +224,7 @@ Value
DistributedArray< Value, Device, Index, Communicator >::
getElement( IndexType i ) const
{
   const IndexType li = localRange.getLocalIndex( i );
   return localData.getElement( li );
   return view.getElement( i );
}

template< typename Value,
@@ -255,8 +236,7 @@ Value&
DistributedArray< Value, Device, Index, Communicator >::
operator[]( IndexType i )
{
   const IndexType li = localRange.getLocalIndex( i );
   return localData[ li ];
   return view[ i ];
}

template< typename Value,
@@ -268,8 +248,7 @@ const Value&
DistributedArray< Value, Device, Index, Communicator >::
operator[]( IndexType i ) const
{
   const IndexType li = localRange.getLocalIndex( i );
   return localData[ li ];
   return view[ i ];
}

template< typename Value,
@@ -281,7 +260,7 @@ DistributedArray< Value, Device, Index, Communicator >::
operator=( const DistributedArray& array )
{
   setLike( array );
   localData = array.getConstLocalView();
   view = array;
   return *this;
}

@@ -295,7 +274,7 @@ DistributedArray< Value, Device, Index, Communicator >::
operator=( const Array& array )
{
   setLike( array );
   localData = array.getConstLocalView();
   view = array;
   return *this;
}

@@ -308,17 +287,7 @@ bool
DistributedArray< Value, Device, Index, Communicator >::
operator==( const Array& array ) const
{
   // we can't run allreduce if the communication groups are different
   if( group != array.getCommunicationGroup() )
      return false;
   const bool localResult =
         localRange == array.getLocalRange() &&
         globalSize == array.getSize() &&
         localData == array.getConstLocalView();
   bool result = true;
   if( group != CommunicatorType::NullGroup )
      CommunicatorType::Allreduce( &localResult, &result, 1, MPI_LAND, group );
   return result;
   return view == array;
}

template< typename Value,
@@ -330,7 +299,7 @@ bool
DistributedArray< Value, Device, Index, Communicator >::
operator!=( const Array& array ) const
{
   return ! (*this == array);
   return view != array;
}

template< typename Value,
@@ -341,12 +310,7 @@ bool
DistributedArray< Value, Device, Index, Communicator >::
containsValue( ValueType value ) const
{
   bool result = false;
   if( group != CommunicatorType::NullGroup ) {
      const bool localResult = localData.containsValue( value );
      CommunicatorType::Allreduce( &localResult, &result, 1, MPI_LOR, group );
   }
   return result;
   return view.containsValue( value );
}

template< typename Value,
@@ -357,12 +321,7 @@ bool
DistributedArray< Value, Device, Index, Communicator >::
containsOnlyValue( ValueType value ) const
{
   bool result = true;
   if( group != CommunicatorType::NullGroup ) {
      const bool localResult = localData.containsOnlyValue( value );
      CommunicatorType::Allreduce( &localResult, &result, 1, MPI_LAND, group );
   }
   return result;
   return view.containsOnlyValue( value );
}

} // namespace Containers
+6 −3
Original line number Diff line number Diff line
@@ -74,9 +74,12 @@ public:
   __cuda_callable__
   DistributedArrayView( DistributedArrayView&& ) = default;

   // method for rebinding (reinitialization)
   // Note that you can also bind directly to Array and other types implicitly
   // convertible to ArrayView.
   // method for rebinding (reinitialization) to raw data
   __cuda_callable__
   void bind( const LocalRangeType& localRange, IndexType globalSize, CommunicationGroup group, LocalViewType localData );

   // Note that you can also bind directly to DistributedArray and other types implicitly
   // convertible to DistributedArrayView.
   __cuda_callable__
   void bind( DistributedArrayView view );

+18 −0
Original line number Diff line number Diff line
@@ -31,6 +31,24 @@ DistributedArrayView( const DistributedArrayView< Value_, Device, Index, Communi
  localData( view.getConstLocalView() )
{}

template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
__cuda_callable__
void
DistributedArrayView< Value, Device, Index, Communicator >::
bind( const LocalRangeType& localRange, IndexType globalSize, CommunicationGroup group, LocalViewType localData )
{
   TNL_ASSERT_EQ( localData.getSize(), localRange.getSize(),
                  "The local array size does not match the local range of the distributed array." );

   this->localRange = localRange;
   this->globalSize = globalSize;
   this->group = group;
   this->localData.bind( localData );
}

template< typename Value,
          typename Device,
          typename Index,