Commit 90a43d2c authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Removing Args... from segments reductions.

parent a006571b
Loading
Loading
Loading
Loading
+103 −103
Original line number Diff line number Diff line
@@ -115,11 +115,11 @@ namespace TNL
      /***
       * \brief Go over all segments and perform a reduction in each of them.
       */
            template <typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args>
            void reduceSegments(IndexType first, IndexType last, Fetch &fetch, const Reduction &reduction, ResultKeeper &keeper, const Real &zero, Args... args) const;
      template <typename Fetch, typename Reduction, typename ResultKeeper, typename Real >
      void reduceSegments(IndexType first, IndexType last, Fetch &fetch, const Reduction &reduction, ResultKeeper &keeper, const Real &zero ) const;

            template <typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args>
            void reduceAllSegments(Fetch &fetch, const Reduction &reduction, ResultKeeper &keeper, const Real &zero, Args... args) const;
      template <typename Fetch, typename Reduction, typename ResultKeeper, typename Real >
      void reduceAllSegments(Fetch &fetch, const Reduction &reduction, ResultKeeper &keeper, const Real &zero ) const;

      BiEllpack &operator=(const BiEllpack &source) = default;

+6 −6
Original line number Diff line number Diff line
@@ -497,12 +497,12 @@ template< typename Device,
          typename IndexAllocator,
          ElementsOrganization Organization,
          int WarpSize >
   template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
   template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real >
void
BiEllpack< Device, Index, IndexAllocator, Organization, WarpSize >::
reduceSegments( IndexType first, IndexType last, Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
reduceSegments( IndexType first, IndexType last, Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero ) const
{
   this->getConstView().reduceSegments( first, last, fetch, reduction, keeper, zero, args... );
   this->getConstView().reduceSegments( first, last, fetch, reduction, keeper, zero );
}

template< typename Device,
@@ -510,12 +510,12 @@ template< typename Device,
          typename IndexAllocator,
          ElementsOrganization Organization,
          int WarpSize >
   template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
   template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real >
void
BiEllpack< Device, Index, IndexAllocator, Organization, WarpSize >::
reduceAllSegments( Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
reduceAllSegments( Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero ) const
{
   this->reduceSegments( 0, this->getSegmentsCount(), fetch, reduction, keeper, zero, args... );
   this->reduceSegments( 0, this->getSegmentsCount(), fetch, reduction, keeper, zero );
}

template< typename Device,
+4 −4
Original line number Diff line number Diff line
@@ -126,11 +126,11 @@ class BiEllpackView
      /***
       * \brief Go over all segments and perform a reduction in each of them.
       */
      template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
      void reduceSegments( IndexType first, IndexType last, Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const;
      template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real >
      void reduceSegments( IndexType first, IndexType last, Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero ) const;

      template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
      void reduceAllSegments( Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const;
      template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real >
      void reduceAllSegments( Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero ) const;

      BiEllpackView& operator=( const BiEllpackView& view );

+13 −17
Original line number Diff line number Diff line
@@ -352,10 +352,10 @@ template< typename Device,
          typename Index,
          ElementsOrganization Organization,
          int WarpSize >
   template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
   template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real >
void
BiEllpackView< Device, Index, Organization, WarpSize >::
reduceSegments( IndexType first, IndexType last, Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
reduceSegments( IndexType first, IndexType last, Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero ) const
{
   using RealType = typename detail::FetchLambdaAdapter< Index, Fetch >::ReturnType;
   if( this->getStorageSize() == 0 )
@@ -425,9 +425,9 @@ reduceSegments( IndexType first, IndexType last, Fetch& fetch, const Reduction&
         dim3 cudaGridSize = Cuda::getMaxGridSize();
         if( gridIdx == cudaGrids - 1 )
            cudaGridSize.x = cudaBlocks % Cuda::getMaxGridSize();
         details::BiEllpackreduceSegmentsKernel< ViewType, IndexType, Fetch, Reduction, ResultKeeper, Real, BlockDim, Args...  >
         details::BiEllpackreduceSegmentsKernel< ViewType, IndexType, Fetch, Reduction, ResultKeeper, Real, BlockDim  >
            <<< cudaGridSize, cudaBlockSize, sharedMemory >>>
            ( *this, gridIdx, first, last, fetch, reduction, keeper, zero, args... );
            ( *this, gridIdx, first, last, fetch, reduction, keeper, zero );
         cudaThreadSynchronize();
         TNL_CHECK_CUDA_DEVICE;
      }
@@ -439,12 +439,12 @@ template< typename Device,
          typename Index,
          ElementsOrganization Organization,
          int WarpSize >
   template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
   template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real >
void
BiEllpackView< Device, Index, Organization, WarpSize >::
reduceAllSegments( Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
reduceAllSegments( Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero ) const
{
   this->reduceSegments( 0, this->getSegmentsCount(), fetch, reduction, keeper, zero, args... );
   this->reduceSegments( 0, this->getSegmentsCount(), fetch, reduction, keeper, zero );
}

template< typename Device,
@@ -513,8 +513,7 @@ template< typename Device,
             typename Reduction,
             typename ResultKeeper,
             typename Real,
             int BlockDim,
             typename... Args >
             int BlockDim >
__device__
void
BiEllpackView< Device, Index, Organization, WarpSize >::
@@ -524,10 +523,9 @@ reduceSegmentsKernelWithAllParameters( IndexType gridIdx,
                                          Fetch fetch,
                                          Reduction reduction,
                                          ResultKeeper keeper,
                                          Real zero,
                                          Args... args ) const
                                          Real zero ) const
{
   using RealType = decltype( fetch( IndexType(), IndexType(), IndexType(), std::declval< bool& >(), args... ) );
   using RealType = decltype( fetch( IndexType(), IndexType(), IndexType(), std::declval< bool& >() ) );
   const IndexType segmentIdx = ( gridIdx * Cuda::getMaxGridSize() + blockIdx.x ) * blockDim.x + threadIdx.x + first;
   if( segmentIdx >= last )
      return;
@@ -569,8 +567,7 @@ template< typename Device,
             typename Reduction,
             typename ResultKeeper,
             typename Real,
             int BlockDim,
             typename... Args >
             int BlockDim >
__device__
void
BiEllpackView< Device, Index, Organization, WarpSize >::
@@ -580,10 +577,9 @@ reduceSegmentsKernel( IndexType gridIdx,
                         Fetch fetch,
                         Reduction reduction,
                         ResultKeeper keeper,
                         Real zero,
                         Args... args ) const
                         Real zero ) const
{
   using RealType = decltype( fetch( IndexType(), std::declval< bool& >(), args... ) );
   using RealType = decltype( fetch( IndexType(), std::declval< bool& >() ) );
   Index segmentIdx = ( gridIdx * Cuda::getMaxGridSize() + blockIdx.x ) * blockDim.x + threadIdx.x + first;

   const IndexType strip = segmentIdx >> getLogWarpSize();
+4 −4
Original line number Diff line number Diff line
@@ -123,11 +123,11 @@ class CSR
      /***
       * \brief Go over all segments and perform a reduction in each of them.
       */
      template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
      void reduceSegments( IndexType first, IndexType last, Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const;
      template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real >
      void reduceSegments( IndexType first, IndexType last, Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero ) const;

      template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
      void reduceAllSegments( Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const;
      template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real >
      void reduceAllSegments( Fetch& fetch, const Reduction& reduction, ResultKeeper& keeper, const Real& zero ) const;

      CSR& operator=( const CSR& rhsSegments ) = default;

Loading