Commit 02482962 authored by Jakub Klinkovský's avatar Jakub Klinkovský Committed by Jakub Klinkovský
Browse files

NDArray: added forAll method

parent ffc00260
Loading
Loading
Loading
Loading
+8 −0
Original line number Diff line number Diff line
@@ -173,6 +173,14 @@ public:
      return ConstViewType( array.getData(), sizes );
   }

   template< typename Device2 = DeviceType, typename Func >
   void forAll( Func f ) const
   {
      __ndarray_impl::ExecutorDispatcher< ConstViewType, Device2 > dispatch;
      dispatch( getConstView(), f );
   }


   // extra methods

   // TODO: rename to setSizes and make sure that overloading with the following method works
+8 −0
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@
#include <TNL/Containers/ndarray/Indexing.h>
#include <TNL/Containers/ndarray/SizesHolder.h>
#include <TNL/Containers/ndarray/Subarrays.h>
#include <TNL/Containers/ndarray/Operations.h>

namespace TNL {
namespace Containers {
@@ -226,6 +227,13 @@ public:
      return SubarrayView{ &begin, subarray_sizes, strides };
   }

   template< typename Device2 = DeviceType, typename Func >
   void forAll( Func f ) const
   {
      __ndarray_impl::ExecutorDispatcher< NDArrayView, Device2 > dispatch;
      dispatch( *this, f );
   }

protected:
   Value* array = nullptr;
   SizesHolder sizes;
+47 −0
Original line number Diff line number Diff line
@@ -180,6 +180,53 @@ TEST( NDArrayTest, SizesHolderPrinter )
   EXPECT_EQ( str.str(), "SizesHolder< 0, 1, 2 >( 3, 1, 2 )" );
}

TEST( NDArrayTest, forAll_dynamic )
{
    int I = 2, J = 2, K = 2, L = 2, M = 2, N = 2;
    NDArray< int,
             SizesHolder< int, 0, 0, 0, 0, 0, 0 >,
             index_sequence< 5, 3, 4, 2, 0, 1 > > a;
    a.setSizes( I, J, K, L, M, N );
    a.setValue( 0 );

    auto setter = [&] ( int i, int j, int k, int l, int m, int n )
    {
       a( i, j, k, l, m, n ) = 1;
    };

    a.forAll( setter );

    for( int n = 0; n < N; n++ )
    for( int l = 0; l < L; l++ )
    for( int m = 0; m < M; m++ )
    for( int k = 0; k < K; k++ )
    for( int i = 0; i < I; i++ )
    for( int j = 0; j < J; j++ )
        EXPECT_EQ( a( i, j, k, l, m, n ), 1 );
}

TEST( NDArrayTest, forAll_static )
{
    constexpr int I = 3, J = 4;
    NDArray< int, SizesHolder< int, I, J > > a;
    a.setSizes( 0, 0 );

    for( int i = 0; i < I; i++ )
    for( int j = 0; j < J; j++ )
        a( i, j ) = 0;

    auto setter = [&] ( int i, int j )
    {
       a( i, j ) = 1;
    };

    a.forAll( setter );

    for( int i = 0; i < I; i++ )
    for( int j = 0; j < J; j++ )
        EXPECT_EQ( a( i, j ), 1 );
}

//#include "GtestMissingError.h"
int main( int argc, char* argv[] )
{