Commit 632087b6 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Merge branch 'mpi' into 'develop'

MPI: refactoring of DistributedGridSynchronizer

See merge request !21
parents 9402d6cd e1dec977
Loading
Loading
Loading
Loading
+24 −12
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@
#include <TNL/Devices/Host.h>
#include <TNL/Devices/Cuda.h>
#include <TNL/ParallelFor.h>
#include <TNL/Containers/StaticVector.h>

namespace TNL {
namespace Meshes { 
@@ -43,10 +44,14 @@ class BufferEntitiesHelper< MeshFunctionType, MaskPointer, 1, RealType, Device,
         const MaskPointer& maskPointer,
         RealType* buffer,
         bool isBoundary,
         const Index& beginx,
         const Index& sizex,
         const Containers::StaticVector<1,Index>& begin,
         const Containers::StaticVector<1,Index>& size,
         bool tobuffer )
      {

         Index beginx=begin.x();
         Index sizex=size.x();

         auto mesh = meshFunction.getMesh();
         RealType* meshFunctionData = meshFunction.getData().getData();
         const typename MaskPointer::ObjectType* mask( nullptr );
@@ -83,12 +88,16 @@ class BufferEntitiesHelper< MeshFunctionType, MaskPointer, 2, RealType, Device,
         const MaskPointer& maskPointer,
         RealType* buffer,
         bool isBoundary,
         const Index& beginx,
         const Index& beginy,
         const Index& sizex,
         const Index& sizey,
         const Containers::StaticVector<2,Index>& begin,
         const Containers::StaticVector<2,Index>& size,
         bool tobuffer)
      {

         Index beginx=begin.x();
         Index beginy=begin.y();
         Index sizex=size.x();
         Index sizey=size.y();

         auto mesh=meshFunction.getMesh();
         RealType* meshFunctionData = meshFunction.getData().getData();      
         const typename MaskPointer::ObjectType* mask( nullptr );
@@ -127,15 +136,18 @@ class BufferEntitiesHelper< MeshFunctionType, MaskPointer, 3, RealType, Device,
         const MaskPointer& maskPointer,
         RealType* buffer,
         bool isBoundary,
         const Index& beginx,
         const Index& beginy,
         const Index& beginz,
         const Index& sizex,
         const Index& sizey,
         const Index& sizez,
         const Containers::StaticVector<3,Index>& begin,
         const Containers::StaticVector<3,Index>& size,
         bool tobuffer)
      {

         Index beginx=begin.x();
         Index beginy=begin.y();
         Index beginz=begin.z();
         Index sizex=size.x();
         Index sizey=size.y();
         Index sizez=size.z();

         auto mesh=meshFunction.getMesh();
         RealType * meshFunctionData=meshFunction.getData().getData();
         const typename MaskPointer::ObjectType* mask( nullptr );
+0 −3
Original line number Diff line number Diff line
@@ -5,9 +5,6 @@ SET( headers BufferEntitiesHelper.h
             DistributedGrid.h
             DistributedGrid.hpp
             DistributedGridSynchronizer.h
             DistributedGridSynchronizer_1D.h
             DistributedGridSynchronizer_2D.h
             DistributedGridSynchronizer_3D.h
             DistributedGridIO.h
             DistributedGridIO_MeshFunction.h
             DistributedGridIO_VectorField.h
+27 −10
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@ namespace DistributedMeshes {
//  -> 0 - not used, 1 negative direction, 2 positive direction
//finaly we subtrackt 1 because we dont need (0,0,0) aka 0 aka no direction

enum Directions2D { Left = 0 , Right = 1 , Up = 2, UpLeft =3, UpRight=4, Down=5, DownLeft=6, DownRight=7 }; 
//enum Directions2D { Left = 0 , Right = 1 , Up = 2, UpLeft =3, UpRight=4, Down=5, DownLeft=6, DownRight=7 }; 

/*MEH - osa zed je zdola nahoru, asi---
enum Directions3D { West = 0 , East = 1 , 
@@ -28,16 +28,33 @@ enum Directions3D { West = 0 , East = 1 ,
                    BottomSouth = 23, BottomSouthWest = 24, BottomSouthEast = 25,
                  };*/

enum Directions3D { West = 0 , East = 1 , 
                    North = 2, NorthWest = 3, NorthEast = 4,
                    South = 5, SouthWest = 6, SouthEast = 7,
                    Bottom = 8 ,BottomWest = 9 , BottomEast = 10 , 
                    BottomNorth = 11, BottomNorthWest = 12, BottomNorthEast = 13,
                    BottomSouth = 14, BottomSouthWest = 15, BottomSouthEast = 16,
                    Top = 17, TopWest = 18, TopEast =19,
                    TopNorth = 20, TopNorthWest = 21, TopNorthEast = 22,
                    TopSouth = 23, TopSouthWest = 24,TopSouthEast = 25,
/*
with self
enum Directions3D { 
                    ZzYzXz =  0, ZzYzXm =  1, ZzYzXp =  2, 
                    ZzYmXz =  3, ZzYmXm =  4, ZzYmXp =  5,
                    ZzYpXz =  6, ZzYpXm =  7, ZzYpXp =  8,                    
                    ZmYzXz =  9, ZmYzXm = 10, ZmYzXp = 11, 
                    ZmYmXz = 12, ZmYmXm = 13, ZmYmXp = 14,
                    ZmYpXz = 15, ZmYpXm = 16, ZmYpXp = 17,
                    ZpYzXz = 18, ZpYzXm = 19, ZpYzXp = 20, 
                    ZpYmXz = 21, ZpYmXm = 22, ZpYmXp = 23,
                    ZpYpXz = 24, ZpYpXm = 25, ZpYpXp = 26
                  };
*/

enum Directions3D { 
                    ZzYzXm =  0, ZzYzXp =  1, 
                    ZzYmXz =  2, ZzYmXm =  3, ZzYmXp =  4,
                    ZzYpXz =  5, ZzYpXm =  6, ZzYpXp =  7,                    
                    ZmYzXz =  8, ZmYzXm =  9, ZmYzXp = 10, 
                    ZmYmXz = 11, ZmYmXm = 12, ZmYmXp = 13,
                    ZmYpXz = 14, ZmYpXm = 15, ZmYpXp = 16,
                    ZpYzXz = 17, ZpYzXm = 18, ZpYzXp = 19, 
                    ZpYmXz = 20, ZpYmXm = 21, ZpYmXp = 22,
                    ZpYpXz = 23, ZpYpXm = 24, ZpYpXp = 25
                  };


class Directions {

+3 −2
Original line number Diff line number Diff line
@@ -74,6 +74,7 @@ class DistributedMesh< Grid< Dimension, Real, Device, Index > >
      // It is still being used in cuts set-up
      const CoordinatesType& getOverlap() const { return this->overlap;};
      
      //currently used overlaps at this subdomain
      const SubdomainOverlapsType& getLowerOverlap() const;
      
      const SubdomainOverlapsType& getUpperOverlap() const;
@@ -147,7 +148,7 @@ class DistributedMesh< Grid< Dimension, Real, Device, Index > >
      CoordinatesType globalBegin;
      PointType spaceSteps;
      
      SubdomainOverlapsType lowerOverlap, upperOverlap;
      SubdomainOverlapsType lowerOverlap, upperOverlap, globalLowerOverlap, globalUpperOverlap;

      CoordinatesType domainDecomposition;
      CoordinatesType subdomainCoordinates;   
+219 −3
Original line number Diff line number Diff line
@@ -10,6 +10,222 @@

#pragma once

#include <TNL/Meshes/DistributedMeshes/DistributedGridSynchronizer_1D.h>
#include <TNL/Meshes/DistributedMeshes/DistributedGridSynchronizer_2D.h>
#include <TNL/Meshes/DistributedMeshes/DistributedGridSynchronizer_3D.h>
#include <TNL/Meshes/Grid.h>
#include <TNL/Containers/Array.h>
#include <TNL/Meshes/DistributedMeshes/BufferEntitiesHelper.h>
#include <TNL/Meshes/DistributedMeshes/Directions.h>

namespace TNL {
namespace Functions{
template< typename Mesh,
          int MeshEntityDimension,
          typename Real  >
class MeshFunction;
}//Functions
}//TNL

namespace TNL {
namespace Meshes { 
namespace DistributedMeshes { 

template <typename RealType,
          int EntityDimension,
          int MeshDimension,
          typename Index,
          typename Device,
          typename GridReal>  
class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< MeshDimension, GridReal, Device, Index >,EntityDimension, RealType>>
{

   public:
      static constexpr int getMeshDimension() { return MeshDimension; };
      static constexpr int getNeighborCount() {return DirectionCount<MeshDimension>::get();};

      typedef typename Grid< MeshDimension, GridReal, Device, Index >::Cell Cell;
      // FIXME: clang does not like this (incomplete type error)
//      typedef typename Functions::MeshFunction< Grid< 3, GridReal, Device, Index >,EntityDimension, RealType> MeshFunctionType;
      typedef typename Grid< MeshDimension, GridReal, Device, Index >::DistributedMeshType DistributedGridType; 
      typedef typename DistributedGridType::CoordinatesType CoordinatesType;
      using SubdomainOverlapsType = typename DistributedGridType::SubdomainOverlapsType;
          
      DistributedMeshSynchronizer()
      {
         isSet = false;
      };

      DistributedMeshSynchronizer( DistributedGridType *distributedGrid )
      {
         isSet = false;
         setDistributedGrid( distributedGrid );
      };

      void setDistributedGrid( DistributedGridType *distributedGrid )
      {
         isSet = true;

         this->distributedGrid = distributedGrid;
         
         const SubdomainOverlapsType& lowerOverlap = this->distributedGrid->getLowerOverlap();
         const SubdomainOverlapsType& upperOverlap = this->distributedGrid->getUpperOverlap();
       
         const CoordinatesType& localBegin = this->distributedGrid->getLocalBegin(); 
         const CoordinatesType& localSize = this->distributedGrid->getLocalSize(); 
         const CoordinatesType& localGridSize = this->distributedGrid->getLocalGridSize();

         const int *neighbors = distributedGrid->getNeighbors();

         for( int i=0; i<this->getNeighborCount(); i++ )
         {
            Index sendSize=1;//sended and recieve areas has same size

           // bool isBoundary=( neighbor[ i ] == -1 );
            auto directions=Directions::template getXYZ<getMeshDimension()>(i);

            sendDimensions[i]=localSize;//send and recieve areas has same dimensions
            sendBegin[i]=localBegin;
            recieveBegin[i]=localBegin;

            for(int j=0;j<this->getMeshDimension();j++)
            {
               if(directions[j]==-1)
               {
                  sendDimensions[i][j]=lowerOverlap[j];
                  recieveBegin[i][j]=0;
               }

               if(directions[j]==1)
               {
                  sendDimensions[i][j]=upperOverlap[j];
                  sendBegin[i][j]=localBegin[j]+localSize[j]-upperOverlap[j];
                  recieveBegin[i][j]=localBegin[j]+localSize[j];
               }

               sendSize*=sendDimensions[i][j];
            }

            sendSizes[ i ] = sendSize;
            sendBuffers[ i ].setSize( sendSize );
            recieveBuffers[ i ].setSize( sendSize);

            //Periodic-BC copy from overlap into domain
            //if Im on boundary, and i is direction of the boundary i will swap source and destination
            //i do this only for basic 6 directions, 
            //because this swap at conners and edges produces writing multiple values at sam place in localsubdomain
            {
               if(  (  i==ZzYzXm || i==ZzYzXp 
                     ||i==ZzYmXz || i==ZzYpXz 
                     ||i==ZmYzXz || i==ZpYzXz ) 
                  && neighbors[ i ] == -1) 
               {
                  //swap begins
                  CoordinatesType tmp = sendBegin[i];
                  sendBegin[i]=recieveBegin[i];
                  recieveBegin[i]=tmp;
               }
            }

         }
     }
        
      template< typename CommunicatorType,
                typename MeshFunctionType,
                typename PeriodicBoundariesMaskPointer = Pointers::SharedPointer< MeshFunctionType > >
      void synchronize( MeshFunctionType &meshFunction,
                        bool periodicBoundaries = false,
                        const PeriodicBoundariesMaskPointer& mask = PeriodicBoundariesMaskPointer( nullptr ) )
      {

         TNL_ASSERT_TRUE( isSet, "Synchronizer is not set, but used to synchronize" );

    	   if( !distributedGrid->isDistributed() ) return;
         
         const int *neighbors = distributedGrid->getNeighbors();
         const int *periodicNeighbors = distributedGrid->getPeriodicNeighbors();
        
         //fill send buffers
         copyBuffers( meshFunction, 
            sendBuffers, sendBegin,sendDimensions,
            true,
            neighbors,
            periodicBoundaries,
            PeriodicBoundariesMaskPointer( nullptr ) ); // the mask is used only when receiving data );
        
         //async send and receive
         typename CommunicatorType::Request requests[2*this->getNeighborCount()];
         typename CommunicatorType::CommunicationGroup group;
         group=*((typename CommunicatorType::CommunicationGroup *)(distributedGrid->getCommunicationGroup()));
         int requestsCount( 0 );
		                
         //send everything, recieve everything 
         for( int i=0; i<this->getNeighborCount(); i++ )
            if( neighbors[ i ] != -1 )
            {
               requests[ requestsCount++ ] = CommunicatorType::ISend( sendBuffers[ i ].getData(),  sendSizes[ i ], neighbors[ i ], 0, group );
               requests[ requestsCount++ ] = CommunicatorType::IRecv( recieveBuffers[ i ].getData(),  sendSizes[ i ], neighbors[ i ], 0, group );
            }
            else if( periodicBoundaries && sendSizes[ i ] !=0 )
      	   {
               requests[ requestsCount++ ] = CommunicatorType::ISend( sendBuffers[ i ].getData(),  sendSizes[ i ], periodicNeighbors[ i ], 1, group );
               requests[ requestsCount++ ] = CommunicatorType::IRecv( recieveBuffers[ i ].getData(),  sendSizes[ i ], periodicNeighbors[ i ], 1, group );
            }

        //wait until send is done
        CommunicatorType::WaitAll( requests, requestsCount );

        //copy data from receive buffers
        copyBuffers(meshFunction, 
            recieveBuffers,recieveBegin,sendDimensions  ,          
            false,
            neighbors,
            periodicBoundaries,
            mask );
    }
    
   private:      
      template< typename Real_, 
                typename MeshFunctionType,
                typename PeriodicBoundariesMaskPointer >
      void copyBuffers( 
         MeshFunctionType& meshFunction,
         Containers::Array<Real_, Device, Index>* buffers,
         CoordinatesType* begins,
         CoordinatesType* sizes,
         bool toBuffer,
         const int* neighbor,
         bool periodicBoundaries,
         const PeriodicBoundariesMaskPointer& mask )
      {
         using Helper = BufferEntitiesHelper< MeshFunctionType, PeriodicBoundariesMaskPointer, getMeshDimension(), Real_, Device >;
       
         for(int i=0;i<this->getNeighborCount();i++)
         {
            bool isBoundary=( neighbor[ i ] == -1 );           
            if( ! isBoundary || periodicBoundaries )
            {
                  Helper::BufferEntities( meshFunction, mask, buffers[ i ].getData(), isBoundary, begins[i], sizes[i], toBuffer );
            }                  
         }      
      }
    
   private:
   
      Containers::Array<RealType, Device, Index> sendBuffers[getNeighborCount()];
      Containers::Array<RealType, Device, Index> recieveBuffers[getNeighborCount()];
      Containers::StaticArray< getNeighborCount(), int > sendSizes;
  
      CoordinatesType sendDimensions[getNeighborCount()];
      CoordinatesType recieveDimensions[getNeighborCount()];
      CoordinatesType sendBegin[getNeighborCount()];
      CoordinatesType recieveBegin[getNeighborCount()];
      
      DistributedGridType *distributedGrid;

      bool isSet;
    
};


} // namespace DistributedMeshes
} // namespace Meshes
} // namespace TNL
Loading