Commit ee6823f0 authored by Jakub Klinkovský
Browse files

Distributed GMRES

parent 0a561076
Loading
Loading
Loading
Loading
+48 −27
Original line number Diff line number Diff line
@@ -14,8 +14,6 @@

#include "LinearSolver.h"

#include <TNL/Containers/Vector.h>

namespace TNL {
namespace Solvers {
namespace Linear {
@@ -25,12 +23,19 @@ class GMRES
: public LinearSolver< Matrix >
{
   using Base = LinearSolver< Matrix >;

   // compatibility shortcuts
   using Traits = Linear::Traits< Matrix >;
   using CommunicatorType = typename Traits::CommunicatorType;

public:
   using RealType = typename Base::RealType;
   using DeviceType = typename Base::DeviceType;
   using IndexType = typename Base::IndexType;
   // distributed vectors/views
   using VectorViewType = typename Base::VectorViewType;
   using ConstVectorViewType = typename Base::ConstVectorViewType;
   using VectorType = typename Traits::VectorType;

   String getType() const;

@@ -43,11 +48,12 @@ public:
   bool solve( ConstVectorViewType b, VectorViewType x ) override;

protected:
   using ConstDeviceView = Containers::VectorView< const RealType, DeviceType, IndexType >;
   using DeviceView = Containers::VectorView< RealType, DeviceType, IndexType >;
   using HostView = Containers::VectorView< RealType, Devices::Host, IndexType >;
   using DeviceVector = Containers::Vector< RealType, DeviceType, IndexType >;
   using HostVector = Containers::Vector< RealType, Devices::Host, IndexType >;
   // local vectors/views
   using ConstDeviceView = typename Traits::ConstLocalVectorViewType;
   using DeviceView = typename Traits::LocalVectorViewType;
   using HostView = typename DeviceView::HostType;
   using DeviceVector = typename Traits::LocalVectorType;
   using HostVector = typename DeviceVector::HostType;

   enum class Variant { MGS, MGSR, CWY };

@@ -57,36 +63,37 @@ protected:

   void compute_residue( VectorViewType r, ConstVectorViewType x, ConstVectorViewType b );

   void preconditioned_matvec( DeviceVector& w, ConstDeviceView v );
   void preconditioned_matvec( VectorViewType w, ConstVectorViewType v );

   void hauseholder_generate( DeviceVector& Y,
                              HostVector& T,
                              const int i,
                              DeviceVector& w );
// nvcc allows __cuda_callable__ lambdas only in public methods
#ifdef __NVCC__
public:
#endif
   void hauseholder_generate( const int i,
                              VectorViewType y_i,
                              ConstVectorViewType z );
#ifdef __NVCC__
protected:
#endif

   void hauseholder_apply_trunc( HostView out,
                                 DeviceVector& Y,
                                 HostVector& T,
                                 const int i,
                                 DeviceVector& w );
                                 VectorViewType y_i,
                                 ConstVectorViewType z );

   void hauseholder_cwy( DeviceVector& w,
                         DeviceVector& Y,
                         HostVector& T,
   void hauseholder_cwy( VectorViewType v,
                         const int i );

   void hauseholder_cwy_transposed( DeviceVector& w,
                                    DeviceVector& Y,
                                    HostVector& T,
   void hauseholder_cwy_transposed( VectorViewType z,
                                    const int i,
                                    DeviceVector& z );
                                    ConstVectorViewType w );

   template< typename Vector >
   void update( const int k,
                const int m,
                const HostVector& H,
                const HostVector& s,
                DeviceVector& v,
                DeviceVector& V,
                Vector& x );

   void generatePlaneRotation( RealType& dx,
@@ -101,15 +108,28 @@ protected:

   void apply_givens_rotations( const int i, const int m );

   void setSize( const VectorViewType& x );

   // Specialized methods to distinguish between normal and distributed matrices
   // in the implementation.
   /**
    * Generic overload: yields a zero offset for any matrix type that is not
    * matched by the DistributedMatrix overload. The argument itself is never
    * inspected, it only drives the overload selection.
    */
   template< typename M >
   static IndexType
   getLocalOffset( const M& )
   {
      return IndexType( 0 );
   }

   void setSize( const IndexType size );
   /**
    * Overload for distributed matrices: returns the beginning of the
    * matrix's local row range.
    *
    * The communicator is deduced as a separate template parameter so that
    * this overload is selected for every DistributedMatrix instantiation and
    * not only for the one that uses the default communicator. Otherwise such
    * a matrix would silently bind to the generic overload and get offset 0.
    *
    * \param m  distributed matrix whose local row range is queried
    * \return   m.getLocalRowRange().getBegin()
    */
   template< typename M, typename Communicator >
   static IndexType getLocalOffset( const DistributedContainers::DistributedMatrix< M, Communicator >& m )
   {
      return m.getLocalRowRange().getBegin();
   }

   // selected GMRES variant
   Variant variant = Variant::CWY;

   // single vectors
   DeviceVector r, w, z, _M_tmp;
   // matrices (in column-major format)
   // single vectors (distributed)
   VectorType r, w, z, _M_tmp;
   // matrices (in column-major format) (local)
   DeviceVector V, Y;
   // (CWY only) duplicate of the upper (m+1)x(m+1) submatrix of Y (it is lower triangular) for fast access
   HostVector YL, T;
@@ -118,6 +138,7 @@ protected:

   IndexType size = 0;
   IndexType ldSize = 0;
   IndexType localOffset = 0;
   int restarting_min = 10;
   int restarting_max = 10;
   int restarting_step_min = 3;
+104 −124
Original line number Diff line number Diff line
@@ -91,7 +91,7 @@ solve( ConstVectorViewType b, VectorViewType x )
                << ", d_max = " << restarting_step_max << std::endl;
      return false;
   }
   setSize( this->matrix->getRows() );
   setSize( x );

   // initialize the norm of the preconditioned right-hand-side
   RealType normb;
@@ -183,13 +183,15 @@ int
GMRES< Matrix >::
orthogonalize_MGS( const int m, const RealType normb, const RealType beta )
{
   DeviceView vi, vk;
   // initial binding to _M_tmp sets the correct local range, global size and
   // communication group for distributed views
   VectorViewType v_i( _M_tmp ), v_k( _M_tmp );

   /***
    * v_0 = r / | r | =  1.0 / beta * r
    */
   vi.bind( V.getData(), ldSize );
   vi.addVector( r, 1.0 / beta, 0.0 );
   v_i.bind( V.getData(), size );
   v_i.addVector( r, 1.0 / beta, 0.0 );

   H.setValue( 0.0 );
   s.setValue( 0.0 );
@@ -199,28 +201,28 @@ orthogonalize_MGS( const int m, const RealType normb, const RealType beta )
    * Starting m-loop
    */
   for( int i = 0; i < m && this->nextIteration(); i++ ) {
      vi.bind( &( V.getData()[ i * ldSize ] ), size );
      v_i.bind( &V.getData()[ i * ldSize ], size );
      /****
       * Solve w from M w = A v_i
       */
      preconditioned_matvec( w, vi );
      preconditioned_matvec( w, v_i );

      for( int k = 0; k <= i; k++ )
         H[ k + i * ( m + 1 ) ] = 0.0;
      const int reorthogonalize = (variant == Variant::MGSR) ? 2 : 1;
      for( int l = 0; l < reorthogonalize; l++ )
         for( int k = 0; k <= i; k++ ) {
            vk.bind( &( V.getData()[ k * ldSize ] ), size );
            v_k.bind( &V.getData()[ k * ldSize ], size );
            /***
             * H_{k,i} = ( w, v_k )
             */
            RealType H_k_i = w.scalarProduct( vk );
            RealType H_k_i = w.scalarProduct( v_k );
            H[ k + i * ( m + 1 ) ] += H_k_i;

            /****
             * w = w - H_{k,i} v_k
             */
            w.addVector( vk, -H_k_i );
            w.addVector( v_k, -H_k_i );
         }
      /***
       * H_{i+1,i} = |w|
@@ -231,8 +233,8 @@ orthogonalize_MGS( const int m, const RealType normb, const RealType beta )
      /***
       * v_{i+1} = w / |w|
       */
      vi.bind( &( V.getData()[ ( i + 1 ) * ldSize ] ), size );
      vi.addVector( w, 1.0 / normw, 0.0 );
      v_i.bind( &V.getData()[ ( i + 1 ) * ldSize ], size );
      v_i.addVector( w, 1.0 / normw, 0.0 );

      /****
       * Applying the Givens rotations G_0, ..., G_i
@@ -254,7 +256,9 @@ int
GMRES< Matrix >::
orthogonalize_CWY( const int m, const RealType normb, const RealType beta )
{
   DeviceVector vi, vk;
   // initial binding to _M_tmp sets the correct local range, global size and
   // communication group for distributed views
   VectorViewType v_i( _M_tmp ), y_i( _M_tmp );

   /***
    * z = r / | r | =  1.0 / beta * r
@@ -278,40 +282,41 @@ orthogonalize_CWY( const int m, const RealType normb, const RealType beta )
      /****
       * Generate new Householder transformation from vector z.
       */
      hauseholder_generate( Y, T, i, z );
      y_i.bind( &Y.getData()[ i * ldSize ], size );
      hauseholder_generate( i, y_i, z );

      if( i == 0 ) {
         /****
          * s = e_1^T * P_i * z
          */
         hauseholder_apply_trunc( s, Y, T, i, z );
         hauseholder_apply_trunc( s, i, y_i, z );
      }
      else {
         /***
          * H_{i-1} = P_i * z
          */
         HostView h( &H.getData()[ (i - 1) * (m + 1) ], m + 1 );
         hauseholder_apply_trunc( h, Y, T, i, z );
         hauseholder_apply_trunc( h, i, y_i, z );
      }

      /***
       * Generate new basis vector v_i, using the compact WY representation:
       *     v_i = (I - Y_i * T_i Y_i^T) * e_i
       */
      vi.bind( &V.getData()[ i * ldSize ], size );
      hauseholder_cwy( vi, Y, T, i );
      v_i.bind( &V.getData()[ i * ldSize ], size );
      hauseholder_cwy( v_i, i );

      if( i < m ) {
         /****
          * Solve w from M w = A v_i
          */
         preconditioned_matvec( w, vi );
         preconditioned_matvec( w, v_i );

         /****
          * Apply all previous Householder transformations, using the compact WY representation:
          *     z = (I - Y_i * T_i^T * Y_i^T) * w
          */
         hauseholder_cwy_transposed( z, Y, T, i, w );
         hauseholder_cwy_transposed( z, i, w );
      }

      /****
@@ -321,7 +326,7 @@ orthogonalize_CWY( const int m, const RealType normb, const RealType beta )
         apply_givens_rotations( i - 1, m );

      this->setResidue( std::fabs( s[ i ] ) / normb );
      if( ! this->checkNextIteration() )
      if( i > 0 && ! this->checkNextIteration() )
         return i - 1;
      else
         this->refreshSolverMonitor();
@@ -352,7 +357,7 @@ compute_residue( VectorViewType r, ConstVectorViewType x, ConstVectorViewType b
template< typename Matrix >
void
GMRES< Matrix >::
preconditioned_matvec( DeviceVector& w, ConstDeviceView v )
preconditioned_matvec( VectorViewType w, ConstVectorViewType v )
{
   /****
    * w = M.solve(A * v_i);
@@ -365,67 +370,35 @@ preconditioned_matvec( DeviceVector& w, ConstDeviceView v )
      this->matrix->vectorProduct( v, w );
}

#ifdef HAVE_CUDA
/**
 * CUDA kernel writing a truncated copy of `source` into `destination`:
 * entries with index below `from` are set to zero and the following entries
 * up to (but excluding) `size` are copied from `source`.
 *
 * Every thread starts at its global index and advances by the total number
 * of threads in the grid (blockDim.x * gridDim.x).
 */
template< typename DestinationElement,
          typename SourceElement,
          typename Index >
__global__ void
copyTruncatedVectorKernel( DestinationElement* destination,
                           const SourceElement* source,
                           const Index from,
                           const Index size )
{
   const Index stride = blockDim.x * gridDim.x;
   Index idx = blockIdx.x * blockDim.x + threadIdx.x;

   // phase 1: zero the leading entries [0, from)
   for( ; idx < from; idx += stride )
      destination[ idx ] = (DestinationElement) 0.0;

   // phase 2: copy the remaining entries; idx carries over from phase 1
   for( ; idx < size; idx += stride )
      destination[ idx ] = source[ idx ];
}
#endif

template< typename Matrix >
void
GMRES< Matrix >::
hauseholder_generate( DeviceVector& Y,
                      HostVector& T,
                      const int i,
                      DeviceVector& z )
hauseholder_generate( const int i,
                      VectorViewType y_i,
                      ConstVectorViewType z )
{
   DeviceView y_i( &Y.getData()[ i * ldSize ], size );

   // XXX: the upper-right triangle of Y will be full of zeros, which can be exploited for optimization
   if( std::is_same< DeviceType, Devices::Host >::value ) {
      for( IndexType j = 0; j < size; j++ ) {
   if( localOffset == 0 ) {
      TNL_ASSERT_LT( i, size, "upper-right triangle of Y is not on rank 0" );
      auto kernel_truncation = [=] __cuda_callable__ ( IndexType j ) mutable
      {
         if( j < i )
            y_i[ j ] = 0.0;
         else
            y_i[ j ] = z[ j ];
      };
      ParallelFor< DeviceType >::exec( (IndexType) 0, size, kernel_truncation );
   }
   }
   if( std::is_same< DeviceType, Devices::Cuda >::value ) {
#ifdef HAVE_CUDA
      dim3 blockSize( 256 );
      dim3 gridSize;
      gridSize.x = min( Devices::Cuda::getMaxGridSize(), Devices::Cuda::getNumberOfBlocks( size, blockSize.x ) );

      copyTruncatedVectorKernel<<< gridSize, blockSize >>>( y_i.getData(),
                                                            z.getData(),
                                                            i,
                                                            size );
      TNL_CHECK_CUDA_DEVICE;
#else
      throw Exceptions::CudaSupportMissing();
#endif
   else {
      ConstDeviceView z_local = Traits::getLocalVectorView( z );
      DeviceView y_i_local = Traits::getLocalVectorView( y_i );
      y_i_local = z_local;
   }

   // norm of the TRUNCATED vector z
   const RealType normz = y_i.lpNorm( 2.0 );
   RealType norm_yi = 0;
   if( localOffset == 0 ) {
      const RealType y_ii = y_i.getElement( i );
      if( y_ii > 0.0 )
         y_i.setElement( i, y_ii + normz );
@@ -434,7 +407,10 @@ hauseholder_generate( DeviceVector& Y,

      // compute the norm of the y_i vector; equivalent to this calculation by definition:
      //       const RealType norm_yi = y_i.lpNorm( 2.0 );
   const RealType norm_yi = std::sqrt( 2 * (normz * normz + std::fabs( y_ii ) * normz) );
      norm_yi = std::sqrt( 2 * (normz * normz + std::fabs( y_ii ) * normz) );
   }
   // no-op if the problem is not distributed
   CommunicatorType::Bcast( &norm_yi, 1, 0, Traits::getCommunicationGroup( *this->matrix ) );

   // XXX: normalization is slower, but more stable
//   y_i *= 1.0 / norm_yi;
@@ -453,8 +429,10 @@ hauseholder_generate( DeviceVector& Y,
                 size,
                 Y.getData(),
                 ldSize,
                 y_i.getData(),
                 Traits::getLocalVectorView( y_i ).getData(),
                 aux );
      // no-op if the problem is not distributed
      CommunicatorType::Allreduce( aux, i, MPI_SUM, Traits::getCommunicationGroup( *this->matrix ) );

      // [T_i]_{0..i-1} = - T_{i-1} * t_i * aux
      for( int k = 0; k < i; k++ ) {
@@ -469,49 +447,49 @@ template< typename Matrix >
void
GMRES< Matrix >::
hauseholder_apply_trunc( HostView out,
                         DeviceVector& Y,
                         HostVector& T,
                         const int i,
                         DeviceVector& z )
                         VectorViewType y_i,
                         ConstVectorViewType z )
{
   DeviceView y_i( &Y.getData()[ i * ldSize ], size );

   // copy part of y_i to the YL buffer
   // The upper (m+1)x(m+1) submatrix of Y is duplicated in the YL buffer,
   // which resides on host and is broadcasted from rank 0 to all processes.
   HostView YL_i( &YL[ i * (restarting_max + 1) ], restarting_max + 1 );
   Containers::Algorithms::ArrayOperations< Devices::Host, DeviceType >::copyMemory( YL_i.getData(), Traits::getLocalVectorView( y_i ).getData(), YL_i.getSize() );
   // no-op if the problem is not distributed
   CommunicatorType::Bcast( YL_i.getData(), YL_i.getSize(), 0, Traits::getCommunicationGroup( *this->matrix ) );

   // TODO: is aux always 1?
   const RealType aux = T[ i + i * (restarting_max + 1) ] * y_i.scalarProduct( z );
   if( localOffset == 0 ) {
      if( std::is_same< DeviceType, Devices::Host >::value ) {
         for( int k = 0; k <= i; k++ )
            out[ k ] = z[ k ] - y_i[ k ] * aux;
      }
      if( std::is_same< DeviceType, Devices::Cuda >::value ) {
      // copy part of y_i to buffer on host
      // here we duplicate the upper (m+1)x(m+1) submatrix of Y on host for fast access
      RealType* host_yi = &YL[ i * (restarting_max + 1) ];
         RealType host_z[ i + 1 ];
      Containers::Algorithms::ArrayOperations< Devices::Host, Devices::Cuda >::copyMemory< RealType, RealType, IndexType >( host_yi, y_i.getData(), restarting_max + 1 );
      Containers::Algorithms::ArrayOperations< Devices::Host, Devices::Cuda >::copyMemory< RealType, RealType, IndexType >( host_z, z.getData(), i + 1 );
         Containers::Algorithms::ArrayOperations< Devices::Host, Devices::Cuda >::copyMemory( host_z, Traits::getLocalVectorView( z ).getData(), i + 1 );
         for( int k = 0; k <= i; k++ )
         out[ k ] = host_z[ k ] - host_yi[ k ] * aux;
            out[ k ] = host_z[ k ] - YL_i[ k ] * aux;
      }
   }

   // no-op if the problem is not distributed
   CommunicatorType::Bcast( out.getData(), i + 1, 0, Traits::getCommunicationGroup( *this->matrix ) );
}

template< typename Matrix >
void
GMRES< Matrix >::
hauseholder_cwy( DeviceVector& v,
                 DeviceVector& Y,
                 HostVector& T,
hauseholder_cwy( VectorViewType v,
                 const int i )
{
   // aux = Y_i^T * e_i
   RealType aux[ i + 1 ];
   if( std::is_same< DeviceType, Devices::Host >::value ) {
      for( int k = 0; k <= i; k++ )
         aux[ k ] = Y[ i + k * ldSize ];
   }
   if( std::is_same< DeviceType, Devices::Cuda >::value ) {
      // the upper (m+1)x(m+1) submatrix of Y is duplicated on host for fast access
   // the upper (m+1)x(m+1) submatrix of Y is duplicated on host
   // (faster access than from the device and it is broadcasted to all processes)
   for( int k = 0; k <= i; k++ )
      aux[ k ] = YL[ i + k * (restarting_max + 1) ];
   }

   // aux = T_i * aux
   // Note that T_i is upper triangular, so we can overwrite the aux vector with the result in place
@@ -525,18 +503,17 @@ hauseholder_cwy( DeviceVector& v,
   // v = e_i - Y_i * aux
   MatrixOperations< DeviceType >::gemv( size, i + 1,
                                         -1.0, Y.getData(), ldSize, aux,
                                         0.0, v.getData() );
                                         0.0, Traits::getLocalVectorView( v ).getData() );
   if( localOffset == 0 )
      v.setElement( i, 1.0 + v.getElement( i ) );
}

template< typename Matrix >
void
GMRES< Matrix >::
hauseholder_cwy_transposed( DeviceVector& z,
                            DeviceVector& Y,
                            HostVector& T,
hauseholder_cwy_transposed( VectorViewType z,
                            const int i,
                            DeviceVector& w )
                            ConstVectorViewType w )
{
   // aux = Y_i^T * w
   RealType aux[ i + 1 ];
@@ -547,8 +524,10 @@ hauseholder_cwy_transposed( DeviceVector& z,
              size,
              Y.getData(),
              ldSize,
              w.getData(),
              Traits::getLocalVectorView( w ).getData(),
              aux );
   // no-op if the problem is not distributed
   Traits::CommunicatorType::Allreduce( aux, i + 1, MPI_SUM, Traits::getCommunicationGroup( *this->matrix ) );

   // aux = T_i^T * aux
   // Note that T_i^T is lower triangular, so we can overwrite the aux vector with the result in place
@@ -563,7 +542,7 @@ hauseholder_cwy_transposed( DeviceVector& z,
   z = w;
   MatrixOperations< DeviceType >::gemv( size, i + 1,
                                         -1.0, Y.getData(), ldSize, aux,
                                         1.0, z.getData() );
                                         1.0, Traits::getLocalVectorView( z ).getData() );
}

template< typename Matrix >
@@ -574,7 +553,7 @@ update( const int k,
        const int m,
        const HostVector& H,
        const HostVector& s,
        DeviceVector& v,
        DeviceVector& V,
        Vector& x )
{
   RealType y[ m + 1 ];
@@ -602,8 +581,8 @@ update( const int k,

   // x = V * y + x
   MatrixOperations< DeviceType >::gemv( size, k + 1,
                                         1.0, v.getData(), ldSize, y,
                                         1.0, x.getData() );
                                         1.0, V.getData(), ldSize, y,
                                         1.0, Traits::getLocalVectorView( x ).getData() );
}

template< typename Matrix >
@@ -673,29 +652,30 @@ apply_givens_rotations( int i, int m )
template< typename Matrix >
void
GMRES< Matrix >::
setSize( const IndexType size )
setSize( const VectorViewType& x )
{
   this->size = size;
   this->size = Traits::getLocalVectorView( x ).getSize();
   if( std::is_same< DeviceType, Devices::Cuda >::value )
      // align each column to 256 bytes - optimal for CUDA
      ldSize = roundToMultiple( size, 256 / sizeof( RealType ) );
   else
      // on the host, we add 1 to disrupt the cache false-sharing pattern
      ldSize = roundToMultiple( size, 256 / sizeof( RealType ) ) + 1;
   localOffset = getLocalOffset( *this->matrix );

   const int m = restarting_max;
   r.setSize( size );
   w.setSize( size );
   r.setLike( x );
   w.setLike( x );
   _M_tmp.setLike( x );
   V.setSize( ldSize * ( m + 1 ) );
   cs.setSize( m + 1 );
   sn.setSize( m + 1 );
   H.setSize( ( m + 1 ) * m );
   s.setSize( m + 1 );
   _M_tmp.setSize( size );

   // CWY-specific storage
   if( variant == Variant::CWY ) {
      z.setSize( size );
      z.setLike( x );
      Y.setSize( ldSize * ( m + 1 ) );
      T.setSize( (m + 1) * (m + 1) );
      YL.setSize( (m + 1) * (m + 1) );
+11 −0
Original line number Diff line number Diff line
@@ -12,7 +12,10 @@

#pragma once

#include <TNL/Communicators/NoDistrCommunicator.h>
#include <TNL/Containers/Vector.h>
#include <TNL/Containers/VectorView.h>
#include <TNL/DistributedContainers/DistributedVector.h>
#include <TNL/DistributedContainers/DistributedVectorView.h>
#include <TNL/DistributedContainers/DistributedMatrix.h>

@@ -23,6 +26,8 @@ namespace Linear {
template< typename Matrix >
struct Traits
{
   using CommunicatorType = Communicators::NoDistrCommunicator;

   using VectorType = Containers::Vector
         < typename Matrix::RealType,
           typename Matrix::DeviceType,
@@ -45,11 +50,15 @@ struct Traits
   // Primary template: the matrix passed in is returned as the local matrix.
   static const Matrix&
   getLocalMatrix( const Matrix& m )
   {
      return m;
   }
   // Primary template: the constant view passed in is returned as the local view.
   static ConstLocalVectorViewType
   getLocalVectorView( ConstVectorViewType v )
   {
      return v;
   }
   // Primary template: the mutable view passed in is returned as the local view.
   static LocalVectorViewType
   getLocalVectorView( VectorViewType v )
   {
      return v;
   }

   // Primary template: the matrix argument is ignored and the communicator's
   // AllGroup is always returned.
   static typename CommunicatorType::CommunicationGroup
   getCommunicationGroup( const Matrix& )
   {
      return CommunicatorType::AllGroup;
   }
};

template< typename Matrix, typename Communicator >
struct Traits< DistributedContainers::DistributedMatrix< Matrix, Communicator > >
{
   using CommunicatorType = Communicator;

   using VectorType = DistributedContainers::DistributedVector
         < typename Matrix::RealType,
           typename Matrix::DeviceType,
@@ -84,6 +93,8 @@ struct Traits< DistributedContainers::DistributedMatrix< Matrix, Communicator >
   { return m.getLocalMatrix(); }
   // Distributed case: forwards to the constant view's getLocalVectorView().
   static ConstLocalVectorViewType
   getLocalVectorView( ConstVectorViewType v )
   {
      return v.getLocalVectorView();
   }
   // Distributed case: forwards to the mutable view's getLocalVectorView().
   static LocalVectorViewType
   getLocalVectorView( VectorViewType v )
   {
      return v.getLocalVectorView();
   }

   // Distributed case: the communication group is taken from the matrix itself.
   static typename CommunicatorType::CommunicationGroup
   getCommunicationGroup( const DistributedContainers::DistributedMatrix< Matrix, Communicator >& m )
   {
      return m.getCommunicationGroup();
   }
};

} // namespace Linear