Skip to content
Snippets Groups Projects
Commit e99d6c87 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber Committed by Tomáš Oberhuber
Browse files

Avoided duplicity in MPI data type wrapper implementation.

parent bd14cf8e
No related branches found
No related tags found
1 merge request!26Lbm
SET( headers MpiCommunicator.h
MpiDefs.h
NoDistrCommunicator.h
MpiDefs.h
MPITypeResolver.h
NoDistrCommunicator.h
ScopedInitializer.h
)
......
/***************************************************************************
MPITypeResolver.h - description
-------------------
begin : Feb 4, 2019
copyright : (C) 2019 by Tomas Oberhuber
email : tomas.oberhuber@fjfi.cvut.cz
***************************************************************************/
/* See Copyright Notice in tnl/Copyright */
#pragma once
namespace TNL {
namespace Communicators {
#ifdef HAVE_MPI
template<typename Type>
struct MPITypeResolver
{
static inline MPI_Datatype getType()
{
switch( sizeof( Type ) )
{
case sizeof( char ):
return MPI_CHAR;
case sizeof( int ):
return MPI_INT;
case sizeof( short int ):
return MPI_SHORT;
case sizeof( long int ):
return MPI_LONG;
}
TNL_ASSERT_TRUE(false, "Fatal Error - Unknown MPI Type");
};
};
template<> struct MPITypeResolver< char >
{
static inline MPI_Datatype getType(){return MPI_CHAR;};
};
template<> struct MPITypeResolver< int >
{
static inline MPI_Datatype getType(){return MPI_INT;};
};
template<> struct MPITypeResolver< short int >
{
static inline MPI_Datatype getType(){return MPI_SHORT;};
};
template<> struct MPITypeResolver< long int >
{
static inline MPI_Datatype getType(){return MPI_LONG;};
};
template<> struct MPITypeResolver< unsigned char >
{
static inline MPI_Datatype getType(){return MPI_UNSIGNED_CHAR;};
};
template<> struct MPITypeResolver< unsigned short int >
{
static inline MPI_Datatype getType(){return MPI_UNSIGNED_SHORT;};
};
template<> struct MPITypeResolver< unsigned int >
{
static inline MPI_Datatype getType(){return MPI_UNSIGNED;};
};
template<> struct MPITypeResolver< unsigned long int >
{
static inline MPI_Datatype getType(){return MPI_UNSIGNED_LONG;};
};
template<> struct MPITypeResolver< float >
{
static inline MPI_Datatype getType(){return MPI_FLOAT;};
};
template<> struct MPITypeResolver< double >
{
static inline MPI_Datatype getType(){return MPI_DOUBLE;};
};
template<> struct MPITypeResolver< long double >
{
static inline MPI_Datatype getType(){return MPI_LONG_DOUBLE;};
};
template<> struct MPITypeResolver< bool >
{
// sizeof(bool) is implementation-defined: https://stackoverflow.com/a/4897859
static_assert( sizeof(bool) == 1, "The systems where sizeof(bool) != 1 are not supported by MPI." );
static inline MPI_Datatype getType() { return MPI_C_BOOL; };
};
#endif
} // namespace Communicators
} // namespace TNL
......@@ -39,7 +39,7 @@
#include <TNL/Config/ConfigDescription.h>
#include <TNL/Exceptions/MPISupportMissing.h>
#include <TNL/Exceptions/MPIDimsCreateError.h>
#include <TNL/Communicators/MPITypeResolver.h>
namespace TNL {
......@@ -51,7 +51,7 @@ class MpiCommunicator
public: // TODO: this was private
#ifdef HAVE_MPI
inline static MPI_Datatype MPIDataType( const signed char* ) { return MPI_CHAR; };
/*inline static MPI_Datatype MPIDataType( const signed char* ) { return MPI_CHAR; };
inline static MPI_Datatype MPIDataType( const signed short int* ) { return MPI_SHORT; };
inline static MPI_Datatype MPIDataType( const signed int* ) { return MPI_INT; };
inline static MPI_Datatype MPIDataType( const signed long int* ) { return MPI_LONG; };
......@@ -69,7 +69,7 @@ class MpiCommunicator
// sizeof(bool) is implementation-defined: https://stackoverflow.com/a/4897859
static_assert( sizeof(bool) == 1, "The programmer did not count with systems where sizeof(bool) != 1." );
return MPI_CHAR;
};
};*/
using Request = MPI_Request;
using CommunicationGroup = MPI_Comm;
......@@ -241,11 +241,19 @@ class MpiCommunicator
#endif
}
//dim-number of dimensions, distr array of guess distr - 0 for computation
//distr array will be filled by computed distribution
//more information in MPI documentation
static void DimsCreate(int nproc, int dim, int *distr)
{
#ifdef HAVE_MPI
template< typename T >
static MPI_Datatype getDataType( const T& t )
{
return MPITypeResolver< T >::getType();
};
#endif
//dim-number of dimensions, distr array of guess distr - 0 for computation
//distr array will be filled by computed distribution
//more information in MPI documentation
static void DimsCreate(int nproc, int dim, int *distr)
{
#ifdef HAVE_MPI
int sum = 0, prod = 1;
for( int i = 0;i < dim; i++ )
......@@ -288,7 +296,7 @@ class MpiCommunicator
TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not initialized");
TNL_ASSERT_NE(group, NullGroup, "ISend cannot be called with NullGroup");
Request req;
MPI_Isend( const_cast< void* >( ( const void* ) data ), count, MPIDataType(data) , dest, tag, group, &req);
MPI_Isend( const_cast< void* >( ( const void* ) data ), count, MPITypeResolver< T >::getType(), dest, tag, group, &req);
return req;
#else
throw Exceptions::MPISupportMissing();
......@@ -302,7 +310,7 @@ class MpiCommunicator
TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not initialized");
TNL_ASSERT_NE(group, NullGroup, "IRecv cannot be called with NullGroup");
Request req;
MPI_Irecv((void*) data, count, MPIDataType(data) , src, tag, group, &req);
MPI_Irecv((void*) data, count, MPITypeResolver< T >::getType() , src, tag, group, &req);
return req;
#else
throw Exceptions::MPISupportMissing();
......@@ -325,7 +333,7 @@ class MpiCommunicator
#ifdef HAVE_MPI
TNL_ASSERT_TRUE(IsInitialized(), "Fatal Error - MPI communicator is not initialized");
TNL_ASSERT_NE(group, NullGroup, "BCast cannot be called with NullGroup");
MPI_Bcast((void*) data, count, MPIDataType(data), root, group);
MPI_Bcast((void*) data, count, MPITypeResolver< T >::getType(), root, group);
#else
throw Exceptions::MPISupportMissing();
#endif
......@@ -340,7 +348,7 @@ class MpiCommunicator
{
#ifdef HAVE_MPI
TNL_ASSERT_NE(group, NullGroup, "Allreduce cannot be called with NullGroup");
MPI_Allreduce( const_cast< void* >( ( void* ) data ), (void*) reduced_data,count,MPIDataType(data),op,group);
MPI_Allreduce( const_cast< void* >( ( void* ) data ), (void*) reduced_data,count,MPITypeResolver< T >::getType(),op,group);
#else
throw Exceptions::MPISupportMissing();
#endif
......@@ -355,7 +363,7 @@ class MpiCommunicator
{
#ifdef HAVE_MPI
TNL_ASSERT_NE(group, NullGroup, "Allreduce cannot be called with NullGroup");
MPI_Allreduce( MPI_IN_PLACE, (void*) data,count,MPIDataType(data),op,group);
MPI_Allreduce( MPI_IN_PLACE, (void*) data,count,MPITypeResolver< T >::getType(),op,group);
#else
throw Exceptions::MPISupportMissing();
#endif
......@@ -372,7 +380,7 @@ class MpiCommunicator
{
#ifdef HAVE_MPI
TNL_ASSERT_NE(group, NullGroup, "Reduce cannot be called with NullGroup");
MPI_Reduce( const_cast< void* >( ( void*) data ), (void*) reduced_data,count,MPIDataType(data),op,root,group);
MPI_Reduce( const_cast< void* >( ( void*) data ), (void*) reduced_data,count,MPITypeResolver< T >::getType(),op,root,group);
#else
throw Exceptions::MPISupportMissing();
#endif
......@@ -394,12 +402,12 @@ class MpiCommunicator
MPI_Status status;
MPI_Sendrecv( const_cast< void* >( ( void* ) sendData ),
sendCount,
MPIDataType( sendData ),
MPITypeResolver< T >::getType(),
destination,
sendTag,
( void* ) receiveData,
receiveCount,
MPIDataType( receiveData ),
MPITypeResolver< T >::getType(),
source,
receiveTag,
group,
......@@ -420,10 +428,10 @@ class MpiCommunicator
TNL_ASSERT_NE(group, NullGroup, "SendReceive cannot be called with NullGroup");
MPI_Alltoall( const_cast< void* >( ( void* ) sendData ),
sendCount,
MPIDataType( sendData ),
MPITypeResolver< T >::getType(),
( void* ) receiveData,
receiveCount,
MPIDataType( receiveData ),
MPITypeResolver< T >::getType(),
group );
#else
throw Exceptions::MPISupportMissing();
......@@ -523,88 +531,6 @@ std::streambuf* MpiCommunicator::backup = nullptr;
std::ofstream MpiCommunicator::filestr;
bool MpiCommunicator::redirect = true;
#ifdef HAVE_MPI
// TODO: this duplicates MpiCommunicator::MPIDataType
template<typename Type>
struct MPITypeResolver
{
static inline MPI_Datatype getType()
{
switch( sizeof( Type ) )
{
case sizeof( char ):
return MPI_CHAR;
case sizeof( int ):
return MPI_INT;
case sizeof( short int ):
return MPI_SHORT;
case sizeof( long int ):
return MPI_LONG;
}
TNL_ASSERT_TRUE(false, "Fatal Error - Unknown MPI Type");
};
};
template<> struct MPITypeResolver< char >
{
static inline MPI_Datatype getType(){return MPI_CHAR;};
};
template<> struct MPITypeResolver< int >
{
static inline MPI_Datatype getType(){return MPI_INT;};
};
template<> struct MPITypeResolver< short int >
{
static inline MPI_Datatype getType(){return MPI_SHORT;};
};
template<> struct MPITypeResolver< long int >
{
static inline MPI_Datatype getType(){return MPI_LONG;};
};
template<> struct MPITypeResolver< unsigned char >
{
static inline MPI_Datatype getType(){return MPI_UNSIGNED_CHAR;};
};
template<> struct MPITypeResolver< unsigned short int >
{
static inline MPI_Datatype getType(){return MPI_UNSIGNED_SHORT;};
};
template<> struct MPITypeResolver< unsigned int >
{
static inline MPI_Datatype getType(){return MPI_UNSIGNED;};
};
template<> struct MPITypeResolver< unsigned long int >
{
static inline MPI_Datatype getType(){return MPI_UNSIGNED_LONG;};
};
template<> struct MPITypeResolver< float >
{
static inline MPI_Datatype getType(){return MPI_FLOAT;};
};
template<> struct MPITypeResolver< double >
{
static inline MPI_Datatype getType(){return MPI_DOUBLE;};
};
template<> struct MPITypeResolver< long double >
{
static inline MPI_Datatype getType(){return MPI_LONG_DOUBLE;};
};
template<> struct MPITypeResolver< bool >
{
static inline MPI_Datatype getType(){return MPI_C_BOOL;};
};
#endif
} // namespace <unnamed>
} // namespace Communicators
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment