Commit d701e791 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Added constructor and setElements to set-up SparseMatrix from std::map.

parent 8b23a4d8
Loading
Loading
Loading
Loading
+11 −0
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@

#pragma once

#include <map>
#include <TNL/Matrices/Matrix.h>
#include <TNL/Matrices/MatrixType.h>
#include <TNL/Allocators/Default.h>
@@ -92,6 +93,12 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator >
                    const RealAllocatorType& realAllocator = RealAllocatorType(),
                    const IndexAllocatorType& indexAllocator = IndexAllocatorType() );

      template< typename MapIndex,
                typename MapValue >
      explicit SparseMatrix( const IndexType rows,
                             const IndexType columns,
                             const std::map< std::pair< MapIndex, MapIndex > , MapValue >& map );

      ViewType getView() const; // TODO: remove const

      ConstViewType getConstView() const;
@@ -110,6 +117,10 @@ class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator >

      void setElements( const std::initializer_list< std::tuple< IndexType, IndexType, RealType > >& data );

      template< typename MapIndex,
                typename MapValue >
      void setElements( const std::map< std::pair< MapIndex, MapIndex > , MapValue >& map );

      template< typename Vector >
      void getCompressedRowLengths( Vector& rowLengths ) const;

+50 −0
Original line number Diff line number Diff line
@@ -109,6 +109,24 @@ SparseMatrix( const IndexType rows,
   this->setElements( data );
}

template< typename Real,
          typename Device,
          typename Index,
          typename MatrixType,
          template< typename, typename, typename > class Segments,
          typename RealAllocator,
          typename IndexAllocator >
   template< typename MapIndex,
             typename MapValue >
SparseMatrix< Real, Device, Index, MatrixType, Segments, RealAllocator, IndexAllocator >::
SparseMatrix( const IndexType rows,
              const IndexType columns,
              const std::map< std::pair< MapIndex, MapIndex > , MapValue >& map )
{
   this->setDimensions( rows, columns );
   this->setElements( map );
}

template< typename Real,
          typename Device,
          typename Index,
@@ -247,6 +265,38 @@ setElements( const std::initializer_list< std::tuple< IndexType, IndexType, Real
   ( *this ) = hostMatrix;
}

template< typename Real,
          typename Device,
          typename Index,
          typename MatrixType,
          template< typename, typename, typename > class Segments,
          typename RealAllocator,
          typename IndexAllocator >
   template< typename MapIndex,
             typename MapValue >
void
SparseMatrix< Real, Device, Index, MatrixType, Segments, RealAllocator, IndexAllocator >::
setElements( const std::map< std::pair< MapIndex, MapIndex > , MapValue >& map )
{
   Containers::Vector< IndexType, Devices::Host, IndexType > rowsCapacities( this->getRows(), 0 );
   for( auto element : map )
      rowsCapacities[ element.first.first ]++;
   if( !std::is_same< DeviceType, Devices::Host >::value )
   {
      SparseMatrix< Real, Devices::Host, Index, MatrixType, Segments > hostMatrix( this->getRows(), this->getColumns() );
      hostMatrix.setCompressedRowLengths( rowsCapacities );
      for( auto element : map )
         hostMatrix.setElement( element.first.first, element.first.second, element.second );
      *this = hostMatrix;
   }
   else
   {
      this->setCompressedRowLengths( rowsCapacities );
      for( auto element : map )
         this->setElement( element.first.first, element.first.second, element.second );
   }
}

template< typename Real,
          typename Device,
          typename Index,
+52 −0
Original line number Diff line number Diff line
@@ -116,6 +116,58 @@ void test_Constructors()
   EXPECT_EQ( m3.getElement( 5, 2 ),  0 );
   EXPECT_EQ( m3.getElement( 5, 3 ), 12 );
   EXPECT_EQ( m3.getElement( 5, 4 ),  0 );

   std::map< std::pair< int, int >, float > map;
   map[ { 0, 0 } ] = 1.0;
   map[ { 0, 1 } ] = 2.0;
   map[ { 0, 2 } ] = 3.0;
   map[ { 1, 1 } ] = 4.0;
   map[ { 1, 2 } ] = 5.0;
   map[ { 1, 3 } ] = 6.0;
   map[ { 2, 2 } ] = 7.0;
   map[ { 2, 3 } ] = 8.0;
   map[ { 2, 4 } ] = 9.0;
   map[ { 3, 0 } ] = 10.0;
   map[ { 4, 1 } ] = 11.0;
   map[ { 5, 3 } ] = 12.0;
   Matrix m4( 6, 5, map );

   // Check the matrix elements
   EXPECT_EQ( m4.getElement( 0, 0 ),  1 );
   EXPECT_EQ( m4.getElement( 0, 1 ),  2 );
   EXPECT_EQ( m4.getElement( 0, 2 ),  3 );
   EXPECT_EQ( m4.getElement( 0, 3 ),  0 );
   EXPECT_EQ( m4.getElement( 0, 4 ),  0 );

   EXPECT_EQ( m4.getElement( 1, 0 ),  0 );
   EXPECT_EQ( m4.getElement( 1, 1 ),  4 );
   EXPECT_EQ( m4.getElement( 1, 2 ),  5 );
   EXPECT_EQ( m4.getElement( 1, 3 ),  6 );
   EXPECT_EQ( m4.getElement( 1, 4 ),  0 );

   EXPECT_EQ( m4.getElement( 2, 0 ),  0 );
   EXPECT_EQ( m4.getElement( 2, 1 ),  0 );
   EXPECT_EQ( m4.getElement( 2, 2 ),  7 );
   EXPECT_EQ( m4.getElement( 2, 3 ),  8 );
   EXPECT_EQ( m4.getElement( 2, 4 ),  9 );

   EXPECT_EQ( m4.getElement( 3, 0 ), 10 );
   EXPECT_EQ( m4.getElement( 3, 1 ),  0 );
   EXPECT_EQ( m4.getElement( 3, 2 ),  0 );
   EXPECT_EQ( m4.getElement( 3, 3 ),  0 );
   EXPECT_EQ( m4.getElement( 3, 4 ),  0 );

   EXPECT_EQ( m4.getElement( 4, 0 ),  0 );
   EXPECT_EQ( m4.getElement( 4, 1 ), 11 );
   EXPECT_EQ( m4.getElement( 4, 2 ),  0 );
   EXPECT_EQ( m4.getElement( 4, 3 ),  0 );
   EXPECT_EQ( m4.getElement( 4, 4 ),  0 );

   EXPECT_EQ( m4.getElement( 5, 0 ),  0 );
   EXPECT_EQ( m4.getElement( 5, 1 ),  0 );
   EXPECT_EQ( m4.getElement( 5, 2 ),  0 );
   EXPECT_EQ( m4.getElement( 5, 3 ), 12 );
   EXPECT_EQ( m4.getElement( 5, 4 ),  0 );
}

template< typename Matrix >