Commit 0d735ef4 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Merge branch 'TJ/static-for' into 'develop'

extension of the implementation of staticFor

See merge request !95
parents ee53fa0a 117edb17
Loading
Loading
Loading
Loading
+48 −0
Original line number Diff line number Diff line
#include <iostream>
#include <array>
#include <tuple>
#include <TNL/Algorithms/staticFor.h>

/*
 * Example function printing members of std::tuple using staticFor
 * using lambda with capture.
 */
template< typename... Ts >
void printTuple( const std::tuple<Ts...>& tupleVar )
{
   std::cout << "{ ";
   TNL::Algorithms::staticFor<size_t, 0, sizeof... (Ts)>( [&](auto i) {
      std::cout << std::get<i>(tupleVar);
      if( i < sizeof... (Ts) - 1 )
         std::cout << ", ";
   });
   std::cout << " }" << std::endl;
}

struct TuplePrinter
{
   constexpr TuplePrinter() = default;

   template< typename Index, typename... Ts >
   void operator()( Index i, const std::tuple<Ts...>& tupleVar )
   {
      std::cout << std::get<i>( tupleVar );
      if( i < sizeof... (Ts) - 1 )
         std::cout << ", ";
   }
};

/*
 * Example function printing members of std::tuple using staticFor
 * and a structure with templated operator().
 */
template< typename... Ts >
void printTupleCallableStruct( const std::tuple<Ts...>& tupleVar )
{
   std::cout << "{ ";
   TNL::Algorithms::staticFor< size_t, 0, sizeof... (Ts) >( TuplePrinter(), tupleVar );
   std::cout << " }" << std::endl;
}


int main( int argc, char* argv[] )
{
   // initiate std::array
@@ -13,4 +56,9 @@ int main( int argc, char* argv[] )
         std::cout << "a[ " << i << " ] = " << std::get< i >( a ) << std::endl;
      }
   );

   // example of printing a tuple using staticFor and a lambda function
   printTuple( std::make_tuple( "Hello", 3, 2.1 ) );
   // example of printing a tuple using staticFor and a structure with templated operator()
   printTupleCallableStruct( std::make_tuple( "Hello", 3, 2.1 ) );
}
+37 −25
Original line number Diff line number Diff line
@@ -27,20 +27,23 @@ static_for_dispatch( Func &&f )
#if __cplusplus >= 201703L

// C++17 version using fold expression
template< typename Index, Index begin,  typename Func, Index... idx >
constexpr void static_for_impl( Func &&f, std::integer_sequence< Index, idx... > )
template< typename Index, Index begin,  typename Func, Index... idx, typename... ArgTypes >
constexpr void static_for_impl( Func &&f, std::integer_sequence< Index, idx... >, ArgTypes&&... args )
{
   ( f( std::integral_constant<Index, begin + idx>{} ), ... );
   ( f( std::integral_constant< Index, begin + idx >{},
        std::forward< ArgTypes >( args )... ),
     ... );
}

// general dispatch for `begin < end`
template< typename Index, Index begin, Index end,  typename Func >
template< typename Index, Index begin, Index end,  typename Func, typename... ArgTypes >
constexpr std::enable_if_t< (begin < end) >
static_for_dispatch( Func &&f )
static_for_dispatch( Func &&f, ArgTypes&&... args )
{
   static_for_impl< Index, begin >(
         std::forward< Func >( f ),
         std::make_integer_sequence< Index, end - begin >{}
         std::make_integer_sequence< Index, end - begin >{},
         std::forward< ArgTypes >( args )...
   );
}

@@ -52,21 +55,24 @@ static_for_dispatch( Func &&f )
// the recursion depth.)

// special dispatch for 1 iteration
template< typename Index, Index begin, Index end,  typename Func >
template< typename Index, Index begin, Index end,  typename Func, typename... ArgTypes >
constexpr std::enable_if_t< (begin < end && end - begin == 1) >
static_for_dispatch( Func &&f )
static_for_dispatch( Func &&f, ArgTypes&&... args )
{
   f( std::integral_constant< Index, begin >{} );
   f( std::integral_constant< Index, begin >{},
      std::forward< ArgTypes >( args )... );
}

// general dispatch for at least 2 iterations
template< typename Index, Index begin, Index end,  typename Func >
template< typename Index, Index begin, Index end,  typename Func, typename... ArgTypes >
constexpr std::enable_if_t< (begin < end && end - begin >= 2) >
static_for_dispatch( Func &&f )
static_for_dispatch( Func &&f, ArgTypes&&... args )
{
   constexpr Index mid = begin + (end - begin) / 2;
   static_for_dispatch< Index, begin, mid >( std::forward< Func >( f ) );
   static_for_dispatch< Index, mid, end >( std::forward< Func >( f ) );
   static_for_dispatch< Index, begin, mid >( std::forward< Func >( f ),
                                             std::forward< ArgTypes >( args )... );
   static_for_dispatch< Index, mid, end >( std::forward< Func >( f ),
                                           std::forward< ArgTypes >( args )... );
}

#endif
@@ -79,34 +85,40 @@ static_for_dispatch( Func &&f )
 *
 * \e staticFor is a generic C++14/C++17 implementation of a static for-loop
 * using \e constexpr functions and template metaprogramming. It is equivalent
 * to executing a function `f(i)` for arguments `i` from the integral range
 * `[begin, end)`, but with the type \ref std::integral_constant rather than
 * `int` or `std::size_t` representing the indices. Hence, each index has its
 * own distinct C++ type and the \e value of the index can be deduced from the
 * type.
 * to executing a function `f(i, args...)` for arguments `i` from the integral
 * range `[begin, end)`, but with the type \ref std::integral_constant rather
 * than `int` or `std::size_t` representing the indices. Hence, each index has
 * its own distinct C++ type and the \e value of the index can be deduced from
 * the type. The `args...` are additional user-supplied arguments that are
 * forwarded to the \e staticFor function.
 *
 * Also note that thanks to `constexpr`, the argument `i` can be used in
 * constant expressions and the \e staticFor function can be used from the host
 * code as well as CUDA kernels (TNL requires the `--expt-relaxed-constexpr`
 * parameter when compiled by `nvcc`).
 * Also note that thanks to `constexpr` cast operator, the argument `i` can be
 * used in constant expressions and the \e staticFor function can be used from
 * the host code as well as CUDA kernels (TNL requires the
 * `--expt-relaxed-constexpr` parameter when compiled by `nvcc`).
 *
 * \tparam Index is the type of the loop indices.
 * \tparam begin is the left bound of the iteration range `[begin, end)`.
 * \tparam end is the right bound of the iteration range `[begin, end)`.
 * \tparam Func is the type of the functor (it is usually deduced from the
 *    argument used in the function call).
 * \tparam ArgTypes are the types of additional arguments passed to the
 *    function.
 *
 * \param f is the functor to be called in each iteration.
 * \param args... are additional user-supplied arguments that are forwarded
 *    to each call of \e f.
 *
 * \par Example
 * \include Algorithms/staticForExample.cpp
 * \par Output
 * \include staticForExample.out
 */
template< typename Index, Index begin, Index end,  typename Func >
constexpr void staticFor( Func&& f )
template< typename Index, Index begin, Index end,  typename Func, typename... ArgTypes >
constexpr void staticFor( Func&& f, ArgTypes&&... args )
{
   detail::static_for_dispatch< Index, begin, end >( std::forward< Func >( f ) );
   detail::static_for_dispatch< Index, begin, end >( std::forward< Func >( f ),
                                                     std::forward< ArgTypes >( args )... );
}

} // namespace Algorithms