Commit 62823a9f authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Implementing the tridiagonal matrix format.

parent b2b90a70
Loading
Loading
Loading
Loading
+14 −17
Original line number Diff line number Diff line
@@ -64,7 +64,7 @@ template< typename Real,
   template< typename Real2, typename Device2, typename Index2 >
bool tnlTridiagonalMatrix< Real, Device, Index >::setLike( const tnlTridiagonalMatrix< Real2, Device2, Index2 >& m )
{

   return this->setDimensions( m.getRows() );
}

template< typename Real,
@@ -106,7 +106,7 @@ template< typename Real,
          typename Index >
void tnlTridiagonalMatrix< Real, Device, Index >::setValue( const RealType& v )
{

   this->values.setValue( v );
}

template< typename Real,
@@ -125,6 +125,8 @@ template< typename Real,
Real tnlTridiagonalMatrix< Real, Device, Index >::getElement( const IndexType row,
                                                              const IndexType column ) const
{
   if( abs( column - row ) > 1 )
      return 0.0;
   return this->values.getElement( this->getElementIndex( row, column ) );
}

@@ -229,23 +231,18 @@ template< typename Real,
void tnlTridiagonalMatrix< Real, Device, Index >::getTransposition( const tnlTridiagonalMatrix< Real2, Device, Index2 >& matrix,
                                                                    const RealType& matrixMultiplicator )
{
   tnlAssert( this->getColumns() == matrix.getRows() &&
              this->getRows() == matrix.getColumns(),
               cerr << "This matrix columns: " << this->getColumns() << endl
                    << "This matrix rows: " << this->getRows() << endl
                    << "This matrix name: " << this->getName() << endl
                    << "That matrix columns: " << matrix.getColumns() << endl
                    << "That matrix rows: " << matrix.getRows() << endl
                    << "That matrix name: " << matrix.getName() << endl );
   tnlAssert( this->getRows() == matrix.getRows(),
               cerr << "This matrix rows: " << this->getRows() << endl
                    << "That matrix rows: " << matrix.getRows() << endl );

   const IndexType& rows = matrix.getRows();
   const IndexType& columns = matrix.getColumns();
   for( IndexType i = 0; i < rows; i += tileDim )
      for( IndexType j = 0; j < columns; j += tileDim )
         for( IndexType k = i; k < i + tileDim && k < rows; k++ )
            for( IndexType l = j; l < j + tileDim && l < columns; l++ )
               this->operator()( l, k ) = matrix( k, l );

   for( IndexType i = 1; i < rows; i++ )
   {
      RealType aux = matrix. getElement( i, i - 1 );
      this->setElement( i, i - 1, matrix.getElement( i - 1, i ) );
      this->setElement( i, i, matrix.getElement( i, i ) );
      this->setElement( i - 1, i, aux );
   }
}

template< typename Real,
+42 −7
Original line number Diff line number Diff line
@@ -48,6 +48,7 @@ class tnlTridiagonalMatrixTester : public CppUnit :: TestCase
      CppUnit :: TestResult result;

      suiteOfTests -> addTest( new TestCallerType( "setDimensionsTest", &TesterType::setDimensionsTest ) );
      suiteOfTests -> addTest( new TestCallerType( "setLikeTest", &TesterType::setLikeTest ) );
      suiteOfTests -> addTest( new TestCallerType( "setElementTest", &TesterType::setElementTest ) );
      suiteOfTests -> addTest( new TestCallerType( "addToElementTest", &TesterType::addToElementTest ) );
      suiteOfTests -> addTest( new TestCallerType( "vectorProductTest", &TesterType::vectorProductTest ) );
@@ -65,6 +66,28 @@ class tnlTridiagonalMatrixTester : public CppUnit :: TestCase
      CPPUNIT_ASSERT( m.getColumns() == 10 );
   }

   void setLikeTest()
   {
      MatrixType m1, m2;
      m1.setDimensions( 10 );
      m2.setLike( m1 );
      CPPUNIT_ASSERT( m1.getRows() == m2.getRows() );
   }

   void setValueTest()
   {
      const int size( 10 );
      MatrixType m;
      m.setDimensions( size );
      m.setValue( 1.0 );
      for( int i = 0; i < size; i++ )
         for( int j = 0; j < size; j++ )
            if( abs( i - j ) <= 1 )
               CPPUNIT_ASSERT( m.getElement( i, j ) == 1.0 );
            else
               CPPUNIT_ASSERT( m.getElement( i, j ) == 0.0 );
   }

   void setElementTest()
   {
      MatrixType m;
@@ -88,6 +111,7 @@ class tnlTridiagonalMatrixTester : public CppUnit :: TestCase
         m.setElement( i, i, i );
      for( int i = 0; i < 10; i++ )
         for( int j = 0; j < 10; j++ )
            if( abs( i - j ) <= 1 )
               m.addToElement( i, j, 1 );

      for( int i = 0; i < 10; i++ )
@@ -95,7 +119,10 @@ class tnlTridiagonalMatrixTester : public CppUnit :: TestCase
            if( i == j )
               CPPUNIT_ASSERT( m.getElement( i, i ) == i + 1 );
            else
               if( abs( i - j ) == 1 )
                  CPPUNIT_ASSERT( m.getElement( i, j ) == 1 );
               else
                  CPPUNIT_ASSERT( m.getElement( i, j ) == 0 );
   }

   void vectorProductTest()
@@ -124,6 +151,7 @@ class tnlTridiagonalMatrixTester : public CppUnit :: TestCase
      m.setDimensions( 10);
      for( int i = 0; i < size; i++ )
         for( int j = 0; j < size; j++ )
            if( abs( i - j ) <= 1 )
               m( i, j ) = i*size + j;

      MatrixType m2;
@@ -133,13 +161,19 @@ class tnlTridiagonalMatrixTester : public CppUnit :: TestCase

      for( int i = 0; i < size; i++ )
         for( int j = 0; j < size; j++ )
            CPPUNIT_ASSERT( m2( i, j ) == m( i, j ) + 3.0 );
            if( abs( i - j ) <= 1 )
               CPPUNIT_ASSERT( m2.getElement( i, j ) == m.getElement( i, j ) + 3.0 );
            else
               CPPUNIT_ASSERT( m2.getElement( i, j ) == 0.0 );

      m2.addMatrix( m, 0.5, 0.0 );

      for( int i = 0; i < size; i++ )
         for( int j = 0; j < size; j++ )
            CPPUNIT_ASSERT( m2( i, j ) == 0.5*m( i, j ) );
            if( abs( i - j ) <= 1 )
               CPPUNIT_ASSERT( m2.getElement( i, j ) == 0.5*m.getElement( i, j ) );
            else
               CPPUNIT_ASSERT( m2.getElement( i, j ) == 0.0 );
   }

   void matrixTranspositionTest()
@@ -149,7 +183,8 @@ class tnlTridiagonalMatrixTester : public CppUnit :: TestCase
      m.setDimensions( 10 );
      for( int i = 0; i < size; i++ )
         for( int j = 0; j < size; j++ )
            m( i, j ) = i*size + j;
            if( abs( i - j ) <= 1 )
               m.setElement( i, j, i*size + j );

      MatrixType mTransposed;
      mTransposed.setLike( m );
@@ -157,7 +192,7 @@ class tnlTridiagonalMatrixTester : public CppUnit :: TestCase

      for( int i = 0; i < size; i++ )
         for( int j = 0; j < size; j++ )
            CPPUNIT_ASSERT( m( i, j ) == mTransposed( j, i ) );
            CPPUNIT_ASSERT( m.getElement( i, j ) == mTransposed.getElement( j, i ) );
   }
};