Commit 560aba80 authored by Jakub Klinkovský's avatar Jakub Klinkovský Committed by Jakub Klinkovský
Browse files

Copied getSubarrayView method from NDArrayView to NDArray

parent e885aa29
Loading
Loading
Loading
Loading
+26 −0
Original line number Diff line number Diff line
@@ -173,6 +173,32 @@ public:
      return ConstViewType( array.getData(), sizes );
   }

   template< std::size_t... Dimensions, typename... IndexTypes >
   __cuda_callable__
   auto getSubarrayView( IndexTypes&&... indices )
   {
      static_assert( sizeof...( indices ) == getDimension(), "got wrong number of indices" );
      static_assert( 0 < sizeof...(Dimensions) && sizeof...(Dimensions) <= getDimension(), "got wrong number of dimensions" );
      static_assert( __ndarray_impl::all_elements_in_range( 0, Permutation::size(), {Dimensions...} ),
                     "invalid dimensions" );
// FIXME: nvcc chokes on the variadic brace-initialization
#ifndef __NVCC__
      static_assert( __ndarray_impl::is_increasing_sequence( {Dimensions...} ),
                     "specifying permuted dimensions is not supported" );
#endif

      using Getter = __ndarray_impl::SubarrayGetter< Base, Permutation, Dimensions... >;
      using Subpermutation = typename Getter::Subpermutation;
      auto& begin = operator()( std::forward< IndexTypes >( indices )... );
      auto subarray_sizes = Getter::filterSizes( sizes, std::forward< IndexTypes >( indices )... );
      auto strides = Getter::getStrides( sizes, std::forward< IndexTypes >( indices )... );
      static_assert( Subpermutation::size() == sizeof...(Dimensions), "Bug - wrong subpermutation length." );
      static_assert( decltype(subarray_sizes)::getDimension() == sizeof...(Dimensions), "Bug - wrong dimension of the new sizes." );
      static_assert( decltype(strides)::getDimension() == sizeof...(Dimensions), "Bug - wrong dimension of the strides." );
      using SubarrayView = NDArrayView< ValueType, Device, decltype(subarray_sizes), Subpermutation, Base, decltype(strides) >;
      return SubarrayView{ &begin, subarray_sizes, strides };
   }

   template< typename Device2 = DeviceType, typename Func >
   void forAll( Func f ) const
   {