Commit c869e976 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Fixing sparse matrix constructors with initializer list together with unit tests.

parent 5d79ecbc
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -74,7 +74,7 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator >
                    const RealAllocatorType& realAllocator = RealAllocatorType(),
                    const IndexAllocatorType& indexAllocator = IndexAllocatorType() );

      SparseMatrix( const std::initializer_list< std::tuple< IndexType > >& rowCapacities,
      SparseMatrix( const std::initializer_list< IndexType >& rowCapacities,
                    const IndexType columns,
                    const RealAllocatorType& realAllocator = RealAllocatorType(),
                    const IndexAllocatorType& indexAllocator = IndexAllocatorType() );
+3 −3
Original line number Diff line number Diff line
@@ -81,13 +81,13 @@ template< typename Real,
          typename RealAllocator,
          typename IndexAllocator >
SparseMatrix< Real, Device, Index, MatrixType, Segments, RealAllocator, IndexAllocator >::
SparseMatrix( const std::initializer_list< std::tuple< IndexType > >& rowCapacities,
SparseMatrix( const std::initializer_list< IndexType >& rowCapacities,
              const IndexType columns,
              const RealAllocatorType& realAllocator,
              const IndexAllocatorType& indexAllocator )
: BaseType( rowCapacities.size(), columns, realAllocator ), columnIndexes( indexAllocator )
{
   this->setCompressedRowLengths( RowCapacitiesType ( rowCapacities ) );
   this->setCompressedRowLengths( RowsCapacitiesType( rowCapacities ) );
}

template< typename Real,
@@ -109,7 +109,7 @@ SparseMatrix( const IndexType rows,
   for( const auto& i : data )
      rowCapacities[ std::get< 0 >( i ) ]++;
   SparseMatrix< Real, Devices::Host, Index, MatrixType, Segments > hostMatrix( rows, columns );
   hostMatrix.setCompressedRowLength( rowCapacities );
   hostMatrix.setCompressedRowLengths( rowCapacities );
   for( const auto& i : data )
      hostMatrix.setElement( std::get< 0 >( i ), std::get< 1 >( i ), std::get< 2 >( i ) );
   ( *this ) = hostMatrix;
+20 −5
Original line number Diff line number Diff line
@@ -48,8 +48,17 @@ void test_Constructors()
   EXPECT_EQ( m1.getColumns(), 6 );

   Matrix m2( {1, 2, 2, 2, 1 }, 5 );
   typename Matrix::RowCapacitiesType v1, v2{ 1, 2, 2, 2, 1 }; 
   m2.getCompressedRowLength( v1 );
   typename Matrix::RowsCapacitiesType v1, v2{ 1, 2, 2, 2, 1 }; 
   m2.setElement( 0, 0, 1 );   // 0th row
   m2.setElement( 1, 0, 1 );   // 1st row
   m2.setElement( 1, 1, 1 );   
   m2.setElement( 2, 1, 1 );   // 2nd row
   m2.setElement( 2, 2, 1 );
   m2.setElement( 3, 2, 1 );   // 3rd row
   m2.setElement( 3, 3, 1 );
   m2.setElement( 4, 4, 1 );   // 4th row
   m2.getCompressedRowLengths( v1 );
   
   EXPECT_EQ( v1, v2 );

   /*
@@ -662,8 +671,14 @@ void test_AddElement()
   const IndexType rows = 6;
   const IndexType cols = 5;

   Matrix m( rows, cols );
   typename Matrix::CompressedRowLengthsVector rowLengths( rows, 3 );
   Matrix m( rows, cols, {
      { 0, 0,  1 }, { 0, 1,  2 }, { 0, 2, 3 },
                    { 1, 1,  4 }, { 1, 2, 5 }, { 1, 3,  6 },
                                  { 2, 2, 7 }, { 2, 3,  8 }, { 2, 4, 9 },
      { 3, 0, 10 }, { 3, 1,  0 }, { 3, 2, 0 },
                    { 4, 1, 11 }, { 4, 2, 0 }, { 4, 3,  0 },
                                  { 5, 2, 0 }, { 5, 3, 12 }, { 5, 4, 0 } } );
   /*typename Matrix::CompressedRowLengthsVector rowLengths( rows, 3 );
   m.setCompressedRowLengths( rowLengths );

   RealType value = 1;
@@ -680,7 +695,7 @@ void test_AddElement()

   m.setElement( 4, 1, value++ );      // 4th row

   m.setElement( 5, 3, value++ );      // 5th row
   m.setElement( 5, 3, value++ );      // 5th row*/


   // Check the set elements
+7 −0
Original line number Diff line number Diff line
@@ -59,6 +59,13 @@ using CSRMatrixTypes = ::testing::Types

TYPED_TEST_SUITE( CSRMatrixTest, CSRMatrixTypes);

TYPED_TEST( CSRMatrixTest, Constructors )
{
    using CSRMatrixType = typename TestFixture::CSRMatrixType;

    test_Constructors< CSRMatrixType >();
}

TYPED_TEST( CSRMatrixTest, setDimensionsTest )
{
    using CSRMatrixType = typename TestFixture::CSRMatrixType;
+7 −0
Original line number Diff line number Diff line
@@ -70,6 +70,13 @@ using EllpackMatrixTypes = ::testing::Types

TYPED_TEST_SUITE( EllpackMatrixTest, EllpackMatrixTypes);

TYPED_TEST( EllpackMatrixTest, Constructors )
{
    using EllpackMatrixType = typename TestFixture::EllpackMatrixType;

    test_Constructors< EllpackMatrixType >();
}

TYPED_TEST( EllpackMatrixTest, setDimensionsTest )
{
    using EllpackMatrixType = typename TestFixture::EllpackMatrixType;
Loading