Commit 69deca31 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Updated code using TemplateStaticFor to use staticFor

parent 4f8d1e21
Loading
Loading
Loading
Loading
+16 −23
Original line number Diff line number Diff line
@@ -391,7 +391,22 @@ public:
   void allocate()
   {
      SizesHolderType localSizes;
      Algorithms::TemplateStaticFor< std::size_t, 0, SizesHolderType::getDimension(), LocalSizesSetter >::execHost( localSizes, globalSizes, localBegins, localEnds );
      Algorithms::staticFor< std::size_t, 0, SizesHolderType::getDimension() >(
         [&] ( auto level ) {
            if( SizesHolderType::template getStaticSize< level >() != 0 )
               return;

            const auto begin = localBegins.template getSize< level >();
            const auto end = localEnds.template getSize< level >();
            if( begin == end )
               localSizes.template setSize< level >( globalSizes.template getSize< level >() );
            else {
               TNL_ASSERT_GE( end - begin, (decltype(end)) __ndarray_impl::get<level>( OverlapsType{} ), "local size is less than the size of overlaps" );
               //localSizes.template setSize< level >( end - begin + 2 * __ndarray_impl::get<level>( OverlapsType{} ) );
               localSizes.template setSize< level >( end - begin );
            }
         }
      );
      localArray.setSize( localSizes );
   }

@@ -439,28 +454,6 @@ protected:
   // static sizes should have different type: localBegin is always 0, localEnd is always the full size
   LocalBeginsType localBegins;
   SizesHolderType localEnds;

private:
   template< std::size_t level >
   struct LocalSizesSetter
   {
      template< typename SizesHolder, typename LocalBegins >
      static void exec( SizesHolder& localSizes, const SizesHolder& globalSizes, const LocalBegins& localBegins, const SizesHolder& localEnds )
      {
         if( SizesHolder::template getStaticSize< level >() != 0 )
            return;

         const auto begin = localBegins.template getSize< level >();
         const auto end = localEnds.template getSize< level >();
         if( begin == end )
            localSizes.template setSize< level >( globalSizes.template getSize< level >() );
         else {
            TNL_ASSERT_GE( end - begin, (decltype(end)) __ndarray_impl::get<level>( OverlapsType{} ), "local size is less than the size of overlaps" );
            //localSizes.template setSize< level >( end - begin + 2 * __ndarray_impl::get<level>( OverlapsType{} ) );
            localSizes.template setSize< level >( end - begin );
         }
      }
   };
};

} // namespace Containers
+147 −141
Original line number Diff line number Diff line
@@ -156,7 +156,11 @@ public:
         this->mask = mask;

         // allocate buffers
         Algorithms::TemplateStaticFor< std::size_t, 0, DistributedNDArray::getDimension(), AllocateHelper >::execHost( buffers, array_view );
         Algorithms::staticFor< std::size_t, 0, DistributedNDArray::getDimension() >(
            [&] ( auto dim ) {
               allocateHelper< dim >( buffers, array_view );
            }
         );
      }
      else {
         // only bind to the actual data
@@ -239,12 +243,20 @@ protected:
   RequestsVector worker_init()
   {
      // fill send buffers
      Algorithms::TemplateStaticFor< std::size_t, 0, DistributedNDArray::getDimension(), CopyHelper >::execHost( buffers, array_view, true, mask );
      Algorithms::staticFor< std::size_t, 0, DistributedNDArray::getDimension() >(
         [&] ( auto dim ) {
            copyHelper< dim >( buffers, array_view, true, mask );
         }
      );

      // issue all send and receive async operations
      RequestsVector requests;
      const MPI_Comm group = array_view.getCommunicationGroup();
      Algorithms::TemplateStaticFor< std::size_t, 0, DistributedNDArray::getDimension(), SendHelper >::execHost( buffers, requests, group, tag_offset, mask );
      Algorithms::staticFor< std::size_t, 0, DistributedNDArray::getDimension() >(
         [&] ( auto dim ) {
            sendHelper< dim >( buffers, requests, group, tag_offset, mask );
         }
      );

      return requests;
   }
@@ -252,13 +264,15 @@ protected:
   void worker_finish()
   {
      // copy data from receive buffers
      Algorithms::TemplateStaticFor< std::size_t, 0, DistributedNDArray::getDimension(), CopyHelper >::execHost( buffers, array_view, false, mask );
      Algorithms::staticFor< std::size_t, 0, DistributedNDArray::getDimension() >(
         [&] ( auto dim ) {
            copyHelper< dim >( buffers, array_view, false, mask );
         }
      );
   }

   template< std::size_t dim >
   struct AllocateHelper
   {
      static void exec( Buffers& buffers, const DistributedNDArrayView& array_view )
   static void allocateHelper( Buffers& buffers, const DistributedNDArrayView& array_view )
   {
      auto& dim_buffers = buffers.template getDimBuffers< dim >();

@@ -312,12 +326,9 @@ protected:
      dim_buffers.left_neighbor = (rank + nproc - 1) % nproc;
      dim_buffers.right_neighbor = (rank + 1) % nproc;
   }
   };

   template< std::size_t dim >
   struct CopyHelper
   {
      static void exec( Buffers& buffers, DistributedNDArrayView& array_view, bool to_buffer, SyncDirection mask )
   static void copyHelper( Buffers& buffers, DistributedNDArrayView& array_view, bool to_buffer, SyncDirection mask )
   {
      // skip if there are no overlaps
      constexpr std::size_t overlap = DistributedNDArrayView::LocalViewType::IndexerType::template getOverlap< dim >();
@@ -368,13 +379,9 @@ protected:
      }

   }
   };

   template< std::size_t dim >
   struct SendHelper
   {
      template< typename Requests, typename Group >
      static void exec( Buffers& buffers, Requests& requests, Group group, int tag_offset, SyncDirection mask )
   static void sendHelper( Buffers& buffers, RequestsVector& requests, MPI_Comm group, int tag_offset, SyncDirection mask )
   {
      constexpr std::size_t overlap = DistributedNDArrayView::LocalViewType::IndexerType::template getOverlap< dim >();
      if( overlap == 0 )
@@ -415,7 +422,6 @@ protected:
                                         dim_buffers.right_neighbor, tag_offset + 0, group ) );
      }
   }
   };

#ifdef __NVCC__
public:
+33 −49
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@

#include <TNL/Assert.h>
#include <TNL/Cuda/CudaCallable.h>
#include <TNL/Algorithms/TemplateStaticFor.h>
#include <TNL/Algorithms/staticFor.h>

#include <TNL/Containers/ndarray/Meta.h>

@@ -124,48 +124,6 @@ protected:
    }
};

template< std::size_t dimension >
struct SizesHolderStaticSizePrinter
{
   template< typename SizesHolder >
   static void exec( std::ostream& str, const SizesHolder& holder )
   {
      str << holder.template getStaticSize< dimension >() << ", ";
   }
};

template< std::size_t dimension >
struct SizesHolderSizePrinter
{
   template< typename SizesHolder >
   static void exec( std::ostream& str, const SizesHolder& holder )
   {
      str << holder.template getSize< dimension >() << ", ";
   }
};

template< std::size_t level >
struct SizesHolerOperatorPlusHelper
{
   template< typename Result, typename LHS, typename RHS >
   static void exec( Result& result, const LHS& lhs, const RHS& rhs )
   {
      if( result.template getStaticSize< level >() == 0 )
         result.template setSize< level >( lhs.template getSize< level >() + rhs.template getSize< level >() );
   }
};

template< std::size_t level >
struct SizesHolerOperatorMinusHelper
{
   template< typename Result, typename LHS, typename RHS >
   static void exec( Result& result, const LHS& lhs, const RHS& rhs )
   {
      if( result.template getStaticSize< level >() == 0 )
         result.template setSize< level >( lhs.template getSize< level >() - rhs.template getSize< level >() );
   }
};

} // namespace __ndarray_impl


@@ -231,7 +189,12 @@ SizesHolder< Index, sizes... >
operator+( const SizesHolder< Index, sizes... >& lhs, const OtherHolder& rhs )
{
   SizesHolder< Index, sizes... > result;
   Algorithms::TemplateStaticFor< std::size_t, 0, sizeof...(sizes), __ndarray_impl::SizesHolerOperatorPlusHelper >::execHost( result, lhs, rhs );
   Algorithms::staticFor< std::size_t, 0, sizeof...(sizes) >(
      [&result, &lhs, &rhs] ( auto level ) {
         if( result.template getStaticSize< level >() == 0 )
            result.template setSize< level >( lhs.template getSize< level >() + rhs.template getSize< level >() );
      }
   );
   return result;
}

@@ -242,7 +205,12 @@ SizesHolder< Index, sizes... >
operator-( const SizesHolder< Index, sizes... >& lhs, const OtherHolder& rhs )
{
   SizesHolder< Index, sizes... > result;
   Algorithms::TemplateStaticFor< std::size_t, 0, sizeof...(sizes), __ndarray_impl::SizesHolerOperatorMinusHelper >::execHost( result, lhs, rhs );
   Algorithms::staticFor< std::size_t, 0, sizeof...(sizes) >(
      [&result, &lhs, &rhs] ( auto level ) {
         if( result.template getStaticSize< level >() == 0 )
            result.template setSize< level >( lhs.template getSize< level >() - rhs.template getSize< level >() );
      }
   );
   return result;
}

@@ -295,9 +263,17 @@ template< typename Index,
std::ostream& operator<<( std::ostream& str, const SizesHolder< Index, sizes... >& holder )
{
   str << "SizesHolder< ";
   Algorithms::TemplateStaticFor< std::size_t, 0, sizeof...(sizes) - 1, __ndarray_impl::SizesHolderStaticSizePrinter >::execHost( str, holder );
   Algorithms::staticFor< std::size_t, 0, sizeof...(sizes) - 1 >(
      [&str, &holder] ( auto dimension ) {
         str << holder.template getStaticSize< dimension >() << ", ";
      }
   );
   str << holder.template getStaticSize< sizeof...(sizes) - 1 >() << " >( ";
   Algorithms::TemplateStaticFor< std::size_t, 0, sizeof...(sizes) - 1, __ndarray_impl::SizesHolderSizePrinter >::execHost( str, holder );
   Algorithms::staticFor< std::size_t, 0, sizeof...(sizes) - 1 >(
      [&str, &holder] ( auto dimension ) {
         str << holder.template getSize< dimension >() << ", ";
      }
   );
   str << holder.template getSize< sizeof...(sizes) - 1 >() << " )";
   return str;
}
@@ -360,10 +336,18 @@ template< typename Index,
std::ostream& operator<<( std::ostream& str, const __ndarray_impl::LocalBeginsHolder< SizesHolder< Index, sizes... >, ConstValue >& holder )
{
   str << "LocalBeginsHolder< SizesHolder< ";
   Algorithms::TemplateStaticFor< std::size_t, 0, sizeof...(sizes) - 1, __ndarray_impl::SizesHolderStaticSizePrinter >::execHost( str, (SizesHolder< Index, sizes... >) holder );
   Algorithms::staticFor< std::size_t, 0, sizeof...(sizes) - 1 >(
      [&str, &holder] ( auto dimension ) {
         str << holder.template getStaticSize< dimension >() << ", ";
      }
   );
   str << holder.template getStaticSize< sizeof...(sizes) - 1 >() << " >, ";
   str << ConstValue << " >( ";
   Algorithms::TemplateStaticFor< std::size_t, 0, sizeof...(sizes) - 1, __ndarray_impl::SizesHolderSizePrinter >::execHost( str, holder );
   Algorithms::staticFor< std::size_t, 0, sizeof...(sizes) - 1 >(
      [&str, &holder] ( auto dimension ) {
         str << holder.template getSize< dimension >() << ", ";
      }
   );
   str << holder.template getSize< sizeof...(sizes) - 1 >() << " )";
   return str;
}
+6 −14
Original line number Diff line number Diff line
@@ -15,7 +15,7 @@
#include <algorithm>

#include <TNL/Assert.h>
#include <TNL/Algorithms/TemplateStaticFor.h>
#include <TNL/Algorithms/staticFor.h>
#include <TNL/Containers/ndarray/Meta.h>

namespace TNL {
@@ -209,18 +209,6 @@ struct SetSizesCopyHelper< TargetHolder, SourceHolder, 0 >
};


template< std::size_t level >
struct WeakCompareHelper
{
   template< typename SizesHolder1,
             typename SizesHolder2 >
   __cuda_callable__
   static void exec( const SizesHolder1& sizes1, const SizesHolder2& sizes2, bool& result )
   {
      result &= sizes1.template getSize< level >() == sizes2.template getSize< level >();
   }
};

// helper for the assignment operator in NDArrayView
template< typename SizesHolder1,
          typename SizesHolder2 >
@@ -230,7 +218,11 @@ bool sizesWeakCompare( const SizesHolder1& sizes1, const SizesHolder2& sizes2 )
   static_assert( SizesHolder1::getDimension() == SizesHolder2::getDimension(),
                  "Cannot compare sizes of different dimensions." );
   bool result = true;
   Algorithms::TemplateStaticFor< std::size_t, 0, SizesHolder1::getDimension(), WeakCompareHelper >::exec( sizes1, sizes2, result );
   Algorithms::staticFor< std::size_t, 0, SizesHolder1::getDimension() >(
      [&result, &sizes1, &sizes2] ( auto level ) {
         result = result && sizes1.template getSize< level >() == sizes2.template getSize< level >();
      }
   );
   return result;
}

+45 −53
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@
#include <TNL/Meshes/GridDetails/Grid1D.h>
#include <TNL/Meshes/GridDetails/Grid2D.h>
#include <TNL/Meshes/GridDetails/Grid3D.h>
#include <TNL/Algorithms/TemplateStaticFor.h>
#include <TNL/Algorithms/staticFor.h>

namespace TNL {
namespace Meshes {
@@ -160,23 +160,15 @@ class NeighborGridEntityGetter<

      }

      template< IndexType index >
      class StencilRefresher
      {
         public:
 
            __cuda_callable__
            static void exec( NeighborGridEntityGetter& neighborEntityGetter, const IndexType& entityIndex )
            {
               neighborEntityGetter.stencil[ index + stencilSize ] = entityIndex + index;
            }
      };
 
      __cuda_callable__
      void refresh( const GridType& grid, const IndexType& entityIndex )
      {
#ifndef HAVE_CUDA  // TODO: fix it -- does not work with nvcc
         Algorithms::TemplateStaticFor< IndexType, -stencilSize, stencilSize + 1, StencilRefresher >::exec( *this, entityIndex );
         Algorithms::staticFor< IndexType, -stencilSize, stencilSize + 1 >(
            [&] ( auto index ) {
               stencil[ index + stencilSize ] = entityIndex + index;
            }
         );
#endif
      };

Loading