Commit 099f12f2 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

PyTNL: export bindings for sparse matrices based on segments

Fixes #86
parent b75c0e74
Loading
Loading
Loading
Loading
+33 −19
Original line number Diff line number Diff line
@@ -2,17 +2,27 @@
#include "../tnl_conversions.h"

#include "SparseMatrix.h"
#include "../typedefs.h"

#include <Benchmarks/SpMV/ReferenceFormats/Legacy/CSR.h>
#include <Benchmarks/SpMV/ReferenceFormats/Legacy/Ellpack.h>
#include <Benchmarks/SpMV/ReferenceFormats/Legacy/SlicedEllpack.h>
#include <TNL/Matrices/SparseMatrix.h>
#include <TNL/Matrices/SparseOperations.h>
#include <TNL/Algorithms/Segments/CSR.h>
#include <TNL/Algorithms/Segments/Ellpack.h>
#include <TNL/Algorithms/Segments/SlicedEllpack.h>

using CSR_host = TNL::Benchmarks::SpMV::ReferenceFormats::Legacy::CSR< double, TNL::Devices::Host, int >;
using CSR_cuda = TNL::Benchmarks::SpMV::ReferenceFormats::Legacy::CSR< double, TNL::Devices::Cuda, int >;
using E_host = TNL::Benchmarks::SpMV::ReferenceFormats::Legacy::Ellpack< double, TNL::Devices::Host, int >;
using E_cuda = TNL::Benchmarks::SpMV::ReferenceFormats::Legacy::Ellpack< double, TNL::Devices::Cuda, int >;
using SE_host = TNL::Benchmarks::SpMV::ReferenceFormats::Legacy::SlicedEllpack< double, TNL::Devices::Host, int >;
using SE_cuda = TNL::Benchmarks::SpMV::ReferenceFormats::Legacy::SlicedEllpack< double, TNL::Devices::Cuda, int >;
template< typename Device, typename Index, typename IndexAllocator >
using CSR = TNL::Algorithms::Segments::CSRDefault< Device, Index, IndexAllocator >;
template< typename Device, typename Index, typename IndexAllocator >
using Ellpack = TNL::Algorithms::Segments::Ellpack< Device, Index, IndexAllocator >;
template< typename Device, typename Index, typename IndexAllocator >
using SlicedEllpack = TNL::Algorithms::Segments::SlicedEllpack< Device, Index, IndexAllocator >;

using CSR_host = TNL::Matrices::SparseMatrix< RealType, TNL::Devices::Host, IndexType, TNL::Matrices::GeneralMatrix, CSR >;
using CSR_cuda = TNL::Matrices::SparseMatrix< RealType, TNL::Devices::Cuda, IndexType, TNL::Matrices::GeneralMatrix, CSR >;
using E_host   = TNL::Matrices::SparseMatrix< RealType, TNL::Devices::Host, IndexType, TNL::Matrices::GeneralMatrix, Ellpack >;
using E_cuda   = TNL::Matrices::SparseMatrix< RealType, TNL::Devices::Cuda, IndexType, TNL::Matrices::GeneralMatrix, Ellpack >;
using SE_host  = TNL::Matrices::SparseMatrix< RealType, TNL::Devices::Host, IndexType, TNL::Matrices::GeneralMatrix, SlicedEllpack >;
using SE_cuda  = TNL::Matrices::SparseMatrix< RealType, TNL::Devices::Cuda, IndexType, TNL::Matrices::GeneralMatrix, SlicedEllpack >;

void export_SparseMatrices( py::module & m )
{
@@ -20,11 +30,15 @@ void export_SparseMatrices( py::module & m )
   export_Matrix< E_host   >( m, "Ellpack" );
   export_Matrix< SE_host  >( m, "SlicedEllpack" );

    // TODO: copySparseMatrix does not work with Legacy matrices anymore
    //m.def("copySparseMatrix", &TNL::Matrices::copySparseMatrix< CSR_host, E_host >);
    //m.def("copySparseMatrix", &TNL::Matrices::copySparseMatrix< E_host, CSR_host >);
    //m.def("copySparseMatrix", &TNL::Matrices::copySparseMatrix< CSR_host, SE_host >);
    //m.def("copySparseMatrix", &TNL::Matrices::copySparseMatrix< SE_host, CSR_host >);
    //m.def("copySparseMatrix", &TNL::Matrices::copySparseMatrix< E_host, SE_host >);
    //m.def("copySparseMatrix", &TNL::Matrices::copySparseMatrix< SE_host, E_host >);
   // NOTE: all exported formats (CSR, Ellpack, SlicedEllpack) use the same SegmentView,
   // so the RowView and ConstRowView are also the same types in all three formats
   export_RowView< typename CSR_host::RowView >( m, "SparseMatrixRowView" );
   export_RowView< typename CSR_host::ConstRowView >( m, "SparseMatrixConstRowView" );

   m.def("copySparseMatrix", &TNL::Matrices::copySparseMatrix< CSR_host, E_host >);
   m.def("copySparseMatrix", &TNL::Matrices::copySparseMatrix< E_host, CSR_host >);
   m.def("copySparseMatrix", &TNL::Matrices::copySparseMatrix< CSR_host, SE_host >);
   m.def("copySparseMatrix", &TNL::Matrices::copySparseMatrix< SE_host, CSR_host >);
   m.def("copySparseMatrix", &TNL::Matrices::copySparseMatrix< E_host, SE_host >);
   m.def("copySparseMatrix", &TNL::Matrices::copySparseMatrix< SE_host, E_host >);
}
+145 −95
Original line number Diff line number Diff line
#pragma once

#include <pybind11/pybind11.h>
#include <pybind11/operators.h>
namespace py = pybind11;

#include <TNL/String.h>
#include <TNL/Containers/Vector.h>
#include <Benchmarks/SpMV/ReferenceFormats/Legacy/CSR.h>
#include <TNL/TypeTraits.h>

template< typename Matrix >
struct SpecificExports
template< typename RowView, typename Scope >
std::enable_if_t< ! std::is_const< typename RowView::RealType >::value >
export_RowView_nonconst( Scope & s )
{
   using RealType = typename RowView::RealType;
   using IndexType = typename RowView::IndexType;

   s
      .def("getColumnIndex", []( RowView& row, IndexType localIdx ) -> IndexType& {
               return row.getColumnIndex( localIdx );
         }, py::return_value_policy::reference_internal)
      .def("getValue", []( RowView& row, IndexType localIdx ) -> RealType& {
               return row.getValue( localIdx );
         }, py::return_value_policy::reference_internal)
      .def("setValue",         &RowView::setValue)
      .def("setColumnIndex",   &RowView::setColumnIndex)
      .def("setElement",       &RowView::setElement)
   ;
}

template< typename RowView, typename Scope >
std::enable_if_t< std::is_const< typename RowView::RealType >::value >
export_RowView_nonconst( Scope & s )
{}

template< typename RowView, typename Scope >
void export_RowView( Scope & s, const char* name )
{
   using RealType = typename RowView::RealType;
   using IndexType = typename RowView::IndexType;

   auto rowView = py::class_< RowView >( s, name )
      .def("getSize",          &RowView::getSize)
      .def("getRowIndex",      &RowView::getRowIndex)
      .def("getColumnIndex", []( const RowView& row, IndexType localIdx ) -> const IndexType& {
               return row.getColumnIndex( localIdx );
         }, py::return_value_policy::reference_internal)
      .def("getValue", []( const RowView& row, IndexType localIdx ) -> const RealType& {
               return row.getValue( localIdx );
         }, py::return_value_policy::reference_internal)
      .def(py::self == py::self)
//      .def(py::self_ns::str(py::self_ns::self))
   ;
   export_RowView_nonconst< RowView >( rowView );
}

template< typename Segments, typename Enable = void >
struct export_CSR
{
   template< typename Scope >
    static void exec( Scope & s ) {}
   static void e( Scope & s ) {}
};

template< typename Real, typename Device, typename Index >
struct SpecificExports< TNL::Benchmarks::SpMV::ReferenceFormats::Legacy::CSR< Real, Device, Index > >
template< typename Segments >
struct export_CSR< Segments, typename TNL::enable_if_type< decltype(Segments{}.getOffsets()) >::type >
{
   template< typename Scope >
    static void exec( Scope & s )
   static void e( Scope & s )
   {
        using Matrix = TNL::Benchmarks::SpMV::ReferenceFormats::Legacy::CSR< Real, Device, Index >;

        s.def("getRowPointers",   py::overload_cast<>(&Matrix::getRowPointers),   py::return_value_policy::reference_internal);
        s.def("getColumnIndexes", py::overload_cast<>(&Matrix::getColumnIndexes), py::return_value_policy::reference_internal);
        s.def("getValues",        py::overload_cast<>(&Matrix::getValues),        py::return_value_policy::reference_internal);
      s
         .def("getOffsets", []( const Segments& segments ) -> const typename Segments::OffsetsHolder& {
                  return segments.getOffsets();
            }, py::return_value_policy::reference_internal)
      ;
   }
};

template< typename MatrixRow >
void export_MatrixRow( py::module & m, const char* name )
template< typename Segments, typename Scope >
void export_Segments( Scope & s, const char* name )
{
    // guard against duplicate to-Python converters for the same type
    static bool defined = false;
    if( ! defined ) {
        py::class_< MatrixRow >( m, name )
            .def("setElement", &MatrixRow::setElement)
            .def("getElementColumn", &MatrixRow::getElementColumn, py::return_value_policy::reference_internal)
            .def("getElementValue", &MatrixRow::getElementValue, py::return_value_policy::reference_internal)
//            .def(py::self_ns::str(py::self_ns::self))
   auto segments = py::class_< Segments >( s, name )
      .def("getSegmentsCount", &Segments::getSegmentsCount)
      .def("getSegmentSize", &Segments::getSegmentSize)
      .def("getSize", &Segments::getSize)
      .def("getStorageSize", &Segments::getStorageSize)
      .def("getGlobalIndex", &Segments::getGlobalIndex)
      // FIXME: this does not compile
//      .def(py::self == py::self)
      // TODO: forElements, forAllElements, forSegments, forAllSegments, segmentsReduction, allReduction
   ;
        defined = true;
    }

   export_CSR< Segments >::e( segments );
}

template< typename Matrix >
void export_Matrix( py::module & m, const char* name )
{
    typename Matrix::MatrixRow (Matrix::* _getRow)(typename Matrix::IndexType) = &Matrix::getRow;

    using VectorType = TNL::Containers::Vector< typename Matrix::RealType, typename Matrix::DeviceType, typename Matrix::IndexType >;
   using RealType = typename Matrix::RealType;
   using DeviceType = typename Matrix::DeviceType;
   using IndexType = typename Matrix::IndexType;

    void (Matrix::* _getCompressedRowLengths)(typename Matrix::RowsCapacitiesTypeView) const = &Matrix::getCompressedRowLengths;
   using VectorType = TNL::Containers::Vector< RealType, DeviceType, IndexType >;
   using IndexVectorType = TNL::Containers::Vector< IndexType, DeviceType, IndexType >;

   auto matrix = py::class_< Matrix, TNL::Object >( m, name )
      .def(py::init<>())
@@ -67,49 +115,51 @@ void export_Matrix( py::module & m, const char* name )

      // Matrix
      .def("setDimensions",           &Matrix::setDimensions)
        .def("setCompressedRowLengths", &Matrix::setCompressedRowLengths)
        .def("getRowLength",            &Matrix::getRowLength)
        .def("getCompressedRowLengths", _getCompressedRowLengths)
      // TODO: export for more types
      .def("setLike", []( Matrix& matrix, const Matrix& other ) -> void {
               matrix.setLike( other );
         })
      .def("getAllocatedElementsCount",   &Matrix::getAllocatedElementsCount)
        .def("getNumberOfNonzeroMatrixElements", &Matrix::getNumberOfNonzeroMatrixElements)
      .def("getNonzeroElementsCount",     &Matrix::getNonzeroElementsCount)
      .def("reset",                       &Matrix::reset)
      .def("getRows",                     &Matrix::getRows)
      .def("getColumns",                  &Matrix::getColumns)
      // TODO: export for more types
      .def(py::self == py::self)
      .def(py::self != py::self)

      // SparseMatrix
      .def("setRowCapacities",         &Matrix::template setRowCapacities< IndexVectorType >)
      .def("getRowCapacities",         &Matrix::template getRowCapacities< IndexVectorType >)
      .def("getCompressedRowLengths",  &Matrix::template getCompressedRowLengths< IndexVectorType >)
      .def("getRowCapacity",           &Matrix::getRowCapacity)
      .def("getPaddingIndex",          &Matrix::getPaddingIndex)
      .def("getRow", []( Matrix& matrix, IndexType rowIdx ) -> typename Matrix::RowView {
               return matrix.getRow( rowIdx );
         })
      .def("getRow", []( const Matrix& matrix, IndexType rowIdx ) -> typename Matrix::ConstRowView {
               return matrix.getRow( rowIdx );
         })
      .def("setElement",               &Matrix::setElement)
      .def("addElement",               &Matrix::addElement)
        // setRow and addRow operate on pointers
        //.def("setRow",                  &Matrix::setRow)
        //.def("addRow",                  &Matrix::addRow)
      .def("getElement",               &Matrix::getElement)
        // TODO: operator== and operator!= are general and very slow

        // Sparse
        .def("getMaxRowLength",     &Matrix::getMaxRowLength)
        .def("getPaddingIndex",     &Matrix::getPaddingIndex)
        // TODO: this one is empty in the C++ code
//        .def("printStructure",      &Matrix::printStructure)

        // specific to each format, but with common interface
        .def("getRow",              _getRow)
      // TODO: reduceRows, reduceAllRows, forElements, forAllElements, forRows, forAllRows
      // TODO: export for more types
        .def("rowVectorProduct",    &Matrix::template rowVectorProduct< VectorType >)
      .def("vectorProduct",       &Matrix::template vectorProduct< VectorType, VectorType >)
      // TODO: these two don't work
      //.def("addMatrix",           &Matrix::addMatrix)
      //.def("getTransposition",    &Matrix::getTransposition)
      .def("performSORIteration", &Matrix::template performSORIteration< VectorType, VectorType >)
//        .def("assign",              &Matrix::operator=)
      // TODO: export for more types
      .def("assign", []( Matrix& matrix, const Matrix& other ) -> Matrix& {
               return matrix = other;
         })
    ;

    // export format-specific methods
    SpecificExports< Matrix >::exec( matrix );
      // accessors for internal vectors
      .def("getValues",        py::overload_cast<>(&Matrix::getValues),        py::return_value_policy::reference_internal)
      .def("getColumnIndexes", py::overload_cast<>(&Matrix::getColumnIndexes), py::return_value_policy::reference_internal)
      .def("getSegments",      py::overload_cast<>(&Matrix::getSegments),      py::return_value_policy::reference_internal)
   ;

    export_MatrixRow< typename Matrix::MatrixRow >( m, "MatrixRow" );
   export_Segments< typename Matrix::SegmentsType >( matrix, "Segments" );
}