Commit d13a2d18 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Implemented distributed prefix-sum

Fixes #43
parent 174ad5fd
Loading
Loading
Loading
Loading
+70 −0
Original line number Diff line number Diff line
/***************************************************************************
                          PrefixSum.h  -  description
                             -------------------
    begin                : Aug 16, 2019
    copyright            : (C) 2019 by Tomas Oberhuber et al.
    email                : tomas.oberhuber@fjfi.cvut.cz
 ***************************************************************************/

/* See Copyright Notice in tnl/Copyright */

// Implemented by: Jakub Klinkovsky

#pragma once

#include <TNL/Containers/Algorithms/PrefixSum.h>
#include <TNL/Containers/Vector.h>

namespace TNL {
namespace Containers {
namespace Algorithms {

template< PrefixSumType Type >
struct DistributedPrefixSum
{
   template< typename DistributedVector,
             typename Reduction >
   static void
   perform( DistributedVector& v,
            typename DistributedVector::IndexType begin,
            typename DistributedVector::IndexType end,
            const Reduction& reduction,
            const typename DistributedVector::RealType zero )
   {
      using RealType = typename DistributedVector::RealType;
      using DeviceType = typename DistributedVector::DeviceType;
      using CommunicatorType = typename DistributedVector::CommunicatorType;

      const auto group = v.getCommunicationGroup();
      if( group != CommunicatorType::NullGroup ) {
         // adjust begin and end for the local range
         const auto localRange = v.getLocalRange();
         begin = min( max( begin, localRange.getBegin() ), localRange.getEnd() ) - localRange.getBegin();
         end = max( min( end, localRange.getEnd() ), localRange.getBegin() ) - localRange.getBegin();

         // perform first phase on the local data
         auto localView = v.getLocalView();
         const auto blockShifts = PrefixSum< DeviceType, Type >::performFirstPhase( localView, begin, end, reduction, zero );
         const RealType localSum = blockShifts.getElement( blockShifts.getSize() - 1 );

         // exchange local sums between ranks
         const int nproc = CommunicatorType::GetSize( group );
         RealType dataForScatter[ nproc ];
         for( int i = 0; i < nproc; i++ ) dataForScatter[ i ] = localSum;
         Vector< RealType, Devices::Host > rankSums( nproc );
         // NOTE: exchanging general data types does not work with MPI
         CommunicatorType::Alltoall( dataForScatter, 1, rankSums.getData(), 1, group );

         // compute prefix-sum of the per-rank sums
         PrefixSum< Devices::Host, PrefixSumType::Exclusive >::perform( rankSums, 0, nproc, reduction, zero );

         // perform second phase: shift by the per-block and per-rank offsets
         const int rank = CommunicatorType::GetRank( group );
         PrefixSum< DeviceType, Type >::performSecondPhase( localView, blockShifts, begin, end, reduction, rankSums[ rank ] );
      }
   }
};

} // namespace Algorithms
} // namespace Containers
} // namespace TNL
+2 −7
Original line number Diff line number Diff line
@@ -127,13 +127,8 @@ public:
             typename = std::enable_if_t< HasSubscriptOperator<Vector>::value > >
   DistributedVector& operator/=( const Vector& vector );

   void computePrefixSum();

   void computePrefixSum( IndexType begin, IndexType end );

   void computeExclusivePrefixSum();

   void computeExclusivePrefixSum( IndexType begin, IndexType end );
   template< Algorithms::PrefixSumType Type = Algorithms::PrefixSumType::Inclusive >
   void prefixSum( IndexType begin = 0, IndexType end = 0 );
};

} // namespace Containers
+6 −36
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@
#pragma once

#include "DistributedVector.h"
#include <TNL/Exceptions/NotImplementedError.h>
#include <TNL/Containers/Algorithms/DistributedPrefixSum.h>

namespace TNL {
namespace Containers {
@@ -298,44 +298,14 @@ template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< Algorithms::PrefixSumType Type >
void
DistributedVector< Real, Device, Index, Communicator >::
computePrefixSum()
prefixSum( IndexType begin, IndexType end )
{
   throw Exceptions::NotImplementedError("Distributed prefix sum is not implemented yet.");
}

template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedVector< Real, Device, Index, Communicator >::
computePrefixSum( IndexType begin, IndexType end )
{
   throw Exceptions::NotImplementedError("Distributed prefix sum is not implemented yet.");
}

template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedVector< Real, Device, Index, Communicator >::
computeExclusivePrefixSum()
{
   throw Exceptions::NotImplementedError("Distributed prefix sum is not implemented yet.");
}

template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedVector< Real, Device, Index, Communicator >::
computeExclusivePrefixSum( IndexType begin, IndexType end )
{
   throw Exceptions::NotImplementedError("Distributed prefix sum is not implemented yet.");
   if( end == 0 )
      end = this->getSize();
   Algorithms::DistributedPrefixSum< Type >::perform( *this, begin, end, std::plus<>{}, (RealType) 0.0 );
}

} // namespace Containers
+2 −7
Original line number Diff line number Diff line
@@ -127,13 +127,8 @@ public:
             typename = std::enable_if_t< HasSubscriptOperator<Vector>::value > >
   DistributedVectorView& operator/=( const Vector& vector );

   void computePrefixSum();

   void computePrefixSum( IndexType begin, IndexType end );

   void computeExclusivePrefixSum();

   void computeExclusivePrefixSum( IndexType begin, IndexType end );
   template< Algorithms::PrefixSumType Type = Algorithms::PrefixSumType::Inclusive >
   void prefixSum( IndexType begin = 0, IndexType end = 0 );
};

} // namespace Containers
+6 −36
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@
#pragma once

#include "DistributedVectorView.h"
#include <TNL/Exceptions/NotImplementedError.h>
#include <TNL/Containers/Algorithms/DistributedPrefixSum.h>

namespace TNL {
namespace Containers {
@@ -274,44 +274,14 @@ template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
   template< Algorithms::PrefixSumType Type >
void
DistributedVectorView< Real, Device, Index, Communicator >::
computePrefixSum()
prefixSum( IndexType begin, IndexType end )
{
   throw Exceptions::NotImplementedError("Distributed prefix sum is not implemented yet.");
}

template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::
computePrefixSum( IndexType begin, IndexType end )
{
   throw Exceptions::NotImplementedError("Distributed prefix sum is not implemented yet.");
}

template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::
computeExclusivePrefixSum()
{
   throw Exceptions::NotImplementedError("Distributed prefix sum is not implemented yet.");
}

template< typename Real,
          typename Device,
          typename Index,
          typename Communicator >
void
DistributedVectorView< Real, Device, Index, Communicator >::
computeExclusivePrefixSum( IndexType begin, IndexType end )
{
   throw Exceptions::NotImplementedError("Distributed prefix sum is not implemented yet.");
   if( end == 0 )
      end = this->getSize();
   Algorithms::DistributedPrefixSum< Type >::perform( *this, begin, end, std::plus<>{}, (RealType) 0.0 );
}

} // namespace Containers
Loading