Commit 1afab991 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Added grid synchronizer methods for asynchronous synchronization.

parent e28a9e4d
Loading
Loading
Loading
Loading
+7 −0
Original line number Diff line number Diff line
@@ -347,6 +347,13 @@ class MpiCommunicator
#endif
        }
         
         // Blocks until the given asynchronous communication request completes.
         // The completion status is discarded (MPI_STATUS_IGNORE).
         // Compiles to a no-op when MPI support is not enabled.
         static void wait( Request& request )
         {
#ifdef HAVE_MPI
            MPI_Wait( &request, MPI_STATUS_IGNORE );
#endif
         }

         static void WaitAll(Request *reqs, int length)
         {
#ifdef HAVE_MPI
+4 −0
Original line number Diff line number Diff line
@@ -89,6 +89,10 @@ class NoDistrCommunicator
          return 1;
      }

      // No-op: without distributed communication there is never a pending
      // request to wait for. Kept so code templated on the communicator type
      // compiles unchanged.
      static void wait( Request& request )
      {
      }
      
      // No-op counterpart of MpiCommunicator::WaitAll — there are no real
      // requests in the non-distributed case.
      static void WaitAll(Request *reqs, int length)
      {
      }
+16 −0
Original line number Diff line number Diff line
@@ -179,6 +179,22 @@ class MeshFunction :
                        const Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >& mask =
                        Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >( nullptr ) );
      
      // Starts a non-blocking synchronization of the overlap regions with
      // neighboring ranks. The issued communication requests are appended to
      // \e requests and must later be completed by finishSynchronization.
      // The mask (when non-null) restricts which entities take part in
      // periodic-boundary synchronization.
      template< typename CommunicatorType,
                typename PeriodicBoundariesMaskType = MeshFunction< Mesh, MeshEntityDimension, bool > >
      void startSynchronization( std::list< typename CommunicatorType::Request >& requests,
         bool withPeriodicBoundaryConditions = false,
         const Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >& mask =
         Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >( nullptr ) );
      
      // Completes a synchronization previously started by
      // startSynchronization: waits for the requests in \e requests and
      // unpacks the received overlap data into this mesh function.
      template< typename CommunicatorType,
                typename PeriodicBoundariesMaskType = MeshFunction< Mesh, MeshEntityDimension, bool > >
      void finishSynchronization( std::list< typename CommunicatorType::Request >& requests,
         bool withPeriodicBoundaryConditions = false,
         const Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >& mask =
         Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >( nullptr ) );



   protected:

      // TODO: synchronizer should not be part of the mesh function - the way of synchronization
+37 −0
Original line number Diff line number Diff line
@@ -555,6 +555,42 @@ synchronize( bool periodicBoundaries,
    }
}

template< typename Mesh,
          int MeshEntityDimension,
          typename Real >
template< typename CommunicatorType,
          typename PeriodicBoundariesMaskType >
void
MeshFunction< Mesh, MeshEntityDimension, Real >::
startSynchronization( std::list< typename CommunicatorType::Request >& requests,
                      bool periodicBoundaries,
                      const Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >& mask )
{
   // Synchronization makes sense only for distributed meshes; for a
   // non-distributed mesh this method is a no-op.
   auto distrMesh = this->getMesh().getDistributedMesh();
   // use nullptr instead of NULL, consistently with the rest of the file
   if( distrMesh != nullptr && distrMesh->isDistributed() )
      this->synchronizer.template startSynchronization< CommunicatorType >( *this, requests, periodicBoundaries, mask );
}

template< typename Mesh,
          int MeshEntityDimension,
          typename Real >
template< typename CommunicatorType,
          typename PeriodicBoundariesMaskType >
void
MeshFunction< Mesh, MeshEntityDimension, Real >::
finishSynchronization( std::list< typename CommunicatorType::Request >& requests,
                       bool periodicBoundaries,
                       const Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >& mask )
{
   // Completes the communication started by startSynchronization; a no-op
   // for non-distributed meshes.
   auto distrMesh = this->getMesh().getDistributedMesh();
   // use nullptr instead of NULL, consistently with the rest of the file
   if( distrMesh != nullptr && distrMesh->isDistributed() )
      this->synchronizer.template finishSynchronization< CommunicatorType >( *this, requests, periodicBoundaries, mask );
}

template< typename Mesh,
          int MeshEntityDimension,
          typename Real >
@@ -579,3 +615,4 @@ operator << ( std::ostream& str, const MeshFunction< Mesh, MeshEntityDimension,
   } // namespace Functions
} // namespace TNL

 
+80 −2
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@

#pragma once

#include <list>
#include <TNL/Meshes/Grid.h>
#include <TNL/Containers/Array.h>
#include <TNL/Meshes/DistributedMeshes/BufferEntitiesHelper.h>
@@ -125,6 +126,83 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< MeshDimension,
         }
      }

      // Starts a non-blocking synchronization of the overlap regions of the
      // given mesh function: fills the send buffers from the function data
      // and appends the non-blocking send/receive requests to \e requests.
      // Call finishSynchronization to complete them.
      // The \e mask parameter exists for interface symmetry with
      // synchronize(); it is applied only when receiving data, i.e. in
      // finishSynchronization.
      template< typename CommunicatorType,
                typename MeshFunctionType,
                typename PeriodicBoundariesMaskPointer = Pointers::SharedPointer< MeshFunctionType > >
      void startSynchronization( MeshFunctionType &meshFunction,
                                 std::list< typename CommunicatorType::Request >& requests,
                                 bool periodicBoundaries = false,
                                 const PeriodicBoundariesMaskPointer& mask = PeriodicBoundariesMaskPointer( nullptr ) )
      {
         TNL_ASSERT_TRUE( isSet, "Synchronizer is not set, but used to synchronize" );

         // nothing to synchronize on a single node
         if( ! distributedGrid->isDistributed() )
            return;

         const int* neighbors = distributedGrid->getNeighbors();
         const int* periodicNeighbors = distributedGrid->getPeriodicNeighbors();

         // fill the send buffers; the mask is used only when receiving data,
         // hence a null mask is passed here
         copyBuffers( meshFunction,
            sendBuffers, sendBegin, sendDimensions,
            true,
            neighbors,
            periodicBoundaries,
            PeriodicBoundariesMaskPointer( nullptr ) );

         typename CommunicatorType::CommunicationGroup group;
         group = *( ( typename CommunicatorType::CommunicationGroup* )( distributedGrid->getCommunicationGroup() ) );

         // issue a non-blocking send and receive for every existing neighbor;
         // periodic neighbors are distinguished by the message tag 1
         for( int i = 0; i < this->getNeighborCount(); i++ )
         {
            if( neighbors[ i ] != -1 )
            {
               requests.push_back( CommunicatorType::ISend( sendBuffers[ i ].getData(), sendSizes[ i ], neighbors[ i ], 0, group ) );
               requests.push_back( CommunicatorType::IRecv( recieveBuffers[ i ].getData(), sendSizes[ i ], neighbors[ i ], 0, group ) );
            }
            else if( periodicBoundaries && sendSizes[ i ] != 0 )
            {
               requests.push_back( CommunicatorType::ISend( sendBuffers[ i ].getData(), sendSizes[ i ], periodicNeighbors[ i ], 1, group ) );
               requests.push_back( CommunicatorType::IRecv( recieveBuffers[ i ].getData(), sendSizes[ i ], periodicNeighbors[ i ], 1, group ) );
            }
         }
      }
      
      // Completes a synchronization started by startSynchronization: waits
      // for all pending communication requests and copies the received data
      // from the receive buffers into the mesh function. The completed
      // requests are removed from \e requests so the list can safely be
      // reused for a subsequent synchronization.
      template< typename CommunicatorType,
                typename MeshFunctionType,
                typename PeriodicBoundariesMaskPointer = Pointers::SharedPointer< MeshFunctionType > >
      void finishSynchronization( MeshFunctionType &meshFunction,
                                  std::list< typename CommunicatorType::Request >& requests,
                                  bool periodicBoundaries = false,
                                  const PeriodicBoundariesMaskPointer& mask = PeriodicBoundariesMaskPointer( nullptr ) )
      {
         // wait until all sends and receives are done
         for( auto& request : requests )
            CommunicatorType::wait( request );
         // completed requests must not be waited on again
         requests.clear();

         // copy data from the receive buffers into the mesh function,
         // applying the periodic-boundaries mask when one is given
         const int* neighbors = distributedGrid->getNeighbors();
         copyBuffers( meshFunction,
             recieveBuffers, recieveBegin, sendDimensions,
             false,
             neighbors,
             periodicBoundaries,
             mask );
      }

      
      template< typename CommunicatorType,
                typename MeshFunctionType,
                typename PeriodicBoundariesMaskPointer = Pointers::SharedPointer< MeshFunctionType > >
@@ -154,7 +232,7 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< MeshDimension,
         group=*((typename CommunicatorType::CommunicationGroup *)(distributedGrid->getCommunicationGroup()));
         int requestsCount( 0 );

         //send everything, recieve everything 
         //send everything, receive everything 
         for( int i=0; i<this->getNeighborCount(); i++ )
         {
            /*TNL_MPI_PRINT( "Sending data... " << i << " sizes -> "