Skip to content
Snippets Groups Projects
Commit b927423b authored by Lukas Cejka's avatar Lukas Cejka Committed by Tomáš Oberhuber
Browse files

Added basic functionality for cross-device copy assignment. Removed StripSize...

Added basic functionality for cross-device copy assignment. Removed StripSize template typename as it was never used anywhere.
parent 44257e9c
No related branches found
No related tags found
1 merge request!45Matrices revision
...@@ -28,9 +28,19 @@ namespace TNL { ...@@ -28,9 +28,19 @@ namespace TNL {
template< typename Device > template< typename Device >
class BiEllpackDeviceDependentCode; class BiEllpackDeviceDependentCode;
template< typename Real, typename Device = Devices::Cuda, typename Index = int, int StripSize = 32 > template< typename Real, typename Device /*= Devices::Cuda*/, typename Index /*= int*/ >
class BiEllpack : public Sparse< Real, Device, Index > class BiEllpack : public Sparse< Real, Device, Index >
{ {
private:
// convenient template alias for controlling the selection of copy-assignment operator
template< typename Device2 >
using Enabler = std::enable_if< ! std::is_same< Device2, Device >::value >;
// friend class will be needed for templated assignment operators
template< typename Real2, typename Device2, typename Index2 >
friend class BiEllpack;
public: public:
typedef Real RealType; typedef Real RealType;
typedef Device DeviceType; typedef Device DeviceType;
...@@ -57,7 +67,15 @@ public: ...@@ -57,7 +67,15 @@ public:
template< typename Real2, template< typename Real2,
typename Device2, typename Device2,
typename Index2 > typename Index2 >
void setLike( const BiEllpack< Real2, Device2, Index2, StripSize >& matrix ); void setLike( const BiEllpack< Real2, Device2, Index2 >& matrix );
void reset();
template< typename Real2, typename Device2, typename Index2 >
bool operator == ( const BiEllpack< Real2, Device2, Index2 >& matrix ) const;
template< typename Real2, typename Device2, typename Index2 >
bool operator != ( const BiEllpack< Real2, Device2, Index2 >& matrix ) const;
void getRowLengths( CompressedRowLengthsVector& rowLengths ) const; void getRowLengths( CompressedRowLengthsVector& rowLengths ) const;
...@@ -124,8 +142,14 @@ public: ...@@ -124,8 +142,14 @@ public:
IndexType getNumberOfGroups( const IndexType row ) const; IndexType getNumberOfGroups( const IndexType row ) const;
bool vectorProductTest() const; bool vectorProductTest() const;
// copy assignment
BiEllpack& operator=( const BiEllpack& matrix );
void reset(); // cross-device copy assignment
template< typename Real2, typename Device2, typename Index2,
typename = typename Enabler< Device2 >::type >
BiEllpack& operator=( const BiEllpack< Real2, Device2, Index2 >& matrix );
void save( File& file ) const; void save( File& file ) const;
...@@ -136,11 +160,13 @@ public: ...@@ -136,11 +160,13 @@ public:
void load( const String& fileName ); void load( const String& fileName );
void print( std::ostream& str ) const; void print( std::ostream& str ) const;
void printValues() const;
void performRowBubbleSort( Containers::Vector< Index, Device, Index >& tempRowLengths ); void performRowBubbleSort( Containers::Vector< Index, Device, Index >& tempRowLengths );
void computeColumnSizes( Containers::Vector< Index, Device, Index >& tempRowLengths ); void computeColumnSizes( Containers::Vector< Index, Device, Index >& tempRowLengths );
// void verifyRowLengths( const typename BiEllpack< Real, Device, Index, StripSize >::CompressedRowLengthsVector& rowLengths ); // void verifyRowLengths( const typename BiEllpack< Real, Device, Index >::CompressedRowLengthsVector& rowLengths );
template< typename InVector, template< typename InVector,
typename OutVector > typename OutVector >
...@@ -157,11 +183,11 @@ public: ...@@ -157,11 +183,11 @@ public:
IndexType getStripLength( const IndexType strip ) const; IndexType getStripLength( const IndexType strip ) const;
__cuda_callable__ __cuda_callable__
void performRowBubbleSortCudaKernel( const typename BiEllpack< Real, Device, Index, StripSize >::CompressedRowLengthsVector& rowLengths, void performRowBubbleSortCudaKernel( const typename BiEllpack< Real, Device, Index >::CompressedRowLengthsVector& rowLengths,
const IndexType strip ); const IndexType strip );
__cuda_callable__ __cuda_callable__
void computeColumnSizesCudaKernel( const typename BiEllpack< Real, Device, Index, StripSize >::CompressedRowLengthsVector& rowLengths, void computeColumnSizesCudaKernel( const typename BiEllpack< Real, Device, Index >::CompressedRowLengthsVector& rowLengths,
const IndexType numberOfStrips, const IndexType numberOfStrips,
const IndexType strip ); const IndexType strip );
...@@ -171,6 +197,8 @@ public: ...@@ -171,6 +197,8 @@ public:
typedef BiEllpackDeviceDependentCode< DeviceType > DeviceDependentCode; typedef BiEllpackDeviceDependentCode< DeviceType > DeviceDependentCode;
friend class BiEllpackDeviceDependentCode< DeviceType >; friend class BiEllpackDeviceDependentCode< DeviceType >;
friend class BiEllpack< RealType, Devices::Host, IndexType >;
friend class BiEllpack< RealType, Devices::Cuda, IndexType >;
private: private:
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment