Commit a2eb4793 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Added masks for periodic boundaries to distributed grid synchronizer.

parent cdd9b5d8
Loading
Loading
Loading
Loading
+5 −2
Original line number Diff line number Diff line
@@ -164,8 +164,11 @@ class MeshFunction :
 
      using Object::boundLoad;

      template< typename CommunicatorType>
      void synchronize( bool withPeriodicBoundaryConditions = false );
      template< typename CommunicatorType,
                typename PeriodicBoundariesMaskType = MeshFunction< Mesh, MeshEntityDimension, bool > >
      void synchronize( bool withPeriodicBoundaryConditions = false,
                        const Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >& mask =
                           Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >( nullptr ) );

 
   protected:
+5 −3
Original line number Diff line number Diff line
@@ -548,15 +548,17 @@ write( const String& fileName,
template< typename Mesh,
          int MeshEntityDimension,
          typename Real >
template< typename CommunicatorType>
template< typename CommunicatorType,
          typename PeriodicBoundariesMaskType >
void
MeshFunction< Mesh, MeshEntityDimension, Real >:: 
synchronize( bool periodicBoundaries )
synchronize( bool periodicBoundaries,
             const Pointers::SharedPointer< PeriodicBoundariesMaskType, DeviceType >& mask )
{
    auto distrMesh = this->getMesh().getDistributedMesh();
    if(distrMesh != NULL && distrMesh->isDistributed())
    {
        this->synchronizer.template synchronize<CommunicatorType>( *this, periodicBoundaries );
        this->synchronizer.template synchronize<CommunicatorType>( *this, periodicBoundaries, mask );
    }
}

+112 −69
Original line number Diff line number Diff line
@@ -20,6 +20,7 @@ namespace DistributedMeshes {


template < typename MeshFunctionType,
           typename PeriodicBoundariesMaskPointer,
           int dim,
           typename RealType=typename MeshFunctionType::MeshType::RealType,
           typename Device=typename MeshFunctionType::MeshType::DeviceType,
@@ -28,93 +29,135 @@ class BufferEntitiesHelper
{
};

//======================================== 1D ====================================================

template < typename MeshFunctionType, typename RealType, typename Device, typename Index >
class BufferEntitiesHelper<MeshFunctionType,1,RealType,Device,Index>
template < typename MeshFunctionType,
           typename MaskPointer,
           typename RealType,
           typename Device,
           typename Index >
class BufferEntitiesHelper< MeshFunctionType, MaskPointer, 1, RealType, Device, Index >
{
   public:
    static void BufferEntities(MeshFunctionType& meshFunction, RealType * buffer, Index beginx, Index sizex, bool tobuffer)
      static void BufferEntities( 
         MeshFunctionType& meshFunction,
         const MaskPointer& maskPointer,
         RealType* buffer,
         bool isBoundary,
         const Index& beginx,
         const Index& sizex,
         bool tobuffer )
      {
         auto mesh = meshFunction.getMesh();
         RealType* meshFunctionData = meshFunction.getData().getData();
        auto kernel = [tobuffer, mesh, buffer, meshFunctionData, beginx] __cuda_callable__ ( Index j )
         const typename MaskPointer::ObjectType* mask( nullptr );
         if( maskPointer )
            mask = &maskPointer.template getData< Device >();
         auto kernel = [tobuffer, mesh, buffer, isBoundary, meshFunctionData, mask, beginx ] __cuda_callable__ ( Index j )
         {
            typename MeshFunctionType::MeshType::Cell entity(mesh);
            entity.getCoordinates().x()=beginx+j;
            entity.refresh();
            if( ! isBoundary || ! mask || ( *mask )[ entity.getIndex() ] )
            {
               if( tobuffer )
                  buffer[ j ] = meshFunctionData[ entity.getIndex() ];
               else
                  meshFunctionData[ entity.getIndex() ] = buffer[ j ];
            }
         };
         ParallelFor< Device >::exec( 0, sizex, kernel );
      };  
};


//======================================== 2D ====================================================
template <typename MeshFunctionType, typename RealType, typename Device, typename Index  > 
class BufferEntitiesHelper<MeshFunctionType,2,RealType,Device,Index>
template< typename MeshFunctionType,
          typename MaskPointer, 
          typename RealType,
          typename Device,
          typename Index  > 
class BufferEntitiesHelper< MeshFunctionType, MaskPointer, 2, RealType, Device, Index >
{
   public:
    static void BufferEntities(MeshFunctionType& meshFunction, RealType * buffer, Index beginx, Index beginy, Index sizex, Index sizey,bool tobuffer)
      static void BufferEntities(
         MeshFunctionType& meshFunction,
         const MaskPointer& maskPointer,
         RealType* buffer,
         bool isBoundary,
         const Index& beginx,
         const Index& beginy,
         const Index& sizex,
         const Index& sizey,
         bool tobuffer)
      {
         auto mesh=meshFunction.getMesh();
         RealType* meshFunctionData = meshFunction.getData().getData();      
        auto kernel = [tobuffer, mesh, buffer, meshFunctionData, beginx, sizex, beginy] __cuda_callable__ ( Index i, Index j )
         const typename MaskPointer::ObjectType* mask( nullptr );
         if( maskPointer )
            mask = &maskPointer.template getData< Device >();

         auto kernel = [ tobuffer, mask, mesh, buffer, isBoundary, meshFunctionData, beginx, sizex, beginy] __cuda_callable__ ( Index i, Index j )
         {
            typename MeshFunctionType::MeshType::Cell entity(mesh);
            entity.getCoordinates().x() = beginx + j;
            entity.getCoordinates().y() = beginy + i;				
            entity.refresh();
            if( ! isBoundary || ! mask || ( *mask )[ entity.getIndex() ] )
            {
               if( tobuffer )
                  buffer[ i * sizex + j ] = meshFunctionData[ entity.getIndex() ];
               else
                  meshFunctionData[ entity.getIndex() ] = buffer[ i * sizex + j ];
            }
         };
        
         ParallelFor2D< Device >::exec( 0, 0, sizey, sizex, kernel );     
        
      };
};


//======================================== 3D ====================================================
template <typename MeshFunctionType, typename RealType, typename Device, typename Index >
class BufferEntitiesHelper<MeshFunctionType,3,RealType,Device,Index>
template< typename MeshFunctionType,
          typename MaskPointer,
          typename RealType,
          typename Device,
          typename Index >
class BufferEntitiesHelper< MeshFunctionType, MaskPointer, 3, RealType, Device, Index >
{
   public:
    static void BufferEntities(MeshFunctionType& meshFunction, RealType * buffer, Index beginx, Index beginy, Index beginz, Index sizex, Index sizey, Index sizez, bool tobuffer)
      static void BufferEntities(
         MeshFunctionType& meshFunction,
         const MaskPointer& maskPointer,
         RealType* buffer,
         bool isBoundary,
         const Index& beginx,
         const Index& beginy,
         const Index& beginz,
         const Index& sizex,
         const Index& sizey,
         const Index& sizez,
         bool tobuffer)
      {

         auto mesh=meshFunction.getMesh();
         RealType * meshFunctionData=meshFunction.getData().getData();
        auto kernel = [tobuffer, mesh, buffer, meshFunctionData, beginx, sizex, beginy, sizey, beginz] __cuda_callable__ ( Index k, Index i, Index j )
         const typename MaskPointer::ObjectType* mask( nullptr );
         if( maskPointer )
            mask = &maskPointer.template getData< Device >();         
         auto kernel = [ tobuffer, mesh, mask, buffer, isBoundary, meshFunctionData, beginx, sizex, beginy, sizey, beginz] __cuda_callable__ ( Index k, Index i, Index j )
         {
            typename MeshFunctionType::MeshType::Cell entity(mesh);
            entity.getCoordinates().x() = beginx + j;
            entity.getCoordinates().z() = beginz + k;
            entity.getCoordinates().y() = beginy + i;
            entity.refresh();
            if( ! isBoundary || ! mask || ( *mask )[ entity.getIndex() ] )
            {
               if( tobuffer )
                    buffer[k*sizex*sizey+i*sizex+j]=meshFunctionData[entity.getIndex()];
                  buffer[ k * sizex * sizey + i * sizex + j ] = 
                     meshFunctionData[ entity.getIndex() ];
               else
                  meshFunctionData[ entity.getIndex() ] = buffer[ k * sizex * sizey + i * sizex + j ];
            }
         };

         ParallelFor3D< Device >::exec( 0, 0, 0, sizez, sizey, sizex, kernel ); 

        /*for(int k=0;k<sizez;k++)
        {
            for(int i=0;i<sizey;i++)
            {
                for(int j=0;j<sizex;j++)
                {
                        kernel(k,i,j);
                }
            }
        }*/
      };
};

+0 −2
Original line number Diff line number Diff line
@@ -380,8 +380,6 @@ void
DistributedMesh< Grid< Dimension, Real, Device, Index > >::
setupNeighbors()
{
   int *neighbors = this->neighbors;

   for( int i = 0; i < getNeighborsCount(); i++ )
   {
      auto direction = Directions::template getXYZ< Dimension >( i );
+26 −13
Original line number Diff line number Diff line
@@ -75,9 +75,12 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< 1, GridReal, D

      };

      template<typename CommunicatorType, typename MeshFunctionType>
      template< typename CommunicatorType,
                typename MeshFunctionType,
                typename PeriodicBoundariesMaskPointer = Pointers::SharedPointer< MeshFunctionType > >
      void synchronize( MeshFunctionType &meshFunction,
                        bool periodicBoundaries = false )
                        bool periodicBoundaries = false,
                        const PeriodicBoundariesMaskPointer& mask = PeriodicBoundariesMaskPointer( nullptr ) )
      {
         TNL_ASSERT_TRUE( isSet, "Synchronizer is not set, but used to synchronize" );
         
@@ -109,7 +112,8 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< 1, GridReal, D
                      leftSource, rightSource,
                      lowerOverlap, upperOverlap,
                      neighbors,
                      periodicBoundaries );
                      periodicBoundaries,
                      PeriodicBoundariesMaskPointer( nullptr ) ); // the mask is used only when receiving data 

         //async send
         typename CommunicatorType::Request requests[ 4 ];
@@ -156,24 +160,33 @@ class DistributedMeshSynchronizer< Functions::MeshFunction< Grid< 1, GridReal, D
            lowerOverlap,
            upperOverlap,
            neighbors,
            periodicBoundaries );
            periodicBoundaries,
            mask );
      }
      
   private:
      template <typename Real_, typename MeshFunctionType >
      void copyBuffers( MeshFunctionType& meshFunction, TNL::Containers::Array<Real_,Device>* buffers, bool toBuffer,
      template< typename Real_,
                typename MeshFunctionType,
                typename PeriodicBoundariesMaskPointer >
      void copyBuffers( 
         MeshFunctionType& meshFunction,
         TNL::Containers::Array<Real_,Device>* buffers,
         bool toBuffer,
         int left, int right,
         const SubdomainOverlapsType& lowerOverlap,
         const SubdomainOverlapsType& upperOverlap,
         const int* neighbors,
         bool periodicBoundaries )
         bool periodicBoundaries,
         const PeriodicBoundariesMaskPointer& mask )
      
      {
         typedef BufferEntitiesHelper< MeshFunctionType, 1, Real_, Device > Helper;
         if( neighbors[ Left ] != -1 || periodicBoundaries )
            Helper::BufferEntities( meshFunction, buffers[ Left ].getData(), left, lowerOverlap.x(), toBuffer );
         if( neighbors[ Right ] != -1 || periodicBoundaries )
            Helper::BufferEntities( meshFunction, buffers[ Right ].getData(), right, upperOverlap.x(), toBuffer );
         typedef BufferEntitiesHelper< MeshFunctionType, PeriodicBoundariesMaskPointer, 1, Real_, Device > Helper;
         bool leftIsBoundary = ( neighbors[ Left ] == -1 );
         bool rightIsBoundary = ( neighbors[ Right ] == -1 );
         if( ! leftIsBoundary || periodicBoundaries )
            Helper::BufferEntities( meshFunction, mask, buffers[ Left ].getData(), leftIsBoundary, left, lowerOverlap.x(), toBuffer );
         if( ! rightIsBoundary || periodicBoundaries )
            Helper::BufferEntities( meshFunction, mask, buffers[ Right ].getData(), rightIsBoundary, right, upperOverlap.x(), toBuffer );
      }

      Containers::Array<RealType, Device> sendBuffers[ 2 ], receiveBuffers[ 2 ];
Loading