diff --git a/src/TNL/Containers/NDArray.h b/src/TNL/Containers/NDArray.h index ee89e87862eb819e1e1c5d94bcb6f6b38e477094..d9e4cb09b0882e44d5475b01331d191048eb53d3 100644 --- a/src/TNL/Containers/NDArray.h +++ b/src/TNL/Containers/NDArray.h @@ -173,6 +173,14 @@ public: return ConstViewType( array.getData(), sizes ); } + template< typename Device2 = DeviceType, typename Func > + void forAll( Func f ) const + { + __ndarray_impl::ExecutorDispatcher< ConstViewType, Device2 > dispatch; + dispatch( getConstView(), f ); + } + + // extra methods // TODO: rename to setSizes and make sure that overloading with the following method works diff --git a/src/TNL/Containers/NDArrayView.h b/src/TNL/Containers/NDArrayView.h index fe75bdadd3c842ae47f750d31c4f6cf8e36b870c..73af6713f83de5e5dc3880fc40ff93bc043d454c 100644 --- a/src/TNL/Containers/NDArrayView.h +++ b/src/TNL/Containers/NDArrayView.h @@ -15,6 +15,7 @@ #include <TNL/Containers/ndarray/Indexing.h> #include <TNL/Containers/ndarray/SizesHolder.h> #include <TNL/Containers/ndarray/Subarrays.h> +#include <TNL/Containers/ndarray/Operations.h> namespace TNL { namespace Containers { @@ -226,6 +227,13 @@ public: return SubarrayView{ &begin, subarray_sizes, strides }; } + template< typename Device2 = DeviceType, typename Func > + void forAll( Func f ) const + { + __ndarray_impl::ExecutorDispatcher< NDArrayView, Device2 > dispatch; + dispatch( *this, f ); + } + protected: Value* array = nullptr; SizesHolder sizes; diff --git a/src/UnitTests/Containers/ndarray/NDArrayTest.cpp b/src/UnitTests/Containers/ndarray/NDArrayTest.cpp index 2a98e71f10572fecda73ea123faa51d7c84f8b5a..385ff93da2f8e214056d3022263a6cbf6f6b53d2 100644 --- a/src/UnitTests/Containers/ndarray/NDArrayTest.cpp +++ b/src/UnitTests/Containers/ndarray/NDArrayTest.cpp @@ -180,6 +180,53 @@ TEST( NDArrayTest, SizesHolderPrinter ) EXPECT_EQ( str.str(), "SizesHolder< 0, 1, 2 >( 3, 1, 2 )" ); } +TEST( NDArrayTest, forAll_dynamic ) +{ + int I = 2, J = 2, K = 2, L = 2, M = 2, N = 2; + NDArray< int, + SizesHolder< int, 0, 0, 0, 0, 0, 0 >, + index_sequence< 5, 3, 4, 2, 0, 1 > > a; + a.setSizes( I, J, K, L, M, N ); + a.setValue( 0 ); + + auto setter = [&] ( int i, int j, int k, int l, int m, int n ) + { + a( i, j, k, l, m, n ) = 1; + }; + + a.forAll( setter ); + + for( int n = 0; n < N; n++ ) + for( int l = 0; l < L; l++ ) + for( int m = 0; m < M; m++ ) + for( int k = 0; k < K; k++ ) + for( int i = 0; i < I; i++ ) + for( int j = 0; j < J; j++ ) + EXPECT_EQ( a( i, j, k, l, m, n ), 1 ); +} + +TEST( NDArrayTest, forAll_static ) +{ + constexpr int I = 3, J = 4; + NDArray< int, SizesHolder< int, I, J > > a; + a.setSizes( 0, 0 ); + + for( int i = 0; i < I; i++ ) + for( int j = 0; j < J; j++ ) + a( i, j ) = 0; + + auto setter = [&] ( int i, int j ) + { + a( i, j ) = 1; + }; + + a.forAll( setter ); + + for( int i = 0; i < I; i++ ) + for( int j = 0; j < J; j++ ) + EXPECT_EQ( a( i, j ), 1 ); +} + //#include "GtestMissingError.h" int main( int argc, char* argv[] ) {