From ff7c9054e2205808d109d3349efefd748b2f809b Mon Sep 17 00:00:00 2001 From: Tomas Oberhuber <tomas.oberhuber@fjfi.cvut.cz> Date: Tue, 3 Dec 2019 12:58:55 +0100 Subject: [PATCH] All tests passed for SparseMatrix using Segments. --- src/TNL/Containers/Segments/CSR.h | 5 +++- src/TNL/Containers/Segments/CSR.hpp | 21 ++++++++++++++- src/TNL/Matrices/SparseMatrix.hpp | 27 +++++++++++++++---- .../Matrices/SparseMatrixTest_CSR_segments.h | 6 ++--- 4 files changed, 49 insertions(+), 10 deletions(-) diff --git a/src/TNL/Containers/Segments/CSR.h b/src/TNL/Containers/Segments/CSR.h index 92b4f39491..e3eff23427 100644 --- a/src/TNL/Containers/Segments/CSR.h +++ b/src/TNL/Containers/Segments/CSR.h @@ -79,8 +79,11 @@ class CSR void segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const; template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args > - void allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ); + void allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const; + void save( File& file ) const; + + void load( File& file ); protected: diff --git a/src/TNL/Containers/Segments/CSR.hpp b/src/TNL/Containers/Segments/CSR.hpp index e2fd099aec..c99611958e 100644 --- a/src/TNL/Containers/Segments/CSR.hpp +++ b/src/TNL/Containers/Segments/CSR.hpp @@ -181,10 +181,29 @@ template< typename Device, template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args > void CSR< Device, Index >:: -allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) +allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const { this->segmentsReduction( 0, this->getSize(), fetch, reduction, keeper, zero, args... ); } + +template< typename Device, + typename Index > +void +CSR< Device, Index >:: +save( File& file ) const +{ + file << this->offsets; +} + +template< typename Device, + typename Index > +void +CSR< Device, Index >:: +load( File& file ) +{ + file >> this->offsets; +} + } // namespace Segements } // namespace Conatiners } // namespace TNL diff --git a/src/TNL/Matrices/SparseMatrix.hpp b/src/TNL/Matrices/SparseMatrix.hpp index 067f36001c..a43ddba829 100644 --- a/src/TNL/Matrices/SparseMatrix.hpp +++ b/src/TNL/Matrices/SparseMatrix.hpp @@ -597,7 +597,9 @@ void SparseMatrix< Real, Segments, Device, Index, RealAllocator, IndexAllocator >:: save( File& file ) const { - + Matrix< RealType, DeviceType, IndexType >::save( file ); + file << this->columnIndexes; + this->segments.save( file ); } template< typename Real, @@ -610,7 +612,9 @@ void SparseMatrix< Real, Segments, Device, Index, RealAllocator, IndexAllocator >:: load( File& file ) { - + Matrix< RealType, DeviceType, IndexType >::load( file ); + file >> this->columnIndexes; + this->segments.load( file ); } template< typename Real, @@ -623,7 +627,7 @@ void SparseMatrix< Real, Segments, Device, Index, RealAllocator, IndexAllocator >:: save( const String& fileName ) const { - + Object::save( fileName ); } template< typename Real, @@ -636,7 +640,7 @@ void SparseMatrix< Real, Segments, Device, Index, RealAllocator, IndexAllocator >:: load( const String& fileName ) { - + Object::load( fileName ); } template< typename Real, @@ -649,7 +653,20 @@ void SparseMatrix< Real, Segments, Device, Index, RealAllocator, IndexAllocator >:: print( std::ostream& str ) const { - + for( IndexType row = 0; row < this->getRows(); row++ ) + { + str <<"Row: " << row << " -> "; + const IndexType rowLength = this->segments.getSegmentSize( row ); + for( IndexType i = 0; i < rowLength; i++ ) + { + const IndexType globalIdx = this->segments.getGlobalIndex( row, i ); + const IndexType column = this->columnIndexes.getElement( globalIdx ); + if( column == this->getPaddingIndex() ) + break; + str << " Col:" << column << "->" << this->values.getElement( globalIdx ) << "\t"; + } + str << std::endl; + } } template< typename Real, diff --git a/src/UnitTests/Matrices/SparseMatrixTest_CSR_segments.h b/src/UnitTests/Matrices/SparseMatrixTest_CSR_segments.h index 4443d7f6c0..a738af0e2a 100644 --- a/src/UnitTests/Matrices/SparseMatrixTest_CSR_segments.h +++ b/src/UnitTests/Matrices/SparseMatrixTest_CSR_segments.h @@ -122,11 +122,11 @@ TYPED_TEST( CSRMatrixTest, vectorProductTest ) test_VectorProduct< CSRMatrixType >(); } -/*TYPED_TEST( CSRMatrixTest, saveAndLoadTest ) +TYPED_TEST( CSRMatrixTest, saveAndLoadTest ) { using CSRMatrixType = typename TestFixture::CSRMatrixType; - test_SaveAndLoad< CSRMatrixType >( "test_SparseMatrixTest_CSR" ); + test_SaveAndLoad< CSRMatrixType >( "test_SparseMatrixTest_CSR_segments" ); } TYPED_TEST( CSRMatrixTest, printTest ) @@ -134,7 +134,7 @@ TYPED_TEST( CSRMatrixTest, printTest ) using CSRMatrixType = typename TestFixture::CSRMatrixType; test_Print< CSRMatrixType >(); -}*/ +} #endif -- GitLab