Commit 34e9dac0 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Fixed distributeSubentities

parent 81eaf23e
Loading
Loading
Loading
Loading
+70 −42
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@
#pragma once

#include <numeric>   // std::iota
#include <atomic>

#include <TNL/Meshes/DistributedMeshes/DistributedMeshSynchronizer.h>
#include <TNL/Meshes/MeshDetails/layers/EntityTags/Traits.h>
@@ -43,35 +44,63 @@ distributeSubentities( DistributedMesh& mesh )
   const int rank = CommunicatorType::GetRank( mesh.getCommunicationGroup() );
   const int nproc = CommunicatorType::GetSize( mesh.getCommunicationGroup() );

   // exchange the global vertex index offsets so that each rank can determine the
   // owner of every vertex by its global index
   const GlobalIndexType ownVertexStart = mesh.template getGlobalIndices< 0 >().getElement( 0 );
   Containers::Array< GlobalIndexType, Devices::Host, int > vertexOffsets( nproc );
   {
      Containers::Array< GlobalIndexType, Devices::Host, int > sendbuf( nproc );
      sendbuf.setValue( ownVertexStart );
      CommunicatorType::Alltoall( sendbuf.getData(), 1,
                                  vertexOffsets.getData(), 1,
                                  mesh.getCommunicationGroup() );
   }
   // 0. exchange vertex data to prepare getVertexOwner and later on synchronizeSparse
   DistributedMeshSynchronizer< DistributedMesh, 0 > synchronizer;
   synchronizer.initialize( mesh );

   auto getVertexOwner = [&] ( GlobalIndexType local_idx ) -> int
   {
      const GlobalIndexType global_idx = mesh.template getGlobalIndices< 0 >()[ local_idx ];
      for( int i = 0; i < nproc - 1; i++ )
         if( vertexOffsets[ i ] <= global_idx && global_idx < vertexOffsets[ i + 1 ] )
            return i;
      return nproc - 1;
      return synchronizer.getEntityOwner( global_idx );
   };

   // find which rank owns all vertices of its local cells
   int rankOwningAllLocalCellSubvertices = nproc;
   {
      std::atomic<bool> its_us( true );
      mesh.getLocalMesh().template forLocal< DistributedMesh::getMeshDimension() >( [&] ( GlobalIndexType i ) mutable {
         for( LocalIndexType v = 0; v < mesh.getLocalMesh().template getSubentitiesCount< DistributedMesh::getMeshDimension(), 0 >( i ); v++ ) {
            const GlobalIndexType gv = mesh.getLocalMesh().template getSubentityIndex< DistributedMesh::getMeshDimension(), 0 >( i, v );
            if( getVertexOwner( gv ) != rank )
               its_us = false;
         }
      });
      Containers::Array< bool, Devices::Host, int > recvbuf( nproc ), sendbuf( nproc );
      sendbuf.setValue( its_us );
      CommunicatorType::Alltoall( sendbuf.getData(), 1,
                                  recvbuf.getData(), 1,
                                  mesh.getCommunicationGroup() );
      for( int i = 0; i < nproc; i++ )
         if( recvbuf[ i ] ) {
            rankOwningAllLocalCellSubvertices = i;
            break;
         }
   }
   if( rankOwningAllLocalCellSubvertices != 0 && rankOwningAllLocalCellSubvertices != nproc - 1 )
      throw std::runtime_error("Vertices are not distributed consistently. Shared vertices on the boundaries must be assigned "
                               "either to the highest or to the lowest rank. Thus, either the first or the last rank must "
                               "own all subvertices of its local cells.");

   auto getEntityOwner = [&] ( GlobalIndexType local_idx ) -> int
   {
      auto entity = mesh.getLocalMesh().template getEntity< Dimension >( local_idx );
      int owner = 0;
      int owner = (rankOwningAllLocalCellSubvertices == 0) ? 0 : nproc;
      if( rankOwningAllLocalCellSubvertices == 0 ) {
         // this assumes that vertices at the boundaries were assigned to the subdomain with the lowest rank
         // (this is used in DistributedMeshTest for simplicitty)
         for( LocalIndexType v = 0; v < entity.template getSubentitiesCount< 0 >(); v++ ) {
            const GlobalIndexType gv = entity.template getSubentityIndex< 0 >( v );
            owner = TNL::max( owner, getVertexOwner( gv ) );
         }
      }
      else {
         // this assumes that vertices at the boundaries were assigned to the subdomain with the highest rank
         // (this is what tnl-decompose-mesh does)
         for( LocalIndexType v = 0; v < entity.template getSubentitiesCount< 0 >(); v++ ) {
            const GlobalIndexType gv = entity.template getSubentityIndex< 0 >( v );
            owner = TNL::min( owner, getVertexOwner( gv ) );
         }
      }
      return owner;
   };

@@ -134,36 +163,35 @@ distributeSubentities( DistributedMesh& mesh )
      mesh.template getGlobalIndices< Dimension >()[ i ] = globalOffsets[ rank ] + i;
   });

   // 6. exchange cell data to prepare the communication pattern
   DistributedMeshSynchronizer< DistributedMesh > synchronizer;
   synchronizer.initialize( mesh );

   // 7. exchange local indices for ghost entities
   const auto sparseResult = synchronizer.synchronizeSparse( localMesh.template getSubentitiesMatrix< DistributedMesh::getMeshDimension(), Dimension >() );
   // 6. exchange local indices for ghost entities
   // We have to synchronize the vertex-entity superentity matrix, synchronization based
   // on the cell-entity subentity matrix is not general. For example, two subdomains can
   // have a common face, but no common cell, even when ghost_levels > 0. On the other
   // hand, if two subdomains have a common face, they have common all its subvertices,
   // so it is ensured that we send/receive indices for all ghost entities (with a rather
   // great redundancy).
   const auto sparseResult = synchronizer.synchronizeSparse( localMesh.template getSuperentitiesMatrix< 0, Dimension >() );
   const auto& rankOffsets = std::get< 0 >( sparseResult );
   const auto& rowPointers = std::get< 1 >( sparseResult );
   const auto& columnIndices = std::get< 2 >( sparseResult );

   // 8. set the global indices of our ghost entities
   for( int i = 0; i < nproc; i++ ) {
      if( i == rank )
         continue;
      for( GlobalIndexType cell = synchronizer.getGhostOffsets()[ i ]; cell < synchronizer.getGhostOffsets()[ i + 1 ]; cell++ ) {
         for( LocalIndexType e = 0; e < mesh.getLocalMesh().template getSubentitiesCount< DistributedMesh::getMeshDimension(), Dimension >( cell ); e++ ) {
            const GlobalIndexType entityIndex = mesh.getLocalMesh().template getSubentityIndex< DistributedMesh::getMeshDimension(), Dimension >( cell, e );
   // 7. set the global indices of our ghost entities
   localMesh.template forGhost< Dimension >( [&] ( GlobalIndexType entityIndex ) mutable {
      const int owner = getEntityOwner( entityIndex );
            // pick the right owner as we might have received an index from multiple ranks
            if( owner == i ) {
               const GlobalIndexType ghostOffset = cell - synchronizer.getGhostOffsets()[ owner ];
      for( LocalIndexType v = 0; v < localMesh.template getSubentitiesCount< Dimension, 0 >( entityIndex ); v++ ) {
         const GlobalIndexType vertex = localMesh.template getSubentityIndex< Dimension, 0 >( entityIndex, v );
         const int vertexOwner = getVertexOwner( vertex );
         if( vertexOwner == owner ) {
            const GlobalIndexType ghostOffset = vertex - synchronizer.getGhostOffsets()[ vertexOwner ];
            // global index = owner's local index + owner's offset
               const GlobalIndexType globalEntityIndex = columnIndices[ rowPointers[ rankOffsets[ owner ] + ghostOffset ] + e ] + globalOffsets[ owner ];
            const GlobalIndexType globalEntityIndex = columnIndices[ rowPointers[ rankOffsets[ vertexOwner ] + ghostOffset ] + v ] + globalOffsets[ owner ];
            mesh.template getGlobalIndices< Dimension >()[ entityIndex ] = globalEntityIndex;
            break;
         }
      }
      }
   }
   });

   // 9. reorder the entities to make sure that global indices are sorted
   // 8. reorder the entities to make sure that global indices are sorted
   {
      // prepare vector with an identity permutation
      std::vector< GlobalIndexType > permutation( localMesh.template getEntitiesCount< Dimension >() );