Commit be4dbbb9 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Implemented more efficient algorithm for the spheres benchmark

parent 4143a4a6
Loading
Loading
Loading
Loading
+24 −47
Original line number Original line Diff line number Diff line
@@ -431,71 +431,48 @@ struct MeshBenchmarks
      if( ! checkDevice< Device >( parameters ) )
      if( ! checkDevice< Device >( parameters ) )
         return;
         return;


      const Index entitiesCount = mesh_src.template getEntitiesCount< 0 >();
      const Index verticesCount = mesh_src.template getEntitiesCount< 0 >();
      const Index facesCount = mesh_src.template getEntitiesCount< Mesh::getMeshDimension() - 1 >();


      const DeviceMesh mesh = mesh_src;
      const DeviceMesh mesh = mesh_src;
      Pointers::DevicePointer< const DeviceMesh > meshPointer( mesh );
      Pointers::DevicePointer< const DeviceMesh > meshPointer( mesh );
      Containers::Array< Real, Device, Index > spheres;
      Containers::Vector< Real, Device, Index > spheres;
      spheres.setSize( entitiesCount );
      spheres.setSize( verticesCount );


//      auto hasSubvertex = [] __cuda_callable__
      auto getLocalFaceIndex = [] __cuda_callable__
//         ( const typename DeviceMesh::Face & face,
//           const Index i )
//      {
////         constexpr auto verticesCount = Mesh::Face::template getSubentitiesCount< 0 >();
//         constexpr auto verticesCount = Mesh::Face::template SubentityTraits< 0 >::count;
//         for( LocalIndex v = 0; v < verticesCount; v++ ) {
//            const auto vid = face.template getSubentityIndex< 0 >( v );
//            if( vid == i )
//               return true;
//         }
//         return false;
//      };

      auto getLocalVertexIndex = [] __cuda_callable__
         ( const typename DeviceMesh::Cell & cell,
         ( const typename DeviceMesh::Cell & cell,
           const Index i )
           const Index i )
      {
      {
//         constexpr auto verticesCount = Mesh::Cell::template getSubentitiesCount< 0 >();
//         constexpr auto facesCount = Mesh::Cell::template getSubentitiesCount< 0 >();
         constexpr auto verticesCount = Mesh::Cell::template SubentityTraits< 0 >::count;
         constexpr auto facesCount = Mesh::Cell::template SubentityTraits< Mesh::getMeshDimension() - 1 >::count;
         for( LocalIndex v = 0; v < verticesCount; v++ ) {
         for( LocalIndex f = 0; f < facesCount; f++ ) {
            const auto vid = cell.template getSubentityIndex< 0 >( v );
            const auto fid = cell.template getSubentityIndex< Mesh::getMeshDimension() - 1 >( f );
            if( vid == i ) {
            if( fid == i ) {
               return v;
               return f;
            }
            }
         }
         }
         TNL_ASSERT( false,
         TNL_ASSERT( false,
                     std::cerr << "local vertex index not found -- this is a BUG!" << std::endl; );
                     std::cerr << "local face index not found -- this is a BUG!" << std::endl; );
         return (LocalIndex) 0;
         return (LocalIndex) 0;
      };
      };


      auto kernel_spheres = [getLocalVertexIndex] __cuda_callable__
      auto kernel_spheres = [getLocalFaceIndex] __cuda_callable__
         ( Index i,
         ( Index fid,
           const DeviceMesh* mesh,
           const DeviceMesh* mesh,
           Real* array )
           Real* array )
      {
      {
         Real s = 0.0;
        const auto& face = mesh->template getEntity< Mesh::getMeshDimension() - 1 >( fid );
         const auto& vertex = mesh->template getEntity< 0 >( i );
        const auto face_measure = getEntityMeasure( *mesh, face );
         const auto cellsCount = vertex.template getSuperentitiesCount< Mesh::getMeshDimension() >();

         const auto cellsCount = face.template getSuperentitiesCount< Mesh::getMeshDimension() >();
         for( LocalIndex c = 0; c < cellsCount; c++ ) {
         for( LocalIndex c = 0; c < cellsCount; c++ ) {
            const auto cid = vertex.template getSuperentityIndex< Mesh::getMeshDimension() >( c );
            const auto cid = face.template getSuperentityIndex< Mesh::getMeshDimension() >( c );
            const auto& cell = mesh->template getEntity< Mesh::getMeshDimension() >( cid );
            const auto& cell = mesh->template getEntity< Mesh::getMeshDimension() >( cid );
            // general version, but very slow
////            constexpr auto facesCount = Mesh::Cell::template getSubentitiesCount< Mesh::getMeshDimension() - 1 >();
//            constexpr auto facesCount = Mesh::Cell::template SubentityTraits< Mesh::getMeshDimension() - 1 >::count;
//            for( LocalIndex f = 0; f < facesCount; f++ ) {
//               const auto fid = cell.template getSubentityIndex< Mesh::getMeshDimension() - 1 >( f );
//               const auto& face = mesh->template getEntity< Mesh::getMeshDimension() - 1 >( fid );
//               if( ! hasSubvertex( face, i ) )
//                  s += getEntityMeasure( *mesh, face );
//            }
            // specialized version for simplices (assuming that opposite vertex and face have the same local index)
            // specialized version for simplices (assuming that opposite vertex and face have the same local index)
            const auto f = getLocalVertexIndex( cell, i );
            const auto v = getLocalFaceIndex( cell, fid );
            const auto fid = cell.template getSubentityIndex< Mesh::getMeshDimension() - 1 >( f );
            const auto vid = cell.template getSubentityIndex< 0 >( v );
            const auto& face = mesh->template getEntity< Mesh::getMeshDimension() - 1 >( fid );
            Algorithms::AtomicOperations< Device >::add( array[ vid ], face_measure );
            s += getEntityMeasure( *mesh, face );
         }
         }
         array[ i ] = s;
      };
      };


      auto reset = [&]() {
      auto reset = [&]() {
@@ -504,7 +481,7 @@ struct MeshBenchmarks


      auto benchmark_func = [&] () {
      auto benchmark_func = [&] () {
         Algorithms::ParallelFor< Device >::exec(
         Algorithms::ParallelFor< Device >::exec(
               (Index) 0, entitiesCount,
               (Index) 0, facesCount,
               kernel_spheres,
               kernel_spheres,
               &meshPointer.template getData< Device >(),
               &meshPointer.template getData< Device >(),
               spheres.getData() );
               spheres.getData() );