From 1afab991dc08f1f6349765b932625d4adfd8b747 Mon Sep 17 00:00:00 2001 From: Tomas Oberhuber <tomas.oberhuber@fjfi.cvut.cz> Date: Mon, 1 Apr 2019 12:48:04 +0200 Subject: [PATCH] Added grid synchronizer methods for asynchronous synchronization. --- src/TNL/Communicators/MpiCommunicator.h | 7 ++ src/TNL/Communicators/NoDistrCommunicator.h | 4 + src/TNL/Functions/MeshFunction.h | 16 ++++ src/TNL/Functions/MeshFunction_impl.h | 37 +++++++++ .../DistributedGridSynchronizer.h | 82 ++++++++++++++++++- 5 files changed, 144 insertions(+), 2 deletions(-) diff --git a/src/TNL/Communicators/MpiCommunicator.h b/src/TNL/Communicators/MpiCommunicator.h index a40b2e4bb8..48d4d3c8b1 100644 --- a/src/TNL/Communicators/MpiCommunicator.h +++ b/src/TNL/Communicators/MpiCommunicator.h @@ -346,6 +346,13 @@ class MpiCommunicator throw Exceptions::MPISupportMissing(); #endif } + + static void wait( Request& request ) /* Waits for one request: calls MPI_Wait with MPI_STATUS_IGNORE when HAVE_MPI is defined, otherwise the body is empty. NOTE(review): the function ending just above this hunk contains throw Exceptions::MPISupportMissing() before its #endif; confirm that doing nothing without MPI is intended here. */ + { +#ifdef HAVE_MPI + MPI_Wait( &request, MPI_STATUS_IGNORE ); +#endif + } static void WaitAll(Request *reqs, int length) { diff --git a/src/TNL/Communicators/NoDistrCommunicator.h b/src/TNL/Communicators/NoDistrCommunicator.h index 5628522c3e..f471a99ba9 100644 --- a/src/TNL/Communicators/NoDistrCommunicator.h +++ b/src/TNL/Communicators/NoDistrCommunicator.h @@ -89,6 +89,10 @@ class NoDistrCommunicator return 1; } + static void wait( Request& request ) /* Empty body: nothing is done with the request. */ + { + } + static void WaitAll(Request *reqs, int length) { } diff --git a/src/TNL/Functions/MeshFunction.h b/src/TNL/Functions/MeshFunction.h index 9d67808e1e..860b44e37a 100644 --- a/src/TNL/Functions/MeshFunction.h +++ b/src/TNL/Functions/MeshFunction.h @@ -178,6 +178,22 @@ class MeshFunction : void synchronize( bool withPeriodicBoundaryConditions = false, const Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >& mask = Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >( nullptr ) ); + + template< typename CommunicatorType, /* startSynchronization (declaration): takes the list that collects the communication requests; the definition is added to MeshFunction_impl.h by this patch. */ + typename PeriodicBoundariesMaskType = MeshFunction< Mesh, 
MeshEntityDimension, bool > > + void startSynchronization( std::list< typename CommunicatorType::Request >& requests, + bool withPeriodicBoundaryConditions = false, + const Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >& mask = + Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >( nullptr ) ); + + template< typename CommunicatorType, /* finishSynchronization (declaration): takes the request list to wait on; the definition is added to MeshFunction_impl.h by this patch. */ + typename PeriodicBoundariesMaskType = MeshFunction< Mesh, MeshEntityDimension, bool > > + void finishSynchronization( std::list< typename CommunicatorType::Request >& requests, + bool withPeriodicBoundaryConditions = false, + const Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >& mask = + Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >( nullptr ) ); + + protected: diff --git a/src/TNL/Functions/MeshFunction_impl.h b/src/TNL/Functions/MeshFunction_impl.h index e9426084e0..6ac031d4f0 100644 --- a/src/TNL/Functions/MeshFunction_impl.h +++ b/src/TNL/Functions/MeshFunction_impl.h @@ -555,6 +555,42 @@ synchronize( bool periodicBoundaries, } } +template< typename Mesh, + int MeshEntityDimension, + typename Real > +template< typename CommunicatorType, + typename PeriodicBoundariesMaskType > +void +MeshFunction< Mesh, MeshEntityDimension, Real >:: +startSynchronization( std::list< typename CommunicatorType::Request >& requests, + bool periodicBoundaries, + const Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >& mask ) +{ /* Forwards to the synchronizer only when getDistributedMesh() is non-null and reports isDistributed(); otherwise this function does nothing and requests is left unchanged. */ + auto distrMesh = this->getMesh().getDistributedMesh(); + if(distrMesh != NULL && distrMesh->isDistributed()) /* NOTE(review): the default arguments in MeshFunction.h use nullptr; NULL here is inconsistent with that. */ + { + this->synchronizer.template startSynchronization<CommunicatorType>( *this, requests, periodicBoundaries, mask ); + } +} + +template< typename Mesh, + int MeshEntityDimension, + typename Real > +template< typename CommunicatorType, + typename PeriodicBoundariesMaskType > +void +MeshFunction< Mesh, MeshEntityDimension, Real >:: +finishSynchronization( std::list< typename CommunicatorType::Request >& requests, + bool 
periodicBoundaries, + const Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >& mask ) +{ /* Same guard as in startSynchronization: forwards to the synchronizer only when getDistributedMesh() is non-null and reports isDistributed(). */ + auto distrMesh = this->getMesh().getDistributedMesh(); + if(distrMesh != NULL && distrMesh->isDistributed()) + { + this->synchronizer.template finishSynchronization<CommunicatorType>( *this, requests, periodicBoundaries, mask ); + } +} + template< typename Mesh, int MeshEntityDimension, typename Real > @@ -579,3 +615,4 @@ operator << ( std::ostream& str, const MeshFunction< Mesh, MeshEntityDimension, } // namespace Functions } // namespace TNL + diff --git a/src/TNL/Meshes/DistributedMeshes/DistributedGridSynchronizer.h b/src/TNL/Meshes/DistributedMeshes/DistributedGridSynchronizer.h index 1347683fd0..e99e755fd0 100644 --- a/src/TNL/Meshes/DistributedMeshes/DistributedGridSynchronizer.h +++ b/src/TNL/Meshes/DistributedMeshes/DistributedGridSynchronizer.h @@ -10,6 +10,7 @@ #pragma once +#include <list> /* std::list holds the requests in the two methods added below */ #include <TNL/Meshes/Grid.h> #include <TNL/Containers/Array.h> #include <TNL/Meshes/DistributedMeshes/BufferEntitiesHelper.h> @@ -123,8 +124,85 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< MeshDimension, neighbors[ i ] == -1 ) swap( sendBegin[i], recieveBegin[i] ); } - } + } + + template< typename CommunicatorType, + typename MeshFunctionType, + typename PeriodicBoundariesMaskPointer = Pointers::SharedPointer< MeshFunctionType > > + void startSynchronization( MeshFunctionType &meshFunction, + std::list< typename CommunicatorType::Request >& requests, + bool periodicBoundaries = false, + const PeriodicBoundariesMaskPointer& mask = PeriodicBoundariesMaskPointer( nullptr ) ) + { /* Asserts that the synchronizer is set and returns early when the grid is not distributed. Otherwise fills the send buffers with copyBuffers and appends one ISend and one IRecv request per handled index i to requests; nothing is waited for here. The mask parameter is not read in this method. */ + TNL_ASSERT_TRUE( isSet, "Synchronizer is not set, but used to synchronize" ); + + if( !distributedGrid->isDistributed() ) return; + + const int *neighbors = distributedGrid->getNeighbors(); + const int *periodicNeighbors = distributedGrid->getPeriodicNeighbors(); + + //fill send buffers + copyBuffers( meshFunction, + sendBuffers, sendBegin,sendDimensions, + true, /* NOTE(review): finishSynchronization passes false in this position; presumably the flag selects the copy direction - confirm against copyBuffers. */ + 
neighbors, + periodicBoundaries, + PeriodicBoundariesMaskPointer( nullptr ) ); // the mask is used only when receiving data + + //async send and receive + //requests.setSize( 2 * this->getNeighborCount() ); + typename CommunicatorType::CommunicationGroup group; + group=*((typename CommunicatorType::CommunicationGroup *)(distributedGrid->getCommunicationGroup())); + int requestsCount( 0 ); /* NOTE(review): requestsCount is never read in this method. */ + + //send everything, receive everything + for( int i=0; i<this->getNeighborCount(); i++ ) + { + /*TNL_MPI_PRINT( "Sending data... " << i << " sizes -> " + << sendSizes[ i ] << "sendDimensions -> " << sendDimensions[ i ] + << " upperOverlap -> " << this->distributedGrid->getUpperOverlap() );*/ + if( neighbors[ i ] != -1 ) /* neighbors[ i ] is set: post the exchange with it (the literal 0 is presumably the message tag - confirm against ISend and IRecv) */ + { + //TNL_MPI_PRINT( "Sending data to node " << neighbors[ i ] ); + requests.push_back( CommunicatorType::ISend( sendBuffers[ i ].getData(), sendSizes[ i ], neighbors[ i ], 0, group ) ); + //TNL_MPI_PRINT( "Receiving data from node " << neighbors[ i ] ); + requests.push_back( CommunicatorType::IRecv( recieveBuffers[ i ].getData(), sendSizes[ i ], neighbors[ i ], 0, group ) ); + } + else if( periodicBoundaries && sendSizes[ i ] !=0 ) /* neighbors[ i ] is -1: with periodicBoundaries and a non-zero sendSizes[ i ], post the exchange with periodicNeighbors[ i ] (literal 1 instead of 0) */ + { + //TNL_MPI_PRINT( "Sending data to node " << periodicNeighbors[ i ] ); + requests.push_back( CommunicatorType::ISend( sendBuffers[ i ].getData(), sendSizes[ i ], periodicNeighbors[ i ], 1, group ) ); + //TNL_MPI_PRINT( "Receiving data to node " << periodicNeighbors[ i ] ); + requests.push_back( CommunicatorType::IRecv( recieveBuffers[ i ].getData(), sendSizes[ i ], periodicNeighbors[ i ], 1, group ) ); + } + } + }; /* NOTE(review): the semicolon after this closing brace is redundant. */ + + template< typename CommunicatorType, + typename MeshFunctionType, + typename PeriodicBoundariesMaskPointer = Pointers::SharedPointer< MeshFunctionType > > + void finishSynchronization( MeshFunctionType &meshFunction, + std::list< typename CommunicatorType::Request >& requests, + bool periodicBoundaries = false, + const PeriodicBoundariesMaskPointer& mask = PeriodicBoundariesMaskPointer( nullptr ) ) + { /* Calls CommunicatorType::wait on every element of requests, then copies from recieveBuffers into meshFunction with copyBuffers, passing mask. NOTE(review): unlike startSynchronization there is no isSet assertion and no isDistributed() early return, and requests is not emptied after waiting - confirm that callers handle this. */ + 
//wait for every request in the list + //TNL_MPI_PRINT( "Waiting for data ..." ) + for( auto& request : requests ) + CommunicatorType::wait( request ); + + //copy data from receive buffers + //TNL_MPI_PRINT( "Copying data ..." ) + const int *neighbors = distributedGrid->getNeighbors(); + copyBuffers(meshFunction, + recieveBuffers,recieveBegin,sendDimensions , + false, /* startSynchronization passes true in this position */ + neighbors, + periodicBoundaries, + mask ); + } + template< typename CommunicatorType, typename MeshFunctionType, typename PeriodicBoundariesMaskPointer = Pointers::SharedPointer< MeshFunctionType > > @@ -154,7 +232,7 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< MeshDimension, group=*((typename CommunicatorType::CommunicationGroup *)(distributedGrid->getCommunicationGroup())); int requestsCount( 0 ); - //send everything, recieve everything + //send everything, receive everything for( int i=0; i<this->getNeighborCount(); i++ ) { /*TNL_MPI_PRINT( "Sending data... " << i << " sizes -> " -- GitLab