Commit 05634323 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

DistributedNDArray: added methods getLocalIndexer, getLocalView,...

DistributedNDArray: added methods getLocalIndexer, getLocalView, getConstLocalView, getStorageIndex and getData
parent e4e780e1
Loading
Loading
Loading
Loading
+58 −0
Original line number Diff line number Diff line
@@ -36,9 +36,12 @@ public:
   using LocalBeginsType = __ndarray_impl::LocalBeginsHolder< typename NDArray::SizesHolderType >;
   using LocalRangeType = Subrange< IndexType >;
   using OverlapsType = Overlaps;
   using LocalIndexerType = NDArrayIndexer< SizesHolderType, PermutationType, typename NDArray::NDBaseType, typename NDArray::StridesHolderType, Overlaps >;

   using ViewType = DistributedNDArrayView< typename NDArray::ViewType, Communicator, Overlaps >;
   using ConstViewType = DistributedNDArrayView< typename NDArray::ConstViewType, Communicator, Overlaps >;
   using LocalViewType = typename NDArray::ViewType;
   using ConstLocalViewType = typename NDArray::ConstViewType;

   static_assert( Overlaps::size() == NDArray::getDimension(), "invalid overlaps" );

@@ -59,6 +62,18 @@ public:
   DistributedNDArray( DistributedNDArray&& ) = default;
   DistributedNDArray& operator=( DistributedNDArray&& ) = default;

   // Templated copy-assignment
   template< typename OtherArray >
   DistributedNDArray& operator=( const OtherArray& other )
   {
      globalSizes = other.getSizes();
      localBegins = other.getLocalBegins();
      localEnds = other.getLocalEnds();
      group = other.getCommunicationGroup();
      localArray = other.getConstLocalView();
      return *this;
   }

   static constexpr std::size_t getDimension()
   {
      return NDArray::getDimension();
@@ -111,6 +126,49 @@ public:
      return localArray.getStorageSize();
   }

   LocalIndexerType getLocalIndexer() const
   {
      return LocalIndexerType( localEnds - localBegins, typename NDArray::StridesHolderType{} );
   }

   LocalViewType getLocalView()
   {
      return localArray.getView();
   }

   ConstLocalViewType getConstLocalView() const
   {
      return localArray.getConstView();
   }

   // returns the *local* storage index for given *global* indices
   template< typename... IndexTypes >
   __cuda_callable__
   IndexType
   getStorageIndex( IndexTypes&&... indices ) const
   {
      static_assert( sizeof...( indices ) == SizesHolderType::getDimension(), "got wrong number of indices" );
      __ndarray_impl::assertIndicesInRange( localBegins, localEnds, Overlaps{}, std::forward< IndexTypes >( indices )... );
      auto getStorageIndex = [this]( auto&&... indices )
      {
         return this->localArray.getStorageIndex( std::forward< decltype(indices) >( indices )... );
      };
      return __ndarray_impl::call_with_unshifted_indices< LocalBeginsType, Overlaps >( localBegins, getStorageIndex, std::forward< IndexTypes >( indices )... );
   }

   __cuda_callable__
   ValueType* getData()
   {
      return localArray.getData();
   }

   __cuda_callable__
   std::add_const_t< ValueType >* getData() const
   {
      return localArray.getData();
   }


   template< typename... IndexTypes >
   __cuda_callable__
   ValueType&
+74 −1
Original line number Diff line number Diff line
@@ -35,9 +35,12 @@ public:
   using LocalBeginsType = __ndarray_impl::LocalBeginsHolder< typename NDArrayView::SizesHolderType >;
   using LocalRangeType = Subrange< IndexType >;
   using OverlapsType = Overlaps;
   using LocalIndexerType = NDArrayIndexer< SizesHolderType, PermutationType, typename NDArrayView::NDBaseType, typename NDArrayView::StridesHolderType, Overlaps >;

   using ViewType = DistributedNDArrayView< NDArrayView, Communicator, Overlaps >;
   using ConstViewType = DistributedNDArrayView< typename NDArrayView::ConstViewType, Communicator, Overlaps >;
   using LocalViewType = NDArrayView;
   using ConstLocalViewType = typename NDArrayView::ConstViewType;

   static_assert( Overlaps::size() == NDArrayView::getDimension(), "invalid overlaps" );

@@ -67,7 +70,19 @@ public:
   // There is no move-assignment operator, so expressions like `a = b.getView()`
   // are resolved as copy-assignment.

   // method for rebinding (reinitialization)
   // Templated copy-assignment
   template< typename OtherArray >
   DistributedNDArrayView& operator=( const OtherArray& other )
   {
      globalSizes = other.getSizes();
      localBegins = other.getLocalBegins();
      localEnds = other.getLocalEnds();
      group = other.getCommunicationGroup();
      localView = other.getConstLocalView();
      return *this;
   }

   // methods for rebinding (reinitialization)
   __cuda_callable__
   void bind( DistributedNDArrayView view )
   {
@@ -78,6 +93,21 @@ public:
      localEnds = view.localEnds;
   }

   // binds to the given raw pointer and changes the indexer
   __cuda_callable__
   void bind( ValueType* data, LocalIndexerType indexer )
   {
      localView.bind( data, indexer );
      localView.bind( data );
   }

   // binds to the given raw pointer and preserves the current indexer
   __cuda_callable__
   void bind( ValueType* data )
   {
      localView.bind( data );
   }

   __cuda_callable__
   void reset()
   {
@@ -140,6 +170,49 @@ public:
      return localView.getStorageSize();
   }

   LocalIndexerType getLocalIndexer() const
   {
      return LocalIndexerType( localEnds - localBegins, typename NDArrayView::StridesHolderType{} );
   }

   LocalViewType getLocalView()
   {
      return localView;
   }

   ConstLocalViewType getConstLocalView() const
   {
      return localView.getConstView();
   }

   // returns the *local* storage index for given *global* indices
   template< typename... IndexTypes >
   __cuda_callable__
   IndexType
   getStorageIndex( IndexTypes&&... indices ) const
   {
      static_assert( sizeof...( indices ) == SizesHolderType::getDimension(), "got wrong number of indices" );
      __ndarray_impl::assertIndicesInRange( localBegins, localEnds, Overlaps{}, std::forward< IndexTypes >( indices )... );
      auto getStorageIndex = [this]( auto&&... indices )
      {
         return this->localView.getStorageIndex( std::forward< decltype(indices) >( indices )... );
      };
      return __ndarray_impl::call_with_unshifted_indices< LocalBeginsType, Overlaps >( localBegins, getStorageIndex, std::forward< IndexTypes >( indices )... );
   }

   __cuda_callable__
   ValueType* getData()
   {
      return localView.getData();
   }

   __cuda_callable__
   std::add_const_t< ValueType >* getData() const
   {
      return localView.getData();
   }


   template< typename... IndexTypes >
   __cuda_callable__
   ValueType&
+16 −1
Original line number Diff line number Diff line
@@ -98,7 +98,7 @@ public:
   // There is no move-assignment operator, so expressions like `a = b.getView()`
   // are resolved as copy-assignment.

   // method for rebinding (reinitialization)
   // methods for rebinding (reinitialization)
   __cuda_callable__
   void bind( NDArrayView view )
   {
@@ -106,6 +106,21 @@ public:
      array = view.array;
   }

   // binds to the given raw pointer and changes the indexer
   __cuda_callable__
   void bind( Value* data, IndexerType indexer )
   {
      IndexerType::operator=( indexer );
      array = data;
   }

   // binds to the given raw pointer and preserves the current indexer
   __cuda_callable__
   void bind( Value* data )
   {
      array = data;
   }

   __cuda_callable__
   void reset()
   {
+44 −0
Original line number Diff line number Diff line
@@ -144,6 +144,28 @@ struct SizesHolderSizePrinter
   }
};

template< std::size_t level >
struct SizesHolerOperatorPlusHelper
{
   template< typename Result, typename LHS, typename RHS >
   static void exec( Result& result, const LHS& lhs, const RHS& rhs )
   {
      if( result.template getStaticSize< level >() == 0 )
         result.template setSize< level >( lhs.template getSize< level >() + rhs.template getSize< level >() );
   }
};

template< std::size_t level >
struct SizesHolerOperatorMinusHelper
{
   template< typename Result, typename LHS, typename RHS >
   static void exec( Result& result, const LHS& lhs, const RHS& rhs )
   {
      if( result.template getStaticSize< level >() == 0 )
         result.template setSize< level >( lhs.template getSize< level >() - rhs.template getSize< level >() );
   }
};

} // namespace __ndarray_impl


@@ -202,6 +224,28 @@ public:
   }
};

template< typename Index,
          std::size_t... sizes,
          typename OtherHolder >
SizesHolder< Index, sizes... >
operator+( const SizesHolder< Index, sizes... >& lhs, const OtherHolder& rhs )
{
   SizesHolder< Index, sizes... > result;
   TemplateStaticFor< std::size_t, 0, sizeof...(sizes), __ndarray_impl::SizesHolerOperatorPlusHelper >::execHost( result, lhs, rhs );
   return result;
}

template< typename Index,
          std::size_t... sizes,
          typename OtherHolder >
SizesHolder< Index, sizes... >
operator-( const SizesHolder< Index, sizes... >& lhs, const OtherHolder& rhs )
{
   SizesHolder< Index, sizes... > result;
   TemplateStaticFor< std::size_t, 0, sizeof...(sizes), __ndarray_impl::SizesHolerOperatorMinusHelper >::execHost( result, lhs, rhs );
   return result;
}


template< typename Index,
          std::size_t dimension,
+5 −0
Original line number Diff line number Diff line
@@ -156,6 +156,11 @@ public:
template< typename Index, std::size_t Dimension >
struct DummyStrideBase
{
   static constexpr std::size_t getDimension()
   {
      return Dimension;
   }

   static constexpr bool isContiguous()
   {
      return true;