Commit e99d6c87 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber Committed by Tomáš Oberhuber
Browse files

Avoided duplicity in MPI data type wrapper implementation.

parent bd14cf8e
Loading
Loading
Loading
Loading
+3 −2
Original line number Original line Diff line number Diff line
SET( headers MpiCommunicator.h
SET( headers MpiCommunicator.h
             MpiDefs.h
             MpiDefs.h
             MPITypeResolver.h
             NoDistrCommunicator.h
             NoDistrCommunicator.h
             ScopedInitializer.h
             ScopedInitializer.h
    )
    )
+103 −0
Original line number Original line Diff line number Diff line
/***************************************************************************
                          MPITypeResolver.h  -  description
                             -------------------
    begin                : Feb 4, 2019
    copyright            : (C) 2019 by Tomas Oberhuber
    email                : tomas.oberhuber@fjfi.cvut.cz
 ***************************************************************************/

/* See Copyright Notice in tnl/Copyright */

#pragma once

namespace TNL {
namespace Communicators {

#ifdef HAVE_MPI
template<typename Type>
struct MPITypeResolver
{
   static inline MPI_Datatype getType()
   {
      switch( sizeof( Type ) )
      {
         case sizeof( char ):
            return MPI_CHAR;
         case sizeof( int ):
            return MPI_INT;
         case sizeof( short int ):
            return MPI_SHORT;
         case sizeof( long int ):
            return MPI_LONG;
      }
      TNL_ASSERT_TRUE(false, "Fatal Error - Unknown MPI Type");
   };
};

template<> struct MPITypeResolver< char >
{
    static inline MPI_Datatype getType(){return MPI_CHAR;};
};

template<> struct MPITypeResolver< int >
{
    static inline MPI_Datatype getType(){return MPI_INT;};
};

template<> struct MPITypeResolver< short int >
{
    static inline MPI_Datatype getType(){return MPI_SHORT;};
};

template<> struct MPITypeResolver< long int >
{
    static inline MPI_Datatype getType(){return MPI_LONG;};
};

template<> struct MPITypeResolver< unsigned char >
{
    static inline MPI_Datatype getType(){return MPI_UNSIGNED_CHAR;};
};

template<> struct MPITypeResolver< unsigned short int >
{
    static inline MPI_Datatype getType(){return MPI_UNSIGNED_SHORT;};
};

template<> struct MPITypeResolver< unsigned int >
{
    static inline MPI_Datatype getType(){return MPI_UNSIGNED;};
};

template<> struct MPITypeResolver< unsigned long int >
{
    static inline MPI_Datatype getType(){return MPI_UNSIGNED_LONG;};
};

template<> struct MPITypeResolver< float >
{
    static inline MPI_Datatype getType(){return MPI_FLOAT;};
};

template<> struct MPITypeResolver< double >
{
    static inline MPI_Datatype getType(){return MPI_DOUBLE;};
};

template<> struct MPITypeResolver< long double >
{
    static inline MPI_Datatype getType(){return MPI_LONG_DOUBLE;};
};

template<> struct MPITypeResolver< bool >
{
   // sizeof(bool) is implementation-defined: https://stackoverflow.com/a/4897859
   static_assert( sizeof(bool) == 1, "The systems where sizeof(bool) != 1 are not supported by MPI." );
   static inline MPI_Datatype getType() { return MPI_C_BOOL; };
};
#endif
   
   
   
   } // namespace Communicators
} // namespace TNL
+26 −100
Original line number Original line Diff line number Diff line
@@ -39,7 +39,7 @@
#include <TNL/Config/ConfigDescription.h>
#include <TNL/Config/ConfigDescription.h>
#include <TNL/Exceptions/MPISupportMissing.h>
#include <TNL/Exceptions/MPISupportMissing.h>
#include <TNL/Exceptions/MPIDimsCreateError.h>
#include <TNL/Exceptions/MPIDimsCreateError.h>

#include <TNL/Communicators/MPITypeResolver.h>




namespace TNL {
namespace TNL {
@@ -51,7 +51,7 @@ class MpiCommunicator


   public: // TODO: this was private
   public: // TODO: this was private
#ifdef HAVE_MPI
#ifdef HAVE_MPI
      inline static MPI_Datatype MPIDataType( const signed char* ) { return MPI_CHAR; };
      /*inline static MPI_Datatype MPIDataType( const signed char* ) { return MPI_CHAR; };
      inline static MPI_Datatype MPIDataType( const signed short int* ) { return MPI_SHORT; };
      inline static MPI_Datatype MPIDataType( const signed short int* ) { return MPI_SHORT; };
      inline static MPI_Datatype MPIDataType( const signed int* ) { return MPI_INT; };
      inline static MPI_Datatype MPIDataType( const signed int* ) { return MPI_INT; };
      inline static MPI_Datatype MPIDataType( const signed long int* ) { return MPI_LONG; };
      inline static MPI_Datatype MPIDataType( const signed long int* ) { return MPI_LONG; };
@@ -69,7 +69,7 @@ class MpiCommunicator
         // sizeof(bool) is implementation-defined: https://stackoverflow.com/a/4897859
         // sizeof(bool) is implementation-defined: https://stackoverflow.com/a/4897859
         static_assert( sizeof(bool) == 1, "The programmer did not count with systems where sizeof(bool) != 1." );
         static_assert( sizeof(bool) == 1, "The programmer did not count with systems where sizeof(bool) != 1." );
         return MPI_CHAR;
         return MPI_CHAR;
      };
      };*/


      using Request = MPI_Request;
      using Request = MPI_Request;
      using CommunicationGroup = MPI_Comm;
      using CommunicationGroup = MPI_Comm;
@@ -241,6 +241,14 @@ class MpiCommunicator
#endif
#endif
      }
      }


#ifdef HAVE_MPI
      template< typename T >
      static MPI_Datatype getDataType( const T& t )
      { 
         return MPITypeResolver< T >::getType();
      };
#endif
      
      //dim-number of dimensions, distr array of guess distr - 0 for computation
      //dim-number of dimensions, distr array of guess distr - 0 for computation
      //distr array will be filled by computed distribution
      //distr array will be filled by computed distribution
      //more information in MPI documentation
      //more information in MPI documentation
@@ -288,7 +296,7 @@ class MpiCommunicator
            TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not initialized");
            TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not initialized");
            TNL_ASSERT_NE(group, NullGroup, "ISend cannot be called with NullGroup");
            TNL_ASSERT_NE(group, NullGroup, "ISend cannot be called with NullGroup");
            Request req;
            Request req;
            MPI_Isend( const_cast< void* >( ( const void* ) data ), count, MPIDataType(data) , dest, tag, group, &req);
            MPI_Isend( const_cast< void* >( ( const void* ) data ), count, MPITypeResolver< T >::getType(), dest, tag, group, &req);
            return req;
            return req;
#else
#else
            throw Exceptions::MPISupportMissing();
            throw Exceptions::MPISupportMissing();
@@ -302,7 +310,7 @@ class MpiCommunicator
            TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not initialized");
            TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not initialized");
            TNL_ASSERT_NE(group, NullGroup, "IRecv cannot be called with NullGroup");
            TNL_ASSERT_NE(group, NullGroup, "IRecv cannot be called with NullGroup");
            Request req;
            Request req;
            MPI_Irecv((void*) data, count, MPIDataType(data) , src, tag, group, &req);
            MPI_Irecv((void*) data, count, MPITypeResolver< T >::getType() , src, tag, group, &req);
            return req;
            return req;
#else
#else
            throw Exceptions::MPISupportMissing();
            throw Exceptions::MPISupportMissing();
@@ -325,7 +333,7 @@ class MpiCommunicator
#ifdef HAVE_MPI
#ifdef HAVE_MPI
           TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not initialized");
           TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not initialized");
           TNL_ASSERT_NE(group, NullGroup, "BCast cannot be called with NullGroup");
           TNL_ASSERT_NE(group, NullGroup, "BCast cannot be called with NullGroup");
           MPI_Bcast((void*) data, count, MPIDataType(data), root, group);
           MPI_Bcast((void*) data, count, MPITypeResolver< T >::getType(), root, group);
#else
#else
           throw Exceptions::MPISupportMissing();
           throw Exceptions::MPISupportMissing();
#endif
#endif
@@ -340,7 +348,7 @@ class MpiCommunicator
        {
        {
#ifdef HAVE_MPI
#ifdef HAVE_MPI
            TNL_ASSERT_NE(group, NullGroup, "Allreduce cannot be called with NullGroup");
            TNL_ASSERT_NE(group, NullGroup, "Allreduce cannot be called with NullGroup");
            MPI_Allreduce( const_cast< void* >( ( void* ) data ), (void*) reduced_data,count,MPIDataType(data),op,group);
            MPI_Allreduce( const_cast< void* >( ( void* ) data ), (void*) reduced_data,count,MPITypeResolver< T >::getType(),op,group);
#else
#else
            throw Exceptions::MPISupportMissing();
            throw Exceptions::MPISupportMissing();
#endif
#endif
@@ -355,7 +363,7 @@ class MpiCommunicator
        {
        {
#ifdef HAVE_MPI
#ifdef HAVE_MPI
            TNL_ASSERT_NE(group, NullGroup, "Allreduce cannot be called with NullGroup");
            TNL_ASSERT_NE(group, NullGroup, "Allreduce cannot be called with NullGroup");
            MPI_Allreduce( MPI_IN_PLACE, (void*) data,count,MPIDataType(data),op,group);
            MPI_Allreduce( MPI_IN_PLACE, (void*) data,count,MPITypeResolver< T >::getType(),op,group);
#else
#else
            throw Exceptions::MPISupportMissing();
            throw Exceptions::MPISupportMissing();
#endif
#endif
@@ -372,7 +380,7 @@ class MpiCommunicator
         {
         {
#ifdef HAVE_MPI
#ifdef HAVE_MPI
            TNL_ASSERT_NE(group, NullGroup, "Reduce cannot be called with NullGroup");
            TNL_ASSERT_NE(group, NullGroup, "Reduce cannot be called with NullGroup");
            MPI_Reduce( const_cast< void* >( ( void*) data ), (void*) reduced_data,count,MPIDataType(data),op,root,group);
            MPI_Reduce( const_cast< void* >( ( void*) data ), (void*) reduced_data,count,MPITypeResolver< T >::getType(),op,root,group);
#else
#else
            throw Exceptions::MPISupportMissing();
            throw Exceptions::MPISupportMissing();
#endif
#endif
@@ -394,12 +402,12 @@ class MpiCommunicator
            MPI_Status status;
            MPI_Status status;
            MPI_Sendrecv( const_cast< void* >( ( void* ) sendData ),
            MPI_Sendrecv( const_cast< void* >( ( void* ) sendData ),
                          sendCount,
                          sendCount,
                          MPIDataType( sendData ),
                          MPITypeResolver< T >::getType(),
                          destination,
                          destination,
                          sendTag,
                          sendTag,
                          ( void* ) receiveData,
                          ( void* ) receiveData,
                          receiveCount,
                          receiveCount,
                          MPIDataType( receiveData ),
                          MPITypeResolver< T >::getType(),
                          source,
                          source,
                          receiveTag,
                          receiveTag,
                          group,
                          group,
@@ -420,10 +428,10 @@ class MpiCommunicator
            TNL_ASSERT_NE(group, NullGroup, "SendReceive cannot be called with NullGroup");
            TNL_ASSERT_NE(group, NullGroup, "SendReceive cannot be called with NullGroup");
            MPI_Alltoall( const_cast< void* >( ( void* ) sendData ),
            MPI_Alltoall( const_cast< void* >( ( void* ) sendData ),
                          sendCount,
                          sendCount,
                          MPIDataType( sendData ),
                          MPITypeResolver< T >::getType(),
                          ( void* ) receiveData,
                          ( void* ) receiveData,
                          receiveCount,
                          receiveCount,
                          MPIDataType( receiveData ),
                          MPITypeResolver< T >::getType(),
                          group );
                          group );
#else
#else
            throw Exceptions::MPISupportMissing();
            throw Exceptions::MPISupportMissing();
@@ -523,88 +531,6 @@ std::streambuf* MpiCommunicator::backup = nullptr;
std::ofstream MpiCommunicator::filestr;
std::ofstream MpiCommunicator::filestr;
bool MpiCommunicator::redirect = true;
bool MpiCommunicator::redirect = true;


#ifdef HAVE_MPI
// TODO: this duplicates MpiCommunicator::MPIDataType
template<typename Type>
struct MPITypeResolver
{
    static inline MPI_Datatype getType()
    {
      switch( sizeof( Type ) )
      {
         case sizeof( char ):
            return MPI_CHAR;
         case sizeof( int ):
            return MPI_INT;
         case sizeof( short int ):
            return MPI_SHORT;
         case sizeof( long int ):
            return MPI_LONG;
      }
      TNL_ASSERT_TRUE(false, "Fatal Error - Unknown MPI Type");
    };
};

template<> struct MPITypeResolver< char >
{
    static inline MPI_Datatype getType(){return MPI_CHAR;};
};

template<> struct MPITypeResolver< int >
{
    static inline MPI_Datatype getType(){return MPI_INT;};
};

template<> struct MPITypeResolver< short int >
{
    static inline MPI_Datatype getType(){return MPI_SHORT;};
};

template<> struct MPITypeResolver< long int >
{
    static inline MPI_Datatype getType(){return MPI_LONG;};
};

template<> struct MPITypeResolver< unsigned char >
{
    static inline MPI_Datatype getType(){return MPI_UNSIGNED_CHAR;};
};

template<> struct MPITypeResolver< unsigned short int >
{
    static inline MPI_Datatype getType(){return MPI_UNSIGNED_SHORT;};
};

template<> struct MPITypeResolver< unsigned int >
{
    static inline MPI_Datatype getType(){return MPI_UNSIGNED;};
};

template<> struct MPITypeResolver< unsigned long int >
{
    static inline MPI_Datatype getType(){return MPI_UNSIGNED_LONG;};
};

template<> struct MPITypeResolver< float >
{
    static inline MPI_Datatype getType(){return MPI_FLOAT;};
};

template<> struct MPITypeResolver< double >
{
    static inline MPI_Datatype getType(){return MPI_DOUBLE;};
};

template<> struct MPITypeResolver< long double >
{
    static inline MPI_Datatype getType(){return MPI_LONG_DOUBLE;};
};

template<> struct MPITypeResolver< bool >
{
    static inline MPI_Datatype getType(){return MPI_C_BOOL;};
};
#endif


} // namespace <unnamed>
} // namespace <unnamed>
} // namespace Communicators
} // namespace Communicators