Commit d070cc39 authored by Jakub Klinkovský's avatar Jakub Klinkovský

Removed HostType and CudaType aliases in containers, matrices and grids

They are not suitable for more than 2 devices/execution types; their design
breaks the Open-Closed Principle. Instead, a type template "Self" was
created, which allows to change any template parameter.
parent 3a997233
......@@ -73,8 +73,8 @@ benchmarkSpmvCuda( Benchmark& benchmark,
{
using RealType = typename Matrix::RealType;
using IndexType = typename Matrix::IndexType;
using CudaMatrix = typename Matrix::CudaType;
using CudaVector = typename Vector::CudaType;
using CudaMatrix = typename Matrix::template Self< RealType, Devices::Cuda >;
using CudaVector = typename Vector::template Self< typename Vector::RealType, Devices::Cuda >;
CudaVector cuda_x;
cuda_x = x;
......@@ -125,8 +125,8 @@ benchmarkDistributedSpmvCuda( Benchmark& benchmark,
{
using RealType = typename Matrix::RealType;
using IndexType = typename Matrix::IndexType;
using CudaMatrix = typename Matrix::CudaType;
using CudaVector = typename Vector::CudaType;
using CudaMatrix = typename Matrix::template Self< RealType, Devices::Cuda >;
using CudaVector = typename Vector::template Self< typename Vector::RealType, Devices::Cuda >;
CudaVector cuda_x;
cuda_x = x;
......
......@@ -119,8 +119,8 @@ benchmarkIterativeSolvers( Benchmark& benchmark,
const Vector& b )
{
#ifdef HAVE_CUDA
using CudaMatrix = typename Matrix::CudaType;
using CudaVector = typename Vector::CudaType;
using CudaMatrix = typename Matrix::template Self< typename Matrix::RealType, Devices::Cuda >;
using CudaVector = typename Vector::template Self< typename Vector::RealType, Devices::Cuda >;
CudaVector cuda_x0, cuda_b;
cuda_x0 = x0;
......@@ -461,9 +461,11 @@ struct LinearSolversBenchmark
SharedPointer< CSR > matrixCopy;
Matrices::copySparseMatrix( *matrixCopy, *matrixPointer );
SharedPointer< typename CSR::CudaType > cuda_matrixCopy;
using CudaCSR = Matrices::CSR< RealType, Devices::Cuda, IndexType >;
using CudaVector = typename VectorType::template Self< RealType, Devices::Cuda >;
SharedPointer< CudaCSR > cuda_matrixCopy;
*cuda_matrixCopy = *matrixCopy;
typename VectorType::CudaType cuda_x0, cuda_b;
CudaVector cuda_x0, cuda_b;
cuda_x0.setLike( x0 );
cuda_b.setLike( b );
cuda_x0 = x0;
......
......@@ -52,7 +52,8 @@ template< typename Array >
void expect_eq( Array& a, Array& b )
{
if( std::is_same< typename Array::DeviceType, TNL::Devices::Cuda >::value ) {
typename Array::HostType a_host, b_host;
using HostArray = typename Array::template Self< typename Array::ValueType, TNL::Devices::Host >;
HostArray a_host, b_host;
a_host = a;
b_host = b;
expect_eq_chunked( a_host, b_host );
......
......@@ -54,7 +54,8 @@ template< typename Array >
void expect_eq( Array& a, Array& b )
{
if( std::is_same< typename Array::DeviceType, TNL::Devices::Cuda >::value ) {
typename Array::HostType a_host, b_host;
using HostArray = typename Array::template Self< typename Array::ValueType, TNL::Devices::Host >;
HostArray a_host, b_host;
a_host = a;
b_host = b;
expect_eq_chunked( a_host, b_host );
......
......@@ -73,7 +73,6 @@ template< typename Value,
class Array
{
public:
/**
* \brief Type of elements stored in this array.
*/
......@@ -98,16 +97,6 @@ class Array
*/
using AllocatorType = Allocator;
/**
* \brief Defines the same array type but allocated on host (CPU).
*/
using HostType = Array< Value, TNL::Devices::Host, Index >;
/**
* \brief Defines the same array type but allocated on CUDA device (GPU).
*/
using CudaType = Array< Value, TNL::Devices::Cuda, Index >;
/**
* \brief Compatible ArrayView type.
*/
......@@ -118,6 +107,15 @@ class Array
*/
using ConstViewType = ArrayView< std::add_const_t< Value >, Device, Index >;
/**
* \brief A template which allows to quickly obtain an \ref Array type with changed template parameters.
*/
template< typename _Value,
typename _Device = Device,
typename _Index = Index,
typename _Allocator = typename Allocators::Default< _Device >::template Allocator< _Value > >
using Self = Array< _Value, _Device, _Index, _Allocator >;
/**
* \brief Constructs an empty array with zero size.
......
......@@ -80,16 +80,6 @@ public:
*/
using IndexType = Index;
/**
* \brief Defines the same array type but allocated on host (CPU).
*/
using HostType = ArrayView< Value, TNL::Devices::Host, Index >;
/**
* \brief Defines the same array type but allocated on CUDA device (GPU).
*/
using CudaType = ArrayView< Value, TNL::Devices::Cuda, Index >;
/**
* \brief Compatible ArrayView type.
*/
......@@ -100,6 +90,15 @@ public:
*/
using ConstViewType = ArrayView< std::add_const_t< Value >, Device, Index >;
/**
* \brief A template which allows to quickly obtain an \ref ArrayView type with changed template parameters.
*/
template< typename _Value,
typename _Device = Device,
typename _Index = Index >
using Self = ArrayView< _Value, _Device, _Index >;
/**
* \brief Constructs an empty array view.
*
......
......@@ -35,11 +35,19 @@ public:
using LocalRangeType = Subrange< Index >;
using LocalViewType = Containers::ArrayView< Value, Device, Index >;
using ConstLocalViewType = Containers::ArrayView< std::add_const_t< Value >, Device, Index >;
using HostType = DistributedArray< Value, Devices::Host, Index, Communicator >;
using CudaType = DistributedArray< Value, Devices::Cuda, Index, Communicator >;
using ViewType = DistributedArrayView< Value, Device, Index, Communicator >;
using ConstViewType = DistributedArrayView< std::add_const_t< Value >, Device, Index, Communicator >;
/**
* \brief A template which allows to quickly obtain a \ref DistributedArray type with changed template parameters.
*/
template< typename _Value,
typename _Device = Device,
typename _Index = Index,
typename _Communicator = Communicator >
using Self = DistributedArray< _Value, _Device, _Index, _Communicator >;
DistributedArray() = default;
DistributedArray( DistributedArray& ) = default;
......
......@@ -34,11 +34,19 @@ public:
using LocalRangeType = Subrange< Index >;
using LocalViewType = Containers::ArrayView< Value, Device, Index >;
using ConstLocalViewType = Containers::ArrayView< std::add_const_t< Value >, Device, Index >;
using HostType = DistributedArrayView< Value, Devices::Host, Index, Communicator >;
using CudaType = DistributedArrayView< Value, Devices::Cuda, Index, Communicator >;
using ViewType = DistributedArrayView< Value, Device, Index, Communicator >;
using ConstViewType = DistributedArrayView< std::add_const_t< Value >, Device, Index, Communicator >;
/**
* \brief A template which allows to quickly obtain a \ref DistributedArrayView type with changed template parameters.
*/
template< typename _Value,
typename _Device = Device,
typename _Index = Index,
typename _Communicator = Communicator >
using Self = DistributedArrayView< _Value, _Device, _Index, _Communicator >;
// Initialization by raw data
__cuda_callable__
DistributedArrayView( const LocalRangeType& localRange, IndexType globalSize, CommunicationGroup group, LocalViewType localData )
......
......@@ -34,11 +34,19 @@ public:
using IndexType = Index;
using LocalViewType = Containers::VectorView< Real, Device, Index >;
using ConstLocalViewType = Containers::VectorView< std::add_const_t< Real >, Device, Index >;
using HostType = DistributedVector< Real, Devices::Host, Index, Communicator >;
using CudaType = DistributedVector< Real, Devices::Cuda, Index, Communicator >;
using ViewType = DistributedVectorView< Real, Device, Index, Communicator >;
using ConstViewType = DistributedVectorView< std::add_const_t< Real >, Device, Index, Communicator >;
/**
* \brief A template which allows to quickly obtain a \ref Vector type with changed template parameters.
*/
template< typename _Real,
typename _Device = Device,
typename _Index = Index,
typename _Communicator = Communicator >
using Self = DistributedVector< _Real, _Device, _Index, _Communicator >;
// inherit all constructors and assignment operators from Array
using BaseType::DistributedArray;
using BaseType::operator=;
......
......@@ -35,11 +35,19 @@ public:
using IndexType = Index;
using LocalViewType = Containers::VectorView< Real, Device, Index >;
using ConstLocalViewType = Containers::VectorView< std::add_const_t< Real >, Device, Index >;
using HostType = DistributedVectorView< Real, Devices::Host, Index, Communicator >;
using CudaType = DistributedVectorView< Real, Devices::Cuda, Index, Communicator >;
using ViewType = DistributedVectorView< Real, Device, Index, Communicator >;
using ConstViewType = DistributedVectorView< std::add_const_t< Real >, Device, Index, Communicator >;
/**
* \brief A template which allows to quickly obtain a \ref VectorView type with changed template parameters.
*/
template< typename _Real,
typename _Device = Device,
typename _Index = Index,
typename _Communicator = Communicator >
using Self = DistributedVectorView< _Real, _Device, _Index, _Communicator >;
// inherit all constructors and assignment operators from ArrayView
using BaseType::DistributedArrayView;
using BaseType::operator=;
......
......@@ -42,7 +42,6 @@ class Vector
: public Array< Real, Device, Index, Allocator >
{
public:
/**
* \brief Type of elements stored in this vector.
*/
......@@ -67,16 +66,6 @@ public:
*/
using AllocatorType = Allocator;
/**
* \brief Defines the same vector type but allocated on host (CPU).
*/
using HostType = Vector< Real, TNL::Devices::Host, Index >;
/**
* \brief Defines the same vector type but allocated on CUDA device (GPU).
*/
using CudaType = Vector< Real, TNL::Devices::Cuda, Index >;
/**
* \brief Compatible VectorView type.
*/
......@@ -87,6 +76,16 @@ public:
*/
using ConstViewType = VectorView< std::add_const_t< Real >, Device, Index >;
/**
* \brief A template which allows to quickly obtain a \ref Vector type with changed template parameters.
*/
template< typename _Real,
typename _Device = Device,
typename _Index = Index,
typename _Allocator = typename Allocators::Default< _Device >::template Allocator< _Real > >
using Self = Vector< _Real, _Device, _Index, _Allocator >;
// constructors and assignment operators inherited from the class Array
using Array< Real, Device, Index, Allocator >::Array;
using Array< Real, Device, Index, Allocator >::operator=;
......
......@@ -39,7 +39,6 @@ class VectorView
using BaseType = ArrayView< Real, Device, Index >;
using NonConstReal = typename std::remove_const< Real >::type;
public:
/**
* \brief Type of elements stored in this vector.
*/
......@@ -57,16 +56,6 @@ public:
*/
using IndexType = Index;
/**
* \brief Defines the same vector type but allocated on host (CPU).
*/
using HostType = VectorView< Real, TNL::Devices::Host, Index >;
/**
* \brief Defines the same vector type but allocated on CUDA device (GPU).
*/
using CudaType = VectorView< Real, TNL::Devices::Cuda, Index >;
/**
* \brief Compatible VectorView type.
*/
......@@ -77,6 +66,15 @@ public:
*/
using ConstViewType = VectorView< std::add_const_t< Real >, Device, Index >;
/**
* \brief A template which allows to quickly obtain a \ref VectorView type with changed template parameters.
*/
template< typename _Real,
typename _Device = Device,
typename _Index = Index >
using Self = VectorView< _Real, _Device, _Index >;
// constructors and assignment operators inherited from the class ArrayView
using ArrayView< Real, Device, Index >::ArrayView;
using ArrayView< Real, Device, Index >::operator=;
......
......@@ -84,8 +84,11 @@ public:
typedef Index IndexType;
typedef typename Sparse< RealType, DeviceType, IndexType >::CompressedRowLengthsVector CompressedRowLengthsVector;
typedef typename Sparse< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
typedef AdEllpack< Real, Devices::Host, Index > HostType;
typedef AdEllpack< Real, Devices::Cuda, Index > CudaType;
template< typename _Real = Real,
typename _Device = Device,
typename _Index = Index >
using Self = AdEllpack< _Real, _Device, _Index >;
AdEllpack();
......
......@@ -39,8 +39,11 @@ public:
typedef typename Sparse< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
typedef typename Sparse< RealType, DeviceType, IndexType >::ValuesVector ValuesVector;
typedef typename Sparse< RealType, DeviceType, IndexType >::ColumnIndexesVector ColumnIndexesVector;
typedef BiEllpack< Real, Devices::Host, Index > HostType;
typedef BiEllpack< Real, Devices::Cuda, Index > CudaType;
template< typename _Real = Real,
typename _Device = Device,
typename _Index = Index >
using Self = BiEllpack< _Real, _Device, _Index >;
BiEllpack();
......
......@@ -30,8 +30,11 @@ public:
typedef typename Sparse< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
typedef typename Sparse< RealType, DeviceType, IndexType >::ValuesVector ValuesVector;
typedef typename Sparse< RealType, DeviceType, IndexType >::ColumnIndexesVector ColumnIndexesVector;
typedef BiEllpackSymmetric< Real, Devices::Host, Index > HostType;
typedef BiEllpackSymmetric< Real, Devices::Cuda, Index > CudaType;
template< typename _Real = Real,
typename _Device = Device,
typename _Index = Index >
using Self = BiEllpackSymmetric< _Real, _Device, _Index >;
BiEllpackSymmetric();
......
......@@ -35,8 +35,11 @@ public:
typedef Index IndexType;
typedef typename Sparse< RealType, DeviceType, IndexType >:: CompressedRowLengthsVector CompressedRowLengthsVector;
typedef typename Sparse< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
typedef COOMatrix< Real, Devices::Host, Index > HostType;
typedef COOMatrix< Real, Devices::Cuda, Index > CudaType;
template< typename _Real = Real,
typename _Device = Device,
typename _Index = Index >
using Self = COOMatrix< _Real, _Device, _Index >;
COOMatrix();
......
......@@ -49,12 +49,15 @@ public:
using IndexType = Index;
typedef typename Sparse< RealType, DeviceType, IndexType >::CompressedRowLengthsVector CompressedRowLengthsVector;
typedef typename Sparse< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
typedef CSR< Real, Devices::Host, Index > HostType;
typedef CSR< Real, Devices::Cuda, Index > CudaType;
typedef Sparse< Real, Device, Index > BaseType;
using MatrixRow = typename BaseType::MatrixRow;
using ConstMatrixRow = typename BaseType::ConstMatrixRow;
template< typename _Real = Real,
typename _Device = Device,
typename _Index = Index >
using Self = CSR< _Real, _Device, _Index >;
enum SPMVCudaKernel { scalar, vector, hybrid };
CSR();
......
......@@ -75,12 +75,15 @@ public:
typedef tnlChunkedEllpackSliceInfo< IndexType > ChunkedEllpackSliceInfo;
typedef typename Sparse< RealType, DeviceType, IndexType >:: CompressedRowLengthsVector CompressedRowLengthsVector;
typedef typename Sparse< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
typedef ChunkedEllpack< Real, Devices::Host, Index > HostType;
typedef ChunkedEllpack< Real, Devices::Cuda, Index > CudaType;
typedef Sparse< Real, Device, Index > BaseType;
typedef typename BaseType::MatrixRow MatrixRow;
typedef SparseRow< const RealType, const IndexType > ConstMatrixRow;
template< typename _Real = Real,
typename _Device = Device,
typename _Index = Index >
using Self = ChunkedEllpack< _Real, _Device, _Index >;
ChunkedEllpack();
static String getSerializationType();
......
......@@ -41,11 +41,13 @@ public:
typedef Index IndexType;
typedef typename Matrix< Real, Device, Index >::CompressedRowLengthsVector CompressedRowLengthsVector;
typedef typename Matrix< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
typedef Dense< Real, Devices::Host, Index > HostType;
typedef Dense< Real, Devices::Cuda, Index > CudaType;
typedef Matrix< Real, Device, Index > BaseType;
typedef DenseRow< Real, Index > MatrixRow;
template< typename _Real = Real,
typename _Device = Device,
typename _Index = Index >
using Self = Dense< _Real, _Device, _Index >;
Dense();
......
......@@ -54,14 +54,17 @@ public:
using CommunicatorType = Communicator;
using LocalRangeType = Containers::Subrange< typename Matrix::IndexType >;
using HostType = DistributedMatrix< typename Matrix::HostType, Communicator >;
using CudaType = DistributedMatrix< typename Matrix::CudaType, Communicator >;
using CompressedRowLengthsVector = Containers::DistributedVector< IndexType, DeviceType, IndexType, CommunicatorType >;
using MatrixRow = Matrices::SparseRow< RealType, IndexType >;
using ConstMatrixRow = Matrices::SparseRow< std::add_const_t< RealType >, std::add_const_t< IndexType > >;
template< typename _Real = RealType,
typename _Device = DeviceType,
typename _Index = IndexType,
typename _Communicator = Communicator >
using Self = DistributedMatrix< typename MatrixType::template Self< _Real, _Device, _Index >, _Communicator >;
DistributedMatrix() = default;
DistributedMatrix( DistributedMatrix& ) = default;
......
......@@ -39,12 +39,15 @@ public:
typedef typename Sparse< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
typedef typename Sparse< RealType, DeviceType, IndexType >::ValuesVector ValuesVector;
typedef typename Sparse< RealType, DeviceType, IndexType >::ColumnIndexesVector ColumnIndexesVector;
typedef Ellpack< Real, Devices::Host, Index > HostType;
typedef Ellpack< Real, Devices::Cuda, Index > CudaType;
typedef Sparse< Real, Device, Index > BaseType;
typedef typename BaseType::MatrixRow MatrixRow;
typedef SparseRow< const RealType, const IndexType > ConstMatrixRow;
template< typename _Real = Real,
typename _Device = Device,
typename _Index = Index >
using Self = Ellpack< _Real, _Device, _Index >;
Ellpack();
static String getSerializationType();
......
......@@ -31,9 +31,11 @@ class EllpackSymmetric : public Sparse< Real, Device, Index >
typedef typename Sparse< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
typedef typename Sparse< RealType, DeviceType, IndexType >::ValuesVector ValuesVector;
typedef typename Sparse< RealType, DeviceType, IndexType >::ColumnIndexesVector ColumnIndexesVector;
typedef EllpackSymmetric< Real, Devices::Host, Index > HostType;
typedef EllpackSymmetric< Real, Devices::Cuda, Index > CudaType;
template< typename _Real = Real,
typename _Device = Device,
typename _Index = Index >
using Self = EllpackSymmetric< _Real, _Device, _Index >;
EllpackSymmetric();
......
......@@ -31,9 +31,11 @@ class EllpackSymmetricGraph : public Sparse< Real, Device, Index >
typedef typename Sparse< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
typedef typename Sparse< RealType, DeviceType, IndexType >::ValuesVector ValuesVector;
typedef typename Sparse< RealType, DeviceType, IndexType >::ColumnIndexesVector ColumnIndexesVector;
typedef EllpackSymmetricGraph< Real, Devices::Host, Index > HostType;
typedef EllpackSymmetricGraph< Real, Devices::Cuda, Index > CudaType;
template< typename _Real = Real,
typename _Device = Device,
typename _Index = Index >
using Self = EllpackSymmetricGraph< _Real, _Device, _Index >;
EllpackSymmetricGraph();
......
......@@ -648,8 +648,8 @@ Ellpack< Real, Device, Index >::operator=( const Ellpack< Real2, Device2, Index2
// host -> cuda
if( std::is_same< Device, Devices::Cuda >::value ) {
typename ValuesVector::HostType tmpValues;
typename ColumnIndexesVector::HostType tmpColumnIndexes;
typename ValuesVector::template Self< typename ValuesVector::ValueType, Devices::Sequential > tmpValues;
typename ColumnIndexesVector::template Self< typename ColumnIndexesVector::ValueType, Devices::Sequential > tmpColumnIndexes;
tmpValues.setLike( this->values );
tmpColumnIndexes.setLike( this->columnIndexes );
......
......@@ -425,11 +425,11 @@ class MatrixReaderDeviceDependentCode< Devices::Cuda >
bool verbose,
bool symReader )
{
typedef typename Matrix::HostType HostMatrixType;
typedef typename HostMatrixType::CompressedRowLengthsVector CompressedRowLengthsVector;
using HostMatrixType = typename Matrix::template Self< typename Matrix::RealType, Devices::Sequential >;
using CompressedRowLengthsVector = typename HostMatrixType::CompressedRowLengthsVector;
HostMatrixType hostMatrix;
typename Matrix::CompressedRowLengthsVector rowLengths;
CompressedRowLengthsVector rowLengths;
return MatrixReader< Matrix >::readMtxFileHostMatrix( file, matrix, rowLengths, verbose, symReader );
matrix = hostMatrix;
......
......@@ -38,11 +38,13 @@ public:
typedef Index IndexType;
typedef typename Matrix< Real, Device, Index >::CompressedRowLengthsVector CompressedRowLengthsVector;
typedef typename Matrix< Real, Device, Index >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
typedef Multidiagonal< Real, Devices::Host, Index > HostType;
typedef Multidiagonal< Real, Devices::Cuda, Index > CudaType;
typedef Matrix< Real, Device, Index > BaseType;
typedef MultidiagonalRow< Real, Index > MatrixRow;
template< typename _Real = Real,
typename _Device = Device,
typename _Index = Index >
using Self = Multidiagonal< _Real, _Device, _Index >;
Multidiagonal();
......
......@@ -68,12 +68,15 @@ public:
typedef typename Sparse< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
typedef typename Sparse< RealType, DeviceType, IndexType >::ValuesVector ValuesVector;
typedef typename Sparse< RealType, DeviceType, IndexType >::ColumnIndexesVector ColumnIndexesVector;
typedef SlicedEllpack< Real, Devices::Host, Index, SliceSize > HostType;
typedef SlicedEllpack< Real, Devices::Cuda, Index, SliceSize > CudaType;
typedef Sparse< Real, Device, Index > BaseType;
typedef typename BaseType::MatrixRow MatrixRow;
typedef SparseRow< const RealType, const IndexType > ConstMatrixRow;
template< typename _Real = Real,
typename _Device = Device,
typename _Index = Index,
int _SliceSize = SliceSize >
using Self = SlicedEllpack< _Real, _Device, _Index, _SliceSize >;
SlicedEllpack();
......
......@@ -49,9 +49,12 @@ class SlicedEllpackSymmetric : public Sparse< Real, Device, Index >
typedef typename Sparse< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
typedef typename Sparse< RealType, DeviceType, IndexType >::ValuesVector ValuesVector;
typedef typename Sparse< RealType, DeviceType, IndexType >::ColumnIndexesVector ColumnIndexesVector;
typedef SlicedEllpackSymmetric< Real, Devices::Host, Index > HostType;
typedef SlicedEllpackSymmetric< Real, Devices::Cuda, Index > CudaType;
template< typename _Real = Real,
typename _Device = Device,
typename _Index = Index,
int _SliceSize = SliceSize >
using Self = SlicedEllpackSymmetric< _Real, _Device, _Index, _SliceSize >;
SlicedEllpackSymmetric();
......
......@@ -49,9 +49,12 @@ class SlicedEllpackSymmetricGraph : public Sparse< Real, Device, Index >
typedef typename Sparse< RealType, DeviceType, IndexType >::ConstCompressedRowLengthsVectorView ConstCompressedRowLengthsVectorView;
typedef typename Sparse< RealType, DeviceType, IndexType >::ValuesVector ValuesVector;
typedef typename Sparse< RealType, DeviceType, IndexType >::ColumnIndexesVector ColumnIndexesVector;
typedef SlicedEllpackSymmetricGraph< Real, Devices::Host, Index > HostType;
typedef SlicedEllpackSymmetricGraph< Real, Devices::Cuda, Index > CudaType;
template< typename _Real = Real,
typename _Device = Device,
typename _Index = Index,
int _SliceSize = SliceSize >
using Self = SlicedEllpackSymmetricGraph< _Real, _Device, _Index, _SliceSize >;
SlicedEllpackSymmetricGraph();
......
......@@ -620,19 +620,14 @@ template< typename Real,
SlicedEllpack< Real, Device, Index, SliceSize >&
SlicedEllpack< Real, Device, Index, SliceSize >::operator=( const SlicedEllpack< Real2, Device2, Index2, SliceSize >& matrix )
{
static_assert( std::is_same< Device, Devices::Host >::value || std::is_same< Device, Devices::Cuda >::value,
"unknown device" );
static_assert( std::is_same< Device2, Devices::Host >::value || std::is_same< Device2, Devices::Cuda >::value,
"unknown device" );
this->setLike( matrix );
this->slicePointers = matrix.slicePointers;
this->sliceCompressedRowLengths = matrix.sliceCompressedRowLengths;
// host -> cuda
if( std::is_same< Device, Devices::Cuda >::value ) {
typename ValuesVector::HostType tmpValues;
typename ColumnIndexesVector::HostType tmpColumnIndexes;
typename ValuesVector::template Self< typename ValuesVector::ValueType, Devices::Sequential > tmpValues;
typename ColumnIndexesVector::template Self< typename ColumnIndexesVector::ValueType, Devices::Sequential > tmpColumnIndexes;
tmpValues.setLike( matrix.values );
tmpColumnIndexes.setLike( matrix.columnIndexes );
......@@ -654,7 +649,7 @@ SlicedEllpack< Real, Device, Index, SliceSize >::operator=( const SlicedEllpack<
}
// cuda -> host
if( std::is_same< Device, Devices::Host >::value ) {
else {
ValuesVector tmpValues;
ColumnIndexesVector tmpColumnIndexes;
tmpValues.setLike( matrix.values );
......@@ -724,7 +719,7 @@ template< typename Real,
int SliceSize >
void SlicedEllpack< Real, Device, Index, SliceSize >::print( std::ostream& str ) const
{
if( std::is_same< Device, Devices::Host >::value ) {
if( ! std::is_same< Device, Devices::Cuda >::value ) {
for( IndexType row = 0; row < this->getRows(); row++ )
{
str <<"Row: " << row << " -> ";
......@@ -745,7 +740,7 @@ void SlicedEllpack< Real, Device, Index, SliceSize >::print( std::ostream& str )
}
}
else {
HostType hostMatrix;
Self< Real, Devices::Sequential > hostMatrix;
hostMatrix = *this;
hostMatrix.print( str );
}
......@@ -778,12 +773,13 @@ __device__ void SlicedEllpack< Real, Device, Index, SliceSize >::computeMaximalR
}
#endif
template<>
class SlicedEllpackDeviceDependentCode< Devices::Host >
// implementation for host types
template< typename Device_ >
class SlicedEllpackDeviceDependentCode
{
public:
typedef Devices::Host Device;
typedef Device_ Device;
template< typename Real,
typename Index,
......
......@@ -170,7 +170,8 @@ typename std::enable_if< ! std::is_same< typename Matrix1::DeviceType, typename
std::is_same< typename Matrix2::DeviceType, Devices::Host >::value >::type
copySparseMatrix_impl( Matrix1& A, const Matrix2& B )
{
typename Matrix2::CudaType B_tmp;
using CudaMatrix2 = typename Matrix2::template Self< typename Matrix2::RealType, Devices::Cuda >;
CudaMatrix2 B_tmp;
B_tmp = B;
copySparseMatrix_impl( A, B_tmp );
}
......@@ -182,7 +183,8 @@ typename std::enable_if< ! std::is_same< typename Matrix1::DeviceType, typename
std::is_same< typename Matrix2::DeviceType, Devices::Cuda >::value >::type
copySparseMatrix_impl( Matrix1& A, const Matrix2& B )
{
typename Matrix1::CudaType A_tmp;
using CudaMatrix1 = typename Matrix1::template Self< typename Matrix1::RealType, Devices::Cuda >;
CudaMatrix1 A_tmp;
copySparseMatrix_impl( A_tmp, B );
A = A_tmp;