Commit ae4adcd0 authored by Vít Hanousek's avatar Vít Hanousek
Browse files

Light refactoring of the 1D and 2D distributedGridSynchronizer. They now have the same...

Light refactoring of the 1D and 2D distributedGridSynchronizer. They now have the same interface as the 3D distributedGridSynchronizer, and the Synchronize function is lighter.
parent 8647b05f
Loading
Loading
Loading
Loading
+5 −0
Original line number Diff line number Diff line
@@ -415,6 +415,11 @@ class DistributedGrid <GridType,2>
           return this->localsize;
       }

       CoordinatesType getLocalGridSize()
       {
           return this->localgridsize;
       }
       
              
       CoordinatesType getLocalBegin()
       {
+176 −106
Original line number Diff line number Diff line
@@ -32,45 +32,75 @@ template <typename DistributedGridType,
		typename MeshFunctionType>  
class DistributedGridSynchronizer<DistributedGridType,MeshFunctionType,1>
{
    typedef typename MeshFunctionType::RealType Real;

    public:
    static void Synchronize(DistributedGridType distributedgrid, MeshFunctionType meshfunction)
    {
        if(!distributedgrid.isMPIUsed())
                return;
#ifdef USE_MPI
public:
        typedef typename MeshFunctionType::MeshType::Cell Cell;
        typedef typename MeshFunctionType::RealType Real;


private:  
        Real * leftsendbuf;
        Real * rightsendbuf;
        Real * leftrcvbuf;
        Real * rightrcvbuf;

        int size = distributedgrid.getOverlap().x();
        int size;

        DistributedGridType *distributedgrid;
#endif

    
    public:
    DistributedGridSynchronizer(DistributedGridType *distrgrid)
    {
        this->distributedgrid=distrgrid;
#ifdef USE_MPI
        size = distributedgrid->getOverlap().x();

        leftsendbuf=new Real[size];
        rightsendbuf=new Real[size];
        leftrcvbuf=new Real[size];
        rightrcvbuf=new Real[size];      
#endif
    }

    ~DistributedGridSynchronizer()
    {
        delete [] leftrcvbuf;
        delete [] rightrcvbuf;
        delete [] leftsendbuf;
        delete [] rightsendbuf; 
    }

    void Synchronize(MeshFunctionType &meshfunction)
    {
        if(!distributedgrid->isMPIUsed())
                return;
#ifdef USE_MPI

        Cell leftentity(meshfunction.getMesh());
        Cell rightentity(meshfunction.getMesh());

        int left=distributedgrid->getLeft();
        int right=distributedgrid->getRight();

        //fill send buffers
        typename MeshFunctionType::MeshType::Cell leftentity(meshfunction.getMesh());
        typename MeshFunctionType::MeshType::Cell rightentity(meshfunction.getMesh());
        for(int i=0;i<size;i++)
        {
            if(left!=-1)
            {
                leftentity.getCoordinates().x() = size+i;
                leftentity.refresh();
            //leftsendbuf[i]=meshfunction.getValue(leftentity);
                leftsendbuf[i]=meshfunction.getData()[leftentity.getIndex()];
            }
    
            if(right!=-1)
            {
                rightentity.getCoordinates().x() = meshfunction.getMesh().getDimensions().x()-2*size+i;
                rightentity.refresh();            
            //rightsendbuf[i]=meshfunction.getValue(rightentity);
                rightsendbuf[i]=meshfunction.getData()[rightentity.getIndex()];

            }

        }

        //async send
        MPI::Request leftsendreq;
@@ -79,59 +109,49 @@ class DistributedGridSynchronizer<DistributedGridType,MeshFunctionType,1>
        MPI::Request rightrcvreq;

        //send everithing, recieve everything 
        //cout << distributedgrid.getLeft() << "   " << distributedgrid.getRight() << endl;
        if(distributedgrid.getLeft()!=-1)
        if(left!=-1)
        {
            leftsendreq=MPI::COMM_WORLD.Isend((void*) leftsendbuf, size, MPI::DOUBLE , distributedgrid.getLeft(), 0);
            leftrcvreq=MPI::COMM_WORLD.Irecv((void*) leftrcvbuf, size, MPI::DOUBLE, distributedgrid.getLeft(), 0);
            leftsendreq=MPI::COMM_WORLD.Isend((void*) leftsendbuf, size, MPI::DOUBLE , left, 0);
            leftrcvreq=MPI::COMM_WORLD.Irecv((void*) leftrcvbuf, size, MPI::DOUBLE, left, 0);
        }        
        if(distributedgrid.getRight()!=-1)
        if(right!=-1)
        {
            rightsendreq=MPI::COMM_WORLD.Isend((void*) rightsendbuf, size, MPI::DOUBLE , distributedgrid.getRight(), 0);
            rightrcvreq=MPI::COMM_WORLD.Irecv((void*) rightrcvbuf, size, MPI::DOUBLE, distributedgrid.getRight(), 0);
            rightsendreq=MPI::COMM_WORLD.Isend((void*) rightsendbuf, size, MPI::DOUBLE , right, 0);
            rightrcvreq=MPI::COMM_WORLD.Irecv((void*) rightrcvbuf, size, MPI::DOUBLE, right, 0);
        }

        //wait until send is done
        if(distributedgrid.getLeft()!=-1)
        if(left!=-1)
        {
            leftrcvreq.Wait();
            leftsendreq.Wait();
        }        
        if(distributedgrid.getRight()!=-1)
        if(right!=-1)
        {
            rightrcvreq.Wait();
            rightsendreq.Wait();
        }

        //copy data form rcv buffers
        if(distributedgrid.getLeft()!=-1)
        if(left!=-1)
        {
            for(int i=0;i<size;i++)
            {
                leftentity.getCoordinates().x() = i;
                leftentity.refresh();
                //leftsendbuf[i]=meshfunction.getValue(leftentity);
                meshfunction.getData()[leftentity.getIndex()]=leftrcvbuf[i];
            }
        }


        if(distributedgrid.getRight()!=-1)
        if(right!=-1)
        {
            for(int i=0;i<size;i++)
            {
                rightentity.getCoordinates().x() = meshfunction.getMesh().getDimensions().x()-size+i;
                rightentity.refresh();
                //rightsendbuf[i]=meshfunction.getValue(rightentity);
                meshfunction.getData()[rightentity.getIndex()]=rightrcvbuf[i];
            }
        }

        //free buffers
        delete [] leftrcvbuf;
        delete [] rightrcvbuf;
        delete [] leftsendbuf;
        delete [] rightsendbuf;  
#endif
    };
};
@@ -141,32 +161,51 @@ template <typename DistributedGridType,
		typename MeshFunctionType>  
class DistributedGridSynchronizer<DistributedGridType,MeshFunctionType,2>
{
    public:
    static void Synchronize(DistributedGridType distributedgrid, MeshFunctionType meshfunction)
    {
	if(!distributedgrid.isMPIUsed())
            return;

#ifdef USE_MPI
    public:
        typedef typename MeshFunctionType::RealType Real;
        typedef typename DistributedGridType::CoordinatesType CoordinatesType;

    private:
        DistributedGridType *distributedgrid;

        Real ** sendbuffs=new Real*[8];
        Real ** rcvbuffs=new Real*[8];
        Real * sendbuffs[8];
        Real * rcvbuffs[8];
        int sizes[8];
        
        int *neighbor=distributedgrid.getNeighbors();
        int leftSrc;
        int rightSrc;
        int upSrc;
        int downSrc;
        int xcenter;
        int ycenter;
        int leftDst;
        int rightDst;
        int upDst;
        int downDst;
        
        CoordinatesType overlap = distributedgrid.getOverlap();
        CoordinatesType localgridsize = meshfunction.getMesh().getDimensions();
        CoordinatesType overlap;
        CoordinatesType localsize;
#endif

        CoordinatesType localsize=distributedgrid.getLocalSize();
        CoordinatesType localbegin=distributedgrid.getLocalBegin();
    public:
    DistributedGridSynchronizer(DistributedGridType *distgrid)
    {
        
#ifdef USE_MPI
        this->distributedgrid=distgrid;

        overlap = distributedgrid->getOverlap();
        localsize = distributedgrid->getLocalSize();
        
        CoordinatesType localgridsize = this->distributedgrid->getLocalGridSize();
        CoordinatesType localbegin=this->distributedgrid->getLocalBegin();

        int updownsize=localsize.x()*overlap.y();
        int leftrightsize=localsize.y()*overlap.x();
        int connersize=overlap.x()*overlap.y();

        int sizes[8];
        sizes[Left]=leftrightsize;
        sizes[Right]=leftrightsize;
        sizes[Up]=updownsize;
@@ -182,25 +221,50 @@ class DistributedGridSynchronizer<DistributedGridType,MeshFunctionType,2>
            rcvbuffs[i]=new Real[sizes[i]];
        }

        //fill send buffers
	BufferEntities(meshfunction,sendbuffs[Left],localbegin.x(),localbegin.y(),overlap.x(),localsize.y(),true);
        BufferEntities(meshfunction,sendbuffs[Right],localgridsize.x()-2*overlap.x(),localbegin.y(),overlap.x(),localsize.y(),true);
	BufferEntities(meshfunction,sendbuffs[Up],localbegin.x(),localbegin.y(),localsize.x(),overlap.y(),true);
	BufferEntities(meshfunction,sendbuffs[Down],localbegin.x(),localgridsize.y()-2*overlap.y(),localsize.x(),overlap.y(),true);
        leftSrc=localbegin.x();
        rightSrc=localgridsize.x()-2*overlap.x();
        upSrc=localbegin.y();
        downSrc=localgridsize.y()-2*overlap.y();
            
        xcenter=localbegin.x();
        ycenter=localbegin.y();
        
        leftDst=0;
        rightDst=localgridsize.x()-overlap.x();
        upDst=0;
        downDst=localgridsize.y()-overlap.y();                       
#endif
        
	BufferEntities(meshfunction,sendbuffs[UpLeft],localbegin.x(),localbegin.y(),overlap.x(),overlap.y(),true);
	BufferEntities(meshfunction,sendbuffs[UpRight],localgridsize.x()-2*overlap.x(),localbegin.y(),overlap.x(),overlap.y(),true);
	BufferEntities(meshfunction,sendbuffs[DownLeft],localbegin.x(),localgridsize.y()-2*overlap.y(),overlap.x(),overlap.y(),true);
	BufferEntities(meshfunction,sendbuffs[DownRight],localgridsize.x()-2*overlap.x(),localgridsize.y()-2*overlap.y(),overlap.x(),overlap.y(),true);
    }

    ~DistributedGridSynchronizer()
    {
        for(int i=0;i<8;i++)
        {
            delete [] sendbuffs[i];
            delete [] rcvbuffs[i];
        }
    }
        
    void Synchronize(MeshFunctionType &meshfunction)
    {
	if(!distributedgrid->isMPIUsed())
            return;
#ifdef USE_MPI

    int *neighbor=distributedgrid->getNeighbors();

    CopyBuffers(meshfunction, sendbuffs, true,
            leftSrc, rightSrc, upSrc, downSrc,
            xcenter, ycenter,
            overlap,localsize,
            neighbor);
	
        //async send
        MPI::Request sendreq[8];
        MPI::Request rcvreq[8];
		                
        
                
        //send everithing, recieve everything 
        //cout << distributedgrid.getLeft() << "   " << distributedgrid.getRight() << endl;
        for(int i=0;i<8;i++)	
           if(neighbor[i]!=-1)
           {
@@ -219,35 +283,41 @@ class DistributedGridSynchronizer<DistributedGridType,MeshFunctionType,2>
        }

        //copy data form rcv buffers
        CopyBuffers(meshfunction, rcvbuffs, false,
            leftDst, rightDst, upDst, downDst,
            xcenter, ycenter,
            overlap,localsize,
            neighbor);
    };
    
    private:
    template <typename Real>
    void CopyBuffers(MeshFunctionType meshfunction, Real ** buffers, bool toBuffer,
            int left, int right, int up, int down,
            int xcenter, int ycenter,
            CoordinatesType shortDim, CoordinatesType longDim,
            int *neighbor)
    {
       	if(neighbor[Left]!=-1)        
            BufferEntities(meshfunction,rcvbuffs[Left],0,localbegin.y(),overlap.x(),localsize.y(),false);
            BufferEntities(meshfunction,buffers[Left],left,ycenter,shortDim.x(),longDim.y(),toBuffer);
        if(neighbor[Right]!=-1)
            BufferEntities(meshfunction,rcvbuffs[Right],localgridsize.x()-overlap.x(),localbegin.y(),overlap.x(),localsize.y(),false);
            BufferEntities(meshfunction,buffers[Right],right,ycenter,shortDim.x(),longDim.y(),toBuffer);
        if(neighbor[Up]!=-1)
            BufferEntities(meshfunction,rcvbuffs[Up],localbegin.x(),0,localsize.x(),overlap.y(),false);
            BufferEntities(meshfunction,buffers[Up],xcenter,up,longDim.x(),shortDim.y(),toBuffer);
        if(neighbor[Down]!=-1)
            BufferEntities(meshfunction,rcvbuffs[Down],localbegin.x(),localgridsize.y()-overlap.y(),localsize.x(),overlap.y(),false);
            BufferEntities(meshfunction,buffers[Down],xcenter,down,longDim.x(),shortDim.y(),toBuffer);
        if(neighbor[UpLeft]!=-1)
            BufferEntities(meshfunction,rcvbuffs[UpLeft],0,0,overlap.x(),overlap.y(),false);
            BufferEntities(meshfunction,buffers[UpLeft],left,up,shortDim.x(),shortDim.y(),toBuffer);
        if(neighbor[UpRight]!=-1)
            BufferEntities(meshfunction,rcvbuffs[UpRight],localgridsize.x()-overlap.x(),0,overlap.x(),overlap.y(),false);
            BufferEntities(meshfunction,buffers[UpRight],right,up,shortDim.x(),shortDim.y(),toBuffer);
        if(neighbor[DownLeft]!=-1)        
            BufferEntities(meshfunction,rcvbuffs[DownLeft],0,localgridsize.y()-overlap.y(),overlap.x(),overlap.y(),false);
            BufferEntities(meshfunction,buffers[DownLeft],left,down,shortDim.x(),shortDim.y(),toBuffer);
        if(neighbor[DownRight]!=-1)
            BufferEntities(meshfunction,rcvbuffs[DownRight],localgridsize.x()-overlap.x(),localgridsize.y()-overlap.y(),overlap.x(),overlap.y(),false);
		
        //free buffers
        for(int i=0;i<8;i++)
        {
            delete [] sendbuffs[i];
            delete [] rcvbuffs[i];
            BufferEntities(meshfunction,buffers[DownRight],right,down,shortDim.x(),shortDim.y(),toBuffer);
    }
    
    };
    
    private:    
    template <typename Real>
    static void BufferEntities(MeshFunctionType meshfunction, Real * buffer, int beginx, int beginy, int sizex, int sizey,bool tobuffer)
    void BufferEntities(MeshFunctionType meshfunction, Real * buffer, int beginx, int beginy, int sizex, int sizey,bool tobuffer)
    {

        typename MeshFunctionType::MeshType::Cell entity(meshfunction.getMesh());
@@ -365,7 +435,7 @@ class DistributedGridSynchronizer<DistributedGridType,MeshFunctionType,3>
        
    }
        
    void Synchronize(MeshFunctionType meshfunction)
    void Synchronize(MeshFunctionType &meshfunction)
    {
	if(!distributedgrid->isMPIUsed())
            return;
+4 −5
Original line number Diff line number Diff line
@@ -148,8 +148,7 @@ class DistributedGirdTest_1D : public ::testing::Test {
	
	meshFunctionptr->bind(gridptr,*dof);

	//DistributedGridSynchronizer<DistributedGrid<MeshType>,MeshFunctionType,1> synchronizer(&distrgrid)
	synchronizer=new DistributedGridSynchronizer<DistributedGrid<MeshType>,MeshFunctionType,1>;
	synchronizer=new DistributedGridSynchronizer<DistributedGrid<MeshType>,MeshFunctionType,1>(distrgrid);
	
	constFunctionPtr->Number=rank;
  }
@@ -213,7 +212,7 @@ TEST_F(DistributedGirdTest_1D, LinearFunctionTest)
	//fill meshfunction with linear function (physical center of cell corresponds with its coordinates in grid) 
	setDof_1D(*dof,-1);
	linearFunctionEvaluator.evaluateAllEntities(meshFunctionptr, linearFunctionPtr);
	synchronizer->Synchronize(*distrgrid,*meshFunctionptr);
	synchronizer->Synchronize(*meshFunctionptr);
	
	auto entite= gridptr->template getEntity< Cell >(0);
	entite.refresh();
@@ -228,7 +227,7 @@ TEST_F(DistributedGirdTest_1D, SynchronizerNeighborTest)
{
	setDof_1D(*dof,-1);
	constFunctionEvaluator.evaluateAllEntities( meshFunctionptr , constFunctionPtr );
	synchronizer->Synchronize(*distrgrid,*meshFunctionptr);
	synchronizer->Synchronize(*meshFunctionptr);
	if(rank!=0)
		EXPECT_EQ((*dof)[0],rank-1)<< "Left Overlap was filled by wrong process.";
	if(rank!=nproc-1)
+3 −4
Original line number Diff line number Diff line
@@ -416,8 +416,7 @@ class DistributedGirdTest_2D : public ::testing::Test {
	
	meshFunctionptr->bind(gridptr,*dof);
	
	//DistributedGridSynchronizer<DistributedGrid<MeshType>,MeshFunctionType,1> synchronizer(&distrgrid)
	synchronizer=new DistributedGridSynchronizer<DistributedGrid<MeshType>,MeshFunctionType,2>;
	synchronizer=new DistributedGridSynchronizer<DistributedGrid<MeshType>,MeshFunctionType,2>(distrgrid);
	
	constFunctionPtr->Number=rank;
	
@@ -485,7 +484,7 @@ TEST_F(DistributedGirdTest_2D, LinearFunctionTest)
	//fill meshfunction with linear function (physical center of cell corresponds with its coordinates in grid) 
	setDof_2D(*dof,-1);
	linearFunctionEvaluator.evaluateAllEntities(meshFunctionptr, linearFunctionPtr);
	synchronizer->Synchronize(*distrgrid,*meshFunctionptr);
	synchronizer->Synchronize(*meshFunctionptr);
	
	int count =gridptr->template getEntitiesCount< Cell >();
	for(int i=0;i<count;i++)
@@ -500,7 +499,7 @@ TEST_F(DistributedGirdTest_2D, SynchronizerNeighborTest)
{
	setDof_2D(*dof,-1);
	constFunctionEvaluator.evaluateAllEntities( meshFunctionptr , constFunctionPtr );
	synchronizer->Synchronize(*distrgrid,*meshFunctionptr);
	synchronizer->Synchronize(*meshFunctionptr);
	checkNeighbor_2D(rank, *gridptr, *dof);
}