Commit 3b986213 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber Committed by Tomáš Oberhuber
Browse files

Implementing SparseMatrix using Segments.

parent 6f645c5a
Loading
Loading
Loading
Loading
+10 −15
Original line number Diff line number Diff line
@@ -19,7 +19,7 @@ namespace TNL {

template< typename Device,
          typename Index >
class Segments
class CSR
{
   public:

@@ -29,29 +29,24 @@ class Segments

      CSR();

      CSR( const SizesHolder& sizes );
      CSR( const Vector< IndexType, DeviceType, IndexType >& sizes );

      CSR( const CSR& csr );
      CSR( const CSR& segments );

      CSR( const CSR&& csr );

      /**
       * \brief Set number of segments
       */
      //void setSegmentsCount();
      CSR( const CSR&& segments );

      /**
       * \brief Set sizes of particular segmenets.
       */
      template< typename SizesHolder = OffsetsHolder >
      void setSizes( const SizesHolder& sizes )
      void setSizes( const SizesHolder& sizes );

      /**
       * \brief Number segments.
       */
      Index getSize() const;
      IndexType getSize() const;

      Index getStorageSize() const;
      IndexType getStorageSize() const;

      IndexType getGlobalIndex( const Index segmentIdx, const Index localIdx ) const;

@@ -62,13 +57,13 @@ class Segments
       * function 'f' with arguments 'args'
       */
      template< typename Function, typename... Args >
      void forAll( Function& f, Args args ) const;
      void forAll( Function& f, Args... args ) const;

      /***
       * \brief Go over all segments and perform a reduction in each of them.
       */
      template< typename Fetch, typename Reduction, typename ResultKeeper, typename... Args >
      void segmentsReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, Args args );
      void segmentsReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, Args... args );

   protected:

+56 −18
Original line number Diff line number Diff line
@@ -10,19 +10,46 @@

#pragma once

#include <TNL/Matrices/Matrix.h>
#include <TNL/Allocators/Default.h>

namespace TNL {
namespace Matrices {

template< typename Real,
          typename Organization >
class SparseMatrix : public Matrix< Real, typename Organization::Device, typename Organization::Index >
          template< typename, typename > class Segments,
          typename Device = Devices::Host,
          typename Index = int,
          typename RealAllocator = typename Allocators::Default< Device >::template Allocator< Real >,
          typename IndexAllocator = typename Allocators::Default< Device >::template Allocator< Index > >
class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator >
{
   public:

      using RealType = Real;
      using OrganizationType = Organization;
      using DeviceType = typename Organization::DeviceType;
      using IndexType = typename Organization::IndexType;
      template< typename Device_, typename Index_ >
      using SegmentsTemplate = Segments< Device_, Index_ >;
      using SegmentsType = Segments< Device, Index >;
      using DeviceType = Device;
      using IndexType = Index;
      using RealAllocatorType = RealAllocator;
      using IndexAllocatorType = IndexAllocator;
      using CompressedRowLengthsVectorView = Containers::VectorView< IndexType, DeviceType, IndexType >;
      using ConstCompressedRowLengthsVectorView = typename CompressedRowLengthsVectorView::ConstViewType;
      using ValuesVectorType = typename Matrix< Real, Device, Index, RealAllocator >::ValuesVector;
      using ColumnsVectorType = Containers::Vector< IndexType, DeviceType, IndexType, IndexAllocatorType >;

      SparseMatrix( const RealAllocatorType& realAllocator = RealAllocatorType(),
                    const IndexAllocatorType& indexAllocator = IndexAllocatorType() );

      SparseMatrix( const SparseMatrix& m );

      SparseMatrix( const SparseMatrix&& m );

      SparseMatrix( const IndexType rows,
                    const IndexType columns,
                    const RealAllocatorType& realAllocator = RealAllocatorType(),
                    const IndexAllocatorType& indexAllocator = IndexAllocatorType() );

      static String getSerializationType();

@@ -43,8 +70,10 @@ class SparseMatrix : public Matrix< Real, typename Organization::Device, typenam
      __cuda_callable__
      IndexType getNonZeroRowLengthFast( const IndexType row ) const;

      template< typename Real2, typename Device2, typename Index2 >
      void setLike( const CSR< Real2, Device2, Index2 >& matrix );
      template< typename Real2, template< typename, typename > class Segments2, typename Device2, typename Index2, typename RealAllocator2, typename IndexAllocator2 >
      void setLike( const SparseMatrix< Real2, Segments2, Device2, Index2, RealAllocator2, IndexAllocator2 >& matrix );

      IndexType getNumberOfNonzeroMatrixElements() const;

      void reset();

@@ -106,11 +135,11 @@ class SparseMatrix : public Matrix< Real, typename Organization::Device, typenam
                       IndexType* columns,
                       RealType* values ) const;

      __cuda_callable__
      /*__cuda_callable__
      MatrixRow getRow( const IndexType rowIndex );

      __cuda_callable__
      ConstMatrixRow getRow( const IndexType rowIndex ) const;
      ConstMatrixRow getRow( const IndexType rowIndex ) const;*/

      template< typename Vector >
      __cuda_callable__
@@ -123,14 +152,15 @@ class SparseMatrix : public Matrix< Real, typename Organization::Device, typenam
                          OutVector& outVector ) const;
      // TODO: add const RealType& multiplicator = 1.0 )

      template< typename Real2, typename Index2 >
      void addMatrix( const CSR< Real2, Device, Index2 >& matrix,
      /*template< typename Real2, typename Index2 >
      void addMatrix( const SparseMatrix< Real2, Segments, Device, Index2 >& matrix,
                      const RealType& matrixMultiplicator = 1.0,
                      const RealType& thisMatrixMultiplicator = 1.0 );

      template< typename Real2, typename Index2 >
      void getTransposition( const CSR< Real2, Device, Index2 >& matrix,
      void getTransposition( const SparseMatrix< Real2, Segments, Device, Index2 >& matrix,
                             const RealType& matrixMultiplicator = 1.0 );
       */

      template< typename Vector1, typename Vector2 >
      bool performSORIteration( const Vector1& b,
@@ -139,12 +169,16 @@ class SparseMatrix : public Matrix< Real, typename Organization::Device, typenam
                                const RealType& omega = 1.0 ) const;

      // copy assignment
      CSR& operator=( const CSR& matrix );
      SparseMatrix& operator=( const SparseMatrix& matrix );

      // cross-device copy assignment
      template< typename Real2, typename Device2, typename Index2,
                typename = typename Enabler< Device2 >::type >
      CSR& operator=( const CSR< Real2, Device2, Index2 >& matrix );
      template< typename Real2, 
                template< typename, typename > class Segments2,
                typename Device2,
                typename Index2,
                typename RealAllocator2,
                typename IndexAllocator2 >
      SparseMatrix& operator=( const SparseMatrix< Real2, Segments2, Device2, Index2, RealAllocator2, IndexAllocator2 >& matrix );

      void save( File& file ) const;

@@ -156,8 +190,12 @@ class SparseMatrix : public Matrix< Real, typename Organization::Device, typenam

      void print( std::ostream& str ) const;
      
   protected:

      ColumnsVectorType columnsVector;
};

}  // namespace Conatiners
} // namespace TNL

#include <TNL/Matrices/SparseMatrix.hpp>
+566 −178

File changed.

Preview size limit exceeded, changes collapsed.