From 803fae220b459469fc5e5da22244706382a86c5a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Jakub=20Klinkovsk=C3=BD?= <klinkjak@fjfi.cvut.cz>
Date: Fri, 23 Dec 2016 10:57:12 +0100
Subject: [PATCH] Fixed UmfpackWrapper, added getters for raw data to CSR
 matrix

---
 src/TNL/Matrices/CSR.h                       | 37 +++++++++++++++++---
 src/TNL/Solvers/Linear/UmfpackWrapper.h      | 14 ++++----
 src/TNL/Solvers/Linear/UmfpackWrapper_impl.h | 30 ++++++++--------
 3 files changed, 54 insertions(+), 27 deletions(-)

diff --git a/src/TNL/Matrices/CSR.h b/src/TNL/Matrices/CSR.h
index 5ad4327a63..7c249944bf 100644
--- a/src/TNL/Matrices/CSR.h
+++ b/src/TNL/Matrices/CSR.h
@@ -199,6 +199,38 @@ class CSR : public Sparse< Real, Device, Index >
                            int gridIdx ) const;
 #endif
 
+   // The following getters allow us to interface TNL with external C-like
+   // libraries such as UMFPACK or SuperLU, which need the raw data.
+   Index* getRowPointers()
+   {
+       return this->rowPointers.getData();
+   }
+
+   const Index* getRowPointers() const
+   {
+       return this->rowPointers.getData();
+   }
+
+   Index* getColumnIndexes()
+   {
+       return this->columnIndexes.getData();
+   }
+
+   const Index* getColumnIndexes() const
+   {
+       return this->columnIndexes.getData();
+   }
+
+   Real* getValues()
+   {
+       return this->values.getData();
+   }
+
+   const Real* getValues() const
+   {
+       return this->values.getData();
+   }
+
    protected:
 
    Containers::Vector< Index, Device, Index > rowPointers;
@@ -210,11 +242,6 @@ class CSR : public Sparse< Real, Device, Index >
    typedef CSRDeviceDependentCode< DeviceType > DeviceDependentCode;
    friend class CSRDeviceDependentCode< DeviceType >;
    friend class tnlCusparseCSR< RealType >;
-#ifdef HAVE_UMFPACK
-    template< typename Matrix, typename Preconditioner >
-    friend class UmfpackWrapper;
-#endif
-
 };
 
 } // namespace Matrices
diff --git a/src/TNL/Solvers/Linear/UmfpackWrapper.h b/src/TNL/Solvers/Linear/UmfpackWrapper.h
index 043e46c8d7..e2c59d9f8a 100644
--- a/src/TNL/Solvers/Linear/UmfpackWrapper.h
+++ b/src/TNL/Solvers/Linear/UmfpackWrapper.h
@@ -9,7 +9,7 @@
 #include <TNL/Object.h>
 #include <TNL/Config/ConfigDescription.h>
 #include <TNL/Matrices/CSR.h>
-#include <TNL/Solvers/preconditioners/Dummy.h>
+#include <TNL/Solvers/Linear/Preconditioners/Dummy.h>
 #include <TNL/Solvers/IterativeSolver.h>
 #include <TNL/Solvers/Linear/LinearResidueGetter.h>
 
@@ -25,16 +25,16 @@ struct is_csr_matrix
 };
 
 template< typename Real, typename Device, typename Index >
-struct is_csr_matrix< CSR< Real, Device, Index > >
+struct is_csr_matrix< Matrices::CSR< Real, Device, Index > >
 {
     static const bool value = true;
 };
 
 
 template< typename Matrix,
-          typename Preconditioner = Dummy< typename Matrix :: RealType,
-                                           typename Matrix :: DeviceType,
-                                           typename Matrix :: IndexType> >
+          typename Preconditioner = Preconditioners::Dummy< typename Matrix :: RealType,
+                                                            typename Matrix :: DeviceType,
+                                                            typename Matrix :: IndexType> >
 class UmfpackWrapper
     : public Object,
       // just to ensure the same interface as other linear solvers
@@ -88,7 +88,7 @@ public:
 
 
 template< typename Preconditioner >
-class UmfpackWrapper< CSR< double, Devices::Host, int >, Preconditioner >
+class UmfpackWrapper< Matrices::CSR< double, Devices::Host, int >, Preconditioner >
     : public Object,
       // just to ensure the same interface as other linear solvers
       public IterativeSolver< double, int >
@@ -97,7 +97,7 @@ public:
     typedef double RealType;
     typedef int IndexType;
     typedef Devices::Host DeviceType;
-    typedef CSR< double, Devices::Host, int > MatrixType;
+    typedef Matrices::CSR< double, Devices::Host, int > MatrixType;
     typedef Preconditioner PreconditionerType;
     typedef SharedPointer< const MatrixType, DeviceType, true > MatrixPointer;
     typedef SharedPointer< const PreconditionerType, DeviceType, true > PreconditionerPointer;
diff --git a/src/TNL/Solvers/Linear/UmfpackWrapper_impl.h b/src/TNL/Solvers/Linear/UmfpackWrapper_impl.h
index ab05e75151..a0a44b6e22 100644
--- a/src/TNL/Solvers/Linear/UmfpackWrapper_impl.h
+++ b/src/TNL/Solvers/Linear/UmfpackWrapper_impl.h
@@ -11,13 +11,13 @@ namespace Solvers {
 namespace Linear {   
 
 template< typename Preconditioner >
-UmfpackWrapper< CSR< double, Devices::Host, int >, Preconditioner >::
+UmfpackWrapper< Matrices::CSR< double, Devices::Host, int >, Preconditioner >::
 UmfpackWrapper()
 {}
 
 template< typename Preconditioner >
 void
-UmfpackWrapper< CSR< double, Devices::Host, int >, Preconditioner >::
+UmfpackWrapper< Matrices::CSR< double, Devices::Host, int >, Preconditioner >::
 configSetup( Config::ConfigDescription& config,
              const String& prefix )
 {
@@ -25,7 +25,7 @@ configSetup( Config::ConfigDescription& config,
 
 template< typename Preconditioner >
 bool
-UmfpackWrapper< CSR< double, Devices::Host, int >, Preconditioner >::
+UmfpackWrapper< Matrices::CSR< double, Devices::Host, int >, Preconditioner >::
 setup( const Config::ParameterContainer& parameters,
        const String& prefix )
 {
@@ -33,14 +33,14 @@ setup( const Config::ParameterContainer& parameters,
 }
 
 template< typename Preconditioner >
-void UmfpackWrapper< CSR< double, Devices::Host, int >, Preconditioner >::
+void UmfpackWrapper< Matrices::CSR< double, Devices::Host, int >, Preconditioner >::
 setMatrix( const MatrixPointer& matrix )
 {
     this -> matrix = matrix;
 }
 
 template< typename Preconditioner >
-void UmfpackWrapper< CSR< double, Devices::Host, int >, Preconditioner >::
+void UmfpackWrapper< Matrices::CSR< double, Devices::Host, int >, Preconditioner >::
 setPreconditioner( const PreconditionerPointer& preconditioner )
 {
     this -> preconditioner = preconditioner;
@@ -49,7 +49,7 @@ setPreconditioner( const PreconditionerPointer& preconditioner )
 
 template< typename Preconditioner >
     template< typename Vector, typename ResidueGetter >
-bool UmfpackWrapper< CSR< double, Devices::Host, int >, Preconditioner >::
+bool UmfpackWrapper< Matrices::CSR< double, Devices::Host, int >, Preconditioner >::
 solve( const Vector& b,
        Vector& x )
 {
@@ -77,9 +77,9 @@ solve( const Vector& b,
 
     // symbolic reordering of the sparse matrix
     status = umfpack_di_symbolic( size, size,
-                                  matrix->rowPointers.getData(),
-                                  matrix->columnIndexes.getData(),
-                                  matrix->values.getData(),
+                                  matrix->getRowPointers(),
+                                  matrix->getColumnIndexes(),
+                                  matrix->getValues(),
                                   &Symbolic, Control, Info );
     if( status != UMFPACK_OK ) {
         std::cerr << "error: symbolic reordering failed" << std::endl;
@@ -87,9 +87,9 @@ solve( const Vector& b,
     }
 
     // numeric factorization
-    status = umfpack_di_numeric( matrix->rowPointers.getData(),
-                                 matrix->columnIndexes.getData(),
-                                 matrix->values.getData(),
+    status = umfpack_di_numeric( matrix->getRowPointers(),
+                                 matrix->getColumnIndexes(),
+                                 matrix->getValues(),
                                  Symbolic, &Numeric, Control, Info );
     if( status != UMFPACK_OK ) {
         std::cerr << "error: numeric factorization failed" << std::endl;
@@ -98,9 +98,9 @@ solve( const Vector& b,
 
     // solve with specified right-hand-side
     status = umfpack_di_solve( system_type,
-                               matrix->rowPointers.getData(),
-                               matrix->columnIndexes.getData(),
-                               matrix->values.getData(),
+                               matrix->getRowPointers(),
+                               matrix->getColumnIndexes(),
+                               matrix->getValues(),
                                x.getData(),
                                b.getData(),
                                Numeric, Control, Info );
-- 
GitLab