From 1dde287ebf133b24c147e8f8f790ebbc71004a57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= Date: Tue, 31 Mar 2020 22:17:16 +0200 Subject: [PATCH] Fixed SegmentView types in CSR and ChunkedEllpack. --- src/TNL/Containers/Segments/CSRView.h | 2 +- src/TNL/Containers/Segments/ChunkedEllpack.h | 2 +- src/TNL/Containers/Segments/ChunkedEllpackView.h | 2 +- src/TNL/Containers/Segments/details/ChunkedEllpack.h | 2 +- src/UnitTests/Matrices/SparseMatrixTest.hpp | 12 ++++++++++++ 5 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/TNL/Containers/Segments/CSRView.h b/src/TNL/Containers/Segments/CSRView.h index f7cf815d0..4e53bd204 100644 --- a/src/TNL/Containers/Segments/CSRView.h +++ b/src/TNL/Containers/Segments/CSRView.h @@ -33,7 +33,7 @@ class CSRView template< typename Device_, typename Index_ > using ViewTemplate = CSRView< Device_, Index_ >; using ConstViewType = CSRView< Device, std::add_const_t< Index > >; - using SegmentViewType = SegmentView< IndexType >; + using SegmentViewType = SegmentView< IndexType, true >; __cuda_callable__ CSRView(); diff --git a/src/TNL/Containers/Segments/ChunkedEllpack.h b/src/TNL/Containers/Segments/ChunkedEllpack.h index 93580a9cd..c6c7812db 100644 --- a/src/TNL/Containers/Segments/ChunkedEllpack.h +++ b/src/TNL/Containers/Segments/ChunkedEllpack.h @@ -35,7 +35,7 @@ class ChunkedEllpack template< typename Device_, typename Index_ > using ViewTemplate = ChunkedEllpackView< Device_, Index_, RowMajorOrder >; using ConstViewType = ChunkedEllpackView< Device, std::add_const_t< Index >, RowMajorOrder >; - using SegmentViewType = SegmentView< IndexType, RowMajorOrder >; + using SegmentViewType = ChunkedEllpackSegmentView< IndexType, RowMajorOrder >; using ChunkedEllpackSliceInfoType = details::ChunkedEllpackSliceInfo< IndexType >; //TODO: using ChunkedEllpackSliceInfoAllocator = typename IndexAllocatorType::retype< ChunkedEllpackSliceInfoType >; using ChunkedEllpackSliceInfoAllocator = typename Allocators::Default< Device >::template Allocator< ChunkedEllpackSliceInfoType >; diff --git a/src/TNL/Containers/Segments/ChunkedEllpackView.h b/src/TNL/Containers/Segments/ChunkedEllpackView.h index 4b444d084..eaf2450b5 100644 --- a/src/TNL/Containers/Segments/ChunkedEllpackView.h +++ b/src/TNL/Containers/Segments/ChunkedEllpackView.h @@ -36,7 +36,7 @@ class ChunkedEllpackView template< typename Device_, typename Index_ > using ViewTemplate = ChunkedEllpackView< Device_, Index_ >; using ConstViewType = ChunkedEllpackView< Device, std::add_const_t< Index > >; - using SegmentViewType = ChunkedEllpackSegmentView< IndexType >; + using SegmentViewType = ChunkedEllpackSegmentView< IndexType, RowMajorOrder >; using ChunkedEllpackSliceInfoType = details::ChunkedEllpackSliceInfo< IndexType >; using ChunkedEllpackSliceInfoAllocator = typename Allocators::Default< Device >::template Allocator< ChunkedEllpackSliceInfoType >; using ChunkedEllpackSliceInfoContainer = Containers::Array< ChunkedEllpackSliceInfoType, DeviceType, IndexType, ChunkedEllpackSliceInfoAllocator >; diff --git a/src/TNL/Containers/Segments/details/ChunkedEllpack.h b/src/TNL/Containers/Segments/details/ChunkedEllpack.h index 8807de226..95ae00c88 100644 --- a/src/TNL/Containers/Segments/details/ChunkedEllpack.h +++ b/src/TNL/Containers/Segments/details/ChunkedEllpack.h @@ -69,7 +69,7 @@ class ChunkedEllpack using ChunkedEllpackSliceInfoAllocator = typename Allocators::Default< Device >::template Allocator< ChunkedEllpackSliceInfoType >; using ChunkedEllpackSliceInfoContainer = Containers::Array< ChunkedEllpackSliceInfoType, DeviceType, IndexType, ChunkedEllpackSliceInfoAllocator >; using ChunkedEllpackSliceInfoContainerView = typename ChunkedEllpackSliceInfoContainer::ViewType; - using SegmentViewType = ChunkedEllpackSegmentView< IndexType >; + using SegmentViewType = ChunkedEllpackSegmentView< IndexType, RowMajorOrder >; __cuda_callable__ static IndexType getSegmentSizeDirect( const OffsetsHolderView& segmentsToSlicesMapping, diff --git a/src/UnitTests/Matrices/SparseMatrixTest.hpp b/src/UnitTests/Matrices/SparseMatrixTest.hpp index d88565472..8080d45e5 100644 --- a/src/UnitTests/Matrices/SparseMatrixTest.hpp +++ b/src/UnitTests/Matrices/SparseMatrixTest.hpp @@ -69,6 +69,18 @@ void test_Constructors() EXPECT_EQ( m2.getElement( 3, 3 ), 1 ); EXPECT_EQ( m2.getElement( 4, 4 ), 1 ); // 4th row + if( std::is_same< DeviceType, TNL::Devices::Host >::value ) + { + EXPECT_EQ( m2.getRow( 0 ).getValue( 0 ), 1 ); // 0th row + EXPECT_EQ( m2.getRow( 1 ).getValue( 0 ), 1 ); // 1st row + EXPECT_EQ( m2.getRow( 1 ).getValue( 1 ), 1 ); + EXPECT_EQ( m2.getRow( 2 ).getValue( 0 ), 1 ); // 2nd row + EXPECT_EQ( m2.getRow( 2 ).getValue( 1 ), 1 ); + EXPECT_EQ( m2.getRow( 3 ).getValue( 0 ), 1 ); // 3rd row + EXPECT_EQ( m2.getRow( 3 ).getValue( 1 ), 1 ); + EXPECT_EQ( m2.getRow( 4 ).getValue( 0 ), 1 ); // 4th row + } + m2.getCompressedRowLengths( v1 ); EXPECT_EQ( v1, v2 ); -- GitLab