Commit bc185a64 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber Committed by Tomáš Oberhuber
Browse files

Added template parameters to dense matrix: RowMajorOrder and RealAllocator.

parent 9895f081
Loading
Loading
Loading
Loading
+17 −20
Original line number Diff line number Diff line
@@ -10,6 +10,7 @@

#pragma once

#include <TNL/Allocators/Default.h>
#include <TNL/Devices/Host.h>
#include <TNL/Matrices/Matrix.h>
#include <TNL/Matrices/DenseRow.h>
@@ -23,7 +24,9 @@ class DenseDeviceDependentCode;

template< typename Real = double,
          typename Device = Devices::Host,
          typename Index = int >
          typename Index = int,
          bool RowMajorOrder = std::is_same< Device, Devices::Host >::value,
          typename RealAllocator = typename Allocators::Default< Device >::template Allocator< Real > >
class Dense : public Matrix< Real, Device, Index >
{
private:
@@ -32,17 +35,17 @@ private:
   using Enabler = std::enable_if< ! std::is_same< Device2, Device >::value >;

   // friend class will be needed for templated assignment operators
   template< typename Real2, typename Device2, typename Index2 >
   friend class Dense;
   //template< typename Real2, typename Device2, typename Index2 >
   //friend class Dense;

public:
   typedef Real RealType;
   typedef Device DeviceType;
   typedef Index IndexType;
   typedef typename Matrix< Real, Device, Index >::CompressedRowLengthsVector CompressedRowLengthsVector;
   typedef typename Matrix< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
   typedef Matrix< Real, Device, Index > BaseType;
   typedef DenseRow< Real, Index > MatrixRow;
   using RealType = Real;
   using DeviceType = Device;
   using IndexType = Index;
   using CompressedRowLengthsVector = typename Matrix< Real, Device, Index >::CompressedRowLengthsVector;
   using ConstCompressedRowLengthsVectorView = typename Matrix< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView;
   using BaseType = Matrix< Real, Device, Index >;
   using MatrixRow = DenseRow< Real, Index >;

   template< typename _Real = Real,
             typename _Device = Device,
@@ -58,23 +61,17 @@ public:
   void setDimensions( const IndexType rows,
                       const IndexType columns );

   template< typename Real2, typename Device2, typename Index2 >
   void setLike( const Dense< Real2, Device2, Index2 >& matrix );
   template< typename Matrix >
   void setLike( const Matrix& matrix );

   /****
    * This method is only for the compatibility with the sparse matrices.
    */
   void setCompressedRowLengths( ConstCompressedRowLengthsVectorView rowLengths );

   /****
    * Returns maximal number of the nonzero matrix elements that can be stored
    * in a given row.
    */
   [[deprecated]]
   IndexType getRowLength( const IndexType row ) const;

   __cuda_callable__
   IndexType getRowLengthFast( const IndexType row ) const;

   IndexType getMaxRowLength() const;

   IndexType getNumberOfMatrixElements() const;
@@ -220,4 +217,4 @@ protected:
} // namespace Matrices
} // namespace TNL

#include <TNL/Matrices/Dense_impl.h>
#include <TNL/Matrices/Dense.hpp>
+188 −96

File changed.

Preview size limit exceeded, changes collapsed.

+2 −1
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@
#include <vector>
#include <utility>  // std::pair
#include <limits>   // std::numeric_limits
#include <TNL/Allocators/Host.h>
#include <TNL/Matrices/Dense.h>
#include <TNL/Containers/Vector.h>
#include <TNL/Containers/VectorView.h>
@@ -235,7 +236,7 @@ public:

protected:
   // communication pattern
   Matrices::Dense< IndexType, Devices::Host, int > commPatternStarts, commPatternEnds;
   Matrices::Dense< IndexType, Devices::Host, int, true, Allocators::Host< IndexType > > commPatternStarts, commPatternEnds;

   // span of rows with only block-diagonal entries
   std::pair< IndexType, IndexType > localOnlySpan;
+2 −2
Original line number Diff line number Diff line
@@ -61,8 +61,8 @@ public:

   virtual void getCompressedRowLengths( CompressedRowLengthsVectorView rowLengths ) const;

   template< typename Real2, typename Device2, typename Index2, typename RealAllocator2 >
   void setLike( const Matrix< Real2, Device2, Index2, RealAllocator2 >& matrix );
   template< typename Matrix_ >
   void setLike( const Matrix_& matrix );

   IndexType getNumberOfMatrixElements() const;

+2 −5
Original line number Diff line number Diff line
@@ -81,11 +81,8 @@ template< typename Real,
          typename Device,
          typename Index,
          typename RealAllocator >
   template< typename Real2,
             typename Device2,
             typename Index2,
             typename RealAllocator2 >
void Matrix< Real, Device, Index, RealAllocator >::setLike( const Matrix< Real2, Device2, Index2, RealAllocator2 >& matrix )
   template< typename Matrix_ >
void Matrix< Real, Device, Index, RealAllocator >::setLike( const Matrix_& matrix )
{
   setDimensions( matrix.getRows(), matrix.getColumns() );
}
Loading