Commit 129f17ad authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Added Python bindings (copied from tnl-mhfem)

parent df30f370
Loading
Loading
Loading
Loading
+22 −1
Original line number Diff line number Diff line
find_package( PythonInterp 3 )
find_package( PythonLibs 3 )

set( PYTHON_SITE_PACKAGES_DIR lib/python${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR}/site-packages )

if( PYTHONINTERP_FOUND )
   CONFIGURE_FILE( "__init__.py.in" "${PROJECT_BUILD_PATH}/Python/__init__.py" )
   INSTALL( FILES ${PROJECT_BUILD_PATH}/Python/__init__.py
                  LogParser.py
            DESTINATION lib/python${PYTHON_VERSION_MAJOR}.${PYTHON_VERSION_MINOR}/site-packages/TNL )
            DESTINATION ${PYTHON_SITE_PACKAGES_DIR}/TNL )
endif()

if( PYTHONLIBS_FOUND )
   include(ExternalProject)
   ExternalProject_Add(pybind11_project
     GIT_REPOSITORY    https://github.com/pybind/pybind11.git
     GIT_TAG           master
     SOURCE_DIR        "${CMAKE_BINARY_DIR}/pybind11-src"
     BINARY_DIR        "${CMAKE_BINARY_DIR}/pybind11-build"
     CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${CMAKE_INSTALL_PREFIX} -DPYBIND11_TEST=FALSE
   )
   add_subdirectory(${CMAKE_BINARY_DIR}/pybind11-src ${CMAKE_BINARY_DIR}/pybind11-build)
#   find_package(pybind11 REQUIRED
#                PATHS ${CMAKE_INSTALL_PREFIX})

   add_subdirectory(pytnl)
else()
   message( "The Python.h header file was not found, Python bindings will not be builg." )
endif()
+13 −0
Original line number Diff line number Diff line
add_subdirectory( tnl )

set( headers
         exceptions.h
         RawIterator.h
         tnl_conversions.h
         tnl_indexing.h
         tnl_str_conversion.h
         tnl_tuple_conversion.h
         typedefs.h
)

install( FILES ${headers} DESTINATION "pytnl" )
+51 −0
Original line number Diff line number Diff line
#pragma once

#include <iterator>

template< typename DataType >
class RawIterator : public std::iterator<std::random_access_iterator_tag,
                                           DataType,
                                           ptrdiff_t,
                                           DataType*,
                                           DataType&>
{
protected:
    DataType*               m_ptr;

public:
    RawIterator( DataType* ptr = nullptr ) { m_ptr = ptr; }
    RawIterator( const RawIterator<DataType> & rawIterator ) = default;
    ~RawIterator(){}

    RawIterator<DataType>&  operator=( const RawIterator<DataType> & rawIterator ) = default;
    RawIterator<DataType>&  operator=( DataType* ptr ) { m_ptr = ptr; return (*this); }

    operator                bool() const
    {
        if(m_ptr)
            return true;
        else
            return false;
    }

    bool                    operator==( const RawIterator<DataType> & rawIterator ) const { return ( m_ptr == rawIterator.getConstPtr() ); }
    bool                    operator!=( const RawIterator<DataType> & rawIterator ) const { return ( m_ptr != rawIterator.getConstPtr() ); }

    RawIterator<DataType>&  operator+=( const ptrdiff_t & movement ){ m_ptr += movement; return (*this); }
    RawIterator<DataType>&  operator-=( const ptrdiff_t & movement ){ m_ptr -= movement; return (*this); }
    RawIterator<DataType>&  operator++() { ++m_ptr; return (*this); }
    RawIterator<DataType>&  operator--() { --m_ptr; return (*this); }
    RawIterator<DataType>   operator++( int ) { auto temp(*this); ++m_ptr; return temp; }
    RawIterator<DataType>   operator--( int ) { auto temp(*this); --m_ptr; return temp; }
    RawIterator<DataType>   operator+( const ptrdiff_t & movement ) { auto oldPtr = m_ptr; m_ptr+=movement; auto temp(*this); m_ptr = oldPtr; return temp; }
    RawIterator<DataType>   operator-( const ptrdiff_t & movement ) { auto oldPtr = m_ptr; m_ptr-=movement; auto temp(*this); m_ptr = oldPtr; return temp; }

    ptrdiff_t               operator-( const RawIterator<DataType>& rawIterator ) { return std::distance(rawIterator.getPtr(), this->getPtr()); }

    DataType&               operator*() { return *m_ptr; }
    const DataType&         operator*() const { return *m_ptr; }
    DataType*               operator->() { return m_ptr; }

    DataType*               getPtr() const { return m_ptr; }
    const DataType*         getConstPtr() const { return m_ptr; }
};
+35 −0
Original line number Diff line number Diff line
#pragma once

#include <stdexcept>

#include <pybind11/pybind11.h>
namespace py = pybind11;

#include <TNL/Assert.h>

struct NotImplementedError
   : public std::runtime_error
{
   NotImplementedError( const char* msg )
   : std::runtime_error( msg )
   {}
};

static void register_exceptions( py::module & m )
{
    py::register_exception_translator(
        [](std::exception_ptr p) {
            try {
                if (p) std::rethrow_exception(p);
            }
            // translate exceptions used in the bindings
            catch (const NotImplementedError & e) {
                PyErr_SetString(PyExc_NotImplementedError, e.what());
            }
            // translate TNL::Assert::AssertionError
            catch (const TNL::Assert::AssertionError & e) {
                PyErr_SetString(PyExc_AssertionError, e.what());
            }
        }
    );
}
+63 −0
Original line number Diff line number Diff line
#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
namespace py = pybind11;

#include "../tnl_indexing.h"

#include <TNL/Containers/Array.h>

template< typename ArrayType >
void export_Array(py::module & m, const char* name)
{
    auto array = py::class_<ArrayType, TNL::Object>(m, name, py::buffer_protocol())
        .def(py::init<>())
        .def(py::init<int>())
        .def_static("getType",              &ArrayType::getType)
        .def("getTypeVirtual",              &ArrayType::getTypeVirtual)
        .def_static("getSerializationType", &ArrayType::getSerializationType)
        .def("getSerializationTypeVirtual", &ArrayType::getSerializationTypeVirtual)
        .def("setSize", &ArrayType::setSize)
        .def("setLike", &ArrayType::template setLike<ArrayType>)
        .def("swap", &ArrayType::swap)
        .def("reset", &ArrayType::reset)
        .def("getSize", &ArrayType::getSize)
        .def("setElement", &ArrayType::setElement)
        .def("getElement", &ArrayType::getElement)
        // operator=
        .def("assign", []( ArrayType& array, const ArrayType& other ) -> ArrayType& {
                return array = other;
            })
        .def(py::self == py::self)
        .def(py::self != py::self)
        .def("setValue", &ArrayType::setValue)

        .def("__str__", []( ArrayType & a ) {
                std::stringstream ss;
                ss << a;
                return ss.str();
            } )

        // Python buffer protocol: http://pybind11.readthedocs.io/en/master/advanced/pycpp/numpy.html
        .def_buffer( [](ArrayType & a) -> py::buffer_info {
            return py::buffer_info(
                // Pointer to buffer
                a.getData(),
                // Size of one scalar
                sizeof( typename ArrayType::ElementType ),
                // Python struct-style format descriptor
                py::format_descriptor< typename ArrayType::ElementType >::format(),
                // Number of dimensions
                1,
                // Buffer dimensions
                { a.getSize() },
                // Strides (in bytes) for each index
                { sizeof( typename ArrayType::ElementType ) }
            );
        })
    ;

    tnl_indexing< ArrayType >( array );
    tnl_slice_indexing< ArrayType >( array );
}
Loading