Skip to content
Snippets Groups Projects
Commit 0052f917 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

pytnl: updated bindings for Mesh, added missing methods

parent 0f95798e
No related branches found
No related tags found
1 merge request!82MPI refactoring
#pragma once
#include <pybind11/pybind11.h>
namespace py = pybind11;
enum class EntityTypes { Cell, Face, Vertex };
inline void
export_EntityTypes( py::module & m )
{
// avoid duplicate conversion -> export only once
static bool exported = false;
if( ! exported ) {
// TODO: should be nested types instead
py::enum_< EntityTypes >( m, "EntityTypes" )
.value("Cell", EntityTypes::Cell)
.value("Face", EntityTypes::Face)
.value("Vertex", EntityTypes::Vertex)
;
exported = true;
}
}
template< typename Mesh >
typename Mesh::GlobalIndexType
mesh_getEntitiesCount( const Mesh & self, const EntityTypes & entity )
{
if( entity == EntityTypes::Cell )
return self.template getEntitiesCount< typename Mesh::Cell >();
else if( entity == EntityTypes::Face )
return self.template getEntitiesCount< typename Mesh::Face >();
else if( entity == EntityTypes::Vertex )
return self.template getEntitiesCount< typename Mesh::Vertex >();
else
throw py::value_error("The entity parameter must be either Cell, Face or Vertex.");
}
......@@ -5,7 +5,7 @@ namespace py = pybind11;
#include "StaticVector.h"
#include "Grid_getSpaceStepsProducts.h"
#include "EntityTypes.h"
#include "mesh_getters.h"
#include <type_traits>
......@@ -54,8 +54,6 @@ void export_Grid( py::module & m, const char* name )
// void (Grid::* _setDimensions1)(const IndexType) = &Grid::setDimensions;
void (Grid::* _setDimensions2)(const typename Grid::CoordinatesType &) = &Grid::setDimensions;
export_EntityTypes(m);
auto grid = py::class_<Grid, TNL::Object>( m, name )
.def(py::init<>())
.def_static("getMeshDimension", &Grid::getMeshDimension)
......@@ -68,11 +66,13 @@ void export_Grid( py::module & m, const char* name )
.def("setDomain", &Grid::setDomain)
.def("getOrigin", &Grid::getOrigin, py::return_value_policy::reference_internal)
.def("getProportions", &Grid::getProportions, py::return_value_policy::reference_internal)
.def("getEntitiesCount", &mesh_getEntitiesCount< Grid >)
// TODO: if combined, the return type would depend on the runtime parameter (entity)
.def("getEntity_cell", &Grid::template getEntity<typename Grid::Cell>)
.def("getEntity_face", &Grid::template getEntity<typename Grid::Face>)
.def("getEntity_vertex", &Grid::template getEntity<typename Grid::Vertex>)
.def("getEntitiesCount", &mesh_getEntitiesCount< Grid, typename Grid::Cell >)
.def("getEntitiesCount", &mesh_getEntitiesCount< Grid, typename Grid::Face >)
.def("getEntitiesCount", &mesh_getEntitiesCount< Grid, typename Grid::Vertex >)
// NOTE: if combined into getEntity, the return type would depend on the runtime parameter (entity)
.def("getCell", &Grid::template getEntity<typename Grid::Cell>)
.def("getFace", &Grid::template getEntity<typename Grid::Face>)
.def("getVertex", &Grid::template getEntity<typename Grid::Vertex>)
.def("getEntityIndex", &Grid::template getEntityIndex<typename Grid::Cell>)
.def("getEntityIndex", &Grid::template getEntityIndex<typename Grid::Face>)
.def("getEntityIndex", &Grid::template getEntityIndex<typename Grid::Vertex>)
......
......@@ -5,7 +5,7 @@ namespace py = pybind11;
#include "../typedefs.h"
#include "StaticVector.h"
#include "EntityTypes.h"
#include "mesh_getters.h"
#include <TNL/String.h>
#include <TNL/Meshes/Geometry/getEntityCenter.h>
......@@ -82,8 +82,11 @@ template< typename MeshEntity, typename Scope >
void export_MeshEntity( Scope & scope, const char* name )
{
auto entity = py::class_< MeshEntity >( scope, name )
// .def(py::init<>())
// .def(py::init<typename MeshEntity::MeshType, typename MeshEntity::GlobalIndexType>())
.def_static("getEntityDimension", &MeshEntity::getEntityDimension)
.def("getIndex", &MeshEntity::getIndex)
.def("getTag", &MeshEntity::getTag)
// TODO
;
......@@ -95,23 +98,24 @@ void export_MeshEntity( Scope & scope, const char* name )
template< typename Mesh >
void export_Mesh( py::module & m, const char* name )
{
// there are two templates - const and non-const - take only the const
auto (Mesh::* getEntity_cell)(const typename Mesh::GlobalIndexType) const = &Mesh::template getEntity<typename Mesh::Cell>;
auto (Mesh::* getEntity_face)(const typename Mesh::GlobalIndexType) const = &Mesh::template getEntity<typename Mesh::Face>;
auto (Mesh::* getEntity_vertex)(const typename Mesh::GlobalIndexType) const = &Mesh::template getEntity<typename Mesh::Vertex>;
export_EntityTypes(m);
auto mesh = py::class_< Mesh, TNL::Object >( m, name )
.def(py::init<>())
.def_static("getMeshDimension", &Mesh::getMeshDimension)
.def_static("getSerializationType", &Mesh::getSerializationType)
.def("getSerializationTypeVirtual", &Mesh::getSerializationTypeVirtual)
.def("getEntitiesCount", &mesh_getEntitiesCount< Mesh >)
// TODO: if combined, the return type would depend on the runtime parameter (entity)
.def("getEntity_cell", getEntity_cell)
.def("getEntity_face", getEntity_face)
.def("getEntity_vertex", getEntity_vertex)
.def("getEntitiesCount", &mesh_getEntitiesCount< Mesh, typename Mesh::Cell >)
.def("getEntitiesCount", &mesh_getEntitiesCount< Mesh, typename Mesh::Face >)
.def("getEntitiesCount", &mesh_getEntitiesCount< Mesh, typename Mesh::Vertex >)
.def("getGhostEntitiesCount", &mesh_getGhostEntitiesCount< Mesh, typename Mesh::Cell >)
.def("getGhostEntitiesCount", &mesh_getGhostEntitiesCount< Mesh, typename Mesh::Face >)
.def("getGhostEntitiesCount", &mesh_getGhostEntitiesCount< Mesh, typename Mesh::Vertex >)
.def("getGhostEntitiesOffset", &mesh_getGhostEntitiesOffset< Mesh, typename Mesh::Cell >)
.def("getGhostEntitiesOffset", &mesh_getGhostEntitiesOffset< Mesh, typename Mesh::Face >)
.def("getGhostEntitiesOffset", &mesh_getGhostEntitiesOffset< Mesh, typename Mesh::Vertex >)
// NOTE: if combined into getEntity, the return type would depend on the runtime parameter (entity)
.def("getCell", &Mesh::template getEntity<typename Mesh::Cell>)
.def("getFace", &Mesh::template getEntity<typename Mesh::Face>)
.def("getVertex", &Mesh::template getEntity<typename Mesh::Vertex>)
.def("getEntityCenter", []( const Mesh& mesh, const typename Mesh::Cell& cell ){ return getEntityCenter( mesh, cell ); } )
.def("getEntityCenter", []( const Mesh& mesh, const typename Mesh::Face& face ){ return getEntityCenter( mesh, face ); } )
.def("getEntityCenter", []( const Mesh& mesh, const typename Mesh::Vertex& vertex ){ return getEntityCenter( mesh, vertex ); } )
......@@ -124,6 +128,12 @@ void export_Mesh( py::module & m, const char* name )
return mesh.template isBoundaryEntity< Mesh::Face::getEntityDimension() >( face.getIndex() ); } )
.def("isBoundaryEntity", []( const Mesh& mesh, const typename Mesh::Vertex& vertex ){
return mesh.template isBoundaryEntity< Mesh::Vertex::getEntityDimension() >( vertex.getIndex() ); } )
.def("isGhostEntity", []( const Mesh& mesh, const typename Mesh::Cell& cell ){
return mesh.template isGhostEntity< Mesh::Cell::getEntityDimension() >( cell.getIndex() ); } )
.def("isGhostEntity", []( const Mesh& mesh, const typename Mesh::Face& face ){
return mesh.template isGhostEntity< Mesh::Face::getEntityDimension() >( face.getIndex() ); } )
.def("isGhostEntity", []( const Mesh& mesh, const typename Mesh::Vertex& vertex ){
return mesh.template isGhostEntity< Mesh::Vertex::getEntityDimension() >( vertex.getIndex() ); } )
// TODO: more?
;
......
#pragma once
#include <type_traits>
template< typename Mesh, typename EntityType >
typename Mesh::GlobalIndexType
mesh_getEntitiesCount( const Mesh & self, const EntityType & entity )
{
static_assert( std::is_same< EntityType, typename Mesh::Cell >::value ||
std::is_same< EntityType, typename Mesh::Face >::value ||
std::is_same< EntityType, typename Mesh::Vertex >::value,
"incompatible entity type" );
return self.template getEntitiesCount< EntityType::getEntityDimension() >();
}
template< typename Mesh, typename EntityType >
typename Mesh::GlobalIndexType
mesh_getGhostEntitiesCount( const Mesh & self, const EntityType & entity )
{
static_assert( std::is_same< EntityType, typename Mesh::Cell >::value ||
std::is_same< EntityType, typename Mesh::Face >::value ||
std::is_same< EntityType, typename Mesh::Vertex >::value,
"incompatible entity type" );
return self.template getGhostEntitiesCount< EntityType::getEntityDimension() >();
}
template< typename Mesh, typename EntityType >
typename Mesh::GlobalIndexType
mesh_getGhostEntitiesOffset( const Mesh & self, const EntityType & entity )
{
static_assert( std::is_same< EntityType, typename Mesh::Cell >::value ||
std::is_same< EntityType, typename Mesh::Face >::value ||
std::is_same< EntityType, typename Mesh::Vertex >::value,
"incompatible entity type" );
return self.template getGhostEntitiesOffset< EntityType::getEntityDimension() >();
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment