From 5d79ecbc2e7470ae246cd568b3b389a37514111d Mon Sep 17 00:00:00 2001
From: Tomas Oberhuber <tomas.oberhuber@fjfi.cvut.cz>
Date: Fri, 7 Feb 2020 23:02:07 +0100
Subject: [PATCH] Added SparseMatrix constructors with initializer lists.

---
 src/TNL/Matrices/SparseMatrix.h           |  11 +++
 src/TNL/Matrices/SparseMatrix.hpp         |  42 ++++++++
 src/UnitTests/Matrices/SparseMatrixTest.h | 115 ++++++++++++++++------
 3 files changed, 136 insertions(+), 32 deletions(-)

diff --git a/src/TNL/Matrices/SparseMatrix.h b/src/TNL/Matrices/SparseMatrix.h
index 7072ce3c4f..15f5857168 100644
--- a/src/TNL/Matrices/SparseMatrix.h
+++ b/src/TNL/Matrices/SparseMatrix.h
@@ -74,6 +74,17 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator >
                     const RealAllocatorType& realAllocator = RealAllocatorType(),
                     const IndexAllocatorType& indexAllocator = IndexAllocatorType() );
 
+      SparseMatrix( const std::initializer_list< std::tuple< IndexType > >& rowCapacities,
+                    const IndexType columns,
+                    const RealAllocatorType& realAllocator = RealAllocatorType(),
+                    const IndexAllocatorType& indexAllocator = IndexAllocatorType() );
+
+      SparseMatrix( const IndexType rows,
+                    const IndexType columns,
+                    const std::initializer_list< std::tuple< IndexType, IndexType, RealType > >& data,
+                    const RealAllocatorType& realAllocator = RealAllocatorType(),
+                    const IndexAllocatorType& indexAllocator = IndexAllocatorType() );
+
       ViewType getView() const; // TODO: remove const
 
       ConstViewType getConstView() const;
diff --git a/src/TNL/Matrices/SparseMatrix.hpp b/src/TNL/Matrices/SparseMatrix.hpp
index 9924434346..938d883af9 100644
--- a/src/TNL/Matrices/SparseMatrix.hpp
+++ b/src/TNL/Matrices/SparseMatrix.hpp
@@ -73,6 +73,48 @@ SparseMatrix( const IndexType rows,
 {
 }
 
+template< typename Real,
+          typename Device,
+          typename Index,
+          typename MatrixType,
+          template< typename, typename, typename > class Segments,
+          typename RealAllocator,
+          typename IndexAllocator >
+SparseMatrix< Real, Device, Index, MatrixType, Segments, RealAllocator, IndexAllocator >::
+SparseMatrix( const std::initializer_list< std::tuple< IndexType > >& rowCapacities,
+              const IndexType columns,
+              const RealAllocatorType& realAllocator,
+              const IndexAllocatorType& indexAllocator )
+: BaseType( rowCapacities.size(), columns, realAllocator ), columnIndexes( indexAllocator )
+{
+   this->setCompressedRowLengths( RowCapacitiesType ( rowCapacities ) );
+}
+
+template< typename Real,
+          typename Device,
+          typename Index,
+          typename MatrixType,
+          template< typename, typename, typename > class Segments,
+          typename RealAllocator,
+          typename IndexAllocator >
+SparseMatrix< Real, Device, Index, MatrixType, Segments, RealAllocator, IndexAllocator >::
+SparseMatrix( const IndexType rows,
+              const IndexType columns,
+              const std::initializer_list< std::tuple< IndexType, IndexType, RealType > >& data,
+              const RealAllocatorType& realAllocator,
+              const IndexAllocatorType& indexAllocator )
+: BaseType( rows, columns, realAllocator ), columnIndexes( indexAllocator )
+{
+   Containers::Vector< IndexType, Devices::Host, IndexType > rowCapacities( rows, 0 );
+   for( const auto& i : data )
+      rowCapacities[ std::get< 0 >( i ) ]++;
+   SparseMatrix< Real, Devices::Host, Index, MatrixType, Segments > hostMatrix( rows, columns );
+   hostMatrix.setCompressedRowLength( rowCapacities );
+   for( const auto& i : data )
+      hostMatrix.setElement( std::get< 0 >( i ), std::get< 1 >( i ), std::get< 2 >( i ) );
+   ( *this ) = hostMatrix;
+}
+
 template< typename Real,
           typename Device,
           typename Index,
diff --git a/src/UnitTests/Matrices/SparseMatrixTest.h b/src/UnitTests/Matrices/SparseMatrixTest.h
index 04a9b065fb..26b15fafd2 100644
--- a/src/UnitTests/Matrices/SparseMatrixTest.h
+++ b/src/UnitTests/Matrices/SparseMatrixTest.h
@@ -36,6 +36,79 @@ void cuda_test_GetType()
    std::cerr << "This test has not been implemented properly yet.\n" << std::endl;
 }
 
+template< typename Matrix >
+void test_Constructors()
+{
+   using RealType = typename Matrix::RealType;
+   using DeviceType = typename Matrix::DeviceType;
+   using IndexType = typename Matrix::IndexType;
+
+   Matrix m1( 5, 6 );
+   EXPECT_EQ( m1.getRows(), 5 );
+   EXPECT_EQ( m1.getColumns(), 6 );
+
+   Matrix m2( {1, 2, 2, 2, 1 }, 5 );
+   typename Matrix::RowCapacitiesType v1, v2{ 1, 2, 2, 2, 1 }; 
+   m2.getCompressedRowLength( v1 );
+   EXPECT_EQ( v1, v2 );
+
+   /*
+    * Sets up the following 6x5 sparse matrix:
+    *
+    *    /  1  2  3  0  0 \
+    *    |  0  4  5  6  0 |
+    *    |  0  0  7  8  9 |
+    *    | 10  0  0  0  0 |
+    *    |  0 11  0  0  0 |
+    *    \  0  0  0 12  0 /
+    */
+
+   Matrix m3( 6, 5, {
+      { 0, 0,  1.0 }, { 0, 1, 2.0 }, { 0, 2, 3.0 },
+      { 1, 1,  4.0 }, { 1, 2, 5.0 }, { 1, 3, 6.0 },
+      { 2, 2,  7.0 }, { 2, 3, 8.0 }, { 2, 4, 9.0 },
+      { 3, 0, 10.0 },
+      { 4, 1, 11.0 },
+      { 5, 3, 12.0 } } );
+
+   // Check the set elements
+   EXPECT_EQ( m3.getElement( 0, 0 ),  1 );
+   EXPECT_EQ( m3.getElement( 0, 1 ),  2 );
+   EXPECT_EQ( m3.getElement( 0, 2 ),  3 );
+   EXPECT_EQ( m3.getElement( 0, 3 ),  0 );
+   EXPECT_EQ( m3.getElement( 0, 4 ),  0 );
+
+   EXPECT_EQ( m3.getElement( 1, 0 ),  0 );
+   EXPECT_EQ( m3.getElement( 1, 1 ),  4 );
+   EXPECT_EQ( m3.getElement( 1, 2 ),  5 );
+   EXPECT_EQ( m3.getElement( 1, 3 ),  6 );
+   EXPECT_EQ( m3.getElement( 1, 4 ),  0 );
+
+   EXPECT_EQ( m3.getElement( 2, 0 ),  0 );
+   EXPECT_EQ( m3.getElement( 2, 1 ),  0 );
+   EXPECT_EQ( m3.getElement( 2, 2 ),  7 );
+   EXPECT_EQ( m3.getElement( 2, 3 ),  8 );
+   EXPECT_EQ( m3.getElement( 2, 4 ),  9 );
+
+   EXPECT_EQ( m3.getElement( 3, 0 ), 10 );
+   EXPECT_EQ( m3.getElement( 3, 1 ),  0 );
+   EXPECT_EQ( m3.getElement( 3, 2 ),  0 );
+   EXPECT_EQ( m3.getElement( 3, 3 ),  0 );
+   EXPECT_EQ( m3.getElement( 3, 4 ),  0 );
+
+   EXPECT_EQ( m3.getElement( 4, 0 ),  0 );
+   EXPECT_EQ( m3.getElement( 4, 1 ), 11 );
+   EXPECT_EQ( m3.getElement( 4, 2 ),  0 );
+   EXPECT_EQ( m3.getElement( 4, 3 ),  0 );
+   EXPECT_EQ( m3.getElement( 4, 4 ),  0 );
+
+   EXPECT_EQ( m3.getElement( 5, 0 ),  0 );
+   EXPECT_EQ( m3.getElement( 5, 1 ),  0 );
+   EXPECT_EQ( m3.getElement( 5, 2 ),  0 );
+   EXPECT_EQ( m3.getElement( 5, 3 ), 12 );
+   EXPECT_EQ( m3.getElement( 5, 4 ),  0 );
+}
+
 template< typename Matrix >
 void test_SetDimensions()
 {
@@ -64,9 +137,7 @@ void test_SetCompressedRowLengths()
    const IndexType cols = 11;
 
    Matrix m( rows, cols );
-   typename Matrix::CompressedRowLengthsVector rowLengths;
-   rowLengths.setSize( rows );
-   rowLengths = 3;
+   typename Matrix::CompressedRowLengthsVector rowLengths( rows, 3 );
 
    IndexType rowLength = 1;
    for( IndexType i = 2; i < rows; i++ )
@@ -592,8 +663,7 @@ void test_AddElement()
    const IndexType cols = 5;
 
    Matrix m( rows, cols );
-   typename Matrix::CompressedRowLengthsVector rowLengths( rows );
-   rowLengths = 3;
+   typename Matrix::CompressedRowLengthsVector rowLengths( rows, 3 );
    m.setCompressedRowLengths( rowLengths );
 
    RealType value = 1;
@@ -742,12 +812,7 @@ void test_VectorProduct()
    Matrix m_1;
    m_1.reset();
    m_1.setDimensions( m_rows_1, m_cols_1 );
-   typename Matrix::CompressedRowLengthsVector rowLengths_1;
-   rowLengths_1.setSize( m_rows_1 );
-   rowLengths_1.setElement( 0, 1 );
-   rowLengths_1.setElement( 1, 2 );
-   rowLengths_1.setElement( 2, 1 );
-   rowLengths_1.setElement( 3, 1 );
+   typename Matrix::CompressedRowLengthsVector rowLengths_1{ 1, 2, 1, 1 };
    m_1.setCompressedRowLengths( rowLengths_1 );
 
    RealType value_1 = 1;
@@ -770,10 +835,8 @@ void test_VectorProduct()
    for( IndexType j = 0; j < outVector_1.getSize(); j++ )
        outVector_1.setElement( j, 0 );
 
-
    m_1.vectorProduct( inVector_1, outVector_1 );
 
-
    EXPECT_EQ( outVector_1.getElement( 0 ),  2 );
    EXPECT_EQ( outVector_1.getElement( 1 ), 10 );
    EXPECT_EQ( outVector_1.getElement( 2 ),  8 );
@@ -793,21 +856,18 @@ void test_VectorProduct()
 
    Matrix m_2( m_rows_2, m_cols_2 );
    typename Matrix::CompressedRowLengthsVector rowLengths_2{ 3, 1, 3, 1 };
-   /*rowLengths_2 = 3;
-   rowLengths_2.setElement( 1, 1 );
-   rowLengths_2.setElement( 3, 1 );*/
    m_2.setCompressedRowLengths( rowLengths_2 );
 
    RealType value_2 = 1;
-   for( IndexType i = 0; i < 3; i++ )   // 0th row
+   for( IndexType i = 0; i < 3; i++ )      // 0th row
       m_2.setElement( 0, i, value_2++ );
 
    m_2.setElement( 1, 3, value_2++ );      // 1st row
 
-   for( IndexType i = 0; i < 3; i++ )   // 2nd row
+   for( IndexType i = 0; i < 3; i++ )      // 2nd row
       m_2.setElement( 2, i, value_2++ );
 
-   for( IndexType i = 1; i < 2; i++ )       // 3rd row
+   for( IndexType i = 1; i < 2; i++ )      // 3rd row
       m_2.setElement( 3, i, value_2++ );
 
    VectorType inVector_2;
@@ -891,11 +951,6 @@ void test_VectorProduct()
 
    Matrix m_4( m_rows_4, m_cols_4 );
    typename Matrix::CompressedRowLengthsVector rowLengths_4{ 4, 4, 5, 4, 4, 4, 5, 5 };
-   /*rowLengths_4.setSize( m_rows_4 );
-   rowLengths_4.setValue( 4 );
-   rowLengths_4.setElement( 2, 5 );
-   rowLengths_4.setElement( 6, 5 );
-   rowLengths_4.setElement( 7, 5 );*/
    m_4.setCompressedRowLengths( rowLengths_4 );
 
    RealType value_4 = 1;
@@ -1137,8 +1192,7 @@ void test_PerformSORIteration()
    const IndexType m_cols = 4;
 
    Matrix m( m_rows, m_cols );
-   typename Matrix::CompressedRowLengthsVector rowLengths( m_rows );
-   rowLengths = 3;
+   typename Matrix::CompressedRowLengthsVector rowLengths( m_rows, 3 );
    m.setCompressedRowLengths( rowLengths );
 
    m.setElement( 0, 0, 4.0 );        // 0th row
@@ -1210,8 +1264,7 @@ void test_SaveAndLoad( const char* filename )
    const IndexType m_cols = 4;
 
    Matrix savedMatrix( m_rows, m_cols );
-   typename Matrix::CompressedRowLengthsVector rowLengths( m_rows );
-   rowLengths = 3;
+   typename Matrix::CompressedRowLengthsVector rowLengths( m_rows, 3 );
    savedMatrix.setCompressedRowLengths( rowLengths );
 
    RealType value = 1;
@@ -1230,8 +1283,7 @@ void test_SaveAndLoad( const char* filename )
    ASSERT_NO_THROW( savedMatrix.save( filename ) );
 
    Matrix loadedMatrix( m_rows, m_cols );
-   typename Matrix::CompressedRowLengthsVector rowLengths2( m_rows );
-   rowLengths2 = 3;
+   typename Matrix::CompressedRowLengthsVector rowLengths2( m_rows, 3 );
    loadedMatrix.setCompressedRowLengths( rowLengths2 );
 
    ASSERT_NO_THROW( loadedMatrix.load( filename ) );
@@ -1300,8 +1352,7 @@ void test_Print()
    const IndexType m_cols = 4;
 
    Matrix m( m_rows, m_cols );
-   typename Matrix::CompressedRowLengthsVector rowLengths( m_rows );
-   rowLengths = 3;
+   typename Matrix::CompressedRowLengthsVector rowLengths( m_rows, 3 );
    m.setCompressedRowLengths( rowLengths );
 
    RealType value = 1;
-- 
GitLab