Commit 1dde287e authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Fixed SegmentView types in CSR and ChunkedEllpack.

parent 9953d3ab
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -33,7 +33,7 @@ class CSRView
      template< typename Device_, typename Index_ >
      using ViewTemplate = CSRView< Device_, Index_ >;
      using ConstViewType = CSRView< Device, std::add_const_t< Index > >;
      using SegmentViewType = SegmentView< IndexType >;
      using SegmentViewType = SegmentView< IndexType, true >;

      __cuda_callable__
      CSRView();
+1 −1
Original line number Diff line number Diff line
@@ -35,7 +35,7 @@ class ChunkedEllpack
      template< typename Device_, typename Index_ >
      using ViewTemplate = ChunkedEllpackView< Device_, Index_, RowMajorOrder >;
      using ConstViewType = ChunkedEllpackView< Device, std::add_const_t< Index >, RowMajorOrder >;
      using SegmentViewType = SegmentView< IndexType, RowMajorOrder >;
      using SegmentViewType = ChunkedEllpackSegmentView< IndexType, RowMajorOrder >;
      using ChunkedEllpackSliceInfoType = details::ChunkedEllpackSliceInfo< IndexType >;
      //TODO: using ChunkedEllpackSliceInfoAllocator = typename IndexAllocatorType::retype< ChunkedEllpackSliceInfoType >;
      using ChunkedEllpackSliceInfoAllocator = typename Allocators::Default< Device >::template Allocator< ChunkedEllpackSliceInfoType >;
+1 −1
Original line number Diff line number Diff line
@@ -36,7 +36,7 @@ class ChunkedEllpackView
      template< typename Device_, typename Index_ >
      using ViewTemplate = ChunkedEllpackView< Device_, Index_ >;
      using ConstViewType = ChunkedEllpackView< Device, std::add_const_t< Index > >;
      using SegmentViewType = ChunkedEllpackSegmentView< IndexType >;
      using SegmentViewType = ChunkedEllpackSegmentView< IndexType, RowMajorOrder >;
      using ChunkedEllpackSliceInfoType = details::ChunkedEllpackSliceInfo< IndexType >;
      using ChunkedEllpackSliceInfoAllocator = typename Allocators::Default< Device >::template Allocator< ChunkedEllpackSliceInfoType >;
      using ChunkedEllpackSliceInfoContainer = Containers::Array< ChunkedEllpackSliceInfoType, DeviceType, IndexType, ChunkedEllpackSliceInfoAllocator >;
+1 −1
Original line number Diff line number Diff line
@@ -69,7 +69,7 @@ class ChunkedEllpack
      using ChunkedEllpackSliceInfoAllocator = typename Allocators::Default< Device >::template Allocator< ChunkedEllpackSliceInfoType >;
      using ChunkedEllpackSliceInfoContainer = Containers::Array< ChunkedEllpackSliceInfoType, DeviceType, IndexType, ChunkedEllpackSliceInfoAllocator >;
      using ChunkedEllpackSliceInfoContainerView = typename ChunkedEllpackSliceInfoContainer::ViewType;
      using SegmentViewType = ChunkedEllpackSegmentView< IndexType >;
      using SegmentViewType = ChunkedEllpackSegmentView< IndexType, RowMajorOrder >;

      __cuda_callable__ static
      IndexType getSegmentSizeDirect( const OffsetsHolderView& segmentsToSlicesMapping,
+12 −0
Original line number Diff line number Diff line
@@ -69,6 +69,18 @@ void test_Constructors()
   EXPECT_EQ( m2.getElement( 3, 3 ), 1 );
   EXPECT_EQ( m2.getElement( 4, 4 ), 1 );   // 4th row

   if( std::is_same< DeviceType, TNL::Devices::Host >::value )
   {
      EXPECT_EQ( m2.getRow( 0 ).getValue( 0 ), 1 );   // 0th row
      EXPECT_EQ( m2.getRow( 1 ).getValue( 0 ), 1 );   // 1st row
      EXPECT_EQ( m2.getRow( 1 ).getValue( 1 ), 1 );
      EXPECT_EQ( m2.getRow( 2 ).getValue( 0 ), 1 );   // 2nd row
      EXPECT_EQ( m2.getRow( 2 ).getValue( 1 ), 1 );
      EXPECT_EQ( m2.getRow( 3 ).getValue( 0 ), 1 );   // 3rd row
      EXPECT_EQ( m2.getRow( 3 ).getValue( 1 ), 1 );
      EXPECT_EQ( m2.getRow( 4 ).getValue( 0 ), 1 );   // 4th row
   }

   m2.getCompressedRowLengths( v1 );
   EXPECT_EQ( v1, v2 );