From fd6b28fbd557a5d29b04b6726618da198c9dc5e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com> Date: Fri, 12 Feb 2021 20:45:52 +0100 Subject: [PATCH] Optimizing lambda functions for SpMV in sparse matrix view!!! --- src/TNL/Matrices/SparseMatrixView.hpp | 29 +++++++++++++++++++++++++-- 1 file changed, 27 insertions(+), 2 deletions(-) diff --git a/src/TNL/Matrices/SparseMatrixView.hpp b/src/TNL/Matrices/SparseMatrixView.hpp index ae83dfa891..8cf807335e 100644 --- a/src/TNL/Matrices/SparseMatrixView.hpp +++ b/src/TNL/Matrices/SparseMatrixView.hpp @@ -462,12 +462,37 @@ vectorProduct( const InVector& inVector, outVectorView[ row ] = outVectorMultiplicator * outVectorView[ row ] + matrixMultiplicator * value; } }; + auto keeperDirect = [=] __cuda_callable__ ( IndexType row, const ComputeRealType& value ) mutable { + outVectorView[ row ] = value; + }; + auto keeperMatrixMult = [=] __cuda_callable__ ( IndexType row, const ComputeRealType& value ) mutable { + outVectorView[ row ] = matrixMultiplicator * value; + }; + auto keeperVectorMult = [=] __cuda_callable__ ( IndexType row, const ComputeRealType& value ) mutable { + outVectorView[ row ] = outVectorMultiplicator * outVectorView[ row ] + value; + }; + if( lastRow == 0 ) lastRow = this->getRows(); if( isSymmetric() ) - this->segments.segmentsReduction( firstRow, lastRow, symmetricFetch, std::plus<>{}, keeper, ( ComputeRealType ) 0.0 ); + this->segments.segmentsReduction( firstRow, lastRow, symmetricFetch, std::plus<>{}, keeperGeneral, ( ComputeRealType ) 0.0 ); else - this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeper, ( ComputeRealType ) 0.0 ); + { + if( outVectorMultiplicator == 0.0 ) + { + if( matrixMultiplicator == 1.0 ) + this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeperDirect, ( ComputeRealType ) 0.0 ); + else + this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeperMatrixMult, ( ComputeRealType ) 0.0 ); + } + else + { + if( matrixMultiplicator == 1.0 ) + this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeperVectorMult, ( ComputeRealType ) 0.0 ); + else + this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeperGeneral, ( ComputeRealType ) 0.0 ); + } + } } template< typename Real, -- GitLab