Commit e9eecefb authored by Vít Hanousek's avatar Vít Hanousek
Browse files

DisstributedGrid 2D amd 3D refactorization - DistributedGrid_Base template

parent ed98de1f
Loading
Loading
Loading
Loading
+0 −2
Original line number Diff line number Diff line
@@ -28,8 +28,6 @@ class DistributedMesh< Grid< 1, RealType, Device, Index > > : public Distributed
      using typename DistributedGrid_Base<1, RealType, Device, Index >::PointType;
      using typename DistributedGrid_Base<1, RealType, Device, Index >::CoordinatesType;

      DistributedMesh();

      bool setup( const Config::ParameterContainer& parameters,
                  const String& prefix );
      
+0 −7
Original line number Diff line number Diff line
@@ -14,13 +14,6 @@ namespace TNL {
   namespace Meshes {
      namespace DistributedMeshes {

template<typename RealType, typename Device, typename Index >
DistributedMesh< Grid< 1, RealType, Device, Index > >::
DistributedMesh()
{
    this->domainDecomposition=CoordinatesType( 0 );
}

template< typename RealType, typename Device, typename Index >     
bool
DistributedMesh< Grid< 1, RealType, Device, Index > >::
+11 −60
Original line number Diff line number Diff line
@@ -19,31 +19,20 @@ namespace DistributedMeshes {
enum Directions2D { Left = 0 , Right = 1 , Up = 2, Down=3, UpLeft =4, UpRight=5, DownLeft=6, DownRight=7 }; 

template< typename RealType, typename Device, typename Index >
class DistributedMesh< Grid< 2, RealType, Device, Index > >
class DistributedMesh< Grid< 2, RealType, Device, Index > >: public DistributedGrid_Base<2, RealType, Device, Index >
{
   public:

      typedef Index IndexType;
      typedef Grid< 2, RealType, Device, IndexType > GridType;
      typedef typename GridType::PointType PointType;
      typedef Containers::StaticVector< 2, IndexType > CoordinatesType;
      using typename DistributedGrid_Base<2, RealType, Device, Index >::IndexType;
      using typename DistributedGrid_Base<2, RealType, Device, Index >::GridType;
      using typename DistributedGrid_Base<2, RealType, Device, Index >::PointType;
      using typename DistributedGrid_Base<2, RealType, Device, Index >::CoordinatesType;

      static constexpr int getMeshDimension() { return 2; };
    
     
   public:
     
      DistributedMesh();
      
      void setDomainDecomposition( const CoordinatesType& domainDecomposition );
      
      const CoordinatesType& getDomainDecomposition() const;
      
      template< int EntityDimension >
/*      template< int EntityDimension >
      IndexType getEntitiesCount() const;

      template< typename Entity >
      IndexType getEntitiesCount() const;            
      IndexType getEntitiesCount() const;*/

      bool setup( const Config::ParameterContainer& parameters,
                  const String& prefix );
@@ -58,53 +47,15 @@ class DistributedMesh< Grid< 2, RealType, Device, Index > >

      String printProcessDistr() const;
            
      bool isDistributed() const;
       
      const CoordinatesType& getOverlap() const;
       
      const int* getNeighbors() const;
             
      const CoordinatesType& getLocalSize() const;

      //number of elements of global grid
      const CoordinatesType& getGlobalSize() const;

      //coordinates of begin of local subdomain without overlaps in global grid
      const CoordinatesType& getGlobalBegin() const;

      const CoordinatesType& getLocalGridSize() const;
       
      const CoordinatesType& getLocalBegin() const;
       
      void writeProlog( Logger& logger ) const;
               
   private : 
       
      int getRankOfProcCoord(int x, int y) const;        
        
      GridType globalGrid;
      PointType spaceSteps;
      PointType localOrigin;
      CoordinatesType localSize;//velikost gridu zpracovavane danym uzlem bez prekryvu
      CoordinatesType localBegin;//souradnice zacatku zpracovavane vypoctove oblasi
      CoordinatesType localGridSize;//velikost lokálního gridu včetně překryvů
      CoordinatesType overlap;
      CoordinatesType globalSize;//velikost celé sítě
      CoordinatesType globalBegin;
        
        
      IndexType Dimensions;        
      bool distributed;
        
      int rank;
      int nproc;
        
      CoordinatesType domainDecomposition;
      CoordinatesType subdomainCoordinates;
      int numberOfLarger[2];
        
      int neighbors[8];
      bool isSet;

};

} // namespace DistributedMeshes
+79 −157
Original line number Diff line number Diff line
@@ -16,28 +16,7 @@ namespace TNL {
   namespace Meshes { 
      namespace DistributedMeshes {

template< typename RealType, typename Device, typename Index >
DistributedMesh< Grid< 2, RealType, Device, Index > >::
DistributedMesh()
: domainDecomposition( 0 ), isSet( false ) {}

template< typename RealType, typename Device, typename Index >
void
DistributedMesh< Grid< 2, RealType, Device, Index > >::
setDomainDecomposition( const CoordinatesType& domainDecomposition )
{
   this->domainDecomposition = domainDecomposition;
}

template< typename RealType, typename Device, typename Index >
const typename DistributedMesh< Grid< 2, RealType, Device, Index > >::CoordinatesType&
DistributedMesh< Grid< 2, RealType, Device, Index > >::
getDomainDecomposition() const
{
   return this->domainDecomposition;
}

template< typename RealType, typename Device, typename Index >     
/*template< typename RealType, typename Device, typename Index >     
   template< int EntityDimension >
Index
DistributedMesh< Grid< 2, RealType, Device, Index > >::
@@ -53,7 +32,7 @@ DistributedMesh< Grid< 2, RealType, Device, Index > >::
getEntitiesCount() const
{
   return this->globalGrid. template getEntitiesCount< Entity >();
}
}*/

template< typename RealType, typename Device, typename Index >
bool
@@ -74,126 +53,125 @@ setGlobalGrid( const GridType &globalGrid,
               const CoordinatesType& overlap )
{
   this->globalGrid = globalGrid;
   isSet=true;
   this->isSet=true;
   this->overlap=overlap;

   for( int i=0; i<8; i++ )
      neighbors[i]=-1;

   Dimensions= GridType::getMeshDimension();
   spaceSteps=globalGrid.getSpaceSteps();
   distributed=false;
   this->Dimensions= GridType::getMeshDimension();
   this->spaceSteps=globalGrid.getSpaceSteps();
   this->distributed=false;

   if( CommunicatorType::IsInitialized() )
   {
      rank=CommunicatorType::GetRank();
      this->rank=CommunicatorType::GetRank();
      this->nproc=CommunicatorType::GetSize();
      //use MPI only if have more than one process
      if(this->nproc>1)
      {
         distributed=true;
         this->distributed=true;
      }
   }

   if( !distributed )
   if( !this->distributed )
   {
      subdomainCoordinates[0]=0;
      subdomainCoordinates[1]=0;
      domainDecomposition[0]=1;
      domainDecomposition[1]=1;
      localOrigin=globalGrid.getOrigin();
      localGridSize=globalGrid.getDimensions();
      localSize=globalGrid.getDimensions();
      globalSize=globalGrid.getDimensions();
      globalBegin=CoordinatesType(0);
      localBegin.x()=0;
      localBegin.y()=0;
      this->subdomainCoordinates[0]=0;
      this->subdomainCoordinates[1]=0;
      this->domainDecomposition[0]=1;
      this->domainDecomposition[1]=1;
      this->localOrigin=globalGrid.getOrigin();
      this->localGridSize=globalGrid.getDimensions();
      this->localSize=globalGrid.getDimensions();
      this->globalBegin=CoordinatesType(0);
      this->localBegin.x()=0;
      this->localBegin.y()=0;

      return;
   }
   else
   {
      int numberOfLarger[2];
      //compute node distribution
      int dims[ 2 ];
      dims[ 0 ] = domainDecomposition[ 0 ];
      dims[ 1 ] = domainDecomposition[ 1 ];
      dims[ 0 ] = this->domainDecomposition[ 0 ];
      dims[ 1 ] = this->domainDecomposition[ 1 ];

      CommunicatorType::DimsCreate( nproc, 2, dims );
      domainDecomposition[ 0 ] = dims[ 0 ];
      domainDecomposition[ 1 ] = dims[ 1 ];
      CommunicatorType::DimsCreate( this->nproc, 2, dims );
      this->domainDecomposition[ 0 ] = dims[ 0 ];
      this->domainDecomposition[ 1 ] = dims[ 1 ];

      subdomainCoordinates[ 0 ] = rank % domainDecomposition[ 0 ];
      subdomainCoordinates[ 1 ] = rank / domainDecomposition[ 0 ];        
      this->subdomainCoordinates[ 0 ] = this->rank % this->domainDecomposition[ 0 ];
      this->subdomainCoordinates[ 1 ] = this->rank / this->domainDecomposition[ 0 ];        

      //compute local mesh size            
      globalSize=globalGrid.getDimensions();              
      numberOfLarger[0]=globalGrid.getDimensions().x()%domainDecomposition[0];
      numberOfLarger[1]=globalGrid.getDimensions().y()%domainDecomposition[1];
      numberOfLarger[0]=globalGrid.getDimensions().x()%this->domainDecomposition[0];
      numberOfLarger[1]=globalGrid.getDimensions().y()%this->domainDecomposition[1];

      localSize.x()=(globalGrid.getDimensions().x()/domainDecomposition[0]);
      localSize.y()=(globalGrid.getDimensions().y()/domainDecomposition[1]);
      this->localSize.x()=(globalGrid.getDimensions().x()/this->domainDecomposition[0]);
      this->localSize.y()=(globalGrid.getDimensions().y()/this->domainDecomposition[1]);

      if(numberOfLarger[0]>subdomainCoordinates[0])
           localSize.x()+=1;               
      if(numberOfLarger[1]>subdomainCoordinates[1])
          localSize.y()+=1;
      if(numberOfLarger[0]>this->subdomainCoordinates[0])
           this->localSize.x()+=1;               
      if(numberOfLarger[1]>this->subdomainCoordinates[1])
          this->localSize.y()+=1;

      if(numberOfLarger[0]>subdomainCoordinates[0])
          globalBegin.x()=subdomainCoordinates[0]*localSize.x();
      if(numberOfLarger[0]>this->subdomainCoordinates[0])
          this->globalBegin.x()=this->subdomainCoordinates[0]*this->localSize.x();
      else
          globalBegin.x()=numberOfLarger[0]*(localSize.x()+1)+(subdomainCoordinates[0]-numberOfLarger[0])*localSize.x();
          this->globalBegin.x()=numberOfLarger[0]*(this->localSize.x()+1)+(this->subdomainCoordinates[0]-numberOfLarger[0])*this->localSize.x();

      if(numberOfLarger[1]>subdomainCoordinates[1])
          globalBegin.y()=subdomainCoordinates[1]*localSize.y();
      if(numberOfLarger[1]>this->subdomainCoordinates[1])
          this->globalBegin.y()=this->subdomainCoordinates[1]*this->localSize.y();

      else
          globalBegin.y()=numberOfLarger[1]*(localSize.y()+1)+(subdomainCoordinates[1]-numberOfLarger[1])*localSize.y();
          this->globalBegin.y()=numberOfLarger[1]*(this->localSize.y()+1)+(this->subdomainCoordinates[1]-numberOfLarger[1])*this->localSize.y();

      localOrigin=globalGrid.getOrigin()+TNL::Containers::tnlDotProduct(globalGrid.getSpaceSteps(),globalBegin-overlap);
      this->localOrigin=globalGrid.getOrigin()+TNL::Containers::tnlDotProduct(globalGrid.getSpaceSteps(),this->globalBegin-this->overlap);

      //nearnodes
      if(subdomainCoordinates[0]>0)
          neighbors[Left]=getRankOfProcCoord(subdomainCoordinates[0]-1,subdomainCoordinates[1]);
      if(subdomainCoordinates[0]<domainDecomposition[0]-1)
          neighbors[Right]=getRankOfProcCoord(subdomainCoordinates[0]+1,subdomainCoordinates[1]);
      if(subdomainCoordinates[1]>0)
          neighbors[Up]=getRankOfProcCoord(subdomainCoordinates[0],subdomainCoordinates[1]-1);
      if(subdomainCoordinates[1]<domainDecomposition[1]-1)
          neighbors[Down]=getRankOfProcCoord(subdomainCoordinates[0],subdomainCoordinates[1]+1);
      if(subdomainCoordinates[0]>0 && subdomainCoordinates[1]>0)
          neighbors[UpLeft]=getRankOfProcCoord(subdomainCoordinates[0]-1,subdomainCoordinates[1]-1);
      if(subdomainCoordinates[0]>0 && subdomainCoordinates[1]<domainDecomposition[1]-1)
          neighbors[DownLeft]=getRankOfProcCoord(subdomainCoordinates[0]-1,subdomainCoordinates[1]+1);
      if(subdomainCoordinates[0]<domainDecomposition[0]-1 && subdomainCoordinates[1]>0)
          neighbors[UpRight]=getRankOfProcCoord(subdomainCoordinates[0]+1,subdomainCoordinates[1]-1);
      if(subdomainCoordinates[0]<domainDecomposition[0]-1 && subdomainCoordinates[1]<domainDecomposition[1]-1)
          neighbors[DownRight]=getRankOfProcCoord(subdomainCoordinates[0]+1,subdomainCoordinates[1]+1);

      localBegin=overlap;
      if(this->subdomainCoordinates[0]>0)
          neighbors[Left]=getRankOfProcCoord(this->subdomainCoordinates[0]-1,this->subdomainCoordinates[1]);
      if(this->subdomainCoordinates[0]<this->domainDecomposition[0]-1)
          neighbors[Right]=getRankOfProcCoord(this->subdomainCoordinates[0]+1,this->subdomainCoordinates[1]);
      if(this->subdomainCoordinates[1]>0)
          neighbors[Up]=getRankOfProcCoord(this->subdomainCoordinates[0],this->subdomainCoordinates[1]-1);
      if(this->subdomainCoordinates[1]<this->domainDecomposition[1]-1)
          neighbors[Down]=getRankOfProcCoord(this->subdomainCoordinates[0],this->subdomainCoordinates[1]+1);
      if(this->subdomainCoordinates[0]>0 && this->subdomainCoordinates[1]>0)
          neighbors[UpLeft]=getRankOfProcCoord(this->subdomainCoordinates[0]-1,this->subdomainCoordinates[1]-1);
      if(this->subdomainCoordinates[0]>0 && this->subdomainCoordinates[1]<this->domainDecomposition[1]-1)
          neighbors[DownLeft]=getRankOfProcCoord(this->subdomainCoordinates[0]-1,this->subdomainCoordinates[1]+1);
      if(this->subdomainCoordinates[0]<this->domainDecomposition[0]-1 && this->subdomainCoordinates[1]>0)
          neighbors[UpRight]=getRankOfProcCoord(this->subdomainCoordinates[0]+1,this->subdomainCoordinates[1]-1);
      if(this->subdomainCoordinates[0]<this->domainDecomposition[0]-1 && this->subdomainCoordinates[1]<this->domainDecomposition[1]-1)
          neighbors[DownRight]=getRankOfProcCoord(this->subdomainCoordinates[0]+1,this->subdomainCoordinates[1]+1);

      this->localBegin=this->overlap;

      if(neighbors[Left]==-1)
      {
           localOrigin.x()+=overlap.x()*globalGrid.getSpaceSteps().x();
           localBegin.x()=0;
           this->localOrigin.x()+=this->overlap.x()*globalGrid.getSpaceSteps().x();
           this->localBegin.x()=0;
      }

      if(neighbors[Up]==-1)
      {
          localOrigin.y()+=overlap.y()*globalGrid.getSpaceSteps().y();
          localBegin.y()=0;
          this->localOrigin.y()+=this->overlap.y()*globalGrid.getSpaceSteps().y();
          this->localBegin.y()=0;
      }

      localGridSize=localSize;
      this->localGridSize=this->localSize;
      //Add Overlaps
      if(neighbors[Left]!=-1)
          localGridSize.x()+=overlap.x();
          this->localGridSize.x()+=this->overlap.x();
      if(neighbors[Right]!=-1)
          localGridSize.x()+=overlap.x();
          this->localGridSize.x()+=this->overlap.x();

      if(neighbors[Up]!=-1)
          localGridSize.y()+=overlap.y();
          this->localGridSize.y()+=this->overlap.y();
      if(neighbors[Down]!=-1)
          localGridSize.y()+=overlap.y();
          this->localGridSize.y()+=this->overlap.y();
  }
}

@@ -202,11 +180,11 @@ void
DistributedMesh< Grid< 2, RealType, Device, Index > >::
setupGrid( GridType& grid )
{
   TNL_ASSERT_TRUE(isSet,"DistributedGrid is not set, but used by SetupGrid");
   grid.setOrigin( localOrigin );
   grid.setDimensions( localGridSize );
   TNL_ASSERT_TRUE(this->isSet,"DistributedGrid is not set, but used by SetupGrid");
   grid.setOrigin( this->localOrigin );
   grid.setDimensions( this->localGridSize );
   //compute local proporions by sideefect
   grid.setSpaceSteps( spaceSteps );
   grid.setSpaceSteps( this->spaceSteps );
   grid.SetDistMesh(this);
};

@@ -215,7 +193,7 @@ String
DistributedMesh< Grid< 2, RealType, Device, Index > >::
printProcessCoords() const
{
   return convertToString(subdomainCoordinates[0])+String("-")+convertToString(subdomainCoordinates[1]);
   return convertToString(this->subdomainCoordinates[0])+String("-")+convertToString(this->subdomainCoordinates[1]);
};

template< typename RealType, typename Device, typename Index >
@@ -223,23 +201,7 @@ String
DistributedMesh< Grid< 2, RealType, Device, Index > >::
printProcessDistr() const
{
   return convertToString(domainDecomposition[0])+String("-")+convertToString(domainDecomposition[1]);
};  

template< typename RealType, typename Device, typename Index >
bool
DistributedMesh< Grid< 2, RealType, Device, Index > >::
isDistributed() const
{
   return this->distributed;
};

template< typename RealType, typename Device, typename Index >
const typename DistributedMesh< Grid< 2, RealType, Device, Index > >::CoordinatesType&
DistributedMesh< Grid< 2, RealType, Device, Index > >::
getOverlap() const
{
   return this->overlap;
   return convertToString(this->domainDecomposition[0])+String("-")+convertToString(this->domainDecomposition[1]);
};  

template< typename RealType, typename Device, typename Index >
@@ -247,50 +209,10 @@ const int*
DistributedMesh< Grid< 2, RealType, Device, Index > >::
getNeighbors() const
{
   TNL_ASSERT_TRUE(isSet,"DistributedGrid is not set, but used by getNeighbors");
   TNL_ASSERT_TRUE(this->isSet,"DistributedGrid is not set, but used by getNeighbors");
   return this->neighbors;
}

template< typename RealType, typename Device, typename Index >
const typename DistributedMesh< Grid< 2, RealType, Device, Index > >::CoordinatesType&
DistributedMesh< Grid< 2, RealType, Device, Index > >::
getLocalSize() const
{
   return this->localSize;
}

template< typename RealType, typename Device, typename Index >
const typename DistributedMesh< Grid< 2, RealType, Device, Index > >::CoordinatesType&
DistributedMesh< Grid< 2, RealType, Device, Index > >::
getGlobalSize() const
{
   return this->globalSize;
}

template< typename RealType, typename Device, typename Index >
const typename DistributedMesh< Grid< 2, RealType, Device, Index > >::CoordinatesType&
DistributedMesh< Grid< 2, RealType, Device, Index > >::
getGlobalBegin() const
{
   return this->globalBegin;
}

template< typename RealType, typename Device, typename Index >
const typename DistributedMesh< Grid< 2, RealType, Device, Index > >::CoordinatesType&
DistributedMesh< Grid< 2, RealType, Device, Index > >::
getLocalGridSize() const
{
   return this->localGridSize;
}

template< typename RealType, typename Device, typename Index >
const typename DistributedMesh< Grid< 2, RealType, Device, Index > >::CoordinatesType&
DistributedMesh< Grid< 2, RealType, Device, Index > >::
getLocalBegin() const
{
   return this->localBegin;
}

template< typename RealType, typename Device, typename Index >
void
DistributedMesh< Grid< 2, RealType, Device, Index > >::
@@ -304,7 +226,7 @@ int
DistributedMesh< Grid< 2, RealType, Device, Index > >::
getRankOfProcCoord(int x, int y) const
{
   return y*domainDecomposition[0]+x;
   return y*this->domainDecomposition[0]+x;
}
         
      } //namespace DistributedMeshes
+9 −54
Original line number Diff line number Diff line
@@ -27,34 +27,28 @@ enum Directions3D { West = 0 , East = 1 , North = 2, South=3, Top =4, Bottom=5,


template< typename RealType, typename Device, typename Index >
class DistributedMesh<Grid< 3, RealType, Device, Index >>
class DistributedMesh<Grid< 3, RealType, Device, Index >> : public DistributedGrid_Base<3, RealType, Device, Index >
{

    public:

      typedef Index IndexType;
      typedef Grid< 3, RealType, Device, IndexType > GridType;
      typedef typename GridType::PointType PointType;
      typedef Containers::StaticVector< 3, IndexType > CoordinatesType;
      using typename DistributedGrid_Base<3, RealType, Device, Index >::IndexType;
      using typename DistributedGrid_Base<3, RealType, Device, Index >::GridType;
      using typename DistributedGrid_Base<3, RealType, Device, Index >::PointType;
      using typename DistributedGrid_Base<3, RealType, Device, Index >::CoordinatesType;

      static constexpr int getMeshDimension() { return 3; };    
    
      DistributedMesh();
    
      static void configSetup( Config::ConfigDescription& config );
      
      bool setup( const Config::ParameterContainer& parameters,
                  const String& prefix );
              
      void setDomainDecomposition( const CoordinatesType& domainDecomposition );
      
      const CoordinatesType& getDomainDecomposition() const;
            
      template< int EntityDimension >
/*      template< int EntityDimension >
      IndexType getEntitiesCount() const;

      template< typename Entity >
      IndexType getEntitiesCount() const;      
      IndexType getEntitiesCount() const;   */   

      template< typename CommunicatorType > 
      void setGlobalGrid( const GridType& globalGrid,
@@ -66,54 +60,15 @@ class DistributedMesh<Grid< 3, RealType, Device, Index >>

      String printProcessDistr() const;

      bool isDistributed() const;
       
      const CoordinatesType& getOverlap() const;
       
      const int* getNeighbors() const;
       
      const CoordinatesType& getLocalSize() const;
       
      const CoordinatesType& getLocalGridSize() const;
       
      const CoordinatesType& getLocalBegin() const;

      //number of elements of global grid
      const CoordinatesType& getGlobalSize() const;

      //coordinates of begin of local subdomain without overlaps in global grid
      const CoordinatesType& getGlobalBegin() const;
       
      void writeProlog( Logger& logger );

   private:

      int getRankOfProcCoord(int x, int y, int z) const;           
        
      GridType globalGrid;
      
      PointType spaceSteps;
      PointType localOrigin;
      CoordinatesType localSize;
      CoordinatesType localGridSize;
      CoordinatesType localBegin;
      CoordinatesType overlap;
      CoordinatesType globalSize;
      CoordinatesType globalBegin;
        
      IndexType Dimensions;        
      bool distributed;
        
      int rank;
      int nproc;
        
      CoordinatesType domainDecomposition;
      CoordinatesType subdomainCoordinates;
      int numberOfLarger[3];
        
      int neighbors[26];

      bool isSet;
};

} // namespace DistributedMeshes
Loading