Commit 8b23a4d8 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Added DenseMatrix constructor from std initializer list and method setElements.

parent 0e48e0a4
Loading
Loading
Loading
Loading
+12 −1
Original line number Diff line number Diff line
@@ -57,6 +57,8 @@ class Dense : public Matrix< Real, Device, Index >

      Dense( const IndexType rows, const IndexType columns );

      Dense( std::initializer_list< std::initializer_list< RealType > > data );

      ViewType getView();

      ConstViewType getConstView() const;
@@ -71,7 +73,16 @@ class Dense : public Matrix< Real, Device, Index >
      template< typename Matrix >
      void setLike( const Matrix& matrix );

      /****
      /**
       * \brief This method creates dense matrix from 2D initializer list.
       * 
       * The matrix dimensions will be adjusted by the input data.
       * 
       * @param data
       */
      void setElements( std::initializer_list< std::initializer_list< RealType > > data );
      
      /**
       * This method is only for the compatibility with the sparse matrices.
       */
      void setCompressedRowLengths( ConstCompressedRowLengthsVectorView rowLengths );
+51 −0
Original line number Diff line number Diff line
@@ -37,6 +37,57 @@ Dense( const IndexType rows, const IndexType columns )
   this->setDimensions( rows, columns );
}

template< typename Real,
          typename Device,
          typename Index,
          bool RowMajorOrder,
          typename RealAllocator >
Dense< Real, Device, Index, RowMajorOrder, RealAllocator >::
Dense( std::initializer_list< std::initializer_list< RealType > > data )
{
   this->setElements( data );
}

template< typename Real,
          typename Device,
          typename Index,
          bool RowMajorOrder,
          typename RealAllocator >
void
Dense< Real, Device, Index, RowMajorOrder, RealAllocator >::
setElements( std::initializer_list< std::initializer_list< RealType > > data )
{
   IndexType rows = data.size();
   IndexType columns = 0;
   for( auto row : data )
      columns = max( columns, row.size() );
   this->setDimensions( rows, columns );
   if( ! std::is_same< DeviceType, Devices::Host >::value )
   {
      Dense< RealType, Devices::Host, IndexType > hostDense( rows, columns );
      IndexType rowIdx( 0 );
      for( auto row : data )
      {
         IndexType columnIdx( 0 );
         for( auto element : row )
            hostDense.setElement( rowIdx, columnIdx++, element );
         rowIdx++;
      }
      *this = hostDense;
   }
   else
   {
      IndexType rowIdx( 0 );
      for( auto row : data )
      {
         IndexType columnIdx( 0 );
         for( auto element : row )
            this->setElement( rowIdx, columnIdx++, element );
         rowIdx++;
      }
   }
}

template< typename Real,
          typename Device,
          typename Index,
+56 −23
Original line number Diff line number Diff line
@@ -84,6 +84,32 @@ void test_SetLike()
   EXPECT_EQ( m1.getColumns(), m2.getColumns() );
}

template< typename Matrix >
void test_SetElements()
{
   using RealType = typename Matrix::RealType;
   using DeviceType = typename Matrix::DeviceType;
   using IndexType = typename Matrix::IndexType;

   Matrix m( {
      { 1, 2, 3 },
      { 4, 5, 6 },
      { 7, 8, 9 },
   } );

   EXPECT_EQ( m.getRows(), 3 );
   EXPECT_EQ( m.getColumns(), 3 );
   EXPECT_EQ( m.getElement( 0, 0 ), 1 );
   EXPECT_EQ( m.getElement( 0, 1 ), 2 );
   EXPECT_EQ( m.getElement( 0, 2 ), 3 );
   EXPECT_EQ( m.getElement( 1, 0 ), 4 );
   EXPECT_EQ( m.getElement( 1, 1 ), 5 );
   EXPECT_EQ( m.getElement( 1, 2 ), 6 );
   EXPECT_EQ( m.getElement( 2, 0 ), 7 );
   EXPECT_EQ( m.getElement( 2, 1 ), 8 );
   EXPECT_EQ( m.getElement( 2, 2 ), 9 );
}

template< typename Matrix >
void test_GetCompressedRowLengths()
{
@@ -1386,6 +1412,13 @@ TYPED_TEST( MatrixTest, setLikeTest )
    test_SetLike< MatrixType, MatrixType >();
}

TYPED_TEST( MatrixTest, setElementsTest )
{
    using MatrixType = typename TestFixture::MatrixType;

    test_SetElements< MatrixType >();
}

TYPED_TEST( MatrixTest, getRowLengthTest )
{
    using MatrixType = typename TestFixture::MatrixType;