Commit ad461c2d authored by Jakub Klinkovský
Browse files

Improved GMRES - less storage for the CWY variant

parent 62489092
Loading
Loading
Loading
Loading
+50 −6
Original line number Diff line number Diff line
@@ -303,7 +303,9 @@ orthogonalize_CWY( const int m, const RealType normb, const RealType beta )
       * Generate new basis vector v_i, using the compact WY representation:
       *     v_i = (I - Y_i * T_i Y_i^T) * e_i
       */
      v_i.bind( &V.getData()[ i * ldSize ], size );
      // vectors v_i are not stored, they can be reconstructed in the update() method
//      v_i.bind( &V.getData()[ i * ldSize ], size );
      v_i.bind( V.getData(), size );
      hauseholder_cwy( v_i, i );

      if( i < m ) {
@@ -581,11 +583,49 @@ update( const int k,
         y[ j ] -= H[ j + i * ( m + 1 ) ] * y[ i ];
   }

   if( variant != Variant::CWY ) {
      // x = V * y + x
      MatrixOperations< DeviceType >::gemv( size, k + 1,
                                            1.0, V.getData(), ldSize, y,
                                            1.0, Traits::getLocalVectorView( x ).getData() );
   }
   else {
      // The vectors v_i are not stored, they can be reconstructed as P_0...P_j * e_j.
      // Hence, for j = 0, ..., k:  x += y_j P_0...P_j e_j,
      // or equivalently: x += \sum_{j=0}^{k} y_j e_j - Y_k T_k \sum_{j=0}^{k} y_j Y_j^T e_j

      RealType aux[ k + 1 ];
      for( int j = 0; j <= k; j++ )
         aux[ j ] = 0;

      for( int j = 0; j <= k; j++ ) {
         // aux += y_j * Y_j^T * e_j
         // the upper (m+1)x(m+1) submatrix of Y is duplicated on the host
         // (faster access than from the device, and it is broadcast to all processes)
         for( int i = 0; i <= j; i++ )
            aux[ i ] += y[ j ] * YL[ j + i * (restarting_max + 1) ];
      }

      // aux = T_{k+1} * aux
      // Note that T_{k+1} is upper triangular, so we can overwrite the aux vector with the result in place
      for( int i = 0; i <= k; i++ ) {
         RealType aux2 = 0.0;
         for( int j = i; j <= k; j++ )
            aux2 += T[ i + j * (restarting_max + 1) ] * aux[ j ];
         aux[ i ] = aux2;
      }

      // x -= Y_{k+1} * aux
      MatrixOperations< DeviceType >::gemv( size, k + 1,
                                            -1.0, Y.getData(), ldSize, aux,
                                            1.0, Traits::getLocalVectorView( x ).getData() );

      // x += y
      if( localOffset == 0 )
         for( int j = 0; j <= k; j++ )
            x.setElement( j, x.getElement( j ) + y[ j ] );
   }
}

template< typename Matrix >
void
@@ -669,7 +709,6 @@ setSize( const VectorViewType& x )
   r.setLike( x );
   w.setLike( x );
   _M_tmp.setLike( x );
   V.setSize( ldSize * ( m + 1 ) );
   cs.setSize( m + 1 );
   sn.setSize( m + 1 );
   H.setSize( ( m + 1 ) * m );
@@ -681,6 +720,11 @@ setSize( const VectorViewType& x )
      Y.setSize( ldSize * ( m + 1 ) );
      T.setSize( (m + 1) * (m + 1) );
      YL.setSize( (m + 1) * (m + 1) );
      // vectors v_i are not stored, they can be reconstructed in the update() method
      V.setLike( x );
   }
   else {
      V.setSize( ldSize * ( m + 1 ) );
   }
}