Commit e4e780e1 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Added overlaps to NDArrayIndexer

This is necessary for local indexing of DistributedNDArray.
parent daebbce1
Loading
Loading
Loading
Loading
+23 −5
Original line number Diff line number Diff line
@@ -22,14 +22,25 @@ namespace Containers {
template< typename SizesHolder,
          typename Permutation,
          typename Base,
          typename StridesHolder = __ndarray_impl::DummyStrideBase< typename SizesHolder::IndexType, SizesHolder::getDimension() > >
          typename StridesHolder = __ndarray_impl::DummyStrideBase< typename SizesHolder::IndexType, SizesHolder::getDimension() >,
          typename Overlaps = __ndarray_impl::make_constant_index_sequence< SizesHolder::getDimension(), 0 > >
class NDArrayIndexer
    : public StridesHolder
{
public:
   using IndexType = typename SizesHolder::IndexType;
   using NDBaseType = Base;
   using SizesHolderType = SizesHolder;
   using StridesHolderType = StridesHolder;
   using PermutationType = Permutation;
   using OverlapsType = Overlaps;

   static_assert( StridesHolder::getDimension() == SizesHolder::getDimension(),
                  "Dimension of strides does not match the dimension of sizes." );
   static_assert( Permutation::size() == SizesHolder::getDimension(),
                  "Dimension of permutation does not match the dimension of sizes." );
   static_assert( Overlaps::size() == SizesHolder::getDimension(),
                  "Dimension of overlaps does not match the dimension of sizes." );

   __cuda_callable__
   NDArrayIndexer() = default;
@@ -60,12 +71,18 @@ public:
   // method template from base class
   using StridesHolder::getStride;

   template< std::size_t level >
   static constexpr std::size_t getOverlap()
   {
      return __ndarray_impl::get< level >( Overlaps{} );
   }

   // returns the product of the aligned sizes
   __cuda_callable__
   IndexType getStorageSize() const
   {
      using Alignment = typename Base::template Alignment< Permutation >;
      return __ndarray_impl::StorageSizeGetter< SizesHolder, Alignment >::get( sizes );
      return __ndarray_impl::StorageSizeGetter< SizesHolder, Alignment, Overlaps >::get( sizes );
   }

   template< typename... IndexTypes >
@@ -74,7 +91,8 @@ public:
   getStorageIndex( IndexTypes&&... indices ) const
   {
      static_assert( sizeof...( indices ) == SizesHolder::getDimension(), "got wrong number of indices" );
      return Base::template getStorageIndex< Permutation >( sizes,
      return Base::template getStorageIndex< Permutation, Overlaps >
             ( sizes,
               static_cast< const StridesHolder& >( *this ),
               std::forward< IndexTypes >( indices )... );
   }
+24 −15
Original line number Diff line number Diff line
@@ -113,6 +113,7 @@ auto host_call_with_unshifted_indices( const SizesHolder& begins, Func&& f, Indi


template< typename Permutation,
          typename Overlaps,
          typename Alignment,
          typename SliceInfo,
          std::size_t level = Permutation::size() - 1,
@@ -121,10 +122,11 @@ struct SlicedIndexer
{};

template< typename Permutation,
          typename Overlaps,
          typename Alignment,
          typename SliceInfo,
          std::size_t level >
struct SlicedIndexer< Permutation, Alignment, SliceInfo, level, false >
struct SlicedIndexer< Permutation, Overlaps, Alignment, SliceInfo, level, false >
{
   template< typename SizesHolder, typename StridesHolder, typename... Indices >
   __cuda_callable__
@@ -134,17 +136,19 @@ struct SlicedIndexer< Permutation, Alignment, SliceInfo, level, false >
             Indices&&... indices )
   {
      static constexpr std::size_t idx = get< level >( Permutation{} );
      static constexpr std::size_t overlap = __ndarray_impl::get< idx >( Overlaps{} );
      const auto alpha = get_from_pack< idx >( std::forward< Indices >( indices )... );
      const auto previous = SlicedIndexer< Permutation, Alignment, SliceInfo, level - 1 >::getIndex( sizes, strides, std::forward< Indices >( indices )... );
      return strides.template getStride< idx >( alpha ) * ( alpha + Alignment::template getAlignedSize< idx >( sizes ) * previous );
      const auto previous = SlicedIndexer< Permutation, Overlaps, Alignment, SliceInfo, level - 1 >::getIndex( sizes, strides, std::forward< Indices >( indices )... );
      return strides.template getStride< idx >( alpha ) * ( alpha + overlap + Alignment::template getAlignedSize< idx >( sizes ) * previous );
   }
};

template< typename Permutation,
          typename Overlaps,
          typename Alignment,
          typename SliceInfo,
          std::size_t level >
struct SlicedIndexer< Permutation, Alignment, SliceInfo, level, true >
struct SlicedIndexer< Permutation, Overlaps, Alignment, SliceInfo, level, true >
{
   template< typename SizesHolder, typename StridesHolder, typename... Indices >
   __cuda_callable__
@@ -157,20 +161,22 @@ struct SlicedIndexer< Permutation, Alignment, SliceInfo, level, true >
                     "Invalid SliceInfo: static dimension cannot be sliced." );

      static constexpr std::size_t idx = get< level >( Permutation{} );
      static constexpr std::size_t overlap = __ndarray_impl::get< idx >( Overlaps{} );
      const auto alpha = get_from_pack< idx >( std::forward< Indices >( indices )... );
      static constexpr std::size_t S = SliceInfo::getSliceSize( idx );
      // TODO: check the calculation with strides
      return strides.template getStride< idx >( alpha ) *
                  ( S * (alpha / S) * StorageSizeGetter< SizesHolder, Alignment, IndexTag< level - 1 > >::getPermuted( sizes, Permutation{} ) +
                    alpha % S ) +
             S * SlicedIndexer< Permutation, Alignment, SliceInfo, level - 1 >::getIndex( sizes, strides, std::forward< Indices >( indices )... );
                  ( S * ((alpha + overlap) / S) * StorageSizeGetter< SizesHolder, Alignment, Overlaps, IndexTag< level - 1 > >::getPermuted( sizes, Permutation{} ) +
                    (alpha + overlap) % S ) +
             S * SlicedIndexer< Permutation, Overlaps, Alignment, SliceInfo, level - 1 >::getIndex( sizes, strides, std::forward< Indices >( indices )... );
   }
};

template< typename Permutation,
          typename Overlaps,
          typename Alignment,
          typename SliceInfo >
struct SlicedIndexer< Permutation, Alignment, SliceInfo, 0, false >
struct SlicedIndexer< Permutation, Overlaps, Alignment, SliceInfo, 0, false >
{
   template< typename SizesHolder, typename StridesHolder, typename... Indices >
   __cuda_callable__
@@ -180,15 +186,17 @@ struct SlicedIndexer< Permutation, Alignment, SliceInfo, 0, false >
             Indices&&... indices )
   {
      static constexpr std::size_t idx = get< 0 >( Permutation{} );
      static constexpr std::size_t overlap = __ndarray_impl::get< idx >( Overlaps{} );
      const auto alpha = get_from_pack< idx >( std::forward< Indices >( indices )... );
      return strides.template getStride< idx >( alpha ) * alpha;
      return strides.template getStride< idx >( alpha ) * (alpha + overlap);
   }
};

template< typename Permutation,
          typename Overlaps,
          typename Alignment,
          typename SliceInfo >
struct SlicedIndexer< Permutation, Alignment, SliceInfo, 0, true >
struct SlicedIndexer< Permutation, Overlaps, Alignment, SliceInfo, 0, true >
{
   template< typename SizesHolder, typename StridesHolder, typename... Indices >
   __cuda_callable__
@@ -198,8 +206,9 @@ struct SlicedIndexer< Permutation, Alignment, SliceInfo, 0, true >
             Indices&&... indices )
   {
      static constexpr std::size_t idx = get< 0 >( Permutation{} );
      static constexpr std::size_t overlap = __ndarray_impl::get< idx >( Overlaps{} );
      const auto alpha = get_from_pack< idx >( std::forward< Indices >( indices )... );
      return strides.template getStride< idx >( alpha ) * alpha;
      return strides.template getStride< idx >( alpha ) * (alpha + overlap);
   }
};

@@ -227,14 +236,14 @@ struct NDArrayBase
      }
   };

   template< typename Permutation, typename SizesHolder, typename StridesHolder, typename... Indices >
   template< typename Permutation, typename Overlaps, typename SizesHolder, typename StridesHolder, typename... Indices >
   __cuda_callable__
   typename SizesHolder::IndexType
   static getStorageIndex( const SizesHolder& sizes, const StridesHolder& strides, Indices&&... indices )
   {
      static_assert( check_slice_size( SizesHolder::getDimension(), 0 ), "BUG - invalid SliceInfo type passed to NDArrayBase" );
      using Alignment = Alignment< Permutation >;
      return SlicedIndexer< Permutation, Alignment, SliceInfo >::getIndex( sizes, strides, std::forward< Indices >( indices )... );
      return SlicedIndexer< Permutation, Overlaps, Alignment, SliceInfo >::getIndex( sizes, strides, std::forward< Indices >( indices )... );
   }

private:
@@ -271,13 +280,13 @@ struct SlicedNDArrayBase
      }
   };

   template< typename Permutation, typename SizesHolder, typename StridesHolder, typename... Indices >
   template< typename Permutation, typename Overlaps, typename SizesHolder, typename StridesHolder, typename... Indices >
   __cuda_callable__
   static typename SizesHolder::IndexType
   getStorageIndex( const SizesHolder& sizes, const StridesHolder& strides, Indices&&... indices )
   {
      using Alignment = Alignment< Permutation >;
      return SlicedIndexer< Permutation, Alignment, SliceInfo >::getIndex( sizes, strides, std::forward< Indices >( indices )... );
      return SlicedIndexer< Permutation, Overlaps, Alignment, SliceInfo >::getIndex( sizes, strides, std::forward< Indices >( indices )... );
   }
};

+15 −8
Original line number Diff line number Diff line
@@ -25,6 +25,7 @@ namespace __ndarray_impl {
// Dynamic storage size with alignment
template< typename SizesHolder,
          typename Alignment,
          typename Overlaps,
          typename LevelTag = IndexTag< SizesHolder::getDimension() - 1 > >
struct StorageSizeGetter
{
@@ -32,8 +33,10 @@ struct StorageSizeGetter
   __cuda_callable__
   get( const SizesHolder& sizes )
   {
      static constexpr std::size_t overlap = __ndarray_impl::get< LevelTag::value >( Overlaps{} );
      const auto size = Alignment::template getAlignedSize< LevelTag::value >( sizes );
      return size * StorageSizeGetter< SizesHolder, Alignment, IndexTag< LevelTag::value - 1 > >::get( sizes );
      return ( size + 2 * overlap )
             * StorageSizeGetter< SizesHolder, Alignment, Overlaps, IndexTag< LevelTag::value - 1 > >::get( sizes );
   }

   template< typename Permutation >
@@ -41,20 +44,23 @@ struct StorageSizeGetter
   static typename SizesHolder::IndexType
   getPermuted( const SizesHolder& sizes, Permutation )
   {
      constexpr std::size_t idx = __ndarray_impl::get< LevelTag::value >( Permutation{} );
      static constexpr std::size_t idx = __ndarray_impl::get< LevelTag::value >( Permutation{} );
      static constexpr std::size_t overlap = __ndarray_impl::get< idx >( Overlaps{} );
      const auto size = Alignment::template getAlignedSize< idx >( sizes );
      return size * StorageSizeGetter< SizesHolder, Alignment, IndexTag< LevelTag::value - 1 > >::get( sizes );
      return ( size + 2 * overlap )
             * StorageSizeGetter< SizesHolder, Alignment, Overlaps, IndexTag< LevelTag::value - 1 > >::get( sizes );
   }
};

template< typename SizesHolder, typename Alignment >
struct StorageSizeGetter< SizesHolder, Alignment, IndexTag< 0 > >
template< typename SizesHolder, typename Alignment, typename Overlaps >
struct StorageSizeGetter< SizesHolder, Alignment, Overlaps, IndexTag< 0 > >
{
   static typename SizesHolder::IndexType
   __cuda_callable__
   get( const SizesHolder& sizes )
   {
      return Alignment::template getAlignedSize< 0 >( sizes );
      static constexpr std::size_t overlap = __ndarray_impl::get< 0 >( Overlaps{} );
      return Alignment::template getAlignedSize< 0 >( sizes ) + 2 * overlap;
   }

   template< typename Permutation >
@@ -62,8 +68,9 @@ struct StorageSizeGetter< SizesHolder, Alignment, IndexTag< 0 > >
   static typename SizesHolder::IndexType
   getPermuted( const SizesHolder& sizes, Permutation )
   {
      constexpr std::size_t idx = __ndarray_impl::get< 0 >( Permutation{} );
      return Alignment::template getAlignedSize< idx >( sizes );
      static constexpr std::size_t idx = __ndarray_impl::get< 0 >( Permutation{} );
      static constexpr std::size_t overlap = __ndarray_impl::get< idx >( Overlaps{} );
      return Alignment::template getAlignedSize< idx >( sizes ) + 2 * overlap;
   }
};