Commit 335dd0fc authored by Tomáš Oberhuber's avatar Tomáš Oberhuber Committed by Tomáš Oberhuber
Browse files

Fixing sparse matrix to work with StaticVector as RealType.

parent 4e62a8ff
Loading
Loading
Loading
Loading
+5 −5
Original line number Diff line number Diff line
@@ -168,20 +168,20 @@ getNumberOfNonzeroMatrixElements() const
      Containers::Vector< IndexType, DeviceType, IndexType > row_sums( this->getRows(), 0 );
      auto row_sums_view = row_sums.getView();
      const auto columnIndexesView = this->columnIndexes.getConstView();
      auto fetch = [=] __cuda_callable__ ( IndexType row, IndexType localIdx, IndexType globalIdx, bool& compute ) -> RealType {
      auto fetch = [=] __cuda_callable__ ( IndexType row, IndexType localIdx, IndexType globalIdx, bool& compute ) -> IndexType {
         const IndexType column = columnIndexesView[ globalIdx ];
         compute = ( column != paddingIndex );
         if( ! compute )
            return 0.0;
         return 1 + ( column != row && column < rows && row < columns ); // the addition is for non-diagonal elements
      };
      auto reduction = [] __cuda_callable__ ( RealType& sum, const RealType& value ) {
      auto reduction = [] __cuda_callable__ ( IndexType& sum, const IndexType& value ) {
         sum += value;
      };
      auto keeper = [=] __cuda_callable__ ( IndexType row, const RealType& value ) mutable {
      auto keeper = [=] __cuda_callable__ ( IndexType row, const IndexType& value ) mutable {
         row_sums_view[ row ] = value;
      };
      this->segments.segmentsReduction( 0, this->getRows(), fetch, reduction, keeper, ( RealType ) 0.0 );
      this->segments.segmentsReduction( 0, this->getRows(), fetch, reduction, keeper, ( IndexType ) 0 );
      return sum( row_sums );
   }
}
@@ -648,7 +648,7 @@ print( std::ostream& str ) const
         for( IndexType column = 0; column < this->getColumns(); column++ )
         {
            auto value = this->getElement( row, column );
            if( value )
            if( value != ( RealType ) 0 )
               str << " Col:" << column << "->" << value << "\t";
         }
         str << std::endl;