Commit 5e7005a6 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

MPI refactoring: removed MpiCommunicator from solvers: Merson, GMRES, Linear/Traits.h

parent eb8b40dc
Loading
Loading
Loading
Loading
+0 −3
Original line number Diff line number Diff line
@@ -23,10 +23,7 @@ class GMRES
: public LinearSolver< Matrix >
{
   using Base = LinearSolver< Matrix >;

   // compatibility shortcuts
   using Traits = Linear::Traits< Matrix >;
   using CommunicatorType = typename Traits::CommunicatorType;

public:
   using RealType = typename Base::RealType;
+5 −5
Original line number Diff line number Diff line
@@ -510,7 +510,7 @@ hauseholder_generate( const int i,
      norm_yi_squared = 2 * (normz * normz + std::fabs( y_ii ) * normz);
   }
   // no-op if the problem is not distributed
   CommunicatorType::Bcast( &norm_yi_squared, 1, 0, Traits::getCommunicationGroup( *this->matrix ) );
   MPI::Bcast( &norm_yi_squared, 1, 0, Traits::getCommunicationGroup( *this->matrix ) );

   // XXX: normalization is slower, but more stable
//   y_i *= 1.0 / std::sqrt( norm_yi_squared );
@@ -534,7 +534,7 @@ hauseholder_generate( const int i,
                 i,
                 aux );
      // no-op if the problem is not distributed
      CommunicatorType::Allreduce( aux, i, MPI_SUM, Traits::getCommunicationGroup( *this->matrix ) );
      MPI::Allreduce( aux, i, MPI_SUM, Traits::getCommunicationGroup( *this->matrix ) );

      // [T_i]_{0..i-1} = - T_{i-1} * t_i * aux
      for( int k = 0; k < i; k++ ) {
@@ -559,7 +559,7 @@ hauseholder_apply_trunc( HostView out,
   HostView YL_i( &YL[ i * (restarting_max + 1) ], restarting_max + 1 );
   Algorithms::MultiDeviceMemoryOperations< Devices::Host, DeviceType >::copy( YL_i.getData(), Traits::getLocalView( y_i ).getData(), YL_i.getSize() );
   // no-op if the problem is not distributed
   CommunicatorType::Bcast( YL_i.getData(), YL_i.getSize(), 0, Traits::getCommunicationGroup( *this->matrix ) );
   MPI::Bcast( YL_i.getData(), YL_i.getSize(), 0, Traits::getCommunicationGroup( *this->matrix ) );

   // NOTE: aux = t_i * (y_i, z) = 1  since  t_i = 2 / ||y_i||^2  and
   //       (y_i, z) = ||z_trunc||^2 + |z_i| ||z_trunc|| = ||y_i||^2 / 2
@@ -579,7 +579,7 @@ hauseholder_apply_trunc( HostView out,
   }

   // no-op if the problem is not distributed
   CommunicatorType::Bcast( out.getData(), i + 1, 0, Traits::getCommunicationGroup( *this->matrix ) );
   MPI::Bcast( out.getData(), i + 1, 0, Traits::getCommunicationGroup( *this->matrix ) );
}

template< typename Matrix >
@@ -634,7 +634,7 @@ hauseholder_cwy_transposed( VectorViewType z,
              i + 1,
              aux );
   // no-op if the problem is not distributed
   Traits::CommunicatorType::Allreduce( aux, i + 1, MPI_SUM, Traits::getCommunicationGroup( *this->matrix ) );
   MPI::Allreduce( aux, i + 1, MPI_SUM, Traits::getCommunicationGroup( *this->matrix ) );

   // aux = T_i^T * aux
   // Note that T_i^T is lower triangular, so we can overwrite the aux vector with the result in place
+3 −7
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@

#pragma once

#include <TNL/Communicators/MpiCommunicator.h>
#include <TNL/MPI/Wrappers.h>
#include <TNL/Containers/Vector.h>
#include <TNL/Containers/VectorView.h>
#include <TNL/Containers/DistributedVector.h>
@@ -26,8 +26,6 @@ namespace Linear {
template< typename Matrix >
struct Traits
{
   using CommunicatorType = Communicators::MpiCommunicator;

   using VectorType = Containers::Vector
         < typename Matrix::RealType,
           typename Matrix::DeviceType,
@@ -51,7 +49,7 @@ struct Traits
   static ConstLocalViewType getConstLocalView( ConstVectorViewType v ) { return v; }
   static LocalViewType getLocalView( VectorViewType v ) { return v; }

   static typename CommunicatorType::CommunicationGroup getCommunicationGroup( const Matrix& m ) { return CommunicatorType::AllGroup; }
   static MPI_Comm getCommunicationGroup( const Matrix& m ) { return MPI::AllGroup(); }
   static void startSynchronization( VectorViewType v ) {}
   static void waitForSynchronization( VectorViewType v ) {}
};
@@ -59,8 +57,6 @@ struct Traits
template< typename Matrix, typename Communicator >
struct Traits< Matrices::DistributedMatrix< Matrix, Communicator > >
{
   using CommunicatorType = Communicator;

   using VectorType = Containers::DistributedVector
         < typename Matrix::RealType,
           typename Matrix::DeviceType,
@@ -96,7 +92,7 @@ struct Traits< Matrices::DistributedMatrix< Matrix, Communicator > >
   static ConstLocalViewType getConstLocalView( ConstVectorViewType v ) { return v.getConstLocalView(); }
   static LocalViewType getLocalView( VectorViewType v ) { return v.getLocalView(); }

   static typename CommunicatorType::CommunicationGroup getCommunicationGroup( const Matrices::DistributedMatrix< Matrix, Communicator >& m ) { return m.getCommunicationGroup(); }
   static MPI_Comm getCommunicationGroup( const Matrices::DistributedMatrix< Matrix, Communicator >& m ) { return m.getCommunicationGroup(); }
   static void startSynchronization( VectorViewType v ) { v.startSynchronization(); }
   static void waitForSynchronization( VectorViewType v ) { v.waitForSynchronization(); }
};
+5 −5
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@
#include <TNL/Devices/Host.h>
#include <TNL/Devices/Cuda.h>
#include <TNL/Config/ParameterContainer.h>
#include <TNL/Communicators/MpiCommunicator.h>
#include <TNL/MPI/Wrappers.h>

#include "Merson.h"

@@ -156,7 +156,7 @@ bool Merson< Problem, SolverMonitor >::solve( DofVectorPointer& _u )
      {
         const RealType localError =
            max( currentTau / 3.0 * abs( 0.2 * k1 -0.9 * k3 + 0.8 * k4 -0.1 * k5 ) );
            Problem::CommunicatorType::Allreduce( &localError, &error, 1, MPI_MAX, Problem::CommunicatorType::AllGroup );
            MPI::Allreduce( &localError, &error, 1, MPI_MAX, MPI::AllGroup() );
      }

      if( adaptivity == 0.0 || error < adaptivity )