Commit a27e4875 authored by Vít Hanousek's avatar Vít Hanousek
Browse files

Refactor MPI Communicator from (deprecated) C++ MPI API to C MPI API

Prepare Distributed Grids for Communication groups (aka MPI Communicator)
parent af6d900e
Loading
Loading
Loading
Loading
+49 −30
Original line number Diff line number Diff line
@@ -56,14 +56,16 @@ class MpiCommunicator
      inline static MPI_Datatype MPIDataType( const double* ) { return MPI_DOUBLE; };
      inline static MPI_Datatype MPIDataType( const long double* ) { return MPI_LONG_DOUBLE; };

      using Request = MPI::Request;
      using Request = MPI_Request;
      using CommunicationGroup = MPI_Comm;
#else
      using Request = int;
      using CommunicationGroup = int;
#endif

      static bool isDistributed()
      {
         return GetSize()>1;
         return GetSize(AllGroup)>1;
      };

      static void configSetup( Config::ConfigDescription& config, const String& prefix = "" )
@@ -89,8 +91,9 @@ class MpiCommunicator
      static void Init(int argc, char **argv )
      {
#ifdef HAVE_MPI
         MPI::Init( argc, argv );
         NullRequest=MPI::REQUEST_NULL;
         MPI_Init( &argc, &argv );
         NullRequest=MPI_REQUEST_NULL;
         AllGroup=MPI_COMM_WORLD;
         redirect = true;

         selectGPU();
@@ -111,9 +114,9 @@ class MpiCommunicator
            backup=std::cout.rdbuf();

            //redirect output to files...
            if(MPI::COMM_WORLD.Get_rank()!=0)
            if(GetRank(AllGroup)!=0)
            {
               std::cout<< GetRank() <<": Redirecting std::out to file" <<std::endl;
               std::cout<< GetRank(AllGroup) <<": Redirecting std::out to file" <<std::endl;
               String stdoutFile;
               stdoutFile=String( "./stdout-")+convertToString(MPI::COMM_WORLD.Get_rank())+String(".txt");
               filestr.open (stdoutFile.getString()); 
@@ -129,41 +132,47 @@ class MpiCommunicator
#ifdef HAVE_MPI
         if(isDistributed())
         {
            if(MPI::COMM_WORLD.Get_rank()!=0)
            if(GetRank(AllGroup)!=0)
            {
               std::cout.rdbuf(backup);
               filestr.close();
            }
         }
         MPI::Finalize();
         MPI_Finalize();
#endif
      };

      static bool IsInitialized()
      {
#ifdef HAVE_MPI
         return MPI::Is_initialized() && !MPI::Is_finalized();
         int inicialized, finalized;
         MPI_Initialized(&inicialized);
         MPI_Finalized(&finalized);
         return inicialized && !finalized;
#else
        return false;
#endif
      };

      static int GetRank()
      static int GetRank(CommunicationGroup group)
      {
         //CHECK_INICIALIZED_RET(MPI::COMM_WORLD.Get_rank());
#ifdef HAVE_MPI
        TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not initialized");
        return MPI::COMM_WORLD.Get_rank();
        int rank;
        MPI_Comm_rank(group,&rank);
        return rank;
#else
        return 1;
#endif
      };

      static int GetSize()
      static int GetSize(CommunicationGroup group)
      {
#ifdef HAVE_MPI
        TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not initialized");
        return MPI::COMM_WORLD.Get_size();
        int size;
        MPI_Comm_size(group,&size);
        return size;
#else
        return 1;
#endif
@@ -194,33 +203,37 @@ class MpiCommunicator
#endif
        };

         static void Barrier()
         static void Barrier(CommunicationGroup comm)
         {
#ifdef HAVE_MPI
            TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not inicialized");
            MPI::COMM_WORLD.Barrier();;
            MPI_Barrier(comm);
#else
            throw Exceptions::MPISupportMissing();
#endif
        };

         template <typename T>
         static Request ISend( const T *data, int count, int dest)
         static Request ISend( const T *data, int count, int dest, CommunicationGroup group)
         {
#ifdef HAVE_MPI
            TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not inicialized");
            return MPI::COMM_WORLD.Isend((void*) data, count, MPIDataType(data) , dest, 0);
            Request req;
            MPI_Isend((void*) data, count, MPIDataType(data) , dest, 0, group, &req);
            return req;
#else
            throw Exceptions::MPISupportMissing();
#endif
        }

         template <typename T>
         static Request IRecv( const T *data, int count, int src)
         static Request IRecv( const T *data, int count, int src, CommunicationGroup group)
         {
#ifdef HAVE_MPI
            TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not inicialized");
            return MPI::COMM_WORLD.Irecv((void*) data, count, MPIDataType(data) , src, 0);
            Request req;
            MPI_Irecv((void*) data, count, MPIDataType(data) , src, 0, group, &req);
            return req;
#else
            throw Exceptions::MPISupportMissing();
#endif
@@ -230,18 +243,18 @@ class MpiCommunicator
         {
#ifdef HAVE_MPI
            TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not inicialized");
            MPI::Request::Waitall(length, reqs);
            MPI_Waitall(length, reqs, MPI_STATUSES_IGNORE);
#else
            throw Exceptions::MPISupportMissing();
#endif
        };

        template< typename T > 
        static void Bcast(  T& data, int count, int root)
        static void Bcast(  T& data, int count, int root,CommunicationGroup group)
        {
#ifdef HAVE_MPI
        TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not inicialized");
        MPI::COMM_WORLD.Bcast((void*) &data, count,  MPIDataType(data), root);
        MPI_Bcast((void*) &data, count,  MPIDataType(data), root, group);
#else
        throw Exceptions::MPISupportMissing();
#endif
@@ -251,10 +264,11 @@ class MpiCommunicator
        static void Allreduce( T* data,
                               T* reduced_data,
                               int count,
                               const MPI_Op &op )
                               const MPI_Op &op,
                               CommunicationGroup group)
        {
#ifdef HAVE_MPI
            MPI::COMM_WORLD.Allreduce( (void*) data, (void*) reduced_data,count,MPIDataType(data),op);
            MPI_Allreduce( (void*) data, (void*) reduced_data,count,MPIDataType(data),op,group);
#else
            throw Exceptions::MPISupportMissing();
#endif
@@ -266,10 +280,11 @@ class MpiCommunicator
                    T* reduced_data,
                    int count,
                    MPI_Op &op,
                    int root)
                    int root,
                    CommunicationGroup group)
         {
#ifdef HAVE_MPI
            MPI::COMM_WORLD.Reduce( (void*) data, (void*) reduced_data,count,MPIDataType(data),op,root);
            MPI_Reduce( (void*) data, (void*) reduced_data,count,MPIDataType(data),op,root,group);
#else
            throw Exceptions::MPISupportMissing();
#endif
@@ -280,14 +295,16 @@ class MpiCommunicator
      {
         if( isDistributed() )
         {
            logger.writeParameter( "MPI processes:", GetSize() );
            logger.writeParameter( "MPI processes:", GetSize(AllGroup) );
         }
      }

#ifdef HAVE_MPI
      static MPI::Request NullRequest;
      static MPI_Request NullRequest;
      static MPI_Comm AllGroup;
#else
      static int NullRequest;
      static int AllGroup;
#endif
    private :
      static std::streambuf *psbuf;
@@ -341,9 +358,11 @@ class MpiCommunicator
};

#ifdef HAVE_MPI
MPI::Request MpiCommunicator::NullRequest;
MPI_Request MpiCommunicator::NullRequest;
MPI_Comm MpiCommunicator::AllGroup;
#else
int MpiCommunicator::NullRequest;
int MpiCommunicator::AllGroup;
#endif
std::streambuf *MpiCommunicator::psbuf;
std::streambuf *MpiCommunicator::backup;
+13 −8
Original line number Diff line number Diff line
@@ -27,7 +27,9 @@ class NoDistrCommunicator
   public:

      typedef int Request;
      typedef int CommunicationGroup;
      static Request NullRequest;
      static CommunicationGroup AllGroup;

      static void configSetup( Config::ConfigDescription& config, const String& prefix = "" ){};
 
@@ -58,12 +60,12 @@ class NoDistrCommunicator
          return false;
      };

      static int GetRank()
      static int GetRank(CommunicationGroup group)
      {
          return 0;
      };

      static int GetSize()
      static int GetSize(CommunicationGroup group)
      {
          return 1;
      };
@@ -76,18 +78,18 @@ class NoDistrCommunicator
          }
      };

      static void Barrier()
      static void Barrier(CommunicationGroup group)
      {
      };

      template <typename T>
      static Request ISend( const T *data, int count, int dest)
      static Request ISend( const T *data, int count, int dest, CommunicationGroup group)
      {
          return 1;
      }

      template <typename T>
      static Request IRecv( const T *data, int count, int src)
      static Request IRecv( const T *data, int count, int src, CommunicationGroup group)
      {
          return 1;
      }
@@ -97,7 +99,7 @@ class NoDistrCommunicator
      };

      template< typename T > 
      static void Bcast(  T& data, int count, int root)
      static void Bcast(  T& data, int count, int root, CommunicationGroup group)
      {
      }

@@ -105,7 +107,8 @@ class NoDistrCommunicator
      static void Allreduce( T* data,
                             T* reduced_data,
                             int count,
                             const MPI_Op &op )
                             const MPI_Op &op,
                             CommunicationGroup group )
      {
         memcpy( ( void* ) reduced_data, ( void* ) data, count * sizeof( T ) );
      };
@@ -115,7 +118,8 @@ class NoDistrCommunicator
                          T* reduced_data,
                          int count,
                          MPI_Op &op,
                          int root )
                          int root,
                          CommunicationGroup group )
      {
         memcpy( ( void* ) reduced_data, ( void* ) data, count * sizeof( T ) );
      };
@@ -125,6 +129,7 @@ class NoDistrCommunicator


  int NoDistrCommunicator::NullRequest;
  int NoDistrCommunicator::AllGroup;

} // namespace Communicators
} // namespace TNL
+8 −6
Original line number Diff line number Diff line
@@ -180,13 +180,14 @@ class DistributedGridIO_MPIIOBase
            return meshFunction.save(fileName);
        }

       MPI_Comm group=*((MPI_Comm*)(distrGrid->getCommunicationGroup()));
       MPI_Datatype ftype;
       MPI_Datatype atype;
       int dataCount=CreateDataTypes(distrGrid,&ftype,&atype);

       //write 
       MPI_File file;
       MPI_File_open( MPI_COMM_WORLD,
       MPI_File_open( group,
                      const_cast< char* >( fileName.getString() ),
                      MPI_MODE_CREATE | MPI_MODE_WRONLY,
                      MPI_INFO_NULL,
@@ -194,11 +195,11 @@ class DistributedGridIO_MPIIOBase

       int headerSize=0;

       if(Communicators::MpiCommunicator::GetRank()==0)
       if(Communicators::MpiCommunicator::GetRank(group)==0)
       {
            headerSize=writeMeshFunctionHeader(file,meshFunction,dataCount);
       }
       MPI_Bcast(&headerSize, 1, MPI_INT,0, MPI_COMM_WORLD);
       MPI_Bcast(&headerSize, 1, MPI_INT,0, group);

       if( std::is_same< RealType, double >::value)
         MPI_File_set_view(file,headerSize,MPI_DOUBLE,ftype,"native",MPI_INFO_NULL);
@@ -329,13 +330,14 @@ class DistributedGridIO_MPIIOBase
            return meshFunction.boundLoad(fileName);
        }

       MPI_Comm group=*((MPI_Comm*)(distrGrid->getCommunicationGroup()));
       MPI_Datatype ftype;
       MPI_Datatype atype;
       int dataCount=CreateDataTypes(distrGrid,&ftype,&atype);

       //write 
       MPI_File file;
       MPI_File_open( MPI_COMM_WORLD,
       MPI_File_open( group,
                      const_cast< char* >( fileName.getString() ),
                      MPI_MODE_RDONLY,
                      MPI_INFO_NULL,
@@ -343,11 +345,11 @@ class DistributedGridIO_MPIIOBase
       
       int headerSize=0;

       if(Communicators::MpiCommunicator::GetRank()==0)
       if(Communicators::MpiCommunicator::GetRank(group)==0)
       {
            headerSize=readMeshFunctionHeader(file,meshFunction,dataCount);
       }
       MPI_Bcast(&headerSize, 1, MPI_INT,0, MPI_COMM_WORLD);
       MPI_Bcast(&headerSize, 1, MPI_INT,0, group);
       
       if(headerSize<0)
            return false;
+14 −9
Original line number Diff line number Diff line
@@ -91,12 +91,13 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< 1, GridReal, D

          //async send
          typename CommunicatorType::Request req[4];

          typename CommunicatorType::CommunicationGroup group;
          group=*((typename CommunicatorType::CommunicationGroup *)(distributedGrid->getCommunicationGroup()));
          //send everithing, recieve everything 
          if(leftN!=-1)
          {
              req[0]=CommunicatorType::ISend(sendbuffs[Left].getData(), overlapSize, leftN);
              req[2]=CommunicatorType::IRecv(rcvbuffs[Left].getData(), overlapSize, leftN);
              req[0]=CommunicatorType::ISend(sendbuffs[Left].getData(), overlapSize, leftN,group);
              req[2]=CommunicatorType::IRecv(rcvbuffs[Left].getData(), overlapSize, leftN,group);
          }
          else
          {
@@ -106,8 +107,8 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< 1, GridReal, D

          if(rightN!=-1)
          {
              req[1]=CommunicatorType::ISend(sendbuffs[Right].getData(), overlapSize, rightN);
              req[3]=CommunicatorType::IRecv(rcvbuffs[Right].getData(), overlapSize, rightN);
              req[1]=CommunicatorType::ISend(sendbuffs[Right].getData(), overlapSize, rightN,group);
              req[3]=CommunicatorType::IRecv(rcvbuffs[Right].getData(), overlapSize, rightN,group);
          }
          else
          {
@@ -240,13 +241,15 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< 2, GridReal, D

         //async send and rcv
         typename CommunicatorType::Request req[16];
         typename CommunicatorType::CommunicationGroup group;
         group=*((typename CommunicatorType::CommunicationGroup *)(distributedGrid->getCommunicationGroup()));

         //send everything, recieve everything 
         for(int i=0;i<8;i++)	
            if(neighbor[i]!=-1)
            {
               req[i]=CommunicatorType::ISend(sendbuffs[i].getData(), sizes[i], neighbor[i]);
               req[8+i]=CommunicatorType::IRecv(rcvbuffs[i].getData(), sizes[i], neighbor[i]);
               req[i]=CommunicatorType::ISend(sendbuffs[i].getData(), sizes[i], neighbor[i],group);
               req[8+i]=CommunicatorType::IRecv(rcvbuffs[i].getData(), sizes[i], neighbor[i],group);
            }
            else
            {
@@ -444,13 +447,15 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< 3, GridReal, D
        
        //async send and rcv
        typename CommunicatorType::Request req[52];
        typename CommunicatorType::CommunicationGroup group;
        group=*((typename CommunicatorType::CommunicationGroup *)(distributedGrid->getCommunicationGroup()));
		                
        //send everithing, recieve everything 
        for(int i=0;i<26;i++)	
           if(neighbor[i]!=-1)
           {
               req[i]=CommunicatorType::ISend(sendbuffs[i].getData(), sizes[i], neighbor[i]);
               req[26+i]=CommunicatorType::IRecv(rcvbuffs[i].getData(), sizes[i], neighbor[i]);
               req[i]=CommunicatorType::ISend(sendbuffs[i].getData(), sizes[i], neighbor[i],group);
               req[26+i]=CommunicatorType::IRecv(rcvbuffs[i].getData(), sizes[i], neighbor[i],group);
           }
		   else
      	   {
+4 −2
Original line number Diff line number Diff line
@@ -31,6 +31,8 @@ DistributedMesh< Grid< 1, RealType, Device, Index > >::
setGlobalGrid( const GridType& globalGrid,
               const CoordinatesType& overlap )
{
   typename CommunicatorType::CommunicationGroup &group = CommunicatorType::AllGroup;
   this->communicationGroup=(void*)& group;
   this->globalGrid = globalGrid;
   this->isSet = true;
   this->overlap = overlap;
@@ -43,8 +45,8 @@ setGlobalGrid( const GridType& globalGrid,
   this->distributed = false;
   if( CommunicatorType::IsInitialized() )
   {
       this->rank = CommunicatorType::GetRank();
       this->nproc = CommunicatorType::GetSize();
       this->rank = CommunicatorType::GetRank(group);
       this->nproc = CommunicatorType::GetSize(group);
       if( this->nproc>1 )
       {
           this->distributed = true;
Loading