Commit fd6b28fb authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Optimizing lambda functions for SpMV in sparse matrix view!!!

parent 936f2c32
Loading
Loading
Loading
Loading
+27 −2
Original line number Diff line number Diff line
@@ -462,12 +462,37 @@ vectorProduct( const InVector& inVector,
            outVectorView[ row ] = outVectorMultiplicator * outVectorView[ row ] + matrixMultiplicator * value;
      }
   };
   auto keeperDirect = [=] __cuda_callable__ ( IndexType row, const ComputeRealType& value ) mutable {
      outVectorView[ row ] = value;
   };
   auto keeperMatrixMult = [=] __cuda_callable__ ( IndexType row, const ComputeRealType& value ) mutable {
      outVectorView[ row ] = matrixMultiplicator * value;
   };
   auto keeperVectorMult = [=] __cuda_callable__ ( IndexType row, const ComputeRealType& value ) mutable {
      outVectorView[ row ] = outVectorMultiplicator * outVectorView[ row ] + value;
   };

   if( lastRow == 0 )
      lastRow = this->getRows();
   if( isSymmetric() )
      this->segments.segmentsReduction( firstRow, lastRow, symmetricFetch, std::plus<>{}, keeper, ( ComputeRealType ) 0.0 );
      this->segments.segmentsReduction( firstRow, lastRow, symmetricFetch, std::plus<>{}, keeperGeneral, ( ComputeRealType ) 0.0 );
   else
      this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeper, ( ComputeRealType ) 0.0 );
   {
      if( outVectorMultiplicator == 0.0 )
      {
         if( matrixMultiplicator == 1.0 )
            this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeperDirect, ( ComputeRealType ) 0.0 );
         else
            this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeperMatrixMult, ( ComputeRealType ) 0.0 );
      }
      else
      {
         if( matrixMultiplicator == 1.0 )
            this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeperVectorMult, ( ComputeRealType ) 0.0 );
         else
            this->segments.segmentsReduction( firstRow, lastRow, fetch, std::plus<>{}, keeperGeneral, ( ComputeRealType ) 0.0 );
      }
   }
}

template< typename Real,