Commit b3d86f89 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Updating API of tridiagonal matrix.

parent 97888bf1
Loading
Loading
Loading
Loading
+114 −132
Original line number Diff line number Diff line
@@ -13,6 +13,7 @@
#include <TNL/Matrices/Matrix.h>
#include <TNL/Containers/Vector.h>
#include <TNL/Matrices/TridiagonalRow.h>
#include <TNL/Containers/Segments/Ellpack.h>

namespace TNL {
namespace Matrices {
@@ -22,8 +23,10 @@ class TridiagonalDeviceDependentCode;

template< typename Real = double,
          typename Device = Devices::Host,
          typename Index = int >
class Tridiagonal : public Matrix< Real, Device, Index >
          typename Index = int,
          bool RowMajorOrder = std::is_same< Device, Devices::Host >::value,
          typename RealAllocator = typename Allocators::Default< Device >::template Allocator< Real > >
class Tridiagonal : public Matrix< Real, Device, Index, RealAllocator >
{
   private:
      // convenient template alias for controlling the selection of copy-assignment operator
@@ -35,13 +38,17 @@ private:
      friend class Tridiagonal;

   public:
   typedef Real RealType;
   typedef Device DeviceType;
   typedef Index IndexType;
   typedef typename Matrix< Real, Device, Index >::CompressedRowLengthsVector CompressedRowLengthsVector;
   typedef typename Matrix< Real, Device, Index >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
   typedef Matrix< Real, Device, Index > BaseType;
   typedef TridiagonalRow< Real, Index > MatrixRow;
      using RealType = Real;
      using DeviceType = Device;
      using IndexType = Index;
      using RealAllocatorType = RealAllocator;
      using BaseType = Matrix< Real, Device, Index, RealAllocator >;
      using ValuesType = typename BaseType::ValuesVector;
      using ValuesViewType = typename ValuesType::ViewType;
      //using ViewType = TridiagonalMatrixView< Real, Device, Index, RowMajorOrder >;
      //using ConstViewType = TridiagonalMatrixView< typename std::add_const< Real >::type, Device, Index, RowMajorOrder >;
      using RowView = TridiagonalMatrixRowView< SegmentViewType, ValuesViewType >;


      template< typename _Real = Real,
                typename _Device = Device,
@@ -50,6 +57,12 @@ public:

      Tridiagonal();

      Tridiagonal( const IndexType rows, const IndexType columns );

      ViewType getView();

      ConstViewType getConstView() const;

      static String getSerializationType();

      virtual String getSerializationTypeVirtual() const;
@@ -59,10 +72,11 @@ public:

      void setCompressedRowLengths( ConstCompressedRowLengthsVectorView rowLengths );

   IndexType getRowLength( const IndexType row ) const;
      template< typename Vector >
      void getCompressedRowLengths( Vector& rowLengths ) const;

   __cuda_callable__
   IndexType getRowLengthFast( const IndexType row ) const;
      [[deprecated]]
      IndexType getRowLength( const IndexType row ) const;

      IndexType getMaxRowLength() const;

@@ -77,75 +91,43 @@ public:

      void reset();

   template< typename Real2, typename Device2, typename Index2 >
   bool operator == ( const Tridiagonal< Real2, Device2, Index2 >& matrix ) const;
      template< typename Real_, typename Device_, typename Index_, bool RowMajorOrder_ >
      bool operator == ( const Tridiagonal< Real_, Device_, Index_, RowMajorOrder_ >& matrix ) const;

   template< typename Real2, typename Device2, typename Index2 >
   bool operator != ( const Tridiagonal< Real2, Device2, Index2 >& matrix ) const;
      template< typename Real_, typename Device_, typename Index_, bool RowMajorOrder_ >
      bool operator != ( const Tridiagonal< Real_, Device_, Index_ >& matrix ) const;

      void setValue( const RealType& v );

   __cuda_callable__
   bool setElementFast( const IndexType row,
                        const IndexType column,
                        const RealType& value );

      bool setElement( const IndexType row,
                       const IndexType column,
                       const RealType& value );

   __cuda_callable__
   bool addElementFast( const IndexType row,
                        const IndexType column,
                        const RealType& value,
                        const RealType& thisElementMultiplicator = 1.0 );

      bool addElement( const IndexType row,
                       const IndexType column,
                       const RealType& value,
                       const RealType& thisElementMultiplicator = 1.0 );

   __cuda_callable__
   bool setRowFast( const IndexType row,
                    const IndexType* columns,
                    const RealType* values,
                    const IndexType elements );
      RealType getElement( const IndexType row,
                           const IndexType column ) const;

   bool setRow( const IndexType row,
                const IndexType* columns,
                const RealType* values,
                const IndexType elements );
      template< typename Fetch, typename Reduce, typename Keep, typename FetchReal >
      void rowsReduction( IndexType first, IndexType last, Fetch& fetch, Reduce& reduce, Keep& keep, const FetchReal& zero ) const;

   __cuda_callable__
   bool addRowFast( const IndexType row,
                    const IndexType* columns,
                    const RealType* values,
                    const IndexType elements,
                    const RealType& thisRowMultiplicator = 1.0 );

   bool addRow( const IndexType row,
                const IndexType* columns,
                const RealType* values,
                const IndexType elements,
                const RealType& thisRowMultiplicator = 1.0 );
      template< typename Fetch, typename Reduce, typename Keep, typename FetchReal >
      void allRowsReduction( Fetch& fetch, Reduce& reduce, Keep& keep, const FetchReal& zero ) const;

   __cuda_callable__
   RealType getElementFast( const IndexType row,
                            const IndexType column ) const;
      template< typename Function >
      void forRows( IndexType first, IndexType last, Function& function ) const;

   RealType getElement( const IndexType row,
                        const IndexType column ) const;
      template< typename Function >
      void forRows( IndexType first, IndexType last, Function& function );

   __cuda_callable__
   void getRowFast( const IndexType row,
                    IndexType* columns,
                    RealType* values ) const;
      template< typename Function >
      void forAllRows( Function& function ) const;

   __cuda_callable__
   MatrixRow getRow( const IndexType rowIndex );

   __cuda_callable__
   const MatrixRow getRow( const IndexType rowIndex ) const;
      template< typename Function >
      void forAllRows( Function& function );

      template< typename Vector >
      __cuda_callable__
@@ -206,4 +188,4 @@ protected:
} // namespace Matrices
} // namespace TNL

#include <TNL/Matrices/Tridiagonal_impl.h>
#include <TNL/Matrices/Tridiagonal.hpp>
+252 −232

File changed.

Preview size limit exceeded, changes collapsed.