From ff7c9054e2205808d109d3349efefd748b2f809b Mon Sep 17 00:00:00 2001
From: Tomas Oberhuber <tomas.oberhuber@fjfi.cvut.cz>
Date: Tue, 3 Dec 2019 12:58:55 +0100
Subject: [PATCH] All tests passed for SparseMatrix using Segments.

---
 src/TNL/Containers/Segments/CSR.h             |  5 +++-
 src/TNL/Containers/Segments/CSR.hpp           | 21 ++++++++++++++-
 src/TNL/Matrices/SparseMatrix.hpp             | 27 +++++++++++++++----
 .../Matrices/SparseMatrixTest_CSR_segments.h  |  6 ++---
 4 files changed, 49 insertions(+), 10 deletions(-)

diff --git a/src/TNL/Containers/Segments/CSR.h b/src/TNL/Containers/Segments/CSR.h
index 92b4f39491..e3eff23427 100644
--- a/src/TNL/Containers/Segments/CSR.h
+++ b/src/TNL/Containers/Segments/CSR.h
@@ -79,8 +79,11 @@ class CSR
       void segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const;
 
       template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
-      void allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args );
+      void allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const;
 
+      void save( File& file ) const;
+
+      void load( File& file );
 
    protected:
 
diff --git a/src/TNL/Containers/Segments/CSR.hpp b/src/TNL/Containers/Segments/CSR.hpp
index e2fd099aec..c99611958e 100644
--- a/src/TNL/Containers/Segments/CSR.hpp
+++ b/src/TNL/Containers/Segments/CSR.hpp
@@ -181,10 +181,29 @@ template< typename Device,
    template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
 void
 CSR< Device, Index >::
-allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args )
+allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
 {
    this->segmentsReduction( 0, this->getSize(), fetch, reduction, keeper, zero, args... );
 }
+
+template< typename Device,
+          typename Index >
+void
+CSR< Device, Index >::
+save( File& file ) const
+{
+   file << this->offsets;
+}
+
+template< typename Device,
+          typename Index >
+void
+CSR< Device, Index >::
+load( File& file )
+{
+   file >> this->offsets;
+}
+
       } // namespace Segements
    }  // namespace Conatiners
 } // namespace TNL
diff --git a/src/TNL/Matrices/SparseMatrix.hpp b/src/TNL/Matrices/SparseMatrix.hpp
index 067f36001c..a43ddba829 100644
--- a/src/TNL/Matrices/SparseMatrix.hpp
+++ b/src/TNL/Matrices/SparseMatrix.hpp
@@ -597,7 +597,9 @@ void
 SparseMatrix< Real, Segments, Device, Index, RealAllocator, IndexAllocator >::
 save( File& file ) const
 {
-
+   Matrix< RealType, DeviceType, IndexType >::save( file );
+   file << this->columnIndexes;
+   this->segments.save( file );
 }
 
 template< typename Real,
@@ -610,7 +612,9 @@ void
 SparseMatrix< Real, Segments, Device, Index, RealAllocator, IndexAllocator >::
 load( File& file )
 {
-
+   Matrix< RealType, DeviceType, IndexType >::load( file );
+   file >> this->columnIndexes;
+   this->segments.load( file );
 }
 
 template< typename Real,
@@ -623,7 +627,7 @@ void
 SparseMatrix< Real, Segments, Device, Index, RealAllocator, IndexAllocator >::
 save( const String& fileName ) const
 {
-
+   Object::save( fileName );
 }
 
 template< typename Real,
@@ -636,7 +640,7 @@ void
 SparseMatrix< Real, Segments, Device, Index, RealAllocator, IndexAllocator >::
 load( const String& fileName )
 {
-
+   Object::load( fileName );
 }
 
 template< typename Real,
@@ -649,7 +653,20 @@ void
 SparseMatrix< Real, Segments, Device, Index, RealAllocator, IndexAllocator >::
 print( std::ostream& str ) const
 {
-
+   for( IndexType row = 0; row < this->getRows(); row++ )
+   {
+      str <<"Row: " << row << " -> ";
+      const IndexType rowLength = this->segments.getSegmentSize( row );
+      for( IndexType i = 0; i < rowLength; i++ )
+      {
+         const IndexType globalIdx = this->segments.getGlobalIndex( row, i );
+         const IndexType column = this->columnIndexes.getElement( globalIdx );
+         if( column == this->getPaddingIndex() )
+            break;
+         str << " Col:" << column << "->" << this->values.getElement( globalIdx ) << "\t";
+      }
+      str << std::endl;
+   }
 }
 
 template< typename Real,
diff --git a/src/UnitTests/Matrices/SparseMatrixTest_CSR_segments.h b/src/UnitTests/Matrices/SparseMatrixTest_CSR_segments.h
index 4443d7f6c0..a738af0e2a 100644
--- a/src/UnitTests/Matrices/SparseMatrixTest_CSR_segments.h
+++ b/src/UnitTests/Matrices/SparseMatrixTest_CSR_segments.h
@@ -122,11 +122,11 @@ TYPED_TEST( CSRMatrixTest, vectorProductTest )
     test_VectorProduct< CSRMatrixType >();
 }
 
-/*TYPED_TEST( CSRMatrixTest, saveAndLoadTest )
+TYPED_TEST( CSRMatrixTest, saveAndLoadTest )
 {
     using CSRMatrixType = typename TestFixture::CSRMatrixType;
 
-    test_SaveAndLoad< CSRMatrixType >( "test_SparseMatrixTest_CSR" );
+    test_SaveAndLoad< CSRMatrixType >( "test_SparseMatrixTest_CSR_segments" );
 }
 
 TYPED_TEST( CSRMatrixTest, printTest )
@@ -134,7 +134,7 @@ TYPED_TEST( CSRMatrixTest, printTest )
     using CSRMatrixType = typename TestFixture::CSRMatrixType;
 
     test_Print< CSRMatrixType >();
-}*/
+}
 
 #endif
 
-- 
GitLab