Commit 49853fdf authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Added smart lambdas to segments.

parent 740f7551
Loading
Loading
Loading
Loading
+8 −7
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@
#include <TNL/Algorithms/ParallelFor.h>
#include <TNL/Containers/Segments/CSRView.h>
#include <TNL/Containers/Segments/details/CSR.h>
#include <TNL/Containers/Segments/details/LambdaAdapter.h>

namespace TNL {
   namespace Containers {
@@ -215,17 +216,17 @@ void
CSRView< Device, Index >::
segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
{
   using RealType = decltype( fetch( IndexType(), IndexType(), IndexType(), std::declval< bool& >(), args... ) );
   using RealType = typename details::FetchLambdaAdapter< Index, Fetch >::ReturnType;
   const auto offsetsView = this->offsets.getConstView();
   auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) mutable {
      const IndexType begin = offsetsView[ i ];
      const IndexType end = offsetsView[ i + 1 ];
   auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable {
      const IndexType begin = offsetsView[ segmentIdx ];
      const IndexType end = offsetsView[ segmentIdx + 1 ];
      RealType aux( zero );
      IndexType localIdx( 0 );
      bool compute( true );
      for( IndexType j = begin; j < end && compute; j++  )
         reduction( aux, fetch( i, localIdx++, j, compute, args... ) );
      keeper( i, aux );
      for( IndexType globalIdx = begin; globalIdx < end && compute; globalIdx++  )
         reduction( aux, details::FetchLambdaAdapter< IndexType, Fetch >::call( fetch, segmentIdx, localIdx++, globalIdx, compute ) );
      keeper( segmentIdx, aux );
   };
   Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
}
+9 −8
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@
#include <TNL/Containers/Vector.h>
#include <TNL/Algorithms/ParallelFor.h>
#include <TNL/Containers/Segments/ChunkedEllpackView.h>
#include <TNL/Containers/Segments/details/LambdaAdapter.h>
//#include <TNL/Containers/Segments/details/ChunkedEllpack.h>

namespace TNL {
@@ -401,7 +402,7 @@ void
ChunkedEllpackView< Device, Index, RowMajorOrder >::
segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
{
   using RealType = decltype( fetch( IndexType(), IndexType(), IndexType(), std::declval< bool& >(), args... ) );
   using RealType = typename details::FetchLambdaAdapter< Index, Fetch >::ReturnType;
   if( std::is_same< DeviceType, Devices::Host >::value )
   {
      //segmentsReductionKernel( 0, first, last, fetch, reduction, keeper, zero, args... );
@@ -428,8 +429,8 @@ segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& red
            IndexType begin = sliceOffset + firstChunkOfSegment * chunkSize;
            IndexType end = begin + segmentSize;
            IndexType localIdx( 0 );
            for( IndexType j = begin; j < end && compute; j++ )
               reduction( aux, fetch( segmentIdx, localIdx++, j, compute, args...) );
            for( IndexType globalIdx = begin; globalIdx < end && compute; globalIdx++ )
               reduction( aux, details::FetchLambdaAdapter< IndexType, Fetch >::call( fetch, segmentIdx, localIdx++, globalIdx, compute ) );
         }
         else
         {
@@ -438,8 +439,8 @@ segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& red
               IndexType begin = sliceOffset + firstChunkOfSegment + chunkIdx;
               IndexType end = begin + chunksInSlice * chunkSize;
               IndexType localIdx( 0 );
               for( IndexType j = begin; j < end && compute; j += chunksInSlice )
                  reduction( aux, fetch( segmentIdx, localIdx++, j, compute, args...) );
               for( IndexType globalIdx = begin; globalIdx < end && compute; globalIdx += chunksInSlice )
                  reduction( aux, details::FetchLambdaAdapter< IndexType, Fetch >::call( fetch, segmentIdx, localIdx++, globalIdx, compute ) );
            }
         }
         keeper( segmentIdx, aux );
@@ -459,9 +460,9 @@ segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& red
      {
         if( gridIdx == cudaGrids - 1 )
            cudaGridSize.x = cudaBlocks % Cuda::getMaxGridSize();
         ChunkedEllpackSegmentsReductionKernel< ViewType, IndexType, Fetch, Reduction, ResultKeeper, Real, Args...  >
            <<< cudaGridSize, cudaBlockSize, sharedMemory  >>>
            ( *this, gridIdx, first, last, fetch, reduction, keeper, zero, args... );
         //ChunkedEllpackSegmentsReductionKernel< ViewType, IndexType, Fetch, Reduction, ResultKeeper, Real, Args...  >
         //   <<< cudaGridSize, cudaBlockSize, sharedMemory  >>>
         //   ( *this, gridIdx, first, last, fetch, reduction, keeper, zero, args... );
      }
#endif
   }
+11 −9
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@
#include <TNL/Containers/Vector.h>
#include <TNL/Algorithms/ParallelFor.h>
#include <TNL/Containers/Segments/EllpackView.h>
#include <TNL/Containers/Segments/details/LambdaAdapter.h>

namespace TNL {
   namespace Containers {
@@ -258,19 +259,20 @@ void
EllpackView< Device, Index, RowMajorOrder, Alignment >::
segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
{
   using RealType = decltype( fetch( IndexType(), IndexType(), IndexType(), std::declval< bool& >(), args... ) );
   //using RealType = decltype( fetch( IndexType(), IndexType(), IndexType(), std::declval< bool& >(), args... ) );
   using RealType = typename details::FetchLambdaAdapter< Index, Fetch >::ReturnType;
   if( RowMajorOrder )
   {
      const IndexType segmentSize = this->segmentSize;
      auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) mutable {
         const IndexType begin = i * segmentSize;
      auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable {
         const IndexType begin = segmentIdx * segmentSize;
         const IndexType end = begin + segmentSize;
         RealType aux( zero );
         IndexType localIdx( 0 );
         bool compute( true );
         for( IndexType j = begin; j < end && compute; j++  )
            reduction( aux, fetch( i, localIdx++, j, compute, args... ) );
         keeper( i, aux );
            reduction( aux, details::FetchLambdaAdapter< IndexType, Fetch >::call( fetch, segmentIdx, localIdx++, j, compute ) );
         keeper( segmentIdx, aux );
      };
      Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
   }
@@ -278,15 +280,15 @@ segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& red
   {
      const IndexType storageSize = this->getStorageSize();
      const IndexType alignedSize = this->alignedSize;
      auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) mutable {
         const IndexType begin = i;
      auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable {
         const IndexType begin = segmentIdx;
         const IndexType end = storageSize;
         RealType aux( zero );
         IndexType localIdx( 0 );
         bool compute( true );
         for( IndexType j = begin; j < end && compute; j += alignedSize  )
            reduction( aux, fetch( i, localIdx++, j, compute, args... ) );
         keeper( i, aux );
            reduction( aux, details::FetchLambdaAdapter< IndexType, Fetch >::call( fetch, segmentIdx, localIdx++, j, compute ) );
         keeper( segmentIdx, aux );
      };
      Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
   }
+5 −3
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@
#include <TNL/Containers/Vector.h>
#include <TNL/Algorithms/ParallelFor.h>
#include <TNL/Containers/Segments/SlicedEllpackView.h>
#include <TNL/Containers/Segments/details/LambdaAdapter.h>

#include "SlicedEllpackView.h"

@@ -306,7 +307,8 @@ void
SlicedEllpackView< Device, Index, RowMajorOrder, SliceSize >::
segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
{
   using RealType = decltype( fetch( IndexType(), IndexType(), IndexType(), std::declval< bool& >(), args... ) );
   using RealType = typename details::FetchLambdaAdapter< Index, Fetch >::ReturnType;
   //using RealType = decltype( fetch( IndexType(), IndexType(), IndexType(), std::declval< bool& >(), args... ) );
   const auto sliceSegmentSizes_view = this->sliceSegmentSizes.getConstView();
   const auto sliceOffsets_view = this->sliceOffsets.getConstView();
   if( RowMajorOrder )
@@ -321,7 +323,7 @@ segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& red
         IndexType localIdx( 0 );
         bool compute( true );
         for( IndexType globalIdx = begin; globalIdx< end; globalIdx++  )
            reduction( aux, fetch( segmentIdx, localIdx++, globalIdx, compute, args... ) );
            reduction( aux, details::FetchLambdaAdapter< IndexType, Fetch >::call( fetch, segmentIdx, localIdx++, globalIdx, compute ) );
         keeper( segmentIdx, aux );
      };
      Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
@@ -338,7 +340,7 @@ segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& red
         IndexType localIdx( 0 );
         bool compute( true );
         for( IndexType globalIdx = begin; globalIdx < end; globalIdx += SliceSize  )
            reduction( aux, fetch( segmentIdx, localIdx++, globalIdx, compute, args... ) );
            reduction( aux, details::FetchLambdaAdapter< IndexType, Fetch >::call( fetch, segmentIdx, localIdx++, globalIdx, compute ) );
         keeper( segmentIdx, aux );
      };
      Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
+81 −0
Original line number Diff line number Diff line
/***************************************************************************
                          CheckLambdas.h -  description
                             -------------------
    begin                : Dpr 4, 2020
    copyright            : (C) 2020 by Tomas Oberhuber
    email                : tomas.oberhuber@fjfi.cvut.cz
 ***************************************************************************/

/* See Copyright Notice in tnl/Copyright */

#pragma once


namespace TNL {
   namespace Containers {
      namespace Segments {
         namespace details {

template< typename Index,
          typename Lambda >
class CheckFetchLambdaAcceptsSegmentIdxAndCompute
{
   private:
       typedef char YesType[1];
       typedef char NoType[2];

       template< typename C > static YesType& test( decltype(std::declval< C >()( Index(), Index(), Index(), std::declval< bool& >() ) ) );
       template< typename C > static NoType& test(...);

   public:
       static constexpr bool value = ( sizeof( test< Lambda >(0) ) == sizeof( YesType ) );
};

template< typename Index,
          typename Lambda >
class CheckFetchLambdaAcceptsSegmentIdx
{
   private:
       typedef char YesType[1];
       typedef char NoType[2];

       template< typename C > static YesType& test( decltype(std::declval< C >()( Index(), Index(), Index() ) ) );
       template< typename C > static NoType& test(...);

   public:
       static constexpr bool value = ( sizeof( test< Lambda >(0) ) == sizeof( YesType ) );
};

template< typename Index,
          typename Lambda >
class CheckFetchLambdaAcceptsCompute
{
   private:
       typedef char YesType[1];
       typedef char NoType[2];

       template< typename C > static YesType& test( decltype(std::declval< C >()( Index(), Index(), std::declval< bool& >() ) ) );
       template< typename C > static NoType& test(...);

   public:
       static constexpr bool value = ( sizeof( test< Lambda >(0) ) == sizeof( YesType ) );
};


template< typename Index,
          typename Lambda >
class CheckFetchLambda
{
   static constexpr bool AcceptsSegmentIdxAndCompute = CheckFetchLambdaAcceptsSegmentIdxAndCompute< Index, Lambda >::value;
   static constexpr bool AcceptsSegmentIdx = CheckFetchLambdaAcceptsSegmentIdx< Index, Lambda >::value;
   static constexpr bool AcceptsCompute = CheckFetchLambdaAcceptsCompute< Index, Lambda >::value;

   public:
      static constexpr bool acceptsSegmentIdx() { return AcceptsSegmentIdxAndCompute || AcceptsSegmentIdx; };
      static constexpr bool acceptsCompute() { return AcceptsSegmentIdxAndCompute || AcceptsCompute; };
};

         } // namespace details
      } // namespace Segements
   }  // namespace Conatiners
} // namespace TNL
Loading