Commit f0b42e43 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Added support for ghost ranges to DistributedArray and DistributedVector and their views

parent 5184793e
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -228,10 +228,10 @@ struct SpmvBenchmark
      const auto group = CommunicatorType::AllGroup;
      const auto localRange = Partitioner::splitRange( matrix.getRows(), group );
      DistributedMatrix distributedMatrix( localRange, matrix.getRows(), matrix.getColumns(), group );
      DistributedVector distributedVector( localRange, matrix.getRows(), group );
      DistributedVector distributedVector( localRange, 0, matrix.getRows(), group );

      // copy the row lengths from the global matrix to the distributed matrix
      DistributedRowLengths distributedRowLengths( localRange, matrix.getRows(), group );
      DistributedRowLengths distributedRowLengths( localRange, 0, matrix.getRows(), group );
      for( IndexType i = 0; i < distributedMatrix.getLocalMatrix().getRows(); i++ ) {
         const auto gi = distributedMatrix.getLocalRowRange().getGlobalIndex( i );
         distributedRowLengths[ gi ] = matrix.getRowCapacity( gi );
+3 −3
Original line number Diff line number Diff line
@@ -435,11 +435,11 @@ struct LinearSolversBenchmark
      const auto group = CommunicatorType::AllGroup;
      const auto localRange = Partitioner::splitRange( matrixPointer->getRows(), group );
      SharedPointer< DistributedMatrix > distMatrixPointer( localRange, matrixPointer->getRows(), matrixPointer->getColumns(), group );
      DistributedVector dist_x0( localRange, matrixPointer->getRows(), group );
      DistributedVector dist_b( localRange, matrixPointer->getRows(), group );
      DistributedVector dist_x0( localRange, 0, matrixPointer->getRows(), group );
      DistributedVector dist_b( localRange, 0, matrixPointer->getRows(), group );

      // copy the row capacities from the global matrix to the distributed matrix
      DistributedRowLengths distributedRowLengths( localRange, matrixPointer->getRows(), group );
      DistributedRowLengths distributedRowLengths( localRange, 0, matrixPointer->getRows(), group );
      for( IndexType i = 0; i < distMatrixPointer->getLocalMatrix().getRows(); i++ ) {
         const auto gi = distMatrixPointer->getLocalRowRange().getGlobalIndex( i );
         distributedRowLengths[ gi ] = matrixPointer->getRowCapacity( gi );
+39 −21
Original line number Diff line number Diff line
@@ -37,6 +37,7 @@ public:
   using ConstLocalViewType = Containers::ArrayView< std::add_const_t< Value >, Device, Index >;
   using ViewType = DistributedArrayView< Value, Device, Index, Communicator >;
   using ConstViewType = DistributedArrayView< std::add_const_t< Value >, Device, Index, Communicator >;
   using SynchronizerType = typename ViewType::SynchronizerType;

   /**
    * \brief A template which allows to quickly obtain a \ref DistributedArray type with changed template parameters.
@@ -50,46 +51,54 @@ public:

   DistributedArray() = default;

   DistributedArray( const DistributedArray& ) = default;
   // Copy-constructor does deep copy.
   DistributedArray( const DistributedArray& );

   DistributedArray( LocalRangeType localRange, Index globalSize, CommunicationGroup group = Communicator::AllGroup );
   DistributedArray( LocalRangeType localRange, Index ghosts, Index globalSize, CommunicationGroup group = Communicator::AllGroup );

   void setDistribution( LocalRangeType localRange, Index globalSize, CommunicationGroup group = Communicator::AllGroup );
   void setDistribution( LocalRangeType localRange, Index ghosts, Index globalSize, CommunicationGroup group = Communicator::AllGroup );

   const LocalRangeType& getLocalRange() const;

   IndexType getGhosts() const;

   CommunicationGroup getCommunicationGroup() const;

   /**
    * \brief Returns a modifiable view of the local part of the array.
    *
    * If \e begin or \e end is set to a non-zero value, a view for the
    * sub-interval `[begin, end)` is returned. Otherwise a view for whole
    * local part of the array view is returned.
    *
    * \param begin The beginning of the array view sub-interval. It is 0 by
    *              default.
    * \param end The end of the array view sub-interval. The default value is 0
    *            which is, however, replaced with the array size.
    */
   LocalViewType getLocalView();

   /**
    * \brief Returns a non-modifiable view of the local part of the array.
    *
    * If \e begin or \e end is set to a non-zero value, a view for the
    * sub-interval `[begin, end)` is returned. Otherwise a view for whole
    * local part of the array view is returned.
    *
    * \param begin The beginning of the array view sub-interval. It is 0 by
    *              default.
    * \param end The end of the array view sub-interval. The default value is 0
    *            which is, however, replaced with the array size.
    */
   ConstLocalViewType getConstLocalView() const;

   /**
    * \brief Returns a modifiable view of the local part of the array,
    * including ghost values.
    */
   LocalViewType getLocalViewWithGhosts();

   /**
    * \brief Returns a non-modifiable view of the local part of the array,
    * including ghost values.
    */
   ConstLocalViewType getConstLocalViewWithGhosts() const;

   void copyFromGlobal( ConstLocalViewType globalArray );

   // synchronizer stuff
   void setSynchronizer( std::shared_ptr< SynchronizerType > synchronizer, int valuesPerElement = 1 );

   std::shared_ptr< SynchronizerType > getSynchronizer() const;

   int getValuesPerElement() const;

   void startSynchronization();

   void waitForSynchronization() const;


   // Usual Array methods follow below.

@@ -170,6 +179,15 @@ public:
protected:
   ViewType view;
   LocalArrayType localData;

private:
   template< typename Array, std::enable_if_t< std::is_same< typename Array::DeviceType, DeviceType >::value, bool > = true >
   static void setSynchronizerHelper( ViewType& view, const Array& array )
   {
      view.setSynchronizer( array.getSynchronizer(), array.getValuesPerElement() );
   }
   template< typename Array, std::enable_if_t< ! std::is_same< typename Array::DeviceType, DeviceType >::value, bool > = true >
   static void setSynchronizerHelper( ViewType& view, const Array& array ) {}
};

} // namespace Containers
+111 −9
Original line number Diff line number Diff line
@@ -25,9 +25,20 @@ template< typename Value,
          typename Index,
          typename Communicator >
DistributedArray< Value, Device, Index, Communicator >::
DistributedArray( LocalRangeType localRange, IndexType globalSize, CommunicationGroup group )
DistributedArray( const DistributedArray& array )
{
   setDistribution( localRange, globalSize, group );
   setLike( array );
   localData = array.getConstLocalViewWithGhosts();
}

template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
DistributedArray< Value, Device, Index, Communicator >::
DistributedArray( LocalRangeType localRange, IndexType ghosts, IndexType globalSize, CommunicationGroup group )
{
   setDistribution( localRange, ghosts, globalSize, group );
}

template< typename Value,
@@ -36,12 +47,12 @@ template< typename Value,
          typename Communicator >
void
DistributedArray< Value, Device, Index, Communicator >::
setDistribution( LocalRangeType localRange, IndexType globalSize, CommunicationGroup group )
setDistribution( LocalRangeType localRange, IndexType ghosts, IndexType globalSize, CommunicationGroup group )
{
   TNL_ASSERT_LE( localRange.getEnd(), globalSize, "end of the local range is outside of the global range" );
   if( group != Communicator::NullGroup )
      localData.setSize( localRange.getSize() );
   view.bind( localRange, globalSize, group, localData.getView() );
      localData.setSize( localRange.getSize() + ghosts );
   view.bind( localRange, ghosts, globalSize, group, localData.getView() );
}

template< typename Value,
@@ -55,6 +66,17 @@ getLocalRange() const
   return view.getLocalRange();
}

template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
Index
DistributedArray< Value, Device, Index, Communicator >::
getGhosts() const
{
   return view.getGhosts();
}

template< typename Value,
          typename Device,
          typename Index,
@@ -74,7 +96,7 @@ typename DistributedArray< Value, Device, Index, Communicator >::LocalViewType
DistributedArray< Value, Device, Index, Communicator >::
getLocalView()
{
   return localData.getView();
   return view.getLocalView();
}

template< typename Value,
@@ -85,7 +107,29 @@ typename DistributedArray< Value, Device, Index, Communicator >::ConstLocalViewT
DistributedArray< Value, Device, Index, Communicator >::
getConstLocalView() const
{
   return localData.getConstView();
   return view.getConstLocalView();
}

template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
typename DistributedArray< Value, Device, Index, Communicator >::LocalViewType
DistributedArray< Value, Device, Index, Communicator >::
getLocalViewWithGhosts()
{
   return view.getLocalViewWithGhosts();
}

template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
typename DistributedArray< Value, Device, Index, Communicator >::ConstLocalViewType
DistributedArray< Value, Device, Index, Communicator >::
getConstLocalViewWithGhosts() const
{
   return view.getConstLocalViewWithGhosts();
}


@@ -100,6 +144,61 @@ copyFromGlobal( ConstLocalViewType globalArray )
   view.copyFromGlobal( globalArray );
}

template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedArray< Value, Device, Index, Communicator >::
setSynchronizer( std::shared_ptr< SynchronizerType > synchronizer, int valuesPerElement )
{
   view.setSynchronizer( synchronizer, valuesPerElement );
}

template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
std::shared_ptr< typename DistributedArrayView< Value, Device, Index, Communicator >::SynchronizerType >
DistributedArray< Value, Device, Index, Communicator >::
getSynchronizer() const
{
   return view.getSynchronizer();
}

template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
int
DistributedArray< Value, Device, Index, Communicator >::
getValuesPerElement() const
{
   return view.getValuesPerElement();
}

template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedArray< Value, Device, Index, Communicator >::
startSynchronization()
{
   view.startSynchronization();
}

template< typename Value,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedArray< Value, Device, Index, Communicator >::
waitForSynchronization() const
{
   view.waitForSynchronization();
}


/*
 * Usual Array methods follow below.
@@ -156,8 +255,11 @@ void
DistributedArray< Value, Device, Index, Communicator >::
setLike( const Array& array )
{
   localData.setLike( array.getConstLocalView() );
   view.bind( array.getLocalRange(), array.getSize(), array.getCommunicationGroup(), localData.getView() );
   localData.setLike( array.getConstLocalViewWithGhosts() );
   view.bind( array.getLocalRange(), array.getGhosts(), array.getSize(), array.getCommunicationGroup(), localData.getView() );
   // set, but do not unset, the synchronizer
   if( array.getSynchronizer() )
      setSynchronizerHelper( view, array );
}

template< typename Value,
+31 −5
Original line number Diff line number Diff line
@@ -12,9 +12,12 @@

#pragma once

#include <memory>

#include <TNL/Containers/ArrayView.h>
#include <TNL/Communicators/MpiCommunicator.h>
#include <TNL/Containers/Subrange.h>
#include <TNL/Containers/ByteArraySynchronizer.h>

namespace TNL {
namespace Containers {
@@ -36,6 +39,7 @@ public:
   using ConstLocalViewType = Containers::ArrayView< std::add_const_t< Value >, Device, Index >;
   using ViewType = DistributedArrayView< Value, Device, Index, Communicator >;
   using ConstViewType = DistributedArrayView< std::add_const_t< Value >, Device, Index, Communicator >;
   using SynchronizerType = ByteArraySynchronizer< DeviceType, IndexType >;

   /**
    * \brief A template which allows to quickly obtain a \ref DistributedArrayView type with changed template parameters.
@@ -48,11 +52,12 @@ public:


   // Initialization by raw data
   DistributedArrayView( const LocalRangeType& localRange, IndexType globalSize, CommunicationGroup group, LocalViewType localData )
   : localRange(localRange), globalSize(globalSize), group(group), localData(localData)
   DistributedArrayView( const LocalRangeType& localRange, IndexType ghosts, IndexType globalSize, CommunicationGroup group, LocalViewType localData )
   : localRange(localRange), ghosts(ghosts), globalSize(globalSize), group(group), localData(localData)
   {
      TNL_ASSERT_EQ( localData.getSize(), localRange.getSize(),
      TNL_ASSERT_EQ( localData.getSize(), localRange.getSize() + ghosts,
                     "The local array size does not match the local range of the distributed array." );
      TNL_ASSERT_GE( ghosts, 0, "The ghosts count must be non-negative." );
   }

   DistributedArrayView() = default;
@@ -68,27 +73,44 @@ public:
   DistributedArrayView( DistributedArrayView&& ) = default;

   // method for rebinding (reinitialization) to raw data
   void bind( const LocalRangeType& localRange, IndexType globalSize, CommunicationGroup group, LocalViewType localData );
   void bind( const LocalRangeType& localRange, IndexType ghosts, IndexType globalSize, CommunicationGroup group, LocalViewType localData );

   // Note that you can also bind directly to DistributedArray and other types implicitly
   // convertible to DistributedArrayView.
   void bind( DistributedArrayView view );

   // binding to local array via raw pointer
   // (local range, global size and communication group are preserved)
   // (local range, ghosts, global size and communication group are preserved)
   template< typename Value_ >
   void bind( Value_* data, IndexType localSize );

   const LocalRangeType& getLocalRange() const;

   IndexType getGhosts() const;

   CommunicationGroup getCommunicationGroup() const;

   LocalViewType getLocalView();

   ConstLocalViewType getConstLocalView() const;

   LocalViewType getLocalViewWithGhosts();

   ConstLocalViewType getConstLocalViewWithGhosts() const;

   void copyFromGlobal( ConstLocalViewType globalArray );

   // synchronizer stuff
   void setSynchronizer( std::shared_ptr< SynchronizerType > synchronizer, int valuesPerElement = 1 );

   std::shared_ptr< SynchronizerType > getSynchronizer() const;

   int getValuesPerElement() const;

   void startSynchronization();

   void waitForSynchronization() const;


   /*
    * Usual ArrayView methods follow below.
@@ -156,9 +178,13 @@ public:

protected:
   LocalRangeType localRange;
   IndexType ghosts = 0;
   IndexType globalSize = 0;
   CommunicationGroup group = Communicator::NullGroup;
   LocalViewType localData;

   std::shared_ptr< SynchronizerType > synchronizer = nullptr;
   int valuesPerElement = 1;
};

} // namespace Containers
Loading