    /***************************************************************************
    
                              DistributedGridSynchronizer.h  -  description
    
                                 -------------------
        begin                : March 17, 2017
        copyright            : (C) 2017 by Tomas Oberhuber
        email                : tomas.oberhuber@fjfi.cvut.cz
     ***************************************************************************/
    
    /* See Copyright Notice in tnl/Copyright */
    
    #pragma once
    
    
    #include <TNL/Meshes/Grid.h>
    #include <TNL/Containers/Array.h>
    #include <TNL/Meshes/DistributedMeshes/BufferEntitiesHelper.h>
    #include <TNL/Meshes/DistributedMeshes/Directions.h>
    
    #include <TNL/Pointers/SharedPointer.h>
    
    
    // Forward declaration of Functions::MeshFunction so that the synchronizer
    // specialization in this header can name the class template.
    namespace TNL {
       namespace Functions {

          template< typename Mesh, int MeshEntityDimension, typename Real >
          class MeshFunction;

       } // namespace Functions
    } // namespace TNL
    
    namespace TNL {
    namespace Meshes { 
    namespace DistributedMeshes { 
    
    template <typename RealType,
              int EntityDimension,
              int MeshDimension,
              typename Index,
              typename Device,
              typename GridReal>  
    class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< MeshDimension, GridReal, Device, Index >,EntityDimension, RealType>>
    {
    
       public:
          // Returns the MeshDimension template argument of this specialization.
          static constexpr int getMeshDimension()
          {
             return MeshDimension;
          }
          // Returns the number of neighbor directions as reported by
          // DirectionCount for this mesh dimension.
          static constexpr int getNeighborCount()
          {
             return DirectionCount< MeshDimension >::get();
          }
    
          // Type aliases derived from the grid type this specialization is written for.
          using Cell = typename Grid< MeshDimension, GridReal, Device, Index >::Cell;
          // FIXME: clang does not like this (incomplete type error)
    //      typedef typename Functions::MeshFunction< Grid< 3, GridReal, Device, Index >,EntityDimension, RealType> MeshFunctionType;
          using DistributedGridType = typename Grid< MeshDimension, GridReal, Device, Index >::DistributedMeshType;
          using CoordinatesType = typename DistributedGridType::CoordinatesType;
          using SubdomainOverlapsType = typename DistributedGridType::SubdomainOverlapsType;
    
    
          // Copy direction used for directions that have no neighbor (-1).
          // setDistributedGrid() exchanges the send and receive region begins
          // of such directions when OverlapToBoundary is selected.
          enum PeriodicBoundariesCopyDirection
          {
             BoundaryToOverlap = 0,
             OverlapToBoundary = 1
          };
    
    
          DistributedMeshSynchronizer()
          {
             isSet = false;
          };
    
          // Creates a synchronizer and binds it to distributedGrid right away
          // by forwarding the pointer to setDistributedGrid().
          DistributedMeshSynchronizer( DistributedGridType *distributedGrid )
          : isSet( false )
          {
             this->setDistributedGrid( distributedGrid );
          }
    
    
          void setPeriodicBoundariesCopyDirection( const PeriodicBoundariesCopyDirection dir )
          {
             this->periodicBoundariesCopyDirection = dir;
          }
    
    
          // Binds the synchronizer to distributedGrid and precomputes the
          // exchange layout for each of the getNeighborCount() directions.
          //
          // For direction i the loop fills sendBegin[i], recieveBegin[i] and
          // sendDimensions[i], stores the product of the dimensions in
          // sendSizes[i] and resizes sendBuffers[i] and recieveBuffers[i] to
          // that many elements.
          //
          // The periodic copy direction is read here while the layout is built,
          // so the layout reflects the value that is current at the time of
          // this call.
          //
          // distributedGrid: only the pointer is stored. It is dereferenced
          // immediately here and again in synchronize(), so it must be non-null
          // and stay valid while the synchronizer is in use.
          //
          // NOTE(review): isSet is raised before the layout has been computed.
          // NOTE(review): localGridSize is fetched but not used in the visible body.
          // NOTE(review): sendSize is accumulated as Index but stored in
          // sendSizes, whose element type is int.
          // NOTE(review): the closing braces of the loop and of this method are
          // not present in this copy of the file.
          void setDistributedGrid( DistributedGridType *distributedGrid )
          {
             isSet = true;
    
             this->distributedGrid = distributedGrid;
             
             // Overlaps, local begin and local size are indexed per mesh axis below.
             const SubdomainOverlapsType& lowerOverlap = this->distributedGrid->getLowerOverlap();
             const SubdomainOverlapsType& upperOverlap = this->distributedGrid->getUpperOverlap();
           
             const CoordinatesType& localBegin = this->distributedGrid->getLocalBegin(); 
             const CoordinatesType& localSize = this->distributedGrid->getLocalSize(); 
    
             const CoordinatesType& localGridSize = this->distributedGrid->getLocalGridSize();
    
             // neighbors[ i ] == -1 marks a direction without a neighbor.
             const int *neighbors = distributedGrid->getNeighbors();
    
    
             for( int i=0; i<this->getNeighborCount(); i++ )
             {
    
    Tomáš Oberhuber's avatar
    Tomáš Oberhuber committed
                Index sendSize=1;//send and receive  areas have the same size
    
               // bool isBoundary=( neighbor[ i ] == -1 );
                auto directions=Directions::template getXYZ<getMeshDimension()>(i);
    
    Tomáš Oberhuber's avatar
    Tomáš Oberhuber committed
                sendDimensions[i]=localSize; //send and receive areas have the same dimensions
    
                // Both regions start at the local begin; axes with a nonzero
                // direction component are adjusted in the loop below.
                sendBegin[i]=localBegin;
                recieveBegin[i]=localBegin;
    
                for(int j=0;j<this->getMeshDimension();j++)
                {
                   // Component -1 on axis j: the extent is the lower overlap, the
                   // receive region starts at index 0 and the send region keeps
                   // localBegin[j].
                   if(directions[j]==-1)
                   {
                      sendDimensions[i][j]=lowerOverlap[j];
                      recieveBegin[i][j]=0;
                   }
    
                   // Component +1 on axis j: the extent is the upper overlap, the
                   // send region is the last upperOverlap[j] layers of the local
                   // region and the receive region starts right behind it.
                   if(directions[j]==1)
                   {
                      sendDimensions[i][j]=upperOverlap[j];
                      sendBegin[i][j]=localBegin[j]+localSize[j]-upperOverlap[j];
                      recieveBegin[i][j]=localBegin[j]+localSize[j];
                   }
    
                   sendSize*=sendDimensions[i][j];
                }
    
                sendSizes[ i ] = sendSize;
                sendBuffers[ i ].setSize( sendSize );
                recieveBuffers[ i ].setSize( sendSize);
    
                // With OverlapToBoundary and no neighbor in this direction the
                // begins of the send and receive regions are exchanged.
                if( this->periodicBoundariesCopyDirection == OverlapToBoundary &&
                   neighbors[ i ] == -1 )
                      swap( sendBegin[i], recieveBegin[i] );
    
          // Exchanges the precomputed send and receive regions of meshFunction
          // with the neighbors of this subdomain.
          //
          // Steps visible in this body:
          //   1. assert that the synchronizer is set; return at once when the
          //      grid reports it is not distributed,
          //   2. fill the send buffers through copyBuffers (toBuffer = true),
          //   3. post ISend / IRecv requests per direction,
          //   4. wait for all posted requests with WaitAll,
          //   5. copy from the receive buffers through copyBuffers.
          //
          // CommunicatorType has to provide Request, CommunicationGroup, ISend,
          // IRecv and WaitAll as they are used below.
          //
          // periodicBoundaries: forwarded to copyBuffers. The visible else-if
          // branch of the request loop additionally posts requests to
          // periodicNeighbors[ i ] with tag 1 when this flag is set and
          // sendSizes[ i ] is nonzero; the other visible requests go to
          // neighbors[ i ] with tag 0.
          //
          // mask: not used while filling the send buffers (a null mask pointer
          // is passed there). Presumably it is forwarded to the final
          // copyBuffers call - the arguments of that call are cut off here,
          // TODO confirm.
          //
          // NOTE(review): several lines of this method are absent from this
          // copy: the braces of the request loop, the condition of its first
          // branch, and the tail of the final copyBuffers call.
          template< typename CommunicatorType,
                    typename MeshFunctionType,
                    typename PeriodicBoundariesMaskPointer = Pointers::SharedPointer< MeshFunctionType > >
          void synchronize( MeshFunctionType &meshFunction,
                            bool periodicBoundaries = false,
                            const PeriodicBoundariesMaskPointer& mask = PeriodicBoundariesMaskPointer( nullptr ) )
          {
    
             TNL_ASSERT_TRUE( isSet, "Synchronizer is not set, but used to synchronize" );
    
        	   if( !distributedGrid->isDistributed() ) return;
             
             const int *neighbors = distributedGrid->getNeighbors();
             const int *periodicNeighbors = distributedGrid->getPeriodicNeighbors();
            
             //fill send buffers
             copyBuffers( meshFunction, 
                sendBuffers, sendBegin,sendDimensions,
                true,
                neighbors,
                periodicBoundaries,
                PeriodicBoundariesMaskPointer( nullptr ) ); // the mask is used only when receiving data );
    
             //async send and receive
             // At most one send and one receive request are posted per direction.
             typename CommunicatorType::Request requests[2*this->getNeighborCount()];
             typename CommunicatorType::CommunicationGroup group;
             // The pointer returned by the grid is cast to the group type of the
             // communicator and the group is copied by value.
             group=*((typename CommunicatorType::CommunicationGroup *)(distributedGrid->getCommunicationGroup()));
             int requestsCount( 0 );
    
             //send everything, receive everything 
             for( int i=0; i<this->getNeighborCount(); i++ )
    
    Tomáš Oberhuber's avatar
    Tomáš Oberhuber committed
                /*TNL_MPI_PRINT( "Sending data... " << i << " sizes -> " 
    
                   << sendSizes[ i ] << "sendDimensions -> " <<  sendDimensions[ i ]
    
    Tomáš Oberhuber's avatar
    Tomáš Oberhuber committed
                   << " upperOverlap -> " << this->distributedGrid->getUpperOverlap() );*/
    
    Tomáš Oberhuber's avatar
    Tomáš Oberhuber committed
                   //TNL_MPI_PRINT( "Sending data to node " << neighbors[ i ] );
    
    Vít Hanousek's avatar
    Vít Hanousek committed
                   requests[ requestsCount++ ] = CommunicatorType::ISend( sendBuffers[ i ].getData(),  sendSizes[ i ], neighbors[ i ], 0, group );
    
    Tomáš Oberhuber's avatar
    Tomáš Oberhuber committed
                   //TNL_MPI_PRINT( "Receiving data from node " << neighbors[ i ] );
    
    Vít Hanousek's avatar
    Vít Hanousek committed
                   requests[ requestsCount++ ] = CommunicatorType::IRecv( recieveBuffers[ i ].getData(),  sendSizes[ i ], neighbors[ i ], 0, group );
    
                }
                else if( periodicBoundaries && sendSizes[ i ] !=0 )
          	   {
    
    Tomáš Oberhuber's avatar
    Tomáš Oberhuber committed
                   //TNL_MPI_PRINT( "Sending data to node " << periodicNeighbors[ i ] );
    
    Vít Hanousek's avatar
    Vít Hanousek committed
                   requests[ requestsCount++ ] = CommunicatorType::ISend( sendBuffers[ i ].getData(),  sendSizes[ i ], periodicNeighbors[ i ], 1, group );
    
    Tomáš Oberhuber's avatar
    Tomáš Oberhuber committed
                   //TNL_MPI_PRINT( "Receiving data to node " << periodicNeighbors[ i ] );
    
    Vít Hanousek's avatar
    Vít Hanousek committed
                   requests[ requestsCount++ ] = CommunicatorType::IRecv( recieveBuffers[ i ].getData(),  sendSizes[ i ], periodicNeighbors[ i ], 1, group );
    
    Tomáš Oberhuber's avatar
    Tomáš Oberhuber committed
            //TNL_MPI_PRINT( "Waiting for data ..." )
    
            CommunicatorType::WaitAll( requests, requestsCount );
    
            //copy data from receive buffers
    
    Tomáš Oberhuber's avatar
    Tomáš Oberhuber committed
            //TNL_MPI_PRINT( "Copying data ..." )
    
            copyBuffers(meshFunction,
                recieveBuffers,recieveBegin,sendDimensions  ,
    
    
       private:
          // Per-direction copy between a mesh function and an array of buffers.
          //
          // NOTE(review): the line carrying the return type and the function
          // name is absent from this copy. The parameter list matches the
          // copyBuffers calls made from synchronize(), so this is presumably
          // that helper - TODO confirm. The closing brace of the function is
          // absent as well.
          //
          // meshFunction       - mesh function handed to the helper.
          // buffers            - one buffer per direction; its raw data pointer
          //                      is handed to the helper.
          // begins, sizes      - per-direction region begin and extent.
          // toBuffer           - forwarded to the helper unchanged; synchronize()
          //                      passes true when it fills the send buffers.
          // neighbor           - per-direction neighbor ids; -1 marks a direction
          //                      that is treated as a boundary here.
          // periodicBoundaries - when true, boundary directions are processed too.
          // mask               - forwarded to the helper unchanged.
          template< typename Real_,
    
                    typename MeshFunctionType,
                    typename PeriodicBoundariesMaskPointer >
    
             MeshFunctionType& meshFunction,
             Containers::Array<Real_, Device, Index>* buffers,
             CoordinatesType* begins,
             CoordinatesType* sizes,
             bool toBuffer,
             const int* neighbor,
             bool periodicBoundaries,
             const PeriodicBoundariesMaskPointer& mask )
          {
    
             using Helper = BufferEntitiesHelper< MeshFunctionType, PeriodicBoundariesMaskPointer, getMeshDimension(), Real_, Device >;
    
           
             for(int i=0;i<this->getNeighborCount();i++)
             {
    
                bool isBoundary=( neighbor[ i ] == -1 );
    
                // Directions without a neighbor are skipped unless periodic
                // boundaries are requested.
                if( ! isBoundary || periodicBoundaries )
                {
    
                   Helper::BufferEntities( meshFunction, mask, buffers[ i ].getData(), isBoundary, begins[i], sizes[i], toBuffer );
                }
             }
    
          // One send and one receive buffer per neighbor direction; both are
          // resized in setDistributedGrid().
          Containers::Array<RealType, Device, Index> sendBuffers[getNeighborCount()];
          Containers::Array<RealType, Device, Index> recieveBuffers[getNeighborCount()];
          // Number of elements exchanged per direction (the same count is used
          // for sending and for receiving).
          Containers::StaticArray< getNeighborCount(), int > sendSizes;
    
          // Read by setDistributedGrid() while the exchange layout is built.
          PeriodicBoundariesCopyDirection periodicBoundariesCopyDirection = BoundaryToOverlap;
    
    
          // Per-direction region extents and begins filled by setDistributedGrid().
          // NOTE(review): recieveDimensions is not referenced anywhere in the
          // visible code; sendDimensions is used for both directions of the copy.
          CoordinatesType sendDimensions[getNeighborCount()];
          CoordinatesType recieveDimensions[getNeighborCount()];
          CoordinatesType sendBegin[getNeighborCount()];
          CoordinatesType recieveBegin[getNeighborCount()];
    
          // Non-owning pointer to the grid. Initialized to nullptr so that a
          // default-constructed synchronizer never holds an indeterminate
          // pointer; neither constructor assigns it directly.
          DistributedGridType *distributedGrid = nullptr;
    
          // True once setDistributedGrid() has been called.
          bool isSet = false;
    };
    
    
    } // namespace DistributedMeshes
    } // namespace Meshes
    } // namespace TNL