Commit 1f622ebb authored by Vít Hanousek's avatar Vít Hanousek Committed by Jakub Klinkovský
Browse files

Issue #15 - working synchronizer

	- without PeriodicBC implementation
	- SubdomainOverlapGetter has changed interface
	 -> need more updates, only basic grid tests work
parent e21b28af
Loading
Loading
Loading
Loading
+10 −2
Original line number Diff line number Diff line
@@ -62,7 +62,9 @@ class DistributedMesh< Grid< Dimension, Real, Device, Index > >
      const GridType& getGlobalGrid() const;
      
      void setOverlaps( const SubdomainOverlapsType& lower,
                        const SubdomainOverlapsType& upper );
                        const SubdomainOverlapsType& upper,
                        const SubdomainOverlapsType& globalLower,
                        const SubdomainOverlapsType& globalUpper );
      
      void setupGrid( GridType& grid);

@@ -74,10 +76,16 @@ class DistributedMesh< Grid< Dimension, Real, Device, Index > >
      // It is still being used in cuts set-up
      const CoordinatesType& getOverlap() const { return this->overlap;};
      
      //currently used overlaps at this subdomain
      const SubdomainOverlapsType& getLowerOverlap() const;
      
      const SubdomainOverlapsType& getUpperOverlap() const;

      //original overlaps set by user - same values return all subdomains..
      const SubdomainOverlapsType& getGlobalLowerOverlap() const;
      
      const SubdomainOverlapsType& getGlobalUpperOverlap() const;

      //number of elements of local sub domain WITHOUT overlap
      // TODO: getSubdomainDimensions
      const CoordinatesType& getLocalSize() const;
@@ -147,7 +155,7 @@ class DistributedMesh< Grid< Dimension, Real, Device, Index > >
      CoordinatesType globalBegin;
      PointType spaceSteps;
      
      SubdomainOverlapsType lowerOverlap, upperOverlap;
      SubdomainOverlapsType lowerOverlap, upperOverlap, globalLowerOverlap, globalUpperOverlap;

      CoordinatesType domainDecomposition;
      CoordinatesType subdomainCoordinates;   
+22 −1
Original line number Diff line number Diff line
@@ -170,8 +170,13 @@ template< int Dimension, typename Real, typename Device, typename Index >
void
DistributedMesh< Grid< Dimension, Real, Device, Index > >::
setOverlaps( const SubdomainOverlapsType& lower,
             const SubdomainOverlapsType& upper )
             const SubdomainOverlapsType& upper,
             const SubdomainOverlapsType& globalLower,
             const SubdomainOverlapsType& globalUpper )
{
   this->globalLowerOverlap = globalLower;
   this->globalUpperOverlap = globalUpper;

   this->lowerOverlap = lower;
   this->upperOverlap = upper;

@@ -258,6 +263,22 @@ getUpperOverlap() const
   return this->upperOverlap;
};

template< int Dimension, typename Real, typename Device, typename Index >     
const typename DistributedMesh< Grid< Dimension, Real, Device, Index > >::CoordinatesType&
DistributedMesh< Grid< Dimension, Real, Device, Index > >::
getGlobalLowerOverlap() const
{
   return this->globalLowerOverlap;
};

template< int Dimension, typename Real, typename Device, typename Index >     
const typename DistributedMesh< Grid< Dimension, Real, Device, Index > >::CoordinatesType&
DistributedMesh< Grid< Dimension, Real, Device, Index > >::
getGlobalUpperOverlap() const
{
   return this->globalUpperOverlap;
};

template< int Dimension, typename Real, typename Device, typename Index >     
const typename DistributedMesh< Grid< Dimension, Real, Device, Index > >::CoordinatesType&
DistributedMesh< Grid< Dimension, Real, Device, Index > >::
+49 −79
Original line number Diff line number Diff line
@@ -65,8 +65,10 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< MeshDimension,

         this->distributedGrid = distributedGrid;
         
         const SubdomainOverlapsType& lowerOverlap = this->distributedGrid->getLowerOverlap();
         const SubdomainOverlapsType& upperOverlap = this->distributedGrid->getUpperOverlap();
         const SubdomainOverlapsType& globalLowerOverlap = this->distributedGrid->getGlobalLowerOverlap();
         const SubdomainOverlapsType& globalUpperOverlap = this->distributedGrid->getGlobalUpperOverlap();
        // const SubdomainOverlapsType& globalPeriodicBCOverlap = this->distributedGrid->getGlobalPeriodicBCOverlap();
         const CoordinatesType& localBegin = this->distributedGrid->getLocalBegin(); 
         const CoordinatesType& localSize = this->distributedGrid->getLocalSize(); 
         const CoordinatesType& localGridSize = this->distributedGrid->getLocalGridSize();         

@@ -74,34 +76,42 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< MeshDimension,
         {
            Index sendSize=1;
            Index rcvSize=1;

            //bool isBoundary=( neighbor[ i ] == -1 );
            auto directions=Directions::template getXYZ<this->getMeshDimension()>(i);

            sendDimensions[i]=localSize;
            recieveDimensions[i]=localSize;
            sendBegin[i]=localBegin;
            recieveBegin[i]=localBegin;

            for(int j=0;j<this->getMeshDimension();j++)
            {
               if(directions[j]==-1)
               {
                  sendSize*=upperOverlap[j];
                  rcvSize*=lowerOverlap[j];
               }
               if(directions[j]==0)
               {
                  sendSize*=localSize[j];
                  rcvSize*=localSize[j];
                  //TODO:periodicBC
                  sendDimensions[i][j]=globalUpperOverlap[j];
                  recieveDimensions[i][j]=globalLowerOverlap[j];
                  recieveBegin[i][j]=0;
               }

               if(directions[j]==1)
               {
                  sendSize*=lowerOverlap[j];
                  rcvSize*=upperOverlap[j];
                  //TODO:periodicBC
                  sendDimensions[i][j]=globalLowerOverlap[j];
                  recieveDimensions[i][j]=globalUpperOverlap[j];
                  sendBegin[i][j]=localBegin[j]+localSize[j]-globalLowerOverlap[j];
                  recieveBegin[i][j]=localBegin[j]+localSize[j];
               }

               sendSize*=sendDimensions[i][j];
               rcvSize*=recieveDimensions[i][j];
            }

            sendSizes[ i ] = sendSize;
            recieveSizes[ i ] = rcvSize;
            sendBuffers[ i ].setSize( sendSize );
            recieveBuffers[ i ].setSize( rcvSize);

             int world_rank;
             MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
             std::cout<< world_rank<<": " << " "<<lowerOverlap << upperOverlap << std::endl;
         }
        
     }
@@ -118,11 +128,6 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< MeshDimension,

    	   if( !distributedGrid->isDistributed() ) return;
         
         const SubdomainOverlapsType& lowerOverlap = distributedGrid->getLowerOverlap();
         const SubdomainOverlapsType& upperOverlap = distributedGrid->getUpperOverlap();
         const CoordinatesType& localSize = this->distributedGrid->getLocalSize(); 
         const CoordinatesType& localBegin = this->distributedGrid->getLocalBegin();
        
         const int *neighbors = distributedGrid->getNeighbors();
         const int *periodicNeighbors = distributedGrid->getPeriodicNeighbors();
         
@@ -132,9 +137,9 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< MeshDimension,
         }         
        
         //fill send buffers
         copyBuffers( meshFunction, sendBuffers, true,
            localBegin,localSize,
            lowerOverlap, upperOverlap,
         copyBuffers( meshFunction, 
            sendBuffers, sendBegin,sendDimensions,
            true,
            neighbors,
            periodicBoundaries,
            PeriodicBoundariesMaskPointer( nullptr ) ); // the mask is used only when receiving data );
@@ -161,15 +166,10 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< MeshDimension,
        //wait until send is done
        CommunicatorType::WaitAll( requests, requestsCount );

        int world_rank;
        MPI_Comm_rank(MPI_COMM_WORLD, &world_rank);
        for(int i=0;i<this->getNeighborCount();i++)
          std::cout<< world_rank<<": " << i << " send:"<<sendBuffers[ i ]<<" recv:"<<recieveBuffers[ i ] << std::endl;

        //copy data from receive buffers
        copyBuffers(meshFunction, recieveBuffers, false,
            localBegin, localSize,
            lowerOverlap, upperOverlap,
        copyBuffers(meshFunction, 
            recieveBuffers,recieveBegin,recieveDimensions  ,          
            false,
            neighbors,
            periodicBoundaries,
            mask );
@@ -182,67 +182,37 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< MeshDimension,
      void copyBuffers( 
         MeshFunctionType& meshFunction,
         Containers::Array<Real_, Device, Index>* buffers,
         CoordinatesType* begins,
         CoordinatesType* sizes,
         bool toBuffer,
         const CoordinatesType& localBegin,
         const CoordinatesType& localSize,
         const CoordinatesType& lowerOverlap,
         const CoordinatesType& upperOverlap,
         const int* neighbor,
         bool periodicBoundaries,
         const PeriodicBoundariesMaskPointer& mask )
      {

         using Helper = BufferEntitiesHelper< MeshFunctionType, PeriodicBoundariesMaskPointer, this->getMeshDimension(), Real_, Device >;
       
         for(int i=0;i<this->getNeighborCount();i++)//performace isssue - this should be buffered when Synchronizer is created
         for(int i=0;i<this->getNeighborCount();i++)
         {
            bool isBoundary=( neighbor[ i ] == -1 );           
            
            CoordinatesType begin=localBegin;
            CoordinatesType size=localSize;
            auto directions=Directions::template getXYZ<this->getMeshDimension()>(i);
            for(int j=0;j<this->getMeshDimension();j++)
            {
               if(toBuffer)
               {
                  if(directions[j]==-1)
                  {
                     size[j]=upperOverlap[j];
                  }
                  if(directions[j]==1)
                  {
                     begin[j]=localBegin[j]+localSize[j]-lowerOverlap[j];
                     size[j]=lowerOverlap[j];
                  }
               }
               else
               {  
                  if(directions[j]==-1)
                  {
                     //tady se asi bude řešit periodic boundary
                     begin[j]=0;
                     size[j]=lowerOverlap[j];
                  }
                  if(directions[j]==1)
            if( ! isBoundary || periodicBoundaries )
            {
                     begin[j]=localBegin[j]+localSize[j];
                     size[j]=upperOverlap[j];
                  Helper::BufferEntities( meshFunction, mask, buffers[ i ].getData(), isBoundary, begins[i], sizes[i], toBuffer );
            }                  
         }      
      }
    
            if( ! isBoundary || periodicBoundaries )
                  Helper::BufferEntities( meshFunction, mask, buffers[ i ].getData(), isBoundary, begin, size, toBuffer );
   private:
   
         }
      }
      Containers::Array<RealType, Device, Index> sendBuffers[getNeighborCount()];
      Containers::Array<RealType, Device, Index> recieveBuffers[getNeighborCount()];
      Containers::StaticArray< getNeighborCount(), int > sendSizes;
      Containers::StaticArray< getNeighborCount(), int > recieveSizes;

   private:
  
      Containers::Array<RealType, Device, Index> sendBuffers[DirectionCount<MeshDimension>::get()];
      Containers::Array<RealType, Device, Index> recieveBuffers[DirectionCount<MeshDimension>::get()];
      Containers::StaticArray< DirectionCount<MeshDimension>::get(), int > sendSizes;
      Containers::StaticArray< DirectionCount<MeshDimension>::get(), int > recieveSizes;
      CoordinatesType sendDimensions[getNeighborCount()];
      CoordinatesType recieveDimensions[getNeighborCount()];
      CoordinatesType sendBegin[getNeighborCount()];
      CoordinatesType recieveBegin[getNeighborCount()];
      
      DistributedGridType *distributedGrid;

+2 −0
Original line number Diff line number Diff line
@@ -73,6 +73,8 @@ class SubdomainOverlapsGetter< Grid< Dimension, Real, Device, Index >, Communica
      static void getOverlaps( const DistributedMeshType* distributedMesh,
                               SubdomainOverlapsType& lower,
                               SubdomainOverlapsType& upper,
                               SubdomainOverlapsType& globalLower,
                               SubdomainOverlapsType& globalUpper,
                               IndexType subdomainOverlapSize,
                               const SubdomainOverlapsType& periodicBoundariesOverlapSize = 0 );
   
+5 −0
Original line number Diff line number Diff line
@@ -27,6 +27,8 @@ SubdomainOverlapsGetter< Grid< Dimension, Real, Device, Index >, Communicator >:
getOverlaps( const DistributedMeshType* distributedMesh,
             SubdomainOverlapsType& lower,
             SubdomainOverlapsType& upper,
             SubdomainOverlapsType& globalLower,
             SubdomainOverlapsType& globalUpper,
             IndexType subdomainOverlapSize,
             const SubdomainOverlapsType& periodicBoundariesOverlapSize )
{
@@ -38,6 +40,9 @@ getOverlaps( const DistributedMeshType* distributedMesh,
   
   for( int i = 0; i < Dimension; i++ )
   {
      globalLower[i]=subdomainOverlapSize;
      globalUpper[i]=subdomainOverlapSize;

      if( subdomainCoordinates[ i ] > 0 )
         lower[ i ] = subdomainOverlapSize;
      else
Loading