diff --git a/src/TNL/Solvers/Linear/BICGStabL.h b/src/TNL/Solvers/Linear/BICGStabL.h index 77311c442c65de86c41a65ba786fb75b870932eb..a35962d54b900ae0a50dfe1f42ff04d9235fc3a8 100644 --- a/src/TNL/Solvers/Linear/BICGStabL.h +++ b/src/TNL/Solvers/Linear/BICGStabL.h @@ -43,8 +43,6 @@ #include "LinearSolver.h" -#include <TNL/Containers/Vector.h> - namespace TNL { namespace Solvers { namespace Linear { @@ -54,12 +52,18 @@ class BICGStabL : public LinearSolver< Matrix > { using Base = LinearSolver< Matrix >; + + // compatibility shortcut + using Traits = Linear::Traits< Matrix >; + public: using RealType = typename Base::RealType; using DeviceType = typename Base::DeviceType; using IndexType = typename Base::IndexType; + // distributed vectors/views using VectorViewType = typename Base::VectorViewType; using ConstVectorViewType = typename Base::ConstVectorViewType; + using VectorType = typename Traits::VectorType; String getType() const; @@ -75,7 +79,11 @@ protected: using DeviceVector = Containers::Vector< RealType, DeviceType, IndexType >; using HostVector = Containers::Vector< RealType, Devices::Host, IndexType >; - void setSize( IndexType size ); + void compute_residue( VectorViewType r, ConstVectorViewType x, ConstVectorViewType b ); + + void preconditioned_matvec( ConstVectorViewType src, VectorViewType dst ); + + void setSize( const VectorViewType& x ); int ell = 1; @@ -83,12 +91,13 @@ protected: // matrices (in column-major format) DeviceVector R, U; - // single vectors - DeviceVector r_ast, M_tmp, res_tmp; + // single vectors (distributed) + VectorType r_ast, M_tmp, res_tmp; // host-only storage HostVector T, sigma, g_0, g_1, g_2; - IndexType size, ldSize; + IndexType size = 0; + IndexType ldSize = 0; }; } // namespace Linear diff --git a/src/TNL/Solvers/Linear/BICGStabL_impl.h b/src/TNL/Solvers/Linear/BICGStabL_impl.h index 6606bddd561fa8a8d5b3ce5e8855b6e41541b546..8338a3f96fcec4826dc6d34204ed28db7695d0c5 100644 --- a/src/TNL/Solvers/Linear/BICGStabL_impl.h +++ b/src/TNL/Solvers/Linear/BICGStabL_impl.h @@ -22,7 +22,8 @@ namespace Linear { template< typename Matrix > String -BICGStabL< Matrix >::getType() const +BICGStabL< Matrix >:: +getType() const { return String( "BICGStabL< " ) + this->matrix -> getType() + ", " + @@ -52,39 +53,30 @@ setup( const Config::ParameterContainer& parameters, template< typename Matrix > bool -BICGStabL< Matrix >::solve( ConstVectorViewType b, VectorViewType x ) +BICGStabL< Matrix >:: +solve( ConstVectorViewType b, VectorViewType x ) { - this->setSize( this->matrix->getRows() ); + this->setSize( x ); RealType alpha, beta, gamma, rho_0, rho_1, omega, b_norm; - DeviceVector r_0, r_j, r_i, u_0, Au, u; + // initial binding to M_tmp sets the correct local range, global size and + // communication group for distributed views + VectorViewType r_0( M_tmp ), r_j( M_tmp ), r_i( M_tmp ), u_0( M_tmp ), Au( M_tmp ), u( M_tmp ); r_0.bind( R.getData(), size ); u_0.bind( U.getData(), size ); - auto matvec = [this]( const DeviceVector& src, DeviceVector& dst ) - { - if( this->preconditioner ) { - this->matrix->vectorProduct( src, M_tmp ); - this->preconditioner->solve( M_tmp, dst ); - } - else { - this->matrix->vectorProduct( src, dst ); - } - }; - + // initialize the norm of the preconditioned right-hand-side if( this->preconditioner ) { this->preconditioner->solve( b, M_tmp ); - b_norm = M_tmp.lpNorm( ( RealType ) 2.0 ); - - this->matrix->vectorProduct( x, M_tmp ); - M_tmp.addVector( b, 1.0, -1.0 ); - this->preconditioner->solve( M_tmp, r_0 ); + b_norm = M_tmp.lpNorm( 2.0 ); } - else { + else b_norm = b.lpNorm( 2.0 ); - this->matrix->vectorProduct( x, r_0 ); - r_0.addVector( b, 1.0, -1.0 ); - } + if( b_norm == 0.0 ) + b_norm = 1.0; + + // r_0 = M.solve(b - A * x); + compute_residue( r_0, x, b ); sigma[ 0 ] = r_0.lpNorm( 2.0 ); if( std::isnan( sigma[ 0 ] ) ) @@ -97,9 +89,6 @@ BICGStabL< Matrix >::solve( ConstVectorViewType b, VectorViewType x ) omega = 1.0; u_0.setValue( 0.0 ); - if( b_norm == 0.0 ) - b_norm = 1.0; - this->resetIterations(); this->setResidue( sigma[ 0 ] / b_norm ); @@ -132,7 +121,7 @@ BICGStabL< Matrix >::solve( ConstVectorViewType b, VectorViewType x ) */ u.bind( &U.getData()[ j * ldSize ], size ); Au.bind( &U.getData()[ (j + 1) * ldSize ], size ); - matvec( u, Au ); + preconditioned_matvec( u, Au ); gamma = r_ast.scalarProduct( Au ); alpha = rho_0 / gamma; @@ -151,7 +140,7 @@ BICGStabL< Matrix >::solve( ConstVectorViewType b, VectorViewType x ) */ r_j.bind( &R.getData()[ j * ldSize ], size ); r_i.bind( &R.getData()[ (j + 1) * ldSize ], size ); - matvec( r_j, r_i ); + preconditioned_matvec( r_j, r_i ); /**** * x_0 := x_0 + alpha * u_0 @@ -227,29 +216,21 @@ BICGStabL< Matrix >::solve( ConstVectorViewType b, VectorViewType x ) g_2[ 0 ] = g_0[ 1 ]; MatrixOperations< DeviceType >::gemv( size, ell, 1.0, R.getData(), ldSize, g_2.getData(), - 1.0, x.getData() ); + 1.0, Traits::getLocalVectorView( x ).getData() ); // r_0 := r_0 - R_[1:ell] * g_1_[1:ell] MatrixOperations< DeviceType >::gemv( size, ell, -1.0, R.getData() + ldSize, ldSize, &g_1[ 1 ], - 1.0, r_0.getData() ); + 1.0, Traits::getLocalVectorView( r_0 ).getData() ); // u_0 := u_0 - U_[1:ell] * g_0_[1:ell] MatrixOperations< DeviceType >::gemv( size, ell, -1.0, U.getData() + ldSize, ldSize, &g_0[ 1 ], - 1.0, u_0.getData() ); + 1.0, Traits::getLocalVectorView( u_0 ).getData() ); if( exact_residue ) { /**** * Compute the exact preconditioned residue into the 's' vector. */ - if( this->preconditioner ) { - this->matrix->vectorProduct( x, M_tmp ); - M_tmp.addVector( b, 1.0, -1.0 ); - this->preconditioner->solve( M_tmp, res_tmp ); - } - else { - this->matrix->vectorProduct( x, res_tmp ); - res_tmp.addVector( b, 1.0, -1.0 ); - } + compute_residue( res_tmp, x, b ); sigma[ 0 ] = res_tmp.lpNorm( 2.0 ); this->setResidue( sigma[ 0 ] / b_norm ); } @@ -268,15 +249,49 @@ BICGStabL< Matrix >::solve( ConstVectorViewType b, VectorViewType x ) template< typename Matrix > void -BICGStabL< Matrix >::setSize( IndexType size ) +BICGStabL< Matrix >:: +compute_residue( VectorViewType r, ConstVectorViewType x, ConstVectorViewType b ) +{ + /**** + * r = M.solve(b - A * x); + */ + if( this->preconditioner ) { + this->matrix->vectorProduct( x, M_tmp ); + M_tmp.addVector( b, 1.0, -1.0 ); + this->preconditioner->solve( M_tmp, r ); + } + else { + this->matrix->vectorProduct( x, r ); + r.addVector( b, 1.0, -1.0 ); + } +} + +template< typename Matrix > +void +BICGStabL< Matrix >:: +preconditioned_matvec( ConstVectorViewType src, VectorViewType dst ) +{ + if( this->preconditioner ) { + this->matrix->vectorProduct( src, M_tmp ); + this->preconditioner->solve( M_tmp, dst ); + } + else { + this->matrix->vectorProduct( src, dst ); + } +} + +template< typename Matrix > +void +BICGStabL< Matrix >:: +setSize( const VectorViewType& x ) { - this->size = ldSize = size; + this->size = ldSize = Traits::getLocalVectorView( x ).getSize(); R.setSize( (ell + 1) * ldSize ); U.setSize( (ell + 1) * ldSize ); - r_ast.setSize( size ); - M_tmp.setSize( size ); + r_ast.setLike( x ); + M_tmp.setLike( x ); if( exact_residue ) - res_tmp.setSize( size ); + res_tmp.setLike( x ); T.setSize( ell * ell ); sigma.setSize( ell + 1 ); g_0.setSize( ell + 1 );