Commit dafe0492 authored by Vít Hanousek's avatar Vít Hanousek
Browse files

Add very rough (unchecked) support for saving a distributed mesh from multiple GPUs to a single file using MPI-IO.

parent 7b6adc5a
Loading
Loading
Loading
Loading
+58 −12
Original line number Diff line number Diff line
@@ -23,6 +23,8 @@
#include <TNL/Meshes/DistributedMeshes/DistributedMesh.h>
#include <TNL/Meshes/DistributedMeshes/CopyEntitiesHelper.h>
#include <TNL/Functions/MeshFunction.h>
#include <TNL/Devices/Host.h>
#include <TNL/Devices/Cuda.h>

namespace TNL {
namespace Meshes {   
@@ -31,13 +33,15 @@ namespace DistributedMeshes {
enum DistrGridIOTypes { Dummy = 0 , LocalCopy = 1, MpiIO=2 };
    
template<typename MeshFunctionType,
         DistrGridIOTypes type = LocalCopy> 
         DistrGridIOTypes type = LocalCopy,
         typename Device=typename MeshFunctionType::DeviceType> 
class DistributedGridIO
{
};

template<typename MeshFunctionType> 
class DistributedGridIO<MeshFunctionType,Dummy>
template<typename MeshFunctionType,
         typename Device> 
class DistributedGridIO<MeshFunctionType,Dummy,Device>
{
    bool save(const String& fileName, MeshFunctionType &meshFunction)
    {
@@ -55,8 +59,9 @@ class DistributedGridIO<MeshFunctionType,Dummy>
 * This variant creates a smaller copy of the MeshFunction, reduced to local entities only (overlaps stripped).
 * It is slow and has high RAM consumption.
 */
template<typename MeshFunctionType> 
class DistributedGridIO<MeshFunctionType,LocalCopy>
template<typename MeshFunctionType,
         typename Device> 
class DistributedGridIO<MeshFunctionType,LocalCopy,Device>
{

    public:
@@ -161,7 +166,7 @@ class DistributedGridIO<MeshFunctionType,LocalCopy>
#ifdef HAVE_MPI
#ifdef MPIIO  
template<typename MeshFunctionType> 
class DistributedGridIO<MeshFunctionType,MpiIO>
class DistributedGridIO_MPIIOBase
{
   public:

@@ -172,7 +177,7 @@ class DistributedGridIO<MeshFunctionType,MpiIO>
      typedef typename MeshFunctionType::VectorType VectorType;
      //typedef DistributedGrid< MeshType,MeshFunctionType::getMeshDimension()> DistributedGridType;
    
    static bool save(const String& fileName, MeshFunctionType &meshFunction)
    static bool save(const String& fileName, MeshFunctionType &meshFunction, RealType *data)
    {
     
        auto *distrGrid=meshFunction.getMesh().getDistributedMesh();
@@ -186,8 +191,6 @@ class DistributedGridIO<MeshFunctionType,MpiIO>
       MPI_Datatype atype;
       int dataCount=CreateDataTypes(distrGrid,&ftype,&atype);

       RealType* data=meshFunction.getData().getData();

       //write 
       MPI_File file;
       MPI_File_open( MPI_COMM_WORLD,
@@ -325,7 +328,7 @@ class DistributedGridIO<MeshFunctionType,MpiIO>
    };
            
    /* Quick-and-dirty load — performs no consistency checks on the input file, so a mismatched file silently corrupts data */
    static bool load(const String& fileName,MeshFunctionType &meshFunction) 
    static bool load(const String& fileName,MeshFunctionType &meshFunction, double *data ) 
    {
        auto *distrGrid=meshFunction.getMesh().getDistributedMesh();
        if(distrGrid==NULL) //not distributed
@@ -337,8 +340,6 @@ class DistributedGridIO<MeshFunctionType,MpiIO>
       MPI_Datatype atype;
       int dataCount=CreateDataTypes(distrGrid,&ftype,&atype);

       double * data=meshFunction.getData().getData();//TYP

       //read
       MPI_File file;
       MPI_File_open( MPI_COMM_WORLD,
@@ -403,6 +404,51 @@ class DistributedGridIO<MeshFunctionType,MpiIO>
    };
    
};

/**
 * MPI-IO save/load for mesh functions whose data lives on a CUDA device.
 * The MPI-IO routines in DistributedGridIO_MPIIOBase need a host pointer,
 * so the device data is staged through a host-side vector in both directions.
 */
template<typename MeshFunctionType>
class DistributedGridIO<MeshFunctionType,MpiIO,TNL::Devices::Cuda>
{
    public:
    /**
     * Saves the mesh function to a single shared file via MPI-IO.
     * Copies the device data to a host buffer first and hands the host
     * pointer to the shared MPI-IO implementation.
     */
    static bool save(const String& fileName, MeshFunctionType &meshFunction)
    {
        using HostVectorType = Containers::Vector<typename MeshFunctionType::RealType, Devices::Host, typename MeshFunctionType::IndexType >;
        HostVectorType hostVector;
        hostVector = meshFunction.getData();   // device -> host copy
        typename MeshFunctionType::RealType* data = hostVector.getData();
        return DistributedGridIO_MPIIOBase<MeshFunctionType>::save( fileName, meshFunction, data );
    }

    /**
     * Loads the mesh function from a single shared file via MPI-IO.
     * Reads into a host staging buffer and copies the result back to the
     * device. Returns false when the underlying MPI-IO load fails, instead
     * of unconditionally reporting success.
     */
    static bool load(const String& fileName, MeshFunctionType &meshFunction)
    {
        using HostVectorType = Containers::Vector<typename MeshFunctionType::RealType, Devices::Host, typename MeshFunctionType::IndexType >;
        HostVectorType hostVector;
        hostVector.setLike( meshFunction.getData() );
        // Use RealType rather than a hard-coded double: the function data may
        // be single precision. NOTE(review): the base-class load() signature
        // still declares `double*` — generalize it to RealType as well.
        typename MeshFunctionType::RealType* data = hostVector.getData();
        if( ! DistributedGridIO_MPIIOBase<MeshFunctionType>::load( fileName, meshFunction, data ) )
            return false;
        meshFunction.getData() = hostVector;   // host -> device copy
        return true;
    }
};

/**
 * MPI-IO save/load for mesh functions whose data already lives on the host.
 * No staging buffer is needed; the function's own storage is passed directly
 * to the shared MPI-IO implementation.
 */
template<typename MeshFunctionType>
class DistributedGridIO<MeshFunctionType,MpiIO,TNL::Devices::Host>
{
    public:
    /** Saves the mesh function to a single shared file via MPI-IO. */
    static bool save(const String& fileName, MeshFunctionType &meshFunction)
    {
        typename MeshFunctionType::RealType* data = meshFunction.getData().getData();
        return DistributedGridIO_MPIIOBase<MeshFunctionType>::save( fileName, meshFunction, data );
    }

    /** Loads the mesh function from a single shared file via MPI-IO. */
    static bool load(const String& fileName, MeshFunctionType &meshFunction)
    {
        // Use RealType rather than a hard-coded double so single-precision
        // functions work too. NOTE(review): the base-class load() signature
        // still declares `double*` — generalize it to RealType as well.
        typename MeshFunctionType::RealType* data = meshFunction.getData().getData();
        return DistributedGridIO_MPIIOBase<MeshFunctionType>::load( fileName, meshFunction, data );
    }
};

#endif
#endif
}
+3 −3
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@

#include "HeatEquationProblem.h"

//#define MPIIO
#define MPIIO
#include <TNL/Meshes/DistributedMeshes/DistributedGridIO.h>


@@ -149,7 +149,7 @@ setInitialCondition( const Config::ParameterContainer& parameters,
   if(CommunicatorType::isDistributed())
    {
        std::cout<<"Nodes Distribution: " << uPointer->getMesh().getDistributedMesh()->printProcessDistr() << std::endl;
        Meshes::DistributedMeshes::DistributedGridIO<MeshFunctionType,Meshes::DistributedMeshes::LocalCopy> ::load(initialConditionFile, *uPointer );
        Meshes::DistributedMeshes::DistributedGridIO<MeshFunctionType,Meshes::DistributedMeshes::MpiIO> ::load(initialConditionFile, *uPointer );
        uPointer->template synchronize<CommunicatorType>();
    }
    else
@@ -214,7 +214,7 @@ makeSnapshot( const RealType& time,

   if(CommunicatorType::isDistributed())
   {
      Meshes::DistributedMeshes::DistributedGridIO<MeshFunctionType,Meshes::DistributedMeshes::LocalCopy> ::save(fileName.getFileName(), *uPointer );
      Meshes::DistributedMeshes::DistributedGridIO<MeshFunctionType,Meshes::DistributedMeshes::MpiIO> ::save(fileName.getFileName(), *uPointer );
   }
   else
   {