diff --git a/src/TNL/Matrices/SparseMatrixView.hpp b/src/TNL/Matrices/SparseMatrixView.hpp index ae83dfa891c75e553fdf5e38ad92ac8cde51cb76..8cf807335e92954b7f8987d75533ddae243969be 100644 --- a/src/TNL/Matrices/SparseMatrixView.hpp +++ b/src/TNL/Matrices/SparseMatrixView.hpp @@ -462,12 +462,37 @@ vectorProduct( const InVector& inVector, outVectorView[ row ] = outVectorMultiplicator * outVectorView[ row ] + matrixMultiplicator * value; } }; + auto keeperDirect = [=] __cuda_callable__ ( IndexType row, const ComputeRealType& value ) mutable { + outVectorView[ row ] = value; + }; + auto keeperMatrixMult = [=] __cuda_callable__ ( IndexType row, const ComputeRealType& value ) mutable { + outVectorView[ row ] = matrixMultiplicator * value; + }; + auto keeperVectorMult = [=] __cuda_callable__ ( IndexType row, const ComputeRealType& value ) mutable { + outVectorView[ row ] = outVectorMultiplicator * outVectorView[ row ] + value; + }; + if( lastRow == 0 ) lastRow = this->getRows(); if( isSymmetric() ) - this->segments.segmentsReduction( firstRow, lastRow, symmetricFetch, std::plus<>{}, keeper, ( ComputeRealType ) 0.0 ); + this->segments.segmentsReduction( firstRow, lastRow, symmetricFetch, std::plus<>{}, keeperGeneral, ( ComputeRealType ) 0.0 ); else - this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeper, ( ComputeRealType ) 0.0 ); + { + if( outVectorMultiplicator == 0.0 ) + { + if( matrixMultiplicator == 1.0 ) + this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeperDirect, ( ComputeRealType ) 0.0 ); + else + this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeperMatrixMult, ( ComputeRealType ) 0.0 ); + } + else + { + if( matrixMultiplicator == 1.0 ) + this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeperVectorMult, ( ComputeRealType ) 0.0 ); + else + this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeperGeneral, ( ComputeRealType ) 0.0 ); + } + } } template< typename Real,