Skip to content
Snippets Groups Projects
Commit 3b986213 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber Committed by Tomáš Oberhuber
Browse files

Implementing SparseMatrix using Segments.

parent 6f645c5a
No related branches found
No related tags found
1 merge request!48Segments
......@@ -19,7 +19,7 @@ namespace TNL {
template< typename Device,
typename Index >
class Segments
class CSR
{
public:
......@@ -29,29 +29,24 @@ class Segments
CSR();
CSR( const SizesHolder& sizes );
CSR( const Vector< IndexType, DeviceType, IndexType >& sizes );
CSR( const CSR& csr );
CSR( const CSR& segments );
CSR( const CSR&& csr );
/**
* \brief Set number of segments
*/
//void setSegmentsCount();
CSR( const CSR&& segments );
/**
* \brief Set sizes of particular segmenets.
*/
template< typename SizesHolder = OffsetsHolder >
void setSizes( const SizesHolder& sizes )
void setSizes( const SizesHolder& sizes );
/**
* \brief Number segments.
*/
Index getSize() const;
IndexType getSize() const;
Index getStorageSize() const;
IndexType getStorageSize() const;
IndexType getGlobalIndex( const Index segmentIdx, const Index localIdx ) const;
......@@ -62,13 +57,13 @@ class Segments
* function 'f' with arguments 'args'
*/
template< typename Function, typename... Args >
void forAll( Function& f, Args args ) const;
void forAll( Function& f, Args... args ) const;
/***
* \brief Go over all segments and perform a reduction in each of them.
*/
template< typename Fetch, typename Reduction, typename ResultKeeper, typename... Args >
void segmentsReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, Args args );
void segmentsReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, Args... args );
protected:
......@@ -80,4 +75,4 @@ class Segments
} // namespace Conatiners
} // namespace TNL
#include <TNL/Containers/Segments/CSR.h>
\ No newline at end of file
#include <TNL/Containers/Segments/CSR.h>
......@@ -10,19 +10,46 @@
#pragma once
#include <TNL/Matrices/Matrix.h>
#include <TNL/Allocators/Default.h>
namespace TNL {
namespace Matrices {
template< typename Real,
typename Organization >
class SparseMatrix : public Matrix< Real, typename Organization::Device, typename Organization::Index >
template< typename, typename > class Segments,
typename Device = Devices::Host,
typename Index = int,
typename RealAllocator = typename Allocators::Default< Device >::template Allocator< Real >,
typename IndexAllocator = typename Allocators::Default< Device >::template Allocator< Index > >
class SparseMatrix : public Matrix< Real, Device, Index, RealAllocator >
{
public:
using RealType = Real;
using OrganizationType = Organization;
using DeviceType = typename Organization::DeviceType;
using IndexType = typename Organization::IndexType;
template< typename Device_, typename Index_ >
using SegmentsTemplate = Segments< Device_, Index_ >;
using SegmentsType = Segments< Device, Index >;
using DeviceType = Device;
using IndexType = Index;
using RealAllocatorType = RealAllocator;
using IndexAllocatorType = IndexAllocator;
using CompressedRowLengthsVectorView = Containers::VectorView< IndexType, DeviceType, IndexType >;
using ConstCompressedRowLengthsVectorView = typename CompressedRowLengthsVectorView::ConstViewType;
using ValuesVectorType = typename Matrix< Real, Device, Index, RealAllocator >::ValuesVector;
using ColumnsVectorType = Containers::Vector< IndexType, DeviceType, IndexType, IndexAllocatorType >;
SparseMatrix( const RealAllocatorType& realAllocator = RealAllocatorType(),
const IndexAllocatorType& indexAllocator = IndexAllocatorType() );
SparseMatrix( const SparseMatrix& m );
SparseMatrix( const SparseMatrix&& m );
SparseMatrix( const IndexType rows,
const IndexType columns,
const RealAllocatorType& realAllocator = RealAllocatorType(),
const IndexAllocatorType& indexAllocator = IndexAllocatorType() );
static String getSerializationType();
......@@ -43,8 +70,10 @@ class SparseMatrix : public Matrix< Real, typename Organization::Device, typenam
__cuda_callable__
IndexType getNonZeroRowLengthFast( const IndexType row ) const;
template< typename Real2, typename Device2, typename Index2 >
void setLike( const CSR< Real2, Device2, Index2 >& matrix );
template< typename Real2, template< typename, typename > class Segments2, typename Device2, typename Index2, typename RealAllocator2, typename IndexAllocator2 >
void setLike( const SparseMatrix< Real2, Segments2, Device2, Index2, RealAllocator2, IndexAllocator2 >& matrix );
IndexType getNumberOfNonzeroMatrixElements() const;
void reset();
......@@ -106,11 +135,11 @@ class SparseMatrix : public Matrix< Real, typename Organization::Device, typenam
IndexType* columns,
RealType* values ) const;
__cuda_callable__
/*__cuda_callable__
MatrixRow getRow( const IndexType rowIndex );
__cuda_callable__
ConstMatrixRow getRow( const IndexType rowIndex ) const;
ConstMatrixRow getRow( const IndexType rowIndex ) const;*/
template< typename Vector >
__cuda_callable__
......@@ -123,14 +152,15 @@ class SparseMatrix : public Matrix< Real, typename Organization::Device, typenam
OutVector& outVector ) const;
// TODO: add const RealType& multiplicator = 1.0 )
template< typename Real2, typename Index2 >
void addMatrix( const CSR< Real2, Device, Index2 >& matrix,
/*template< typename Real2, typename Index2 >
void addMatrix( const SparseMatrix< Real2, Segments, Device, Index2 >& matrix,
const RealType& matrixMultiplicator = 1.0,
const RealType& thisMatrixMultiplicator = 1.0 );
template< typename Real2, typename Index2 >
void getTransposition( const CSR< Real2, Device, Index2 >& matrix,
void getTransposition( const SparseMatrix< Real2, Segments, Device, Index2 >& matrix,
const RealType& matrixMultiplicator = 1.0 );
*/
template< typename Vector1, typename Vector2 >
bool performSORIteration( const Vector1& b,
......@@ -139,12 +169,16 @@ class SparseMatrix : public Matrix< Real, typename Organization::Device, typenam
const RealType& omega = 1.0 ) const;
// copy assignment
CSR& operator=( const CSR& matrix );
SparseMatrix& operator=( const SparseMatrix& matrix );
// cross-device copy assignment
template< typename Real2, typename Device2, typename Index2,
typename = typename Enabler< Device2 >::type >
CSR& operator=( const CSR< Real2, Device2, Index2 >& matrix );
template< typename Real2,
template< typename, typename > class Segments2,
typename Device2,
typename Index2,
typename RealAllocator2,
typename IndexAllocator2 >
SparseMatrix& operator=( const SparseMatrix< Real2, Segments2, Device2, Index2, RealAllocator2, IndexAllocator2 >& matrix );
void save( File& file ) const;
......@@ -155,9 +189,13 @@ class SparseMatrix : public Matrix< Real, typename Organization::Device, typenam
void load( const String& fileName );
void print( std::ostream& str ) const;
protected:
ColumnsVectorType columnsVector;
};
} // namespace Conatiners
} // namespace TNL
\ No newline at end of file
} // namespace TNL
#include <TNL/Matrices/SparseMatrix.hpp>
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment