Commit 183f565c authored by Tomáš Oberhuber
Browse files

Adding CSR Light kernel.

parent a45e7910
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -33,9 +33,9 @@ class CSR
      using OffsetsHolder = Containers::Vector< Index, DeviceType, IndexType, IndexAllocator >;
      using SegmentsSizes = OffsetsHolder;
      template< typename Device_, typename Index_ >
      using ViewTemplate = CSRView< Device_, Index_ >;
      using ViewType = CSRView< Device, Index >;
      using ConstViewType = CSRView< Device, std::add_const_t< IndexType > >;
      using ViewTemplate = CSRView< Device_, Index_, KernelType_ >;
      using ViewType = CSRView< Device, Index, KernelType_ >;
      using ConstViewType = CSRView< Device, std::add_const_t< IndexType >, KernelType_ >;
      using SegmentViewType = SegmentView< IndexType, RowMajorOrder >;
      CSRKernelTypes KernelType = KernelType_;

+10 −2
Original line number Diff line number Diff line
@@ -14,6 +14,7 @@
#include <TNL/Algorithms/ParallelFor.h>
#include <TNL/Algorithms/Segments/CSRView.h>
#include <TNL/Algorithms/Segments/details/CSR.h>
#include <TNL/Algorithms/Segments/details/CSRKernels.h>
#include <TNL/Algorithms/Segments/details/LambdaAdapter.h>

namespace TNL {
@@ -217,7 +218,7 @@ segmentsReduction( IndexType first, IndexType last, Fetch& fetch, const Reductio
{
   using RealType = typename details::FetchLambdaAdapter< Index, Fetch >::ReturnType;
   const auto offsetsView = this->offsets.getConstView();
   if( KernelType == CSRScalarKernel )
   if( KernelType == CSRScalarKernel || std::is_same< DeviceType, TNL::Devices::Host >::value )
   {
      auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable {
         const IndexType begin = offsetsView[ segmentIdx ];
@@ -231,6 +232,13 @@ segmentsReduction( IndexType first, IndexType last, Fetch& fetch, const Reductio
      };
      Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
   }
   if( KernelType == CSRVectorKernel )
      details::RowsReductionVectorKernelCaller( offsetsView, first, last, fetch, reduction, keeper, zero, args... );
   if( KernelType == CSRLightKernel )
   {
      const IndexType elementsInSegment = ceil( this->getSize() / this->getSegmentsCount() );
      details::RowsReductionLightKernelCaller( elementsInSegment, offsetsView, first, last, fetch, reduction, keeper, zero, args... );
   }
}

template< typename Device,