Commit 0052f917 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

pytnl: updated bindings for Mesh, added missing methods

parent 0f95798e
Loading
Loading
Loading
Loading
+0 −36
Original line number Diff line number Diff line
#pragma once

#include <pybind11/pybind11.h>
namespace py = pybind11;

enum class EntityTypes { Cell, Face, Vertex };

inline void
export_EntityTypes( py::module & m )
{
    // avoid duplicate conversion -> export only once
    static bool exported = false;
    if( ! exported ) {
        // TODO: should be nested types instead
        py::enum_< EntityTypes >( m, "EntityTypes" )
            .value("Cell", EntityTypes::Cell)
            .value("Face", EntityTypes::Face)
            .value("Vertex", EntityTypes::Vertex)
        ;
        exported = true;
    }
}

template< typename Mesh >
typename Mesh::GlobalIndexType
mesh_getEntitiesCount( const Mesh & self, const EntityTypes & entity )
{
    if( entity == EntityTypes::Cell )
        return self.template getEntitiesCount< typename Mesh::Cell >();
    else if( entity == EntityTypes::Face )
        return self.template getEntitiesCount< typename Mesh::Face >();
    else if( entity == EntityTypes::Vertex )
        return self.template getEntitiesCount< typename Mesh::Vertex >();
    else
        throw py::value_error("The entity parameter must be either Cell, Face or Vertex.");
}
+8 −8
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ namespace py = pybind11;

#include "StaticVector.h"
#include "Grid_getSpaceStepsProducts.h"
#include "EntityTypes.h"
#include "mesh_getters.h"

#include <type_traits>

@@ -54,8 +54,6 @@ void export_Grid( py::module & m, const char* name )
//    void (Grid::* _setDimensions1)(const IndexType) = &Grid::setDimensions;
    void (Grid::* _setDimensions2)(const typename Grid::CoordinatesType &) = &Grid::setDimensions;

    export_EntityTypes(m);

    auto grid = py::class_<Grid, TNL::Object>( m, name )
        .def(py::init<>())
        .def_static("getMeshDimension", &Grid::getMeshDimension)
@@ -68,11 +66,13 @@ void export_Grid( py::module & m, const char* name )
        .def("setDomain", &Grid::setDomain)
        .def("getOrigin", &Grid::getOrigin, py::return_value_policy::reference_internal)
        .def("getProportions", &Grid::getProportions, py::return_value_policy::reference_internal)
        .def("getEntitiesCount", &mesh_getEntitiesCount< Grid >)
        // TODO: if combined, the return type would depend on the runtime parameter (entity)
        .def("getEntity_cell", &Grid::template getEntity<typename Grid::Cell>)
        .def("getEntity_face", &Grid::template getEntity<typename Grid::Face>)
        .def("getEntity_vertex", &Grid::template getEntity<typename Grid::Vertex>)
        .def("getEntitiesCount", &mesh_getEntitiesCount< Grid, typename Grid::Cell >)
        .def("getEntitiesCount", &mesh_getEntitiesCount< Grid, typename Grid::Face >)
        .def("getEntitiesCount", &mesh_getEntitiesCount< Grid, typename Grid::Vertex >)
        // NOTE: if combined into getEntity, the return type would depend on the runtime parameter (entity)
        .def("getCell", &Grid::template getEntity<typename Grid::Cell>)
        .def("getFace", &Grid::template getEntity<typename Grid::Face>)
        .def("getVertex", &Grid::template getEntity<typename Grid::Vertex>)
        .def("getEntityIndex", &Grid::template getEntityIndex<typename Grid::Cell>)
        .def("getEntityIndex", &Grid::template getEntityIndex<typename Grid::Face>)
        .def("getEntityIndex", &Grid::template getEntityIndex<typename Grid::Vertex>)
+23 −13
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ namespace py = pybind11;

#include "../typedefs.h"
#include "StaticVector.h"
#include "EntityTypes.h"
#include "mesh_getters.h"

#include <TNL/String.h>
#include <TNL/Meshes/Geometry/getEntityCenter.h>
@@ -82,8 +82,11 @@ template< typename MeshEntity, typename Scope >
void export_MeshEntity( Scope & scope, const char* name )
{
    auto entity = py::class_< MeshEntity >( scope, name )
//        .def(py::init<>())
//        .def(py::init<typename MeshEntity::MeshType, typename MeshEntity::GlobalIndexType>())
        .def_static("getEntityDimension", &MeshEntity::getEntityDimension)
        .def("getIndex", &MeshEntity::getIndex)
        .def("getTag", &MeshEntity::getTag)
        // TODO
    ;

@@ -95,23 +98,24 @@ void export_MeshEntity( Scope & scope, const char* name )
template< typename Mesh >
void export_Mesh( py::module & m, const char* name )
{
    // there are two templates - const and non-const - take only the const
    auto (Mesh::* getEntity_cell)(const typename Mesh::GlobalIndexType) const = &Mesh::template getEntity<typename Mesh::Cell>;
    auto (Mesh::* getEntity_face)(const typename Mesh::GlobalIndexType) const = &Mesh::template getEntity<typename Mesh::Face>;
    auto (Mesh::* getEntity_vertex)(const typename Mesh::GlobalIndexType) const = &Mesh::template getEntity<typename Mesh::Vertex>;

    export_EntityTypes(m);

    auto mesh = py::class_< Mesh, TNL::Object >( m, name )
        .def(py::init<>())
        .def_static("getMeshDimension", &Mesh::getMeshDimension)
        .def_static("getSerializationType", &Mesh::getSerializationType)
        .def("getSerializationTypeVirtual", &Mesh::getSerializationTypeVirtual)
        .def("getEntitiesCount", &mesh_getEntitiesCount< Mesh >)
        // TODO: if combined, the return type would depend on the runtime parameter (entity)
        .def("getEntity_cell", getEntity_cell)
        .def("getEntity_face", getEntity_face)
        .def("getEntity_vertex", getEntity_vertex)
        .def("getEntitiesCount", &mesh_getEntitiesCount< Mesh, typename Mesh::Cell >)
        .def("getEntitiesCount", &mesh_getEntitiesCount< Mesh, typename Mesh::Face >)
        .def("getEntitiesCount", &mesh_getEntitiesCount< Mesh, typename Mesh::Vertex >)
        .def("getGhostEntitiesCount", &mesh_getGhostEntitiesCount< Mesh, typename Mesh::Cell >)
        .def("getGhostEntitiesCount", &mesh_getGhostEntitiesCount< Mesh, typename Mesh::Face >)
        .def("getGhostEntitiesCount", &mesh_getGhostEntitiesCount< Mesh, typename Mesh::Vertex >)
        .def("getGhostEntitiesOffset", &mesh_getGhostEntitiesOffset< Mesh, typename Mesh::Cell >)
        .def("getGhostEntitiesOffset", &mesh_getGhostEntitiesOffset< Mesh, typename Mesh::Face >)
        .def("getGhostEntitiesOffset", &mesh_getGhostEntitiesOffset< Mesh, typename Mesh::Vertex >)
        // NOTE: if combined into getEntity, the return type would depend on the runtime parameter (entity)
        .def("getCell", &Mesh::template getEntity<typename Mesh::Cell>)
        .def("getFace", &Mesh::template getEntity<typename Mesh::Face>)
        .def("getVertex", &Mesh::template getEntity<typename Mesh::Vertex>)
        .def("getEntityCenter", []( const Mesh& mesh, const typename Mesh::Cell& cell ){ return getEntityCenter( mesh, cell ); } )
        .def("getEntityCenter", []( const Mesh& mesh, const typename Mesh::Face& face ){ return getEntityCenter( mesh, face ); } )
        .def("getEntityCenter", []( const Mesh& mesh, const typename Mesh::Vertex& vertex ){ return getEntityCenter( mesh, vertex ); } )
@@ -124,6 +128,12 @@ void export_Mesh( py::module & m, const char* name )
                                       return mesh.template isBoundaryEntity< Mesh::Face::getEntityDimension() >( face.getIndex() ); } )
        .def("isBoundaryEntity", []( const Mesh& mesh, const typename Mesh::Vertex& vertex ){
                                        return mesh.template isBoundaryEntity< Mesh::Vertex::getEntityDimension() >( vertex.getIndex() ); } )
        .def("isGhostEntity", []( const Mesh& mesh, const typename Mesh::Cell& cell ){
                                       return mesh.template isGhostEntity< Mesh::Cell::getEntityDimension() >( cell.getIndex() ); } )
        .def("isGhostEntity", []( const Mesh& mesh, const typename Mesh::Face& face ){
                                       return mesh.template isGhostEntity< Mesh::Face::getEntityDimension() >( face.getIndex() ); } )
        .def("isGhostEntity", []( const Mesh& mesh, const typename Mesh::Vertex& vertex ){
                                        return mesh.template isGhostEntity< Mesh::Vertex::getEntityDimension() >( vertex.getIndex() ); } )
        // TODO: more?
    ;

+36 −0
Original line number Diff line number Diff line
#pragma once

#include <type_traits>

template< typename Mesh, typename EntityType >
typename Mesh::GlobalIndexType
mesh_getEntitiesCount( const Mesh & self, const EntityType & entity )
{
    static_assert( std::is_same< EntityType, typename Mesh::Cell >::value ||
                   std::is_same< EntityType, typename Mesh::Face >::value ||
                   std::is_same< EntityType, typename Mesh::Vertex >::value,
                   "incompatible entity type" );
    return self.template getEntitiesCount< EntityType::getEntityDimension() >();
}

template< typename Mesh, typename EntityType >
typename Mesh::GlobalIndexType
mesh_getGhostEntitiesCount( const Mesh & self, const EntityType & entity )
{
    static_assert( std::is_same< EntityType, typename Mesh::Cell >::value ||
                   std::is_same< EntityType, typename Mesh::Face >::value ||
                   std::is_same< EntityType, typename Mesh::Vertex >::value,
                   "incompatible entity type" );
    return self.template getGhostEntitiesCount< EntityType::getEntityDimension() >();
}

template< typename Mesh, typename EntityType >
typename Mesh::GlobalIndexType
mesh_getGhostEntitiesOffset( const Mesh & self, const EntityType & entity )
{
    static_assert( std::is_same< EntityType, typename Mesh::Cell >::value ||
                   std::is_same< EntityType, typename Mesh::Face >::value ||
                   std::is_same< EntityType, typename Mesh::Vertex >::value,
                   "incompatible entity type" );
    return self.template getGhostEntitiesOffset< EntityType::getEntityDimension() >();
}