From 1afab991dc08f1f6349765b932625d4adfd8b747 Mon Sep 17 00:00:00 2001
From: Tomas Oberhuber <tomas.oberhuber@fjfi.cvut.cz>
Date: Mon, 1 Apr 2019 12:48:04 +0200
Subject: [PATCH] Added grid synchronizer methods for asynchronous
 synchronization.

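The startSynchronization()/finishSynchronization() pair splits the
blocking synchronize() into two phases: the first posts non-blocking
sends and receives for the overlap (ghost) regions and records the
requests, the second waits for those requests and unpacks the received
buffers. Computation that does not touch the overlaps can therefore run
while the transfers are in flight. A minimal usage sketch (the mesh
function u and the interior-only computation are illustrative, not part
of this patch):

   std::list< typename CommunicatorType::Request > requests;
   u.template startSynchronization< CommunicatorType >( requests );
   // ... compute on entities that do not depend on the overlaps ...
   u.template finishSynchronization< CommunicatorType >( requests );
   // the overlap regions of u are now up to date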
---
 src/TNL/Communicators/MpiCommunicator.h       |  8 ++
 src/TNL/Communicators/NoDistrCommunicator.h   |  5 +
 src/TNL/Functions/MeshFunction.h              | 19 ++++
 src/TNL/Functions/MeshFunction_impl.h         | 36 +++++++++
 .../DistributedGridSynchronizer.h             | 82 ++++++++++++++++++-
 5 files changed, 148 insertions(+), 2 deletions(-)

diff --git a/src/TNL/Communicators/MpiCommunicator.h b/src/TNL/Communicators/MpiCommunicator.h
index a40b2e4bb8..48d4d3c8b1 100644
--- a/src/TNL/Communicators/MpiCommunicator.h
+++ b/src/TNL/Communicators/MpiCommunicator.h
@@ -346,6 +346,14 @@ class MpiCommunicator
             throw Exceptions::MPISupportMissing();
 #endif
         }
+
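+         // Waits for the completion of a single non-blocking communication request.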
+         static void wait( Request& request )
+         {
+#ifdef HAVE_MPI
+            MPI_Wait( &request, MPI_STATUS_IGNORE );
+#endif
+         }
 
          static void WaitAll(Request *reqs, int length)
          {
diff --git a/src/TNL/Communicators/NoDistrCommunicator.h b/src/TNL/Communicators/NoDistrCommunicator.h
index 5628522c3e..f471a99ba9 100644
--- a/src/TNL/Communicators/NoDistrCommunicator.h
+++ b/src/TNL/Communicators/NoDistrCommunicator.h
@@ -89,6 +89,11 @@ class NoDistrCommunicator
           return 1;
       }
 
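+      // No-op: there is nothing to wait for without distributed communication.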
+      static void wait( Request& request )
+      {
+      }
+
       static void WaitAll(Request *reqs, int length)
       {
       }
diff --git a/src/TNL/Functions/MeshFunction.h b/src/TNL/Functions/MeshFunction.h
index 9d67808e1e..860b44e37a 100644
--- a/src/TNL/Functions/MeshFunction.h
+++ b/src/TNL/Functions/MeshFunction.h
@@ -178,6 +178,25 @@ class MeshFunction :
       void synchronize( bool withPeriodicBoundaryConditions = false,
                         const Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >& mask =
                         Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >( nullptr ) );
+
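+      // Starts an asynchronous synchronization of the overlap (ghost) regions:
+      // non-blocking send/receive requests are appended to the given list and
+      // must later be completed by finishSynchronization().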
+      template< typename CommunicatorType,
+                typename PeriodicBoundariesMaskType = MeshFunction< Mesh, MeshEntityDimension, bool > >
+      void startSynchronization( std::list< typename CommunicatorType::Request >& requests,
+         bool withPeriodicBoundaryConditions = false,
+         const Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >& mask =
+         Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >( nullptr ) );
+
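+      // Finishes an asynchronous synchronization started by startSynchronization():
+      // waits for all recorded requests and unpacks the received data.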
+      template< typename CommunicatorType,
+                typename PeriodicBoundariesMaskType = MeshFunction< Mesh, MeshEntityDimension, bool > >
+      void finishSynchronization( std::list< typename CommunicatorType::Request >& requests,
+         bool withPeriodicBoundaryConditions = false,
+         const Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >& mask =
+         Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >( nullptr ) );
 
    protected:
 
diff --git a/src/TNL/Functions/MeshFunction_impl.h b/src/TNL/Functions/MeshFunction_impl.h
index e9426084e0..6ac031d4f0 100644
--- a/src/TNL/Functions/MeshFunction_impl.h
+++ b/src/TNL/Functions/MeshFunction_impl.h
@@ -555,6 +555,42 @@ synchronize( bool periodicBoundaries,
     }
 }
 
+template< typename Mesh,
+          int MeshEntityDimension,
+          typename Real >
+template< typename CommunicatorType,
+          typename PeriodicBoundariesMaskType >
+void
+MeshFunction< Mesh, MeshEntityDimension, Real >::
+startSynchronization( std::list< typename CommunicatorType::Request >& requests,
+                      bool periodicBoundaries,
+                      const Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >& mask )
+{
+   auto distrMesh = this->getMesh().getDistributedMesh();
+   if( distrMesh != nullptr && distrMesh->isDistributed() )
+   {
+      this->synchronizer.template startSynchronization<CommunicatorType>( *this, requests, periodicBoundaries, mask );
+   }
+}
+
+template< typename Mesh,
+          int MeshEntityDimension,
+          typename Real >
+template< typename CommunicatorType,
+          typename PeriodicBoundariesMaskType >
+void
+MeshFunction< Mesh, MeshEntityDimension, Real >::
+finishSynchronization( std::list< typename CommunicatorType::Request >& requests,
+                       bool periodicBoundaries,
+                       const Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >& mask )
+{
+   auto distrMesh = this->getMesh().getDistributedMesh();
+   if( distrMesh != nullptr && distrMesh->isDistributed() )
+   {
+      this->synchronizer.template finishSynchronization<CommunicatorType>( *this, requests, periodicBoundaries, mask );
+   }
+}
+
 template< typename Mesh,
           int MeshEntityDimension,
           typename Real >
diff --git a/src/TNL/Meshes/DistributedMeshes/DistributedGridSynchronizer.h b/src/TNL/Meshes/DistributedMeshes/DistributedGridSynchronizer.h
index 1347683fd0..e99e755fd0 100644
--- a/src/TNL/Meshes/DistributedMeshes/DistributedGridSynchronizer.h
+++ b/src/TNL/Meshes/DistributedMeshes/DistributedGridSynchronizer.h
@@ -10,6 +10,7 @@
 
 #pragma once
 
+#include <list>
 #include <TNL/Meshes/Grid.h>
 #include <TNL/Containers/Array.h>
 #include <TNL/Meshes/DistributedMeshes/BufferEntitiesHelper.h>
@@ -123,8 +124,85 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< MeshDimension,
                neighbors[ i ] == -1 )
                   swap( sendBegin[i], recieveBegin[i] );
          }
-     }
+      }
+
+      template< typename CommunicatorType,
+                typename MeshFunctionType,
+                typename PeriodicBoundariesMaskPointer = Pointers::SharedPointer< MeshFunctionType > >
+      void startSynchronization( MeshFunctionType &meshFunction,
+                                 std::list< typename CommunicatorType::Request >& requests,
+                                 bool periodicBoundaries = false,
+                                 const PeriodicBoundariesMaskPointer& mask = PeriodicBoundariesMaskPointer( nullptr ) )
+      {
+         TNL_ASSERT_TRUE( isSet, "Synchronizer is not set, but used to synchronize" );
+
+         if( !distributedGrid->isDistributed() ) return;
+
+         const int *neighbors = distributedGrid->getNeighbors();
+         const int *periodicNeighbors = distributedGrid->getPeriodicNeighbors();
+
+         //fill send buffers
+         copyBuffers( meshFunction,
+            sendBuffers, sendBegin, sendDimensions,
+            true,
+            neighbors,
+            periodicBoundaries,
+            PeriodicBoundariesMaskPointer( nullptr ) ); // the mask is used only when receiving data
+
+         //async send and receive
+         typename CommunicatorType::CommunicationGroup group;
+         group=*((typename CommunicatorType::CommunicationGroup *)(distributedGrid->getCommunicationGroup()));
+
+         //send everything, receive everything 
+         for( int i=0; i<this->getNeighborCount(); i++ )
+         {
+            /*TNL_MPI_PRINT( "Sending data... " << i << " sizes -> " 
+               << sendSizes[ i ] << "sendDimensions -> " <<  sendDimensions[ i ]
+               << " upperOverlap -> " << this->distributedGrid->getUpperOverlap() );*/
+            if( neighbors[ i ] != -1 )
+            {
+               //TNL_MPI_PRINT( "Sending data to node " << neighbors[ i ] );
+               requests.push_back( CommunicatorType::ISend( sendBuffers[ i ].getData(),  sendSizes[ i ], neighbors[ i ], 0, group ) );
+               //TNL_MPI_PRINT( "Receiving data from node " << neighbors[ i ] );
+               requests.push_back( CommunicatorType::IRecv( recieveBuffers[ i ].getData(),  sendSizes[ i ], neighbors[ i ], 0, group ) );
+            }
+            else if( periodicBoundaries && sendSizes[ i ] != 0 )
+            {
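+               // periodic exchanges are tagged with 1 so they cannot be mixed up
+               // with the regular overlap exchange, which uses tag 0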
+               //TNL_MPI_PRINT( "Sending data to node " << periodicNeighbors[ i ] );
+               requests.push_back( CommunicatorType::ISend( sendBuffers[ i ].getData(),  sendSizes[ i ], periodicNeighbors[ i ], 1, group ) );
+               //TNL_MPI_PRINT( "Receiving data to node " << periodicNeighbors[ i ] );
+               requests.push_back( CommunicatorType::IRecv( recieveBuffers[ i ].getData(),  sendSizes[ i ], periodicNeighbors[ i ], 1, group ) );
+            }
+         }
+      }
+
+      template< typename CommunicatorType,
+                typename MeshFunctionType,
+                typename PeriodicBoundariesMaskPointer = Pointers::SharedPointer< MeshFunctionType > >
+      void finishSynchronization( MeshFunctionType &meshFunction,
+                                  std::list< typename CommunicatorType::Request >& requests,
+                                  bool periodicBoundaries = false,
+                                  const PeriodicBoundariesMaskPointer& mask = PeriodicBoundariesMaskPointer( nullptr ) )
+      {
+         //wait until send is done
+         //TNL_MPI_PRINT( "Waiting for data ..." )
+         for( auto& request : requests )
+            CommunicatorType::wait( request );
+
+         //copy data from receive buffers
+         //TNL_MPI_PRINT( "Copying data ..." )
+         const int *neighbors = distributedGrid->getNeighbors();
+         copyBuffers( meshFunction,
+            recieveBuffers, recieveBegin, sendDimensions,
+            false,
+            neighbors,
+            periodicBoundaries,
+            mask );
+      }
 
+
       template< typename CommunicatorType,
                 typename MeshFunctionType,
                 typename PeriodicBoundariesMaskPointer = Pointers::SharedPointer< MeshFunctionType > >
@@ -154,7 +232,7 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< MeshDimension,
          group=*((typename CommunicatorType::CommunicationGroup *)(distributedGrid->getCommunicationGroup()));
          int requestsCount( 0 );
 
-         //send everything, recieve everything 
+         //send everything, receive everything 
          for( int i=0; i<this->getNeighborCount(); i++ )
          {
             /*TNL_MPI_PRINT( "Sending data... " << i << " sizes -> " 
-- 
GitLab