From fd6b28fbd557a5d29b04b6726618da198c9dc5e6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com>
Date: Fri, 12 Feb 2021 20:45:52 +0100
Subject: [PATCH] Optimizing lambda functions for SpMV in sparse matrix view!!!

---
 src/TNL/Matrices/SparseMatrixView.hpp | 29 +++++++++++++++++++++++++--
 1 file changed, 27 insertions(+), 2 deletions(-)

diff --git a/src/TNL/Matrices/SparseMatrixView.hpp b/src/TNL/Matrices/SparseMatrixView.hpp
index ae83dfa891..8cf807335e 100644
--- a/src/TNL/Matrices/SparseMatrixView.hpp
+++ b/src/TNL/Matrices/SparseMatrixView.hpp
@@ -462,12 +462,37 @@ vectorProduct( const InVector& inVector,
             outVectorView[ row ] = outVectorMultiplicator * outVectorView[ row ] + matrixMultiplicator * value;
       }
    };
+   auto keeperDirect = [=] __cuda_callable__ ( IndexType row, const ComputeRealType& value ) mutable { // chosen below when outVectorMultiplicator == 0.0 and matrixMultiplicator == 1.0
+      outVectorView[ row ] = value; // plain overwrite; the previous content of outVectorView[ row ] is not read
+   };
+   auto keeperMatrixMult = [=] __cuda_callable__ ( IndexType row, const ComputeRealType& value ) mutable { // chosen below when outVectorMultiplicator == 0.0 and matrixMultiplicator != 1.0
+      outVectorView[ row ] = matrixMultiplicator * value; // scaled overwrite; the previous content of outVectorView[ row ] is not read
+   };
+   auto keeperVectorMult = [=] __cuda_callable__ ( IndexType row, const ComputeRealType& value ) mutable { // chosen below when outVectorMultiplicator != 0.0 and matrixMultiplicator == 1.0
+      outVectorView[ row ] = outVectorMultiplicator * outVectorView[ row ] + value; // scales the existing entry and adds value without multiplying it
+   };
+
    if( lastRow == 0 )
       lastRow = this->getRows();
    if( isSymmetric() )
-      this->segments.segmentsReduction( firstRow, lastRow, symmetricFetch, std::plus<>{}, keeper, ( ComputeRealType ) 0.0 );
+      this->segments.segmentsReduction( firstRow, lastRow, symmetricFetch, std::plus<>{}, keeperGeneral, ( ComputeRealType ) 0.0 );
    else
-      this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeper, ( ComputeRealType ) 0.0 );
+   { // non-symmetric case: pick the keeper lambda matching the two multiplicators
+      if( outVectorMultiplicator == 0.0 ) // the keepers used in this branch only write outVectorView[ row ], they never read it
+      {
+         if( matrixMultiplicator == 1.0 ) // value is stored as is
+            this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeperDirect, ( ComputeRealType ) 0.0 );
+         else // value is multiplied by matrixMultiplicator before being stored
+            this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeperMatrixMult, ( ComputeRealType ) 0.0 );
+      }
+      else // the existing outVectorView[ row ] contributes to the result
+      {
+         if( matrixMultiplicator == 1.0 ) // value is added without multiplying it
+            this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeperVectorMult, ( ComputeRealType ) 0.0 );
+         else // NOTE(review): keeperGeneral is not declared inside this hunk (the removed lines used the name keeper) - confirm it exists in the target file
+            this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeperGeneral, ( ComputeRealType ) 0.0 );
+      }
+   }
 }
 
 template< typename Real,
-- 
GitLab