Commit ea7bb4c3 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Fixed pytnl export of the CSR matrix format

parent c71cdb6d
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
@@ -16,8 +16,7 @@ using SE_cuda = TNL::Benchmarks::SpMV::ReferenceFormats::Legacy::SlicedEllpack<

void export_SparseMatrices( py::module & m )
{
    // TODO: This stop working after adding template parameter KernelType to Legacy::CSR
    //export_Matrix< CSR_host >( m, "CSR" );
    export_Matrix< CSR_host >( m, "CSR" );
    export_Matrix< E_host   >( m, "Ellpack" );
    export_Matrix< SE_host  >( m, "SlicedEllpack" );

+4 −2
Original line number Diff line number Diff line
@@ -71,7 +71,9 @@ void export_Matrix( py::module & m, const char* name )
        .def("getRowLength",            &Matrix::getRowLength)
        .def("getCompressedRowLengths", _getCompressedRowLengths)
        // TODO: export for more types
        .def("setLike",                 &Matrix::template setLike< typename Matrix::RealType, typename Matrix::DeviceType, typename Matrix::IndexType >)
        .def("setLike", []( Matrix& matrix, const Matrix& other ) -> void {
                matrix.setLike( other );
            })
        .def("getAllocatedElementsCount", &Matrix::getAllocatedElementsCount)
        .def("getNumberOfNonzeroMatrixElements", &Matrix::getNumberOfNonzeroMatrixElements)
        .def("reset",                   &Matrix::reset)