Commit 23cc2019 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Refactoring sparse matrices: setting row lengths via VectorView

parent 29d5c521
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -51,6 +51,8 @@ void export_Matrix( py::module & m, const char* name )

    using VectorType = TNL::Containers::Vector< typename Matrix::RealType, typename Matrix::DeviceType, typename Matrix::IndexType >;

    void (Matrix::* _getCompressedRowLengths)(typename Matrix::CompressedRowLengthsVector&) const = &Matrix::getCompressedRowLengths;

    auto matrix = py::class_< Matrix, TNL::Object >( m, name )
        .def(py::init<>())
        // overloads (defined in Object)
@@ -69,7 +71,7 @@ void export_Matrix( py::module & m, const char* name )
        .def("setDimensions",           &Matrix::setDimensions)
        .def("setCompressedRowLengths", &Matrix::setCompressedRowLengths)
        .def("getRowLength",            &Matrix::getRowLength)
        .def("getCompressedRowLengths", &Matrix::getCompressedRowLengths)
        .def("getCompressedRowLengths", _getCompressedRowLengths)
        // TODO: export for more types
        .def("setLike",                 &Matrix::template setLike< typename Matrix::RealType, typename Matrix::DeviceType, typename Matrix::IndexType >)
        .def("getNumberOfMatrixElements", &Matrix::getNumberOfMatrixElements)
+4 −3
Original line number Diff line number Diff line
@@ -83,6 +83,7 @@ public:
    typedef Device DeviceType;
    typedef Index IndexType;
    typedef typename Sparse< RealType, DeviceType, IndexType >::CompressedRowLengthsVector CompressedRowLengthsVector;
    typedef typename Sparse< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
    typedef AdEllpack< Real, Device, Index > ThisType;
    typedef AdEllpack< Real, Devices::Host, Index > HostType;
    typedef AdEllpack< Real, Devices::Cuda, Index > CudaType;
@@ -93,7 +94,7 @@ public:

    String getTypeVirtual() const;

    void setCompressedRowLengths( const CompressedRowLengthsVector& rowLengths );
    void setCompressedRowLengths( ConstCompressedRowLengthsVectorView rowLengths );

    IndexType getWarp( const IndexType row ) const;

@@ -155,7 +156,7 @@ public:
    void print( std::ostream& str ) const;

    bool balanceLoad( const RealType average,
                      const CompressedRowLengthsVector& rowLengths,
                      ConstCompressedRowLengthsVectorView rowLengths,
                      warpList* list );

    void computeWarps( const IndexType SMs,
@@ -166,7 +167,7 @@ public:

    void performRowTest();

    void performRowLengthsTest( const CompressedRowLengthsVector& rowLengths );
    void performRowLengthsTest( ConstCompressedRowLengthsVectorView rowLengths );

    IndexType getTotalLoad() const;

+4 −4
Original line number Diff line number Diff line
@@ -182,7 +182,7 @@ template< typename Real,
          typename Index >
void
AdEllpack< Real, Device, Index >::
setCompressedRowLengths( const CompressedRowLengthsVector& rowLengths )
setCompressedRowLengths( ConstCompressedRowLengthsVectorView rowLengths )
{
    TNL_ASSERT( this->getRows() > 0, );
    TNL_ASSERT( this->getColumns() > 0, );
@@ -250,7 +250,7 @@ Index AdEllpack< Real, Device, Index >::getTotalLoad() const
template< typename Real,
          typename Device,
          typename Index >
void AdEllpack< Real, Device, Index >::performRowLengthsTest( const CompressedRowLengthsVector& rowLengths )
void AdEllpack< Real, Device, Index >::performRowLengthsTest( ConstCompressedRowLengthsVectorView rowLengths )
{
    bool found = false;
    for( IndexType row = 0; row < this->getRows(); row++ )
@@ -694,7 +694,7 @@ template< typename Real,
          typename Device,
          typename Index >
bool AdEllpack< Real, Device, Index >::balanceLoad( const RealType average,
                                                             const CompressedRowLengthsVector& rowLengths,
                                                    ConstCompressedRowLengthsVectorView rowLengths,
                                                    warpList* list )
{
    IndexType offset, rowOffset, localLoad, reduceMap[ 32 ];
+2 −1
Original line number Diff line number Diff line
@@ -36,6 +36,7 @@ public:
	typedef Device DeviceType;
	typedef Index IndexType;
	typedef typename Sparse< RealType, DeviceType, IndexType >::CompressedRowLengthsVector CompressedRowLengthsVector;
   typedef typename Sparse< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
	typedef typename Sparse< RealType, DeviceType, IndexType >::ValuesVector ValuesVector;
	typedef typename Sparse< RealType, DeviceType, IndexType >::ColumnIndexesVector ColumnIndexesVector;
	typedef BiEllpack< Real, Device, Index > ThisType;
@@ -51,7 +52,7 @@ public:
	void setDimensions( const IndexType rows,
	                    const IndexType columns );

	void setCompressedRowLengths( const CompressedRowLengthsVector& rowLengths );
   void setCompressedRowLengths( ConstCompressedRowLengthsVectorView rowLengths );

	IndexType getRowLength( const IndexType row ) const;

+2 −1
Original line number Diff line number Diff line
@@ -27,6 +27,7 @@ public:
	typedef Device DeviceType;
	typedef Index IndexType;
	typedef typename Sparse< RealType, DeviceType, IndexType >::CompressedRowLengthsVector CompressedRowLengthsVector;
   typedef typename Sparse< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
	typedef typename Sparse< RealType, DeviceType, IndexType >::ValuesVector ValuesVector;
	typedef typename Sparse< RealType, DeviceType, IndexType >::ColumnIndexesVector ColumnIndexesVector;
	typedef BiEllpackSymmetric< Real, Device, Index > ThisType;
@@ -41,7 +42,7 @@ public:

	void setDimensions( const IndexType rows, const IndexType columns );

	void setCompressedRowLengths( const CompressedRowLengthsVector& rowLengths );
   void setCompressedRowLengths( ConstCompressedRowLengthsVectorView rowLengths );

	IndexType getRowLength( const IndexType row ) const;

Loading