Commit 803fae22 authored by Jakub Klinkovský's avatar Jakub Klinkovský
Browse files

Fixed UmfpackWrapper, added getters for raw data to CSR matrix

parent a16d4530
Loading
Loading
Loading
Loading
+32 −5
Original line number Diff line number Diff line
@@ -199,6 +199,38 @@ class CSR : public Sparse< Real, Device, Index >
                           int gridIdx ) const;
#endif

   // The following getters allow us to interface TNL with external C-like
   // libraries such as UMFPACK or SuperLU, which need the raw data.
   Index* getRowPointers()
   {
       return this->rowPointers.getData();
   }

   const Index* getRowPointers() const
   {
       return this->rowPointers.getData();
   }

   Index* getColumnIndexes()
   {
       return this->columnIndexes.getData();
   }

   const Index* getColumnIndexes() const
   {
       return this->columnIndexes.getData();
   }

   Real* getValues()
   {
       return this->values.getData();
   }

   const Real* getValues() const
   {
       return this->values.getData();
   }

   protected:

   Containers::Vector< Index, Device, Index > rowPointers;
@@ -210,11 +242,6 @@ class CSR : public Sparse< Real, Device, Index >
   typedef CSRDeviceDependentCode< DeviceType > DeviceDependentCode;
   friend class CSRDeviceDependentCode< DeviceType >;
   friend class tnlCusparseCSR< RealType >;
#ifdef HAVE_UMFPACK
    template< typename Matrix, typename Preconditioner >
    friend class UmfpackWrapper;
#endif

};

} // namespace Matrices
+7 −7
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@
#include <TNL/Object.h>
#include <TNL/Config/ConfigDescription.h>
#include <TNL/Matrices/CSR.h>
#include <TNL/Solvers/preconditioners/Dummy.h>
#include <TNL/Solvers/Linear/Preconditioners/Dummy.h>
#include <TNL/Solvers/IterativeSolver.h>
#include <TNL/Solvers/Linear/LinearResidueGetter.h>

@@ -25,14 +25,14 @@ struct is_csr_matrix
};

template< typename Real, typename Device, typename Index >
struct is_csr_matrix< CSR< Real, Device, Index > >
struct is_csr_matrix< Matrices::CSR< Real, Device, Index > >
{
    static const bool value = true;
};


template< typename Matrix,
          typename Preconditioner = Dummy< typename Matrix :: RealType,
          typename Preconditioner = Preconditioners::Dummy< typename Matrix :: RealType,
                                                            typename Matrix :: DeviceType,
                                                            typename Matrix :: IndexType> >
class UmfpackWrapper
@@ -88,7 +88,7 @@ public:


template< typename Preconditioner >
class UmfpackWrapper< CSR< double, Devices::Host, int >, Preconditioner >
class UmfpackWrapper< Matrices::CSR< double, Devices::Host, int >, Preconditioner >
    : public Object,
      // just to ensure the same interface as other linear solvers
      public IterativeSolver< double, int >
@@ -97,7 +97,7 @@ public:
    typedef double RealType;
    typedef int IndexType;
    typedef Devices::Host DeviceType;
    typedef CSR< double, Devices::Host, int > MatrixType;
    typedef Matrices::CSR< double, Devices::Host, int > MatrixType;
    typedef Preconditioner PreconditionerType;
    typedef SharedPointer< const MatrixType, DeviceType, true > MatrixPointer;
    typedef SharedPointer< const PreconditionerType, DeviceType, true > PreconditionerPointer;
+15 −15
Original line number Diff line number Diff line
@@ -11,13 +11,13 @@ namespace Solvers {
namespace Linear {   

template< typename Preconditioner >
UmfpackWrapper< CSR< double, Devices::Host, int >, Preconditioner >::
UmfpackWrapper< Matrices::CSR< double, Devices::Host, int >, Preconditioner >::
UmfpackWrapper()
{}

template< typename Preconditioner >
void
UmfpackWrapper< CSR< double, Devices::Host, int >, Preconditioner >::
UmfpackWrapper< Matrices::CSR< double, Devices::Host, int >, Preconditioner >::
configSetup( Config::ConfigDescription& config,
             const String& prefix )
{
@@ -25,7 +25,7 @@ configSetup( Config::ConfigDescription& config,

template< typename Preconditioner >
bool
UmfpackWrapper< CSR< double, Devices::Host, int >, Preconditioner >::
UmfpackWrapper< Matrices::CSR< double, Devices::Host, int >, Preconditioner >::
setup( const Config::ParameterContainer& parameters,
       const String& prefix )
{
@@ -33,14 +33,14 @@ setup( const Config::ParameterContainer& parameters,
}

template< typename Preconditioner >
void UmfpackWrapper< CSR< double, Devices::Host, int >, Preconditioner >::
void UmfpackWrapper< Matrices::CSR< double, Devices::Host, int >, Preconditioner >::
setMatrix( const MatrixPointer& matrix )
{
    this -> matrix = matrix;
}

template< typename Preconditioner >
void UmfpackWrapper< CSR< double, Devices::Host, int >, Preconditioner >::
void UmfpackWrapper< Matrices::CSR< double, Devices::Host, int >, Preconditioner >::
setPreconditioner( const PreconditionerPointer& preconditioner )
{
    this -> preconditioner = preconditioner;
@@ -49,7 +49,7 @@ setPreconditioner( const PreconditionerPointer& preconditioner )

template< typename Preconditioner >
    template< typename Vector, typename ResidueGetter >
bool UmfpackWrapper< CSR< double, Devices::Host, int >, Preconditioner >::
bool UmfpackWrapper< Matrices::CSR< double, Devices::Host, int >, Preconditioner >::
solve( const Vector& b,
       Vector& x )
{
@@ -77,9 +77,9 @@ solve( const Vector& b,

    // symbolic reordering of the sparse matrix
    status = umfpack_di_symbolic( size, size,
                                  matrix->rowPointers.getData(),
                                  matrix->columnIndexes.getData(),
                                  matrix->values.getData(),
                                  matrix->getRowPointers(),
                                  matrix->getColumnIndexes(),
                                  matrix->getValues(),
                                  &Symbolic, Control, Info );
    if( status != UMFPACK_OK ) {
        std::cerr << "error: symbolic reordering failed" << std::endl;
@@ -87,9 +87,9 @@ solve( const Vector& b,
    }

    // numeric factorization
    status = umfpack_di_numeric( matrix->rowPointers.getData(),
                                 matrix->columnIndexes.getData(),
                                 matrix->values.getData(),
    status = umfpack_di_numeric( matrix->getRowPointers(),
                                 matrix->getColumnIndexes(),
                                 matrix->getValues(),
                                 Symbolic, &Numeric, Control, Info );
    if( status != UMFPACK_OK ) {
        std::cerr << "error: numeric factorization failed" << std::endl;
@@ -98,9 +98,9 @@ solve( const Vector& b,

    // solve with specified right-hand-side
    status = umfpack_di_solve( system_type,
                               matrix->rowPointers.getData(),
                               matrix->columnIndexes.getData(),
                               matrix->values.getData(),
                               matrix->getRowPointers(),
                               matrix->getColumnIndexes(),
                               matrix->getValues(),
                               x.getData(),
                               b.getData(),
                               Numeric, Control, Info );