Commit de8f1eb8 authored by Jakub Klinkovský's avatar Jakub Klinkovský Committed by Jakub Klinkovský
Browse files

Split NDArrayIndexer from NDArrayStorage and NDArrayView

parent fd4d8429
Loading
Loading
Loading
Loading
+61 −85
Original line number Diff line number Diff line
@@ -40,6 +40,7 @@ template< typename Array,
          typename Base,
          typename Device = typename Array::DeviceType >
class NDArrayStorage
    : public NDArrayIndexer< SizesHolder, Permutation, Base >
{
public:
   using StorageArray = Array;
@@ -48,17 +49,12 @@ public:
   using IndexType = typename Array::IndexType;
   using SizesHolderType = SizesHolder;
   using PermutationType = Permutation;
   using IndexerType = NDArrayIndexer< SizesHolder, Permutation, Base >;
   using ViewType = NDArrayView< ValueType, DeviceType, SizesHolder, Permutation, Base >;
   using ConstViewType = NDArrayView< std::add_const_t< ValueType >, DeviceType, SizesHolder, Permutation, Base >;

   static_assert( Permutation::size() == SizesHolder::getDimension(), "invalid permutation" );

   // for compatibility with NDArrayView (which inherits from StrideBase)
   static constexpr bool isContiguous()
   {
      return true;
   }

   // all methods from NDArrayView

   NDArrayStorage() = default;
@@ -83,7 +79,7 @@ public:
      static_assert( std::is_same< PermutationType, typename OtherArray::PermutationType >::value,
                     "Arrays must have the same permutation of indices." );
      // update sizes
      __ndarray_impl::SetSizesCopyHelper< SizesHolderType, typename OtherArray::SizesHolderType >::copy( sizes, other.getSizes() );
      __ndarray_impl::SetSizesCopyHelper< SizesHolderType, typename OtherArray::SizesHolderType >::copy( getSizes(), other.getSizes() );
      // (re)allocate storage if necessary
      array.setSize( getStorageSize() );
      // copy data
@@ -94,57 +90,77 @@ public:
   bool operator==( const NDArrayStorage& other ) const
   {
      // FIXME: uninitialized data due to alignment in NDArray and padding in SlicedNDArray
      return sizes == other.sizes && array == other.array;
      return getSizes() == other.getSizes() && array == other.array;
   }

   bool operator!=( const NDArrayStorage& other ) const
   {
      // FIXME: uninitialized data due to alignment in NDArray and padding in SlicedNDArray
      return sizes != other.sizes || array != other.array;
      return getSizes() != other.getSizes() || array != other.array;
   }

   // accessor to the underlying data
   // (should not be used for accessing the elements, intended only for the implementation
   // of operator= and functions like cudaHostRegister)
   std::add_const_t< ValueType >* getData() const
   __cuda_callable__
   ValueType* getData()
   {
      return array.getData();
   }

   static constexpr std::size_t getDimension()
   __cuda_callable__
   std::add_const_t< ValueType >* getData() const
   {
      return SizesHolder::getDimension();
      return array.getData();
   }

   const SizesHolderType& getSizes() const
   // methods from the base class
   using IndexerType::getDimension;
   using IndexerType::getSizes;
   using IndexerType::getSize;
   using IndexerType::getStride;
   using IndexerType::getStorageSize;
   using IndexerType::getStorageIndex;

   __cuda_callable__
   const IndexerType& getIndexer() const
   {
      return sizes;
      return *this;
   }

   template< std::size_t level >
   __cuda_callable__
   IndexType getSize() const
   ViewType getView()
   {
      return sizes.template getSize< level >();
      return ViewType( array.getData(), getSizes() );
   }

   // returns the product of the aligned sizes
   __cuda_callable__
   IndexType getStorageSize() const
   ConstViewType getConstView() const
   {
      using Alignment = typename Base::template Alignment< Permutation >;
      return __ndarray_impl::StorageSizeGetter< SizesHolder, Alignment >::get( sizes );
      return ConstViewType( array.getData(), getSizes() );
   }

   template< typename... IndexTypes >
   template< std::size_t... Dimensions, typename... IndexTypes >
   __cuda_callable__
   IndexType
   getStorageIndex( IndexTypes&&... indices ) const
   auto getSubarrayView( IndexTypes&&... indices )
   {
      static_assert( sizeof...( indices ) == getDimension(), "got wrong number of indices" );
      return Base::template getStorageIndex< Permutation >( sizes,
                                                            StrideBase{},
                                                            std::forward< IndexTypes >( indices )... );
      static_assert( 0 < sizeof...(Dimensions) && sizeof...(Dimensions) <= getDimension(), "got wrong number of dimensions" );
      static_assert( __ndarray_impl::all_elements_in_range( 0, Permutation::size(), {Dimensions...} ),
                     "invalid dimensions" );
// FIXME: nvcc chokes on the variadic brace-initialization
#ifndef __NVCC__
      static_assert( __ndarray_impl::is_increasing_sequence( {Dimensions...} ),
                     "specifying permuted dimensions is not supported" );
#endif

      using Getter = __ndarray_impl::SubarrayGetter< Base, Permutation, Dimensions... >;
      using Subpermutation = typename Getter::Subpermutation;
      auto& begin = operator()( std::forward< IndexTypes >( indices )... );
      auto subarray_sizes = Getter::filterSizes( getSizes(), std::forward< IndexTypes >( indices )... );
      auto strides = Getter::getStrides( getSizes(), std::forward< IndexTypes >( indices )... );
      static_assert( Subpermutation::size() == sizeof...(Dimensions), "Bug - wrong subpermutation length." );
      static_assert( decltype(subarray_sizes)::getDimension() == sizeof...(Dimensions), "Bug - wrong dimension of the new sizes." );
      static_assert( decltype(strides)::getDimension() == sizeof...(Dimensions), "Bug - wrong dimension of the strides." );
      using SubarrayView = NDArrayView< ValueType, Device, decltype(subarray_sizes), Subpermutation, Base, decltype(strides) >;
      return SubarrayView{ &begin, subarray_sizes, strides };
   }

   template< typename... IndexTypes >
@@ -153,7 +169,7 @@ public:
   operator()( IndexTypes&&... indices )
   {
      static_assert( sizeof...( indices ) == getDimension(), "got wrong number of indices" );
      __ndarray_impl::assertIndicesInBounds( sizes, std::forward< IndexTypes >( indices )... );
      __ndarray_impl::assertIndicesInBounds( getSizes(), std::forward< IndexTypes >( indices )... );
      TNL_ASSERT_LT( getStorageIndex( std::forward< IndexTypes >( indices )... ), getStorageSize(),
                     "storage index out of bounds - either input error or a bug in the indexer" );
      return array[ getStorageIndex( std::forward< IndexTypes >( indices )... ) ];
@@ -165,7 +181,7 @@ public:
   operator()( IndexTypes&&... indices ) const
   {
      static_assert( sizeof...( indices ) == getDimension(), "got wrong number of indices" );
      __ndarray_impl::assertIndicesInBounds( sizes, std::forward< IndexTypes >( indices )... );
      __ndarray_impl::assertIndicesInBounds( getSizes(), std::forward< IndexTypes >( indices )... );
      TNL_ASSERT_LT( getStorageIndex( std::forward< IndexTypes >( indices )... ), getStorageSize(),
                     "storage index out of bounds - either input error or a bug in the indexer" );
      return array[ getStorageIndex( std::forward< IndexTypes >( indices )... ) ];
@@ -177,7 +193,7 @@ public:
   operator[]( IndexType index )
   {
      static_assert( getDimension() == 1, "the access via operator[] is provided only for 1D arrays" );
      __ndarray_impl::assertIndicesInBounds( sizes, std::forward< IndexType >( index ) );
      __ndarray_impl::assertIndicesInBounds( getSizes(), std::forward< IndexType >( index ) );
      return array[ index ];
   }

@@ -186,54 +202,16 @@ public:
   operator[]( IndexType index ) const
   {
      static_assert( getDimension() == 1, "the access via operator[] is provided only for 1D arrays" );
      __ndarray_impl::assertIndicesInBounds( sizes, std::forward< IndexType >( index ) );
      __ndarray_impl::assertIndicesInBounds( getSizes(), std::forward< IndexType >( index ) );
      return array[ index ];
   }

   __cuda_callable__
   ViewType getView()
   {
      return ViewType( array.getData(), sizes );
   }

   __cuda_callable__
   ConstViewType getConstView() const
   {
      return ConstViewType( array.getData(), sizes );
   }

   template< std::size_t... Dimensions, typename... IndexTypes >
   __cuda_callable__
   auto getSubarrayView( IndexTypes&&... indices )
   {
      static_assert( sizeof...( indices ) == getDimension(), "got wrong number of indices" );
      static_assert( 0 < sizeof...(Dimensions) && sizeof...(Dimensions) <= getDimension(), "got wrong number of dimensions" );
      static_assert( __ndarray_impl::all_elements_in_range( 0, Permutation::size(), {Dimensions...} ),
                     "invalid dimensions" );
// FIXME: nvcc chokes on the variadic brace-initialization
#ifndef __NVCC__
      static_assert( __ndarray_impl::is_increasing_sequence( {Dimensions...} ),
                     "specifying permuted dimensions is not supported" );
#endif

      using Getter = __ndarray_impl::SubarrayGetter< Base, Permutation, Dimensions... >;
      using Subpermutation = typename Getter::Subpermutation;
      auto& begin = operator()( std::forward< IndexTypes >( indices )... );
      auto subarray_sizes = Getter::filterSizes( sizes, std::forward< IndexTypes >( indices )... );
      auto strides = Getter::getStrides( sizes, std::forward< IndexTypes >( indices )... );
      static_assert( Subpermutation::size() == sizeof...(Dimensions), "Bug - wrong subpermutation length." );
      static_assert( decltype(subarray_sizes)::getDimension() == sizeof...(Dimensions), "Bug - wrong dimension of the new sizes." );
      static_assert( decltype(strides)::getDimension() == sizeof...(Dimensions), "Bug - wrong dimension of the strides." );
      using SubarrayView = NDArrayView< ValueType, Device, decltype(subarray_sizes), Subpermutation, Base, decltype(strides) >;
      return SubarrayView{ &begin, subarray_sizes, strides };
   }

   template< typename Device2 = DeviceType, typename Func >
   void forAll( Func f ) const
   {
      __ndarray_impl::ExecutorDispatcher< PermutationType, Device2 > dispatch;
      using Begins = ConstStaticSizesHolder< IndexType, getDimension(), 0 >;
      dispatch( Begins{}, sizes, f );
      dispatch( Begins{}, getSizes(), f );
   }

   template< typename Device2 = DeviceType, typename Func >
@@ -245,7 +223,7 @@ public:
      using Ends = typename __ndarray_impl::SubtractedSizesHolder< SizesHolder, 1 >::type;
      // subtract dynamic sizes
      Ends ends;
      __ndarray_impl::SetSizesSubtractHelper< 1, Ends, SizesHolder >::subtract( ends, sizes );
      __ndarray_impl::SetSizesSubtractHelper< 1, Ends, SizesHolder >::subtract( ends, getSizes() );
      dispatch( Begins{}, ends, f );
   }

@@ -266,10 +244,10 @@ public:
      using SkipEnds = typename __ndarray_impl::SubtractedSizesHolder< SizesHolder, 1 >::type;
      // subtract dynamic sizes
      SkipEnds skipEnds;
      __ndarray_impl::SetSizesSubtractHelper< 1, SkipEnds, SizesHolder >::subtract( skipEnds, sizes );
      __ndarray_impl::SetSizesSubtractHelper< 1, SkipEnds, SizesHolder >::subtract( skipEnds, getSizes() );

      __ndarray_impl::BoundaryExecutorDispatcher< PermutationType, Device2 > dispatch;
      dispatch( Begins{}, SkipBegins{}, skipEnds, sizes, f );
      dispatch( Begins{}, SkipBegins{}, skipEnds, getSizes(), f );
   }

   template< typename Device2 = DeviceType, typename Func, typename SkipBegins, typename SkipEnds >
@@ -278,7 +256,7 @@ public:
      // TODO: assert "skipBegins <= sizes", "skipEnds <= sizes"
      using Begins = ConstStaticSizesHolder< IndexType, getDimension(), 0 >;
      __ndarray_impl::BoundaryExecutorDispatcher< PermutationType, Device2 > dispatch;
      dispatch( Begins{}, skipBegins, skipEnds, sizes, f );
      dispatch( Begins{}, skipBegins, skipEnds, getSizes(), f );
   }


@@ -287,7 +265,7 @@ public:
   // TODO: rename to setSizes and make sure that overloading with the following method works
   void setSize( const SizesHolderType& sizes )
   {
      this->sizes = sizes;
      getSizes() = sizes;
      array.setSize( getStorageSize() );
   }

@@ -295,19 +273,19 @@ public:
   void setSizes( IndexTypes&&... sizes )
   {
      static_assert( sizeof...( sizes ) == getDimension(), "got wrong number of sizes" );
      __ndarray_impl::setSizesHelper( this->sizes, std::forward< IndexTypes >( sizes )... );
      __ndarray_impl::setSizesHelper( getSizes(), std::forward< IndexTypes >( sizes )... );
      array.setSize( getStorageSize() );
   }

   void setLike( const NDArrayStorage& other )
   {
      this->sizes = other.getSizes();
      getSizes() = other.getSizes();
      array.setSize( getStorageSize() );
   }

   void reset()
   {
      this->sizes = SizesHolder{};
      getSizes() = SizesHolder{};
      TNL_ASSERT_EQ( getStorageSize(), 0, "Failed to reset the sizes." );
      array.reset();
   }
@@ -318,7 +296,7 @@ public:
   getElement( IndexTypes&&... indices ) const
   {
      static_assert( sizeof...( indices ) == getDimension(), "got wrong number of indices" );
      __ndarray_impl::assertIndicesInBounds( sizes, std::forward< IndexTypes >( indices )... );
      __ndarray_impl::assertIndicesInBounds( getSizes(), std::forward< IndexTypes >( indices )... );
      TNL_ASSERT_LT( getStorageIndex( std::forward< IndexTypes >( indices )... ), getStorageSize(),
                     "storage index out of bounds - either input error or a bug in the indexer" );
      return array.getElement( getStorageIndex( std::forward< IndexTypes >( indices )... ) );
@@ -341,9 +319,7 @@ public:

protected:
   StorageArray array;
   SizesHolder sizes;

   using StrideBase = __ndarray_impl::DummyStrideBase< typename SizesHolder::IndexType, SizesHolder::getDimension() >;
   IndexerType indexer;
};

template< typename Value,
+94 −0
Original line number Diff line number Diff line
/***************************************************************************
                          NDArrayIndexer.h  -  description
                             -------------------
    begin                : Apr 14, 2019
    copyright            : (C) 2019 by Tomas Oberhuber et al.
    email                : tomas.oberhuber@fjfi.cvut.cz
 ***************************************************************************/

/* See Copyright Notice in tnl/Copyright */

// Implemented by: Jakub Klinkovsky

#pragma once

#include <TNL/Containers/ndarray/Indexing.h>
#include <TNL/Containers/ndarray/SizesHolderHelpers.h>   // StorageSizeGetter
#include <TNL/Containers/ndarray/Subarrays.h>   // DummyStrideBase

namespace TNL {
namespace Containers {

template< typename SizesHolder,
          typename Permutation,
          typename Base,
          typename StridesHolder = __ndarray_impl::DummyStrideBase< typename SizesHolder::IndexType, SizesHolder::getDimension() > >
class NDArrayIndexer
    : public StridesHolder
{
public:
   using IndexType = typename SizesHolder::IndexType;
   using SizesHolderType = SizesHolder;
   using PermutationType = Permutation;

   __cuda_callable__
   NDArrayIndexer() = default;

   // explicit initialization by sizes and strides
   __cuda_callable__
   NDArrayIndexer( SizesHolder sizes, StridesHolder strides )
   : StridesHolder(strides), sizes(sizes) {}

   static constexpr std::size_t getDimension()
   {
      return SizesHolder::getDimension();
   }

   __cuda_callable__
   const SizesHolderType& getSizes() const
   {
      return sizes;
   }

   template< std::size_t level >
   __cuda_callable__
   IndexType getSize() const
   {
      return sizes.template getSize< level >();
   }

   // method template from base class
   using StridesHolder::getStride;

   // returns the product of the aligned sizes
   __cuda_callable__
   IndexType getStorageSize() const
   {
      using Alignment = typename Base::template Alignment< Permutation >;
      return __ndarray_impl::StorageSizeGetter< SizesHolder, Alignment >::get( sizes );
   }

   template< typename... IndexTypes >
   __cuda_callable__
   IndexType
   getStorageIndex( IndexTypes&&... indices ) const
   {
      static_assert( sizeof...( indices ) == SizesHolder::getDimension(), "got wrong number of indices" );
      return Base::template getStorageIndex< Permutation >( sizes,
                                                            static_cast< const StridesHolder& >( *this ),
                                                            std::forward< IndexTypes >( indices )... );
   }

protected:
   // non-const reference accessor cannot be public - only subclasses like NDArrayStorage may modify the sizes
   __cuda_callable__
   SizesHolderType& getSizes()
   {
      return sizes;
   }

   SizesHolder sizes;
};

} // namespace Containers
} // namespace TNL
+69 −90

File changed.

Preview size limit exceeded, changes collapsed.