From 98fcaffdec94da88f2d28350d6fa595e3f354b3c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jakub=20Klinkovsk=C3=BD?= <klinkovsky@mmg.fjfi.cvut.cz>
Date: Fri, 25 Mar 2022 15:51:17 +0100
Subject: [PATCH] Added RAII wrapper for MPI communicators

---
 src/TNL/MPI.h                                 |   1 +
 src/TNL/MPI/Comm.h                            | 249 ++++++++++++++++++
 src/TNL/MPI/DummyDefs.h                       |  10 +
 src/TNL/MPI/Utils.h                           |  27 --
 src/TNL/MPI/Wrappers.h                        |  12 -
 src/TNL/MPI/selectGPU.h                       |   2 +-
 .../DistributedMeshes/DistributedGrid.h       |   6 +-
 .../DistributedMeshes/DistributedGrid.hpp     |  10 +-
 .../DistributedMeshes/DistributedMesh.h       |   9 +-
 src/TNL/Meshes/Readers/PVTUReader.h           |   4 +-
 src/UnitTests/MPI/CMakeLists.txt              |   2 +-
 src/UnitTests/MPI/MPICommTest.cpp             | 126 +++++++++
 12 files changed, 403 insertions(+), 55 deletions(-)
 create mode 100644 src/TNL/MPI/Comm.h
 create mode 100644 src/UnitTests/MPI/MPICommTest.cpp

diff --git a/src/TNL/MPI.h b/src/TNL/MPI.h
index 77984b53d4..9f006e330c 100644
--- a/src/TNL/MPI.h
+++ b/src/TNL/MPI.h
@@ -22,6 +22,7 @@
 #include "MPI/selectGPU.h"
 #include "MPI/Wrappers.h"
 #include "MPI/Utils.h"
+#include "MPI/Comm.h"
 #include "MPI/ScopedInitializer.h"
 #include "MPI/Config.h"
 #include "MPI/Print.h"
diff --git a/src/TNL/MPI/Comm.h b/src/TNL/MPI/Comm.h
new file mode 100644
index 0000000000..054bf6a119
--- /dev/null
+++ b/src/TNL/MPI/Comm.h
@@ -0,0 +1,249 @@
+// Copyright (c) 2004-2022 Tomáš Oberhuber et al.
+//
+// This file is part of TNL - Template Numerical Library (https://tnl-project.org/)
+//
+// SPDX-License-Identifier: MIT
+
+#pragma once
+
+#include <stdexcept>
+#include <memory>
+
+#include "Wrappers.h"
+
+namespace TNL {
+namespace MPI {
+
+/**
+ * \brief An RAII wrapper for custom MPI communicators.
+ *
+ * This is an RAII wrapper for custom MPI communicators created by calls to
+ * \ref MPI_Comm_create, \ref MPI_Comm_split, or similar functions. It is based
+ * on \ref std::shared_ptr so copy-constructible and copy-assignable, copies of
+ * the object represent the same communicator that is deallocated only when the
+ * internal reference counter drops to zero.
+ *
+ * Note that predefined communicators (i.e. \ref MPI_COMM_WORLD,
+ * \ref MPI_COMM_NULL and \ref MPI_COMM_SELF) can be used to initialize this
+ * class, but other handles of the \ref MPI_Comm type _cannot_ be used to
+ * initialize this class.
+ *
+ * This class follows the factory pattern, i.e. it provides static methods such
+ * as \ref Comm::create or \ref Comm::split that return an instance of a new
+ * communicator.
+ */
+class Comm
+{
+private:
+   struct Wrapper
+   {
+      MPI_Comm comm = MPI_COMM_NULL;
+
+      Wrapper() = default;
+      Wrapper( const Wrapper& other ) = delete;
+      Wrapper( Wrapper&& other ) = default;
+      Wrapper&
+      operator=( const Wrapper& other ) = delete;
+      Wrapper&
+      operator=( Wrapper&& other ) = default;
+
+      Wrapper( MPI_Comm comm ) : comm( comm ) {}
+
+      ~Wrapper()
+      {
+#ifdef HAVE_MPI
+         // cannot free a predefined handle
+         if( comm != MPI_COMM_NULL && comm != MPI_COMM_WORLD && comm != MPI_COMM_SELF )
+            MPI_Comm_free( &comm );
+#endif
+      }
+   };
+
+   std::shared_ptr< Wrapper > wrapper;
+
+   //! \brief Internal constructor for the factory methods - initialization by the wrapper.
+   Comm( std::shared_ptr< Wrapper >&& wrapper ) : wrapper( std::move( wrapper ) ) {}
+
+public:
+   //! \brief Constructs an empty communicator with a null handle (`MPI_COMM_NULL`).
+   Comm() = default;
+
+   //! \brief Deleted copy-constructor.
+   Comm( const Comm& other ) = default;
+
+   //! \brief Default move-constructor.
+   Comm( Comm&& other ) = default;
+
+   //! \brief Deleted copy-assignment operator.
+   Comm&
+   operator=( const Comm& other ) = default;
+
+   //! \brief Default move-assignment operator.
+   Comm&
+   operator=( Comm&& other ) = default;
+
+   /**
+    * \brief Constructs a communicator initialized by given predefined communicator.
+    *
+    * Note that only predefined communicators (i.e. \ref MPI_COMM_WORLD,
+    * \ref MPI_COMM_NULL and \ref MPI_COMM_SELF) can be used to initialize this
+    * class. Other handles of the \ref MPI_Comm type _cannot_ be used to
+    * initialize this class.
+    *
+    * \throws std::logic_error when the \e comm handle is not a predefined
+    * communicator.
+    */
+   Comm( MPI_Comm comm )
+   {
+      if( comm != MPI_COMM_NULL && comm != MPI_COMM_WORLD && comm != MPI_COMM_SELF )
+         throw std::logic_error( "Only predefined communicators (MPI_COMM_WORLD, MPI_COMM_NULL and "
+                                 "MPI_COMM_SELF) can be used to initialize this class. Other "
+                                 "handles of the MPI_Comm type *cannot* be used to initialize "
+                                 "the TNL::MPI::Comm class." );
+      wrapper = std::make_shared< Wrapper >( comm );
+   }
+
+   //! \brief Factory method – wrapper for \ref MPI_Comm_dup
+   static Comm
+   duplicate( MPI_Comm comm )
+   {
+#ifdef HAVE_MPI
+      MPI_Comm newcomm;
+      MPI_Comm_dup( comm, &newcomm );
+      return { std::make_shared< Wrapper >( newcomm ) };
+#else
+      return { std::make_shared< Wrapper >( comm ) };
+#endif
+   }
+
+   //! \brief Non-static factory method – wrapper for \ref MPI_Comm_dup
+   Comm
+   duplicate() const
+   {
+      return duplicate( *this );
+   }
+
+   //! \brief Factory method – wrapper for \ref MPI_Comm_split
+   static Comm
+   split( MPI_Comm comm, int color, int key )
+   {
+#ifdef HAVE_MPI
+      MPI_Comm newcomm;
+      MPI_Comm_split( comm, color, key, &newcomm );
+      return { std::make_shared< Wrapper >( newcomm ) };
+#else
+      return { std::make_shared< Wrapper >( comm ) };
+#endif
+   }
+
+   //! \brief Non-static factory method – wrapper for \ref MPI_Comm_split
+   Comm
+   split( int color, int key ) const
+   {
+      return split( *this, color, key );
+   }
+
+   //! \brief Factory method – wrapper for \ref MPI_Comm_split_type
+   static Comm
+   split_type( MPI_Comm comm, int split_type, int key, MPI_Info info )
+   {
+#ifdef HAVE_MPI
+      MPI_Comm newcomm;
+      MPI_Comm_split_type( comm, split_type, key, info, &newcomm );
+      return { std::make_shared< Wrapper >( newcomm ) };
+#else
+      return { std::make_shared< Wrapper >( comm ) };
+#endif
+   }
+
+   //! \brief Non-static factory method – wrapper for \ref MPI_Comm_split_type
+   Comm
+   split_type( int split_type, int key, MPI_Info info ) const
+   {
+      return Comm::split_type( *this, split_type, key, info );
+   }
+
+   /**
+    * \brief Access the MPI communicator associated with this object.
+    *
+    * This routine permits the implicit conversion from \ref Comm to
+    * \ref MPI_Comm.
+    *
+    * \b Warning: The obtained \ref MPI_Comm handle becomes invalid when the
+    * originating \ref Comm object is destroyed. For example, the following
+    * code is invalid, because the \ref Comm object managing the lifetime of
+    * the communicator is destroyed as soon as it is cast to \ref MPI_Comm:
+    *
+    * \code{.cpp}
+    * const MPI_Comm comm = MPI::Comm::duplicate( MPI_COMM_WORLD );
+    * const int nproc = MPI::GetSize( comm );
+    * \endcode
+    */
+   operator const MPI_Comm&() const
+   {
+      return wrapper->comm;
+   }
+
+   //! \brief Determines the rank of the calling process in the communicator.
+   int
+   rank() const
+   {
+      return GetRank( *this );
+   }
+
+   //! \brief Returns the size of the group associated with a communicator.
+   int
+   size() const
+   {
+      return GetSize( *this );
+   }
+
+   //! \brief Compares two communicators – wrapper for \ref MPI_Comm_compare.
+   int
+   compare( MPI_Comm comm2 ) const
+   {
+#ifdef HAVE_MPI
+      int result;
+      MPI_Comm_compare( *this, comm2, &result );
+      return result;
+#else
+      return MPI_IDENT;
+#endif
+   }
+
+   /**
+    * \brief Wait for all processes within a communicator to reach the barrier.
+    *
+    * This routine is a collective operation that blocks each process until all
+    * processes have entered it, then releases all of the processes
+    * "simultaneously". It is equivalent to calling \ref MPI_Barrier with the
+    * MPI communicator associated with this object.
+    */
+   void barrier() const
+   {
+      Barrier( *this );
+   }
+};
+
+/**
+ * \brief Returns a local rank ID of the current process within a group of
+ * processes running on a shared-memory node.
+ *
+ * The given MPI communicator is split into groups according to the
+ * `MPI_COMM_TYPE_SHARED` type (from MPI-3) and the rank ID of the process
+ * within the group is returned.
+ */
+inline int
+getRankOnNode( MPI_Comm communicator = MPI_COMM_WORLD )
+{
+#ifdef HAVE_MPI
+   const int rank = GetRank( communicator );
+   const MPI::Comm local_comm = MPI::Comm::split_type( communicator, MPI_COMM_TYPE_SHARED, rank, MPI_INFO_NULL );
+   return local_comm.rank();
+#else
+   return 0;
+#endif
+}
+
+}  // namespace MPI
+}  // namespace TNL
diff --git a/src/TNL/MPI/DummyDefs.h b/src/TNL/MPI/DummyDefs.h
index 25a460305c..e22af96704 100644
--- a/src/TNL/MPI/DummyDefs.h
+++ b/src/TNL/MPI/DummyDefs.h
@@ -9,6 +9,7 @@
 #ifndef HAVE_MPI
 using MPI_Request = int;
 using MPI_Comm = int;
+using MPI_Info = int;
 
 enum MPI_Op
 {
@@ -26,6 +27,15 @@ enum MPI_Op
    MPI_MAXLOC,
 };
 
+// Comparison results
+enum
+{
+   MPI_IDENT,
+   MPI_CONGRUENT,
+   MPI_SIMILAR,
+   MPI_UNEQUAL
+};
+
 // MPI_Init_thread constants
 enum
 {
diff --git a/src/TNL/MPI/Utils.h b/src/TNL/MPI/Utils.h
index 3355f08fb5..41a2287d36 100644
--- a/src/TNL/MPI/Utils.h
+++ b/src/TNL/MPI/Utils.h
@@ -42,33 +42,6 @@ restoreRedirection()
    }
 }
 
-/**
- * \brief Returns a local rank ID of the current process within a group of
- * processes running on a shared-memory node.
- *
- * The given MPI communicator is split into groups according to the
- * `MPI_COMM_TYPE_SHARED` type (from MPI-3) and the rank ID of the process
- * within the group is returned.
- */
-inline int
-getRankOnNode( MPI_Comm communicator = MPI_COMM_WORLD )
-{
-#ifdef HAVE_MPI
-   const int rank = GetRank( communicator );
-
-   MPI_Comm local_comm;
-   MPI_Comm_split_type( communicator, MPI_COMM_TYPE_SHARED, rank, MPI_INFO_NULL, &local_comm );
-
-   const int local_rank = GetRank( local_comm );
-
-   MPI_Comm_free( &local_comm );
-
-   return local_rank;
-#else
-   return 0;
-#endif
-}
-
 /**
  * \brief Applies the given reduction operation to the values among all ranks
  * in the given communicator.
diff --git a/src/TNL/MPI/Wrappers.h b/src/TNL/MPI/Wrappers.h
index 872e4ad4de..a6ea059abe 100644
--- a/src/TNL/MPI/Wrappers.h
+++ b/src/TNL/MPI/Wrappers.h
@@ -134,18 +134,6 @@ GetSize( MPI_Comm communicator = MPI_COMM_WORLD )
 
 // wrappers for MPI helper functions
 
-inline MPI_Comm
-Comm_split( MPI_Comm comm, int color, int key )
-{
-#ifdef HAVE_MPI
-   MPI_Comm newcomm;
-   MPI_Comm_split( comm, color, key, &newcomm );
-   return newcomm;
-#else
-   return comm;
-#endif
-}
-
 /**
  * \brief Wrapper for \ref MPI_Dims_create.
  *
diff --git a/src/TNL/MPI/selectGPU.h b/src/TNL/MPI/selectGPU.h
index 9ebf1bcccb..b09bc1f007 100644
--- a/src/TNL/MPI/selectGPU.h
+++ b/src/TNL/MPI/selectGPU.h
@@ -10,7 +10,7 @@
 
 #include <TNL/Cuda/CheckDevice.h>
 
-#include "Utils.h"
+#include "Comm.h"
 
 namespace TNL {
 namespace MPI {
diff --git a/src/TNL/Meshes/DistributedMeshes/DistributedGrid.h b/src/TNL/Meshes/DistributedMeshes/DistributedGrid.h
index 64548a6a99..2c4becf7e5 100644
--- a/src/TNL/Meshes/DistributedMeshes/DistributedGrid.h
+++ b/src/TNL/Meshes/DistributedMeshes/DistributedGrid.h
@@ -100,9 +100,9 @@ public:
    getSubdomainCoordinates() const;
 
    void
-   setCommunicator( MPI_Comm communicator );
+   setCommunicator( MPI::Comm&& communicator );
 
-   MPI_Comm
+   const MPI::Comm&
    getCommunicator() const;
 
    template< int EntityDimension >
@@ -168,7 +168,7 @@ public:
 
    bool isSet = false;
 
-   MPI_Comm communicator = MPI_COMM_WORLD;
+   MPI::Comm communicator = MPI_COMM_WORLD;
 };
 
 template< int Dimension, typename Real, typename Device, typename Index >
diff --git a/src/TNL/Meshes/DistributedMeshes/DistributedGrid.hpp b/src/TNL/Meshes/DistributedMeshes/DistributedGrid.hpp
index c859188549..33cf5d3b1e 100644
--- a/src/TNL/Meshes/DistributedMeshes/DistributedGrid.hpp
+++ b/src/TNL/Meshes/DistributedMeshes/DistributedGrid.hpp
@@ -237,13 +237,13 @@ DistributedMesh< Grid< Dimension, Real, Device, Index > >::getEntitiesCount() co
 
 template< int Dimension, typename Real, typename Device, typename Index >
 void
-DistributedMesh< Grid< Dimension, Real, Device, Index > >::setCommunicator( MPI_Comm communicator )
+DistributedMesh< Grid< Dimension, Real, Device, Index > >::setCommunicator( MPI::Comm&& communicator )
 {
-   this->communicator = communicator;
+   this->communicator = std::move( communicator );
 }
 
 template< int Dimension, typename Real, typename Device, typename Index >
-MPI_Comm
+const MPI::Comm&
 DistributedMesh< Grid< Dimension, Real, Device, Index > >::getCommunicator() const
 {
    return this->communicator;
@@ -383,7 +383,7 @@ DistributedMesh< Grid< Dimension, Real, Device, Index > >::SetupByCut(
       // TODO: set interiorBegin, interiorEnd
 
       const int newRank = getRankOfProcCoord( this->subdomainCoordinates );
-      this->communicator = MPI::Comm_split( oldCommunicator, 1, newRank );
+      this->communicator = MPI::Comm::split( oldCommunicator, 1, newRank );
 
       setupNeighbors();
 
@@ -396,7 +396,7 @@ DistributedMesh< Grid< Dimension, Real, Device, Index > >::SetupByCut(
       return true;
    }
    else {
-      this->communicator = MPI::Comm_split( oldCommunicator, MPI_UNDEFINED, 0 );
+      this->communicator = MPI::Comm::split( oldCommunicator, MPI_UNDEFINED, 0 );
       return false;
    }
 }
diff --git a/src/TNL/Meshes/DistributedMeshes/DistributedMesh.h b/src/TNL/Meshes/DistributedMeshes/DistributedMesh.h
index ff384078b5..6958288730 100644
--- a/src/TNL/Meshes/DistributedMeshes/DistributedMesh.h
+++ b/src/TNL/Meshes/DistributedMeshes/DistributedMesh.h
@@ -10,6 +10,7 @@
 
 #include <TNL/Containers/Array.h>
 #include <TNL/MPI/Wrappers.h>
+#include <TNL/MPI/Comm.h>
 #include <TNL/Meshes/DistributedMeshes/GlobalIndexStorage.h>
 #include <TNL/Meshes/MeshDetails/IndexPermutationApplier.h>
 
@@ -94,12 +95,12 @@ public:
     * Methods specific to the distributed mesh
     */
    void
-   setCommunicator( MPI_Comm communicator )
+   setCommunicator( MPI::Comm&& communicator )
    {
-      this->communicator = communicator;
+      this->communicator = std::move( communicator );
    }
 
-   MPI_Comm
+   const MPI::Comm&
    getCommunicator() const
    {
       return communicator;
@@ -241,7 +242,7 @@ public:
 
 protected:
    MeshType localMesh;
-   MPI_Comm communicator = MPI_COMM_NULL;
+   MPI::Comm communicator = MPI_COMM_NULL;
    int ghostLevels = 0;
 
    // vtkGhostType arrays for points and cells (cached for output into VTK formats)
diff --git a/src/TNL/Meshes/Readers/PVTUReader.h b/src/TNL/Meshes/Readers/PVTUReader.h
index 83abf37d51..f3c72d3e19 100644
--- a/src/TNL/Meshes/Readers/PVTUReader.h
+++ b/src/TNL/Meshes/Readers/PVTUReader.h
@@ -228,10 +228,10 @@ public:
       if( minCount == 0 ) {
          // split the communicator, remove the ranks which did not get a subdomain
          const int color = ( pointsCount > 0 && cellsCount > 0 ) ? 0 : MPI_UNDEFINED;
-         const MPI_Comm subCommunicator = MPI::Comm_split( communicator, color, 0 );
+         MPI::Comm subCommunicator = MPI::Comm::split( communicator, color, 0 );
 
          // set the communicator
-         mesh.setCommunicator( subCommunicator );
+         mesh.setCommunicator( std::move( subCommunicator ) );
       }
       else {
          // set the communicator
diff --git a/src/UnitTests/MPI/CMakeLists.txt b/src/UnitTests/MPI/CMakeLists.txt
index 2c79e14a29..af047e46e6 100644
--- a/src/UnitTests/MPI/CMakeLists.txt
+++ b/src/UnitTests/MPI/CMakeLists.txt
@@ -1,4 +1,4 @@
-set( CPP_TESTS MPIUtilsTest )
+set( CPP_TESTS MPICommTest MPIUtilsTest )
 
 if( ${BUILD_MPI} )
 foreach( target IN ITEMS ${CPP_TESTS} )
diff --git a/src/UnitTests/MPI/MPICommTest.cpp b/src/UnitTests/MPI/MPICommTest.cpp
new file mode 100644
index 0000000000..13562f4189
--- /dev/null
+++ b/src/UnitTests/MPI/MPICommTest.cpp
@@ -0,0 +1,126 @@
+#ifdef HAVE_GTEST
+#include <gtest/gtest.h>
+
+#include <TNL/MPI/Comm.h>
+
+using namespace TNL;
+using namespace TNL::MPI;
+
+TEST( CommTest, COMM_WORLD )
+{
+   const Comm c = MPI_COMM_WORLD;
+   EXPECT_EQ( c.rank(), GetRank( MPI_COMM_WORLD ) );
+   EXPECT_EQ( c.size(), GetSize( MPI_COMM_WORLD ) );
+   EXPECT_EQ( MPI_Comm( c ), MPI_COMM_WORLD );
+   EXPECT_EQ( c.compare( MPI_COMM_WORLD ), MPI_IDENT );
+}
+
+TEST( CommTest, duplicate_static )
+{
+   const Comm c = Comm::duplicate( MPI_COMM_WORLD );
+   EXPECT_EQ( c.rank(), GetRank( MPI_COMM_WORLD ) );
+   EXPECT_EQ( c.size(), GetSize( MPI_COMM_WORLD ) );
+   EXPECT_NE( MPI_Comm( c ), MPI_COMM_WORLD );
+   EXPECT_EQ( c.compare( MPI_COMM_WORLD ), MPI_CONGRUENT );
+}
+
+TEST( CommTest, duplicate )
+{
+   const Comm c = Comm( MPI_COMM_WORLD ).duplicate();
+   EXPECT_EQ( c.rank(), GetRank( MPI_COMM_WORLD ) );
+   EXPECT_EQ( c.size(), GetSize( MPI_COMM_WORLD ) );
+   EXPECT_NE( MPI_Comm( c ), MPI_COMM_WORLD );
+   EXPECT_EQ( c.compare( MPI_COMM_WORLD ), MPI_CONGRUENT );
+}
+
+TEST( CommTest, split_static_odd_even )
+{
+   const int rank = GetRank( MPI_COMM_WORLD );
+   const int size = GetSize( MPI_COMM_WORLD );
+   // split into two groups: odd and even based on the original rank
+   const Comm c = Comm::split( MPI_COMM_WORLD, rank % 2, rank );
+   const int my_size = ( size % 2 == 0 || rank % 2 == 1 ) ? size / 2 : size / 2 + 1;
+   EXPECT_EQ( c.rank(), rank / 2 );
+   EXPECT_EQ( c.size(), my_size );
+   EXPECT_NE( MPI_Comm( c ), MPI_COMM_WORLD );
+   if( size == 1 ) {
+      EXPECT_EQ( c.compare( MPI_COMM_WORLD ), MPI_CONGRUENT );
+   }
+   else {
+      EXPECT_EQ( c.compare( MPI_COMM_WORLD ), MPI_UNEQUAL );
+   }
+}
+
+TEST( CommTest, split_odd_even )
+{
+   const int rank = GetRank( MPI_COMM_WORLD );
+   const int size = GetSize( MPI_COMM_WORLD );
+   // split into two groups: odd and even based on the original rank
+   const Comm c = Comm( MPI_COMM_WORLD ).split( rank % 2, rank );
+   const int my_size = ( size % 2 == 0 || rank % 2 == 1 ) ? size / 2 : size / 2 + 1;
+   EXPECT_EQ( c.rank(), rank / 2 );
+   EXPECT_EQ( c.size(), my_size );
+   EXPECT_NE( MPI_Comm( c ), MPI_COMM_WORLD );
+   if( size == 1 ) {
+      EXPECT_EQ( c.compare( MPI_COMM_WORLD ), MPI_CONGRUENT );
+   }
+   else {
+      EXPECT_EQ( c.compare( MPI_COMM_WORLD ), MPI_UNEQUAL );
+   }
+}
+
+TEST( CommTest, split_static_renumber )
+{
+   const int rank = GetRank( MPI_COMM_WORLD );
+   const int size = GetSize( MPI_COMM_WORLD );
+   // same group, but different ranks
+   const Comm c = Comm::split( MPI_COMM_WORLD, 0, size - 1 - rank );
+   EXPECT_EQ( c.rank(), size - 1 - rank );
+   EXPECT_EQ( c.size(), size );
+   EXPECT_NE( MPI_Comm( c ), MPI_COMM_WORLD );
+   if( size == 1 ) {
+      EXPECT_EQ( c.compare( MPI_COMM_WORLD ), MPI_CONGRUENT );
+   }
+   else {
+      EXPECT_EQ( c.compare( MPI_COMM_WORLD ), MPI_SIMILAR );
+   }
+}
+
+TEST( CommTest, split_renumber )
+{
+   const int rank = GetRank( MPI_COMM_WORLD );
+   const int size = GetSize( MPI_COMM_WORLD );
+   // same group, but different ranks
+   const Comm c = Comm( MPI_COMM_WORLD ).split( 0, size - 1 - rank );
+   EXPECT_EQ( c.rank(), size - 1 - rank );
+   EXPECT_EQ( c.size(), size );
+   EXPECT_NE( MPI_Comm( c ), MPI_COMM_WORLD );
+   if( size == 1 ) {
+      EXPECT_EQ( c.compare( MPI_COMM_WORLD ), MPI_CONGRUENT );
+   }
+   else {
+      EXPECT_EQ( c.compare( MPI_COMM_WORLD ), MPI_SIMILAR );
+   }
+}
+
+#ifdef HAVE_MPI
+TEST( CommTest, split_type_static )
+{
+   // tests are run on a single node, so the resulting communicator is congruent to MPI_COMM_WORLD
+   const int rank = GetRank( MPI_COMM_WORLD );
+   const Comm local_comm = Comm::split_type( MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, rank, MPI_INFO_NULL );
+   EXPECT_EQ( local_comm.compare( MPI_COMM_WORLD ), MPI_CONGRUENT );
+}
+
+TEST( CommTest, split_type )
+{
+   // tests are run on a single node, so the resulting communicator is congruent to MPI_COMM_WORLD
+   const int rank = GetRank( MPI_COMM_WORLD );
+   const Comm local_comm = Comm( MPI_COMM_WORLD ).split_type( MPI_COMM_TYPE_SHARED, rank, MPI_INFO_NULL );
+   EXPECT_EQ( local_comm.compare( MPI_COMM_WORLD ), MPI_CONGRUENT );
+}
+#endif
+
+#endif  // HAVE_GTEST
+
+#include "../main_mpi.h"
-- 
GitLab