Commit 0375005c authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Fixes in MpiCommunicator and added wrapper for MPI_Alltoall

parent f9ac23a7
Loading
Loading
Loading
Loading
+32 −12
Original line number Diff line number Diff line
@@ -277,7 +277,7 @@ class MpiCommunicator
            TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not initialized");
            TNL_ASSERT_NE(group, NullGroup, "ISend cannot be called with NullGroup");
            Request req;
            MPI_Isend((void*) data, count, MPIDataType(data) , dest, 0, group, &req);
            MPI_Isend((const void*) data, count, MPIDataType(data) , dest, 0, group, &req);
            return req;
#else
            throw Exceptions::MPISupportMissing();
@@ -285,7 +285,7 @@ class MpiCommunicator
        }

         template <typename T>
         static Request IRecv( const T *data, int count, int src, CommunicationGroup group)
         static Request IRecv( T* data, int count, int src, CommunicationGroup group)
         {
#ifdef HAVE_MPI
            TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not initialized");
@@ -337,7 +337,7 @@ class MpiCommunicator


         template< typename T >
         static void Reduce( T* data,
         static void Reduce( const T* data,
                    T* reduced_data,
                    int count,
                    MPI_Op &op,
@@ -346,14 +346,14 @@ class MpiCommunicator
         {
#ifdef HAVE_MPI
            TNL_ASSERT_NE(group, NullGroup, "Reduce cannot be called with NullGroup");
            MPI_Reduce( (void*) data, (void*) reduced_data,count,MPIDataType(data),op,root,group);
            MPI_Reduce( (const void*) data, (void*) reduced_data,count,MPIDataType(data),op,root,group);
#else
            throw Exceptions::MPISupportMissing();
#endif
        }

         template< typename T >
         static void SendReceive( T* sendData,
         static void SendReceive( const T* sendData,
                                  int sendCount,
                                  int destination,
                                  int sendTag,
@@ -366,7 +366,7 @@ class MpiCommunicator
#ifdef HAVE_MPI
            TNL_ASSERT_NE(group, NullGroup, "SendReceive cannot be called with NullGroup");
            MPI_Status status;
            MPI_Sendrecv( ( void* ) sendData,
            MPI_Sendrecv( ( const void* ) sendData,
                          sendCount,
                          MPIDataType( sendData ),
                          destination,
@@ -383,6 +383,27 @@ class MpiCommunicator
#endif
         }

         template< typename T >
         static void Alltoall( const T* sendData,
                               int sendCount,
                               T* receiveData,
                               int receiveCount,
                               CommunicationGroup group )
         {
#ifdef HAVE_MPI
            TNL_ASSERT_NE(group, NullGroup, "SendReceive cannot be called with NullGroup");
            MPI_Alltoall( ( const void* ) sendData,
                          sendCount,
                          MPIDataType( sendData ),
                          ( void* ) receiveData,
                          receiveCount,
                          MPIDataType( receiveData ),
                          group );
#else
            throw Exceptions::MPISupportMissing();
#endif
         }


      static void writeProlog( Logger& logger )
      {
@@ -428,10 +449,9 @@ class MpiCommunicator
      {
#ifdef HAVE_MPI
    #ifdef HAVE_CUDA
        	int count,rank, gpuCount, gpuNumber;
         MPI_Comm_size(MPI_COMM_WORLD,&count);
         MPI_Comm_rank(MPI_COMM_WORLD,&rank);

         const int count = GetSize(AllGroup);
         const int rank = GetRank(AllGroup);
         int gpuCount;
         cudaGetDeviceCount(&gpuCount);

         procName names[count];
@@ -454,7 +474,7 @@ class MpiCommunicator
               nodeRank++;
         }

         gpuNumber=nodeRank % gpuCount;
         const int gpuNumber = nodeRank % gpuCount;

         cudaSetDevice(gpuNumber);
         TNL_CHECK_CUDA_DEVICE;