Commit 99ad4528 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Fixed linear systems to work with float as the real type

parent 80c57392
Loading
Loading
Loading
Loading
+24 −24
Original line number Diff line number Diff line
@@ -46,13 +46,13 @@ public:
   template< typename RealType,
             typename IndexType >
   static void
   gemv( const IndexType& m,
         const IndexType& n,
         const RealType& alpha,
   gemv( const IndexType m,
         const IndexType n,
         const RealType alpha,
         const RealType* A,
         const IndexType& lda,
         const IndexType lda,
         const RealType* x,
         const RealType& beta,
         const RealType beta,
         RealType* y )
   {
      TNL_ASSERT_GT( m, 0, "m must be positive" );
@@ -164,16 +164,16 @@ public:
   template< typename RealType,
             typename IndexType >
   static void
   geam( const IndexType& m,
         const IndexType& n,
         const RealType& alpha,
   geam( const IndexType m,
         const IndexType n,
         const RealType alpha,
         const RealType* A,
         const IndexType& lda,
         const RealType& beta,
         const IndexType lda,
         const RealType beta,
         const RealType* B,
         const IndexType& ldb,
         const IndexType ldb,
         RealType* C,
         const IndexType& ldc )
         const IndexType ldc )
   {
      TNL_ASSERT_GT( m, 0, "m must be positive" );
      TNL_ASSERT_GT( n, 0, "n must be positive" );
@@ -326,13 +326,13 @@ public:
   template< typename RealType,
             typename IndexType >
   static void
   gemv( const IndexType& m,
         const IndexType& n,
         const RealType& alpha,
   gemv( const IndexType m,
         const IndexType n,
         const RealType alpha,
         const RealType* A,
         const IndexType& lda,
         const IndexType lda,
         const RealType* x,
         const RealType& beta,
         const RealType beta,
         RealType* y )
   {
      TNL_ASSERT( m <= lda, );
@@ -375,16 +375,16 @@ public:
   template< typename RealType,
             typename IndexType >
   static void
   geam( const IndexType& m,
         const IndexType& n,
         const RealType& alpha,
   geam( const IndexType m,
         const IndexType n,
         const RealType alpha,
         const RealType* A,
         const IndexType& lda,
         const RealType& beta,
         const IndexType lda,
         const RealType beta,
         const RealType* B,
         const IndexType& ldb,
         const IndexType ldb,
         RealType* C,
         const IndexType& ldc )
         const IndexType ldc )
   {
      TNL_ASSERT_GT( m, 0, "m must be positive" );
      TNL_ASSERT_GT( n, 0, "n must be positive" );
+8 −8
Original line number Diff line number Diff line
@@ -102,7 +102,7 @@ solve( ConstVectorViewType b, VectorViewType x )
          */
         Matrices::MatrixOperations< DeviceType >::
            geam( size, j + 1,
                  1.0, R.getData(), ldSize,
                  (RealType) 1.0, R.getData(), ldSize,
                  -beta, U.getData(), ldSize,
                  U.getData(), ldSize );

@@ -121,7 +121,7 @@ solve( ConstVectorViewType b, VectorViewType x )
          */
         Matrices::MatrixOperations< DeviceType >::
            geam( size, j + 1,
                  1.0, R.getData(), ldSize,
                  (RealType) 1.0, R.getData(), ldSize,
                  -alpha, U.getData() + ldSize, ldSize,
                  R.getData(), ldSize );

@@ -206,18 +206,18 @@ solve( ConstVectorViewType b, VectorViewType x )
      g_2[ 0 ] = g_0[ 1 ];
      Matrices::MatrixOperations< DeviceType >::
         gemv( size, ell,
               1.0, R.getData(), ldSize, g_2.getData(),
               1.0, Traits::getLocalView( x ).getData() );
               (RealType) 1.0, R.getData(), ldSize, g_2.getData(),
               (RealType) 1.0, Traits::getLocalView( x ).getData() );
      // r_0 := r_0 - R_[1:ell] * g_1_[1:ell]
      Matrices::MatrixOperations< DeviceType >::
         gemv( size, ell,
               -1.0, R.getData() + ldSize, ldSize, &g_1[ 1 ],
               1.0, Traits::getLocalView( r_0 ).getData() );
               (RealType) -1.0, R.getData() + ldSize, ldSize, &g_1[ 1 ],
               (RealType) 1.0, Traits::getLocalView( r_0 ).getData() );
      // u_0 := u_0 - U_[1:ell] * g_0_[1:ell]
      Matrices::MatrixOperations< DeviceType >::
         gemv( size, ell,
               -1.0, U.getData() + ldSize, ldSize, &g_0[ 1 ],
               1.0, Traits::getLocalView( u_0 ).getData() );
               (RealType) -1.0, U.getData() + ldSize, ldSize, &g_0[ 1 ],
               (RealType) 1.0, Traits::getLocalView( u_0 ).getData() );

      if( exact_residue ) {
         /****
+10 −10
Original line number Diff line number Diff line
@@ -243,8 +243,8 @@ orthogonalize_CGS( const int m, const RealType normb, const RealType beta )
         // w := w - V_i * H_l
         Matrices::MatrixOperations< DeviceType >::
            gemv( size, i + 1,
                  -1.0, V.getData(), ldSize, H_l,
                  1.0, Traits::getLocalView( w ).getData() );
                  (RealType) -1.0, V.getData(), ldSize, H_l,
                  (RealType) 1.0, Traits::getLocalView( w ).getData() );
      }
      /***
       * H_{i+1,i} = |w|
@@ -607,8 +607,8 @@ hauseholder_cwy( VectorViewType v,
   // v = e_i - Y_i * aux
   Matrices::MatrixOperations< DeviceType >::
      gemv( size, i + 1,
            -1.0, Y.getData(), ldSize, aux,
            0.0, Traits::getLocalView( v ).getData() );
            (RealType) -1.0, Y.getData(), ldSize, aux,
            (RealType) 0.0, Traits::getLocalView( v ).getData() );
   if( localOffset == 0 )
      v.setElement( i, 1.0 + v.getElement( i ) );
}
@@ -649,8 +649,8 @@ hauseholder_cwy_transposed( VectorViewType z,
   z = w;
   Matrices::MatrixOperations< DeviceType >::
      gemv( size, i + 1,
            -1.0, Y.getData(), ldSize, aux,
            1.0, Traits::getLocalView( z ).getData() );
            (RealType) -1.0, Y.getData(), ldSize, aux,
            (RealType) 1.0, Traits::getLocalView( z ).getData() );
}

template< typename Matrix >
@@ -691,8 +691,8 @@ update( const int k,
      // x = V * y + x
      Matrices::MatrixOperations< DeviceType >::
         gemv( size, k + 1,
               1.0, V.getData(), ldSize, y,
               1.0, Traits::getLocalView( x ).getData() );
               (RealType) 1.0, V.getData(), ldSize, y,
               (RealType) 1.0, Traits::getLocalView( x ).getData() );
   }
   else {
      // The vectors v_i are not stored, they can be reconstructed as P_0...P_j * e_j.
@@ -723,8 +723,8 @@ update( const int k,
      // x -= Y_{k+1} * aux
      Matrices::MatrixOperations< DeviceType >::
         gemv( size, k + 1,
               -1.0, Y.getData(), ldSize, aux,
               1.0, Traits::getLocalView( x ).getData() );
               (RealType) -1.0, Y.getData(), ldSize, aux,
               (RealType) 1.0, Traits::getLocalView( x ).getData() );

      // x += y
      if( localOffset == 0 )
+20 −1
Original line number Diff line number Diff line
@@ -31,7 +31,26 @@ namespace Preconditioners {
// implementation template
template< typename Matrix, typename Real, typename Device, typename Index >
class ILU0_impl
{};
: public Preconditioner< Matrix >
{
public:
   using RealType = Real;
   using DeviceType = Device;
   using IndexType = Index;
   using typename Preconditioner< Matrix >::VectorViewType;
   using typename Preconditioner< Matrix >::ConstVectorViewType;
   using typename Preconditioner< Matrix >::MatrixPointer;

   virtual void update( const MatrixPointer& matrixPointer ) override
   {
      throw Exceptions::NotImplementedError("ILU0 is not implemented yet for the matrix type " + getType< Matrix >());
   }

   virtual void solve( ConstVectorViewType b, VectorViewType x ) const override
   {
      throw Exceptions::NotImplementedError("ILU0 is not implemented yet for the matrix type " + getType< Matrix >());
   }
};

// actual template to be used by users
template< typename Matrix >