From e07a01684e5f10209a6660c8cfb4e9b2ebcd20db Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com>
Date: Mon, 30 Dec 2019 15:20:48 +0100
Subject: [PATCH] Added boolean compute to stop segment reduction.

---
 src/TNL/Containers/Segments/CSR.hpp           |  7 +++---
 src/TNL/Containers/Segments/CSRView.hpp       |  7 +++---
 src/TNL/Containers/Segments/Ellpack.hpp       | 13 ++++++-----
 src/TNL/Containers/Segments/EllpackView.hpp   | 13 ++++++-----
 src/TNL/Containers/Segments/SlicedEllpack.hpp | 14 +++++++-----
 .../Containers/Segments/SlicedEllpackView.hpp | 22 +++++++++++--------
 src/TNL/Matrices/SparseMatrix.hpp             |  7 +++---
 src/TNL/Matrices/SparseMatrixView.hpp         |  5 +++--
 .../Containers/Segments/SegmentsTest.hpp      |  2 +-
 9 files changed, 51 insertions(+), 39 deletions(-)

diff --git a/src/TNL/Containers/Segments/CSR.hpp b/src/TNL/Containers/Segments/CSR.hpp
index 9ab2186c36..83da548fc9 100644
--- a/src/TNL/Containers/Segments/CSR.hpp
+++ b/src/TNL/Containers/Segments/CSR.hpp
@@ -218,14 +218,15 @@ void
 CSR< Device, Index, IndexAllocator >::
 segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
 {
-   using RealType = decltype( fetch( IndexType(), IndexType() ) );
+   using RealType = decltype( fetch( IndexType(), IndexType(), std::declval< bool& >(), args... ) );
    const auto offsetsView = this->offsets.getConstView();
    auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) mutable {
       const IndexType begin = offsetsView[ i ];
       const IndexType end = offsetsView[ i + 1 ];
       RealType aux( zero );
-      for( IndexType j = begin; j < end; j++  )
-         reduction( aux, fetch( i, j, args... ) );
+      bool compute( true );
+      for( IndexType j = begin; j < end && compute; j++  )
+         reduction( aux, fetch( i, j, compute, args... ) );
       keeper( i, aux );
    };
    Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
diff --git a/src/TNL/Containers/Segments/CSRView.hpp b/src/TNL/Containers/Segments/CSRView.hpp
index f4f59370d4..b4304ee321 100644
--- a/src/TNL/Containers/Segments/CSRView.hpp
+++ b/src/TNL/Containers/Segments/CSRView.hpp
@@ -204,14 +204,15 @@ void
 CSRView< Device, Index >::
 segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
 {
-   using RealType = decltype( fetch( IndexType(), IndexType() ) );
+   using RealType = decltype( fetch( IndexType(), IndexType(), std::declval< bool& >(), args... ) );
    const auto offsetsView = this->offsets.getConstView();
    auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) mutable {
       const IndexType begin = offsetsView[ i ];
       const IndexType end = offsetsView[ i + 1 ];
       RealType aux( zero );
-      for( IndexType j = begin; j < end; j++  )
-         reduction( aux, fetch( i, j, args... ) );
+      bool compute( true );
+      for( IndexType j = begin; j < end && compute; j++  )
+         reduction( aux, fetch( i, j, compute, args... ) );
       keeper( i, aux );
    };
    Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
diff --git a/src/TNL/Containers/Segments/Ellpack.hpp b/src/TNL/Containers/Segments/Ellpack.hpp
index 9f7702a6f2..ebc2b360eb 100644
--- a/src/TNL/Containers/Segments/Ellpack.hpp
+++ b/src/TNL/Containers/Segments/Ellpack.hpp
@@ -306,31 +306,32 @@ void
 Ellpack< Device, Index, IndexAllocator, RowMajorOrder, Alignment >::
 segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
 {
+   using RealType = decltype( fetch( IndexType(), IndexType(), std::declval< bool& >(), args... ) );
    if( RowMajorOrder )
    {
-      using RealType = decltype( fetch( IndexType(), IndexType() ) );
       const IndexType segmentSize = this->segmentSize;
       auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) mutable {
          const IndexType begin = i * segmentSize;
          const IndexType end = begin + segmentSize;
          RealType aux( zero );
-         for( IndexType j = begin; j < end; j++  )
-            reduction( aux, fetch( i, j, args... ) );
+         bool compute( true );
+         for( IndexType j = begin; j < end && compute; j++  )
+            reduction( aux, fetch( i, j, compute, args... ) );
          keeper( i, aux );
       };
       Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
    }
    else
    {
-      using RealType = decltype( fetch( IndexType(), IndexType() ) );
       const IndexType storageSize = this->getStorageSize();
       const IndexType alignedSize = this->alignedSize;
       auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) mutable {
          const IndexType begin = i;
          const IndexType end = storageSize;
          RealType aux( zero );
-         for( IndexType j = begin; j < end; j += alignedSize  )
-            reduction( aux, fetch( i, j, args... ) );
+         bool compute( true );
+         for( IndexType j = begin; j < end && compute; j += alignedSize  )
+            reduction( aux, fetch( i, j, compute, args... ) );
          keeper( i, aux );
       };
       Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
diff --git a/src/TNL/Containers/Segments/EllpackView.hpp b/src/TNL/Containers/Segments/EllpackView.hpp
index f5dba4f3d7..dc6bd485dd 100644
--- a/src/TNL/Containers/Segments/EllpackView.hpp
+++ b/src/TNL/Containers/Segments/EllpackView.hpp
@@ -245,31 +245,32 @@ void
 EllpackView< Device, Index, RowMajorOrder, Alignment >::
 segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
 {
+   using RealType = decltype( fetch( IndexType(), IndexType(), std::declval< bool& >(), args... ) );
    if( RowMajorOrder )
    {
-      using RealType = decltype( fetch( IndexType(), IndexType() ) );
       const IndexType segmentSize = this->segmentSize;
       auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) mutable {
          const IndexType begin = i * segmentSize;
          const IndexType end = begin + segmentSize;
          RealType aux( zero );
-         for( IndexType j = begin; j < end; j++  )
-            reduction( aux, fetch( i, j, args... ) );
+         bool compute( true );
+         for( IndexType j = begin; j < end && compute; j++  )
+            reduction( aux, fetch( i, j, compute, args... ) );
          keeper( i, aux );
       };
       Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
    }
    else
    {
-      using RealType = decltype( fetch( IndexType(), IndexType() ) );
       const IndexType storageSize = this->getStorageSize();
       const IndexType alignedSize = this->alignedSize;
       auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) mutable {
          const IndexType begin = i;
          const IndexType end = storageSize;
          RealType aux( zero );
-         for( IndexType j = begin; j < end; j += alignedSize  )
-            reduction( aux, fetch( i, j, args... ) );
+         bool compute( true );
+         for( IndexType j = begin; j < end && compute; j += alignedSize  )
+            reduction( aux, fetch( i, j, compute, args... ) );
          keeper( i, aux );
       };
       Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
diff --git a/src/TNL/Containers/Segments/SlicedEllpack.hpp b/src/TNL/Containers/Segments/SlicedEllpack.hpp
index e2aec924d5..ecd32abb25 100644
--- a/src/TNL/Containers/Segments/SlicedEllpack.hpp
+++ b/src/TNL/Containers/Segments/SlicedEllpack.hpp
@@ -127,7 +127,7 @@ setSegmentsSizes( const SizesHolder& sizes )
    const auto sizes_view = sizes.getConstView();
    auto slices_view = this->sliceOffsets.getView();
    auto slice_segment_size_view = this->sliceSegmentSizes.getView();
-   auto fetch = [=] __cuda_callable__ ( IndexType segmentIdx, IndexType globalIdx ) -> IndexType {
+   auto fetch = [=] __cuda_callable__ ( IndexType segmentIdx, IndexType globalIdx, bool& compute ) -> IndexType {
       if( globalIdx < _size )
          return sizes_view[ globalIdx ];
       return 0;
@@ -341,7 +341,7 @@ void
 SlicedEllpack< Device, Index, IndexAllocator, RowMajorOrder, SliceSize >::
 segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
 {
-   using RealType = decltype( fetch( IndexType(), IndexType() ) );
+   using RealType = decltype( fetch( IndexType(), IndexType(), std::declval< bool& >(), args... ) );
    const auto sliceSegmentSizes_view = this->sliceSegmentSizes.getConstView();
    const auto sliceOffsets_view = this->sliceOffsets.getConstView();
    if( RowMajorOrder )
@@ -353,8 +353,9 @@ segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& red
          const IndexType begin = sliceOffsets_view[ sliceIdx ] + segmentInSliceIdx * segmentSize;
          const IndexType end = begin + segmentSize;
          RealType aux( zero );
-         for( IndexType globalIdx = begin; globalIdx< end; globalIdx++  )
-            reduction( aux, fetch( segmentIdx, globalIdx, args... ) );
+         bool compute( true );
+         for( IndexType globalIdx = begin; globalIdx< end && compute; globalIdx++  )
+            reduction( aux, fetch( segmentIdx, globalIdx, compute, args... ) );
          keeper( segmentIdx, aux );
       };
       Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
@@ -368,8 +369,9 @@ segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& red
          const IndexType begin = sliceOffsets_view[ sliceIdx ] + segmentInSliceIdx;
          const IndexType end = sliceOffsets_view[ sliceIdx + 1 ];
          RealType aux( zero );
-         for( IndexType globalIdx = begin; globalIdx < end; globalIdx += SliceSize  )
-            reduction( aux, fetch( segmentIdx, globalIdx, args... ) );
+         bool compute( true );
+         for( IndexType globalIdx = begin; globalIdx < end && compute; globalIdx += SliceSize  )
+            reduction( aux, fetch( segmentIdx, globalIdx, compute, args... ) );
          keeper( segmentIdx, aux );
       };
       Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
diff --git a/src/TNL/Containers/Segments/SlicedEllpackView.hpp b/src/TNL/Containers/Segments/SlicedEllpackView.hpp
index 139a09a15e..41b49ed150 100644
--- a/src/TNL/Containers/Segments/SlicedEllpackView.hpp
+++ b/src/TNL/Containers/Segments/SlicedEllpackView.hpp
@@ -247,8 +247,9 @@ forSegments( IndexType first, IndexType last, Function& f, Args... args ) const
          const IndexType begin = sliceOffsets_view[ sliceIdx ] + segmentInSliceIdx * segmentSize;
          const IndexType end = begin + segmentSize;
          IndexType localIdx( 0 );
-         for( IndexType globalIdx = begin; globalIdx < end; globalIdx++  )
-            if( ! f( segmentIdx, localIdx++, globalIdx, args... ) )
+         bool compute( true );
+         for( IndexType globalIdx = begin; globalIdx < end && compute; globalIdx++  )
+            if( ! f( segmentIdx, localIdx++, globalIdx, compute, args... ) )
                break;
       };
       Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
@@ -262,8 +263,9 @@ forSegments( IndexType first, IndexType last, Function& f, Args... args ) const
          const IndexType begin = sliceOffsets_view[ sliceIdx ] + segmentInSliceIdx;
          const IndexType end = sliceOffsets_view[ sliceIdx + 1 ];
          IndexType localIdx( 0 );
-         for( IndexType globalIdx = begin; globalIdx < end; globalIdx += SliceSize )
-            if( ! f( segmentIdx, localIdx++, globalIdx, args... ) )
+         bool compute( true );
+         for( IndexType globalIdx = begin; globalIdx < end && compute; globalIdx += SliceSize )
+            if( ! f( segmentIdx, localIdx++, globalIdx, compute, args... ) )
                break;
       };
       Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
@@ -291,7 +293,7 @@ void
 SlicedEllpackView< Device, Index, RowMajorOrder, SliceSize >::
 segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
 {
-   using RealType = decltype( fetch( IndexType(), IndexType() ) );
+   using RealType = decltype( fetch( IndexType(), IndexType(), std::declval< bool& >(), args... ) );
    const auto sliceSegmentSizes_view = this->sliceSegmentSizes.getConstView();
    const auto sliceOffsets_view = this->sliceOffsets.getConstView();
    if( RowMajorOrder )
@@ -303,8 +305,9 @@ segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& red
          const IndexType begin = sliceOffsets_view[ sliceIdx ] + segmentInSliceIdx * segmentSize;
          const IndexType end = begin + segmentSize;
          RealType aux( zero );
-         for( IndexType globalIdx = begin; globalIdx< end; globalIdx++  )
-            reduction( aux, fetch( segmentIdx, globalIdx, args... ) );
+         bool compute( true );
+         for( IndexType globalIdx = begin; globalIdx< end && compute; globalIdx++  )
+            reduction( aux, fetch( segmentIdx, globalIdx, compute, args... ) );
          keeper( segmentIdx, aux );
       };
       Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
@@ -318,8 +321,9 @@ segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& red
          const IndexType begin = sliceOffsets_view[ sliceIdx ] + segmentInSliceIdx;
          const IndexType end = sliceOffsets_view[ sliceIdx + 1 ];
          RealType aux( zero );
-         for( IndexType globalIdx = begin; globalIdx < end; globalIdx += SliceSize  )
-            reduction( aux, fetch( segmentIdx, globalIdx, args... ) );
+         bool compute( true );
+         for( IndexType globalIdx = begin; globalIdx < end && compute; globalIdx += SliceSize  )
+            reduction( aux, fetch( segmentIdx, globalIdx, compute, args... ) );
          keeper( segmentIdx, aux );
       };
       Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
diff --git a/src/TNL/Matrices/SparseMatrix.hpp b/src/TNL/Matrices/SparseMatrix.hpp
index c0dd3b9a3e..691157a9c4 100644
--- a/src/TNL/Matrices/SparseMatrix.hpp
+++ b/src/TNL/Matrices/SparseMatrix.hpp
@@ -628,9 +628,10 @@ vectorProduct( const InVector& inVector,
    const auto valuesView = this->values.getConstView();
    const auto columnIndexesView = this->columnIndexes.getConstView();
    const IndexType paddingIndex = this->getPaddingIndex();
-   auto fetch = [=] __cuda_callable__ ( IndexType row, IndexType offset ) -> RealType {
+   auto fetch = [=] __cuda_callable__ ( IndexType row, IndexType offset, bool& compute ) -> RealType {
       const IndexType column = columnIndexesView[ offset ];
-      if( column == paddingIndex )
+      compute = ( column != paddingIndex );
+      if( ! compute )
          return 0.0;
       return valuesView[ offset ] * inVectorView[ column ];
    };
@@ -658,7 +659,7 @@ rowsReduction( IndexType first, IndexType last, Fetch& fetch, Reduce& reduce, Ke
    const auto columns_view = this->columnIndexes.getConstView();
    const auto values_view = this->values.getConstView();
    const IndexType paddingIndex_ = this->getPaddingIndex();
-   auto fetch_ = [=] __cuda_callable__ ( IndexType rowIdx, IndexType globalIdx ) mutable -> decltype( fetch( IndexType(), IndexType(), RealType() ) ) {
+   auto fetch_ = [=] __cuda_callable__ ( IndexType rowIdx, IndexType globalIdx, bool& compute ) mutable -> decltype( fetch( IndexType(), IndexType(), RealType() ) ) {
       IndexType columnIdx = columns_view[ globalIdx ];
       if( columnIdx != paddingIndex_ )
          return fetch( rowIdx, columnIdx, values_view[ globalIdx ] );
diff --git a/src/TNL/Matrices/SparseMatrixView.hpp b/src/TNL/Matrices/SparseMatrixView.hpp
index 5ac494a9b6..ce0e7aa181 100644
--- a/src/TNL/Matrices/SparseMatrixView.hpp
+++ b/src/TNL/Matrices/SparseMatrixView.hpp
@@ -508,9 +508,10 @@ vectorProduct( const InVector& inVector,
    const auto valuesView = this->values.getConstView();
    const auto columnIndexesView = this->columnIndexes.getConstView();
    const IndexType paddingIndex = this->getPaddingIndex();
-   auto fetch = [=] __cuda_callable__ ( IndexType row, IndexType offset ) -> RealType {
+   auto fetch = [=] __cuda_callable__ ( IndexType row, IndexType offset, bool& compute ) -> RealType {
       const IndexType column = columnIndexesView[ offset ];
-      if( column == paddingIndex )
+      compute = ( column != paddingIndex );
+      if( ! compute )
          return 0.0;
       return valuesView[ offset ] * inVectorView[ column ];
    };
diff --git a/src/UnitTests/Containers/Segments/SegmentsTest.hpp b/src/UnitTests/Containers/Segments/SegmentsTest.hpp
index 5e74f96b03..6189c2e9a4 100644
--- a/src/UnitTests/Containers/Segments/SegmentsTest.hpp
+++ b/src/UnitTests/Containers/Segments/SegmentsTest.hpp
@@ -143,7 +143,7 @@ void test_AllReduction_MaximumInSegments()
 
    const auto v_view = v.getConstView();
    auto result_view = result.getView();
-   auto fetch = [=] __cuda_callable__ ( IndexType segmentIdx, IndexType globalIdx ) -> IndexType {
+   auto fetch = [=] __cuda_callable__ ( IndexType segmentIdx, IndexType globalIdx, bool& compute ) -> IndexType {
       return v_view[ globalIdx ];
    };
    auto reduce = [] __cuda_callable__ ( IndexType& a, const IndexType b ) {
-- 
GitLab