Loading src/TNL/Algorithms/Segments/CSR.h +3 −3 Original line number Diff line number Diff line Loading @@ -33,9 +33,9 @@ class CSR using OffsetsHolder = Containers::Vector< Index, DeviceType, IndexType, IndexAllocator >; using SegmentsSizes = OffsetsHolder; template< typename Device_, typename Index_ > using ViewTemplate = CSRView< Device_, Index_ >; using ViewType = CSRView< Device, Index >; using ConstViewType = CSRView< Device, std::add_const_t< IndexType > >; using ViewTemplate = CSRView< Device_, Index_, KernelType_ >; using ViewType = CSRView< Device, Index, KernelType_ >; using ConstViewType = CSRView< Device, std::add_const_t< IndexType >, KernelType_ >; using SegmentViewType = SegmentView< IndexType, RowMajorOrder >; CSRKernelTypes KernelType = KernelType_; Loading src/TNL/Algorithms/Segments/CSRView.hpp +10 −2 Original line number Diff line number Diff line Loading @@ -14,6 +14,7 @@ #include <TNL/Algorithms/ParallelFor.h> #include <TNL/Algorithms/Segments/CSRView.h> #include <TNL/Algorithms/Segments/details/CSR.h> #include <TNL/Algorithms/Segments/details/CSRKernels.h> #include <TNL/Algorithms/Segments/details/LambdaAdapter.h> namespace TNL { Loading Loading @@ -217,7 +218,7 @@ segmentsReduction( IndexType first, IndexType last, Fetch& fetch, const Reductio { using RealType = typename details::FetchLambdaAdapter< Index, Fetch >::ReturnType; const auto offsetsView = this->offsets.getConstView(); if( KernelType == CSRScalarKernel ) if( KernelType == CSRScalarKernel || std::is_same< DeviceType, TNL::Devices::Host >::value ) { auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable { const IndexType begin = offsetsView[ segmentIdx ]; Loading @@ -231,6 +232,13 @@ segmentsReduction( IndexType first, IndexType last, Fetch& fetch, const Reductio }; Algorithms::ParallelFor< Device >::exec( first, last, l, args... ); } if( KernelType == CSRVectorKernel ) details::RowsReductionVectorKernelCaller( offsetsView, first, last, fetch, reduction, keeper, zero, args... ); if( KernelType == CSRLightKernel ) { const IndexType elementsInSegment = ceil( this->getSize() / this->getSegmentsCount() ); details::RowsReductionLightKernelCaller( elementsInSegment, offsetsView, first, last, fetch, reduction, keeper, zero, args... ); } } template< typename Device, Loading Loading
src/TNL/Algorithms/Segments/CSR.h +3 −3 Original line number Diff line number Diff line Loading @@ -33,9 +33,9 @@ class CSR using OffsetsHolder = Containers::Vector< Index, DeviceType, IndexType, IndexAllocator >; using SegmentsSizes = OffsetsHolder; template< typename Device_, typename Index_ > using ViewTemplate = CSRView< Device_, Index_ >; using ViewType = CSRView< Device, Index >; using ConstViewType = CSRView< Device, std::add_const_t< IndexType > >; using ViewTemplate = CSRView< Device_, Index_, KernelType_ >; using ViewType = CSRView< Device, Index, KernelType_ >; using ConstViewType = CSRView< Device, std::add_const_t< IndexType >, KernelType_ >; using SegmentViewType = SegmentView< IndexType, RowMajorOrder >; CSRKernelTypes KernelType = KernelType_; Loading
src/TNL/Algorithms/Segments/CSRView.hpp +10 −2 Original line number Diff line number Diff line Loading @@ -14,6 +14,7 @@ #include <TNL/Algorithms/ParallelFor.h> #include <TNL/Algorithms/Segments/CSRView.h> #include <TNL/Algorithms/Segments/details/CSR.h> #include <TNL/Algorithms/Segments/details/CSRKernels.h> #include <TNL/Algorithms/Segments/details/LambdaAdapter.h> namespace TNL { Loading Loading @@ -217,7 +218,7 @@ segmentsReduction( IndexType first, IndexType last, Fetch& fetch, const Reductio { using RealType = typename details::FetchLambdaAdapter< Index, Fetch >::ReturnType; const auto offsetsView = this->offsets.getConstView(); if( KernelType == CSRScalarKernel ) if( KernelType == CSRScalarKernel || std::is_same< DeviceType, TNL::Devices::Host >::value ) { auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable { const IndexType begin = offsetsView[ segmentIdx ]; Loading @@ -231,6 +232,13 @@ segmentsReduction( IndexType first, IndexType last, Fetch& fetch, const Reductio }; Algorithms::ParallelFor< Device >::exec( first, last, l, args... ); } if( KernelType == CSRVectorKernel ) details::RowsReductionVectorKernelCaller( offsetsView, first, last, fetch, reduction, keeper, zero, args... ); if( KernelType == CSRLightKernel ) { const IndexType elementsInSegment = ceil( this->getSize() / this->getSegmentsCount() ); details::RowsReductionLightKernelCaller( elementsInSegment, offsetsView, first, last, fetch, reduction, keeper, zero, args... ); } } template< typename Device, Loading