diff --git a/src/TNL/Communicators/CMakeLists.txt b/src/TNL/Communicators/CMakeLists.txt index fb3193b739c75ea41ddae9ecf664592e44821d79..87feba13eb52023894808f6b75fca3e05eed2708 100644 --- a/src/TNL/Communicators/CMakeLists.txt +++ b/src/TNL/Communicators/CMakeLists.txt @@ -1,6 +1,7 @@ SET( headers MpiCommunicator.h - MpiDefs.h - NoDistrCommunicator.h + MpiDefs.h + MPITypeResolver.h + NoDistrCommunicator.h ScopedInitializer.h ) diff --git a/src/TNL/Communicators/MPITypeResolver.h b/src/TNL/Communicators/MPITypeResolver.h new file mode 100644 index 0000000000000000000000000000000000000000..2d0b45e2b3fcc4a1385a261c18d4bb987f86e2e4 --- /dev/null +++ b/src/TNL/Communicators/MPITypeResolver.h @@ -0,0 +1,103 @@ +/*************************************************************************** + MPITypeResolver.h - description + ------------------- + begin : Feb 4, 2019 + copyright : (C) 2019 by Tomas Oberhuber + email : tomas.oberhuber@fjfi.cvut.cz + ***************************************************************************/ + +/* See Copyright Notice in tnl/Copyright */ + +#pragma once + +namespace TNL { +namespace Communicators { + +#ifdef HAVE_MPI +template<typename Type> +struct MPITypeResolver +{ + static inline MPI_Datatype getType() + { + switch( sizeof( Type ) ) + { + case sizeof( char ): + return MPI_CHAR; + case sizeof( int ): + return MPI_INT; + case sizeof( short int ): + return MPI_SHORT; + case sizeof( long int ): + return MPI_LONG; + } + TNL_ASSERT_TRUE(false, "Fatal Error - Unknown MPI Type"); + }; +}; + +template<> struct MPITypeResolver< char > +{ + static inline MPI_Datatype getType(){return MPI_CHAR;}; +}; + +template<> struct MPITypeResolver< int > +{ + static inline MPI_Datatype getType(){return MPI_INT;}; +}; + +template<> struct MPITypeResolver< short int > +{ + static inline MPI_Datatype getType(){return MPI_SHORT;}; +}; + +template<> struct MPITypeResolver< long int > +{ + static inline MPI_Datatype getType(){return MPI_LONG;}; +}; + +template<> struct MPITypeResolver< unsigned char > +{ + static inline MPI_Datatype getType(){return MPI_UNSIGNED_CHAR;}; +}; + +template<> struct MPITypeResolver< unsigned short int > +{ + static inline MPI_Datatype getType(){return MPI_UNSIGNED_SHORT;}; +}; + +template<> struct MPITypeResolver< unsigned int > +{ + static inline MPI_Datatype getType(){return MPI_UNSIGNED;}; +}; + +template<> struct MPITypeResolver< unsigned long int > +{ + static inline MPI_Datatype getType(){return MPI_UNSIGNED_LONG;}; +}; + +template<> struct MPITypeResolver< float > +{ + static inline MPI_Datatype getType(){return MPI_FLOAT;}; +}; + +template<> struct MPITypeResolver< double > +{ + static inline MPI_Datatype getType(){return MPI_DOUBLE;}; +}; + +template<> struct MPITypeResolver< long double > +{ + static inline MPI_Datatype getType(){return MPI_LONG_DOUBLE;}; +}; + +template<> struct MPITypeResolver< bool > +{ + // sizeof(bool) is implementation-defined: https://stackoverflow.com/a/4897859 + static_assert( sizeof(bool) == 1, "The systems where sizeof(bool) != 1 are not supported by MPI." ); + static inline MPI_Datatype getType() { return MPI_C_BOOL; }; +}; +#endif + + + + } // namespace Communicators +} // namespace TNL diff --git a/src/TNL/Communicators/MpiCommunicator.h b/src/TNL/Communicators/MpiCommunicator.h index 4b66ad33a3b2d2c4814a1831c3845c7e88c735dc..cc5a5cb57f59ea07baf7720770e0df0111b0204e 100644 --- a/src/TNL/Communicators/MpiCommunicator.h +++ b/src/TNL/Communicators/MpiCommunicator.h @@ -39,7 +39,7 @@ #include <TNL/Config/ConfigDescription.h> #include <TNL/Exceptions/MPISupportMissing.h> #include <TNL/Exceptions/MPIDimsCreateError.h> - +#include <TNL/Communicators/MPITypeResolver.h> namespace TNL { @@ -51,7 +51,7 @@ class MpiCommunicator public: // TODO: this was private #ifdef HAVE_MPI - inline static MPI_Datatype MPIDataType( const signed char* ) { return MPI_CHAR; }; + /*inline static MPI_Datatype MPIDataType( const signed char* ) { return MPI_CHAR; }; inline static MPI_Datatype MPIDataType( const signed short int* ) { return MPI_SHORT; }; inline static MPI_Datatype MPIDataType( const signed int* ) { return MPI_INT; }; inline static MPI_Datatype MPIDataType( const signed long int* ) { return MPI_LONG; }; @@ -69,7 +69,7 @@ class MpiCommunicator // sizeof(bool) is implementation-defined: https://stackoverflow.com/a/4897859 static_assert( sizeof(bool) == 1, "The programmer did not count with systems where sizeof(bool) != 1." ); return MPI_CHAR; - }; + };*/ using Request = MPI_Request; using CommunicationGroup = MPI_Comm; @@ -241,11 +241,19 @@ class MpiCommunicator #endif } - //dim-number of dimensions, distr array of guess distr - 0 for computation - //distr array will be filled by computed distribution - //more information in MPI documentation - static void DimsCreate(int nproc, int dim, int *distr) - { +#ifdef HAVE_MPI + template< typename T > + static MPI_Datatype getDataType( const T& t ) + { + return MPITypeResolver< T >::getType(); + }; +#endif + + //dim-number of dimensions, distr array of guess distr - 0 for computation + //distr array will be filled by computed distribution + //more information in MPI documentation + static void DimsCreate(int nproc, int dim, int *distr) + { #ifdef HAVE_MPI int sum = 0, prod = 1; for( int i = 0;i < dim; i++ ) @@ -288,7 +296,7 @@ class MpiCommunicator TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not initialized"); TNL_ASSERT_NE(group, NullGroup, "ISend cannot be called with NullGroup"); Request req; - MPI_Isend( const_cast< void* >( ( const void* ) data ), count, MPIDataType(data) , dest, tag, group, &req); + MPI_Isend( const_cast< void* >( ( const void* ) data ), count, MPITypeResolver< T >::getType(), dest, tag, group, &req); return req; #else throw Exceptions::MPISupportMissing(); @@ -302,7 +310,7 @@ class MpiCommunicator TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not initialized"); TNL_ASSERT_NE(group, NullGroup, "IRecv cannot be called with NullGroup"); Request req; - MPI_Irecv((void*) data, count, MPIDataType(data) , src, tag, group, &req); + MPI_Irecv((void*) data, count, MPITypeResolver< T >::getType() , src, tag, group, &req); return req; #else throw Exceptions::MPISupportMissing(); @@ -325,7 +333,7 @@ class MpiCommunicator #ifdef HAVE_MPI TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not initialized"); TNL_ASSERT_NE(group, NullGroup, "BCast cannot be called with NullGroup"); - MPI_Bcast((void*) data, count, MPIDataType(data), root, group); + MPI_Bcast((void*) data, count, MPITypeResolver< T >::getType(), root, group); #else throw Exceptions::MPISupportMissing(); #endif @@ -340,7 +348,7 @@ class MpiCommunicator { #ifdef HAVE_MPI TNL_ASSERT_NE(group, NullGroup, "Allreduce cannot be called with NullGroup"); - MPI_Allreduce( const_cast< void* >( ( void* ) data ), (void*) reduced_data,count,MPIDataType(data),op,group); + MPI_Allreduce( const_cast< void* >( ( void* ) data ), (void*) reduced_data,count,MPITypeResolver< T >::getType(),op,group); #else throw Exceptions::MPISupportMissing(); #endif @@ -355,7 +363,7 @@ class MpiCommunicator { #ifdef HAVE_MPI TNL_ASSERT_NE(group, NullGroup, "Allreduce cannot be called with NullGroup"); - MPI_Allreduce( MPI_IN_PLACE, (void*) data,count,MPIDataType(data),op,group); + MPI_Allreduce( MPI_IN_PLACE, (void*) data,count,MPITypeResolver< T >::getType(),op,group); #else throw Exceptions::MPISupportMissing(); #endif @@ -372,7 +380,7 @@ class MpiCommunicator { #ifdef HAVE_MPI TNL_ASSERT_NE(group, NullGroup, "Reduce cannot be called with NullGroup"); - MPI_Reduce( const_cast< void* >( ( void*) data ), (void*) reduced_data,count,MPIDataType(data),op,root,group); + MPI_Reduce( const_cast< void* >( ( void*) data ), (void*) reduced_data,count,MPITypeResolver< T >::getType(),op,root,group); #else throw Exceptions::MPISupportMissing(); #endif @@ -394,12 +402,12 @@ class MpiCommunicator MPI_Status status; MPI_Sendrecv( const_cast< void* >( ( void* ) sendData ), sendCount, - MPIDataType( sendData ), + MPITypeResolver< T >::getType(), destination, sendTag, ( void* ) receiveData, receiveCount, - MPIDataType( receiveData ), + MPITypeResolver< T >::getType(), source, receiveTag, group, @@ -420,10 +428,10 @@ class MpiCommunicator TNL_ASSERT_NE(group, NullGroup, "SendReceive cannot be called with NullGroup"); MPI_Alltoall( const_cast< void* >( ( void* ) sendData ), sendCount, - MPIDataType( sendData ), + MPITypeResolver< T >::getType(), ( void* ) receiveData, receiveCount, - MPIDataType( receiveData ), + MPITypeResolver< T >::getType(), group ); #else throw Exceptions::MPISupportMissing(); @@ -523,88 +531,6 @@ std::streambuf* MpiCommunicator::backup = nullptr; std::ofstream MpiCommunicator::filestr; bool MpiCommunicator::redirect = true; -#ifdef HAVE_MPI -// TODO: this duplicates MpiCommunicator::MPIDataType -template<typename Type> -struct MPITypeResolver -{ - static inline MPI_Datatype getType() - { - switch( sizeof( Type ) ) - { - case sizeof( char ): - return MPI_CHAR; - case sizeof( int ): - return MPI_INT; - case sizeof( short int ): - return MPI_SHORT; - case sizeof( long int ): - return MPI_LONG; - } - TNL_ASSERT_TRUE(false, "Fatal Error - Unknown MPI Type"); - }; -}; - -template<> struct MPITypeResolver< char > -{ - static inline MPI_Datatype getType(){return MPI_CHAR;}; -}; - -template<> struct MPITypeResolver< int > -{ - static inline MPI_Datatype getType(){return MPI_INT;}; -}; - -template<> struct MPITypeResolver< short int > -{ - static inline MPI_Datatype getType(){return MPI_SHORT;}; -}; - -template<> struct MPITypeResolver< long int > -{ - static inline MPI_Datatype getType(){return MPI_LONG;}; -}; - -template<> struct MPITypeResolver< unsigned char > -{ - static inline MPI_Datatype getType(){return MPI_UNSIGNED_CHAR;}; -}; - -template<> struct MPITypeResolver< unsigned short int > -{ - static inline MPI_Datatype getType(){return MPI_UNSIGNED_SHORT;}; -}; - -template<> struct MPITypeResolver< unsigned int > -{ - static inline MPI_Datatype getType(){return MPI_UNSIGNED;}; -}; - -template<> struct MPITypeResolver< unsigned long int > -{ - static inline MPI_Datatype getType(){return MPI_UNSIGNED_LONG;}; -}; - -template<> struct MPITypeResolver< float > -{ - static inline MPI_Datatype getType(){return MPI_FLOAT;}; -}; - -template<> struct MPITypeResolver< double > -{ - static inline MPI_Datatype getType(){return MPI_DOUBLE;}; -}; - -template<> struct MPITypeResolver< long double > -{ - static inline MPI_Datatype getType(){return MPI_LONG_DOUBLE;}; -}; - -template<> struct MPITypeResolver< bool > -{ - static inline MPI_Datatype getType(){return MPI_C_BOOL;}; -}; -#endif } // namespace <unnamed> } // namespace Communicators