From 01f367ca5e64c90d8c1db876e16bad37b3ef6e2d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com>
Date: Mon, 9 Dec 2019 20:07:29 +0100
Subject: [PATCH] Implementing SlicedEllpack segments.

---
 src/TNL/Containers/Segments/CSR.hpp           |   2 +-
 src/TNL/Containers/Segments/Ellpack.hpp       |   5 +-
 src/TNL/Containers/Segments/SlicedEllpack.h   |  14 +-
 src/TNL/Containers/Segments/SlicedEllpack.hpp | 191 ++++++++++++------
 4 files changed, 141 insertions(+), 71 deletions(-)

diff --git a/src/TNL/Containers/Segments/CSR.hpp b/src/TNL/Containers/Segments/CSR.hpp
index 486149e042..ef7431038c 100644
--- a/src/TNL/Containers/Segments/CSR.hpp
+++ b/src/TNL/Containers/Segments/CSR.hpp
@@ -153,7 +153,7 @@ void
 CSR< Device, Index >::
 forSegments( IndexType first, IndexType last, Function& f, Args... args ) const
 {
-   const auto offsetsView = this->offsets.getView();
+   const auto offsetsView = this->offsets.getConstView();
    auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) mutable {
       const IndexType begin = offsetsView[ i ];
       const IndexType end = offsetsView[ i + 1 ];
diff --git a/src/TNL/Containers/Segments/Ellpack.hpp b/src/TNL/Containers/Segments/Ellpack.hpp
index 034b0820eb..d3d90be5e0 100644
--- a/src/TNL/Containers/Segments/Ellpack.hpp
+++ b/src/TNL/Containers/Segments/Ellpack.hpp
@@ -189,11 +189,10 @@ void
 Ellpack< Device, Index, RowMajorOrder, Alignment >::
 forSegments( IndexType first, IndexType last, Function& f, Args... args ) const
 {
-   const auto offsetsView = this->offsets.getView();
    if( RowMajorOrder )
    {
       const IndexType segmentSize = this->segmentSize;
-      auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) {
+      auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) mutable {
          const IndexType begin = i * segmentSize;
          const IndexType end = begin + segmentSize;
          for( IndexType j = begin; j < end; j++  )
@@ -206,7 +205,7 @@ forSegments( IndexType first, IndexType last, Function& f, Args... args ) const
    {
       const IndexType storageSize = this->getStorageSize();
       const IndexType alignedSize = this->alignedSize;
-      auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) {
+      auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) mutable {
          const IndexType begin = i;
          const IndexType end = storageSize;
          for( IndexType j = begin; j < end; j += alignedSize )
diff --git a/src/TNL/Containers/Segments/SlicedEllpack.h b/src/TNL/Containers/Segments/SlicedEllpack.h
index a5ef9d1211..ecc2c8c7ef 100644
--- a/src/TNL/Containers/Segments/SlicedEllpack.h
+++ b/src/TNL/Containers/Segments/SlicedEllpack.h
@@ -42,7 +42,13 @@ class SlicedEllpack
        * \brief Set sizes of particular segments.
        */
       template< typename SizesHolder = OffsetsHolder >
-      void setSizes( const SizesHolder& sizes );
+      void setSegmentsSizes( const SizesHolder& sizes );
+
+      __cuda_callable__
+      IndexType getSegmentsCount() const;
+
+      __cuda_callable__
+      IndexType getSegmentSize( const IndexType segmentIdx ) const;
 
       /**
        * \brief Number segments.
@@ -50,8 +56,6 @@ class SlicedEllpack
       __cuda_callable__
       IndexType getSize() const;
 
-      __cuda_callable__
-      IndexType getSegmentSize( const IndexType segmentIdx ) const;
 
       __cuda_callable__
       IndexType getStorageSize() const;
@@ -90,9 +94,9 @@ class SlicedEllpack
 
    protected:
 
-      IndexType size;
+      IndexType size, alignedSize, segmentsCount;
 
-      OffsetHolder sliceOffsets;
+      OffsetsHolder sliceOffsets, sliceSegmentSizes;
 };
 
       } // namespace Segements
diff --git a/src/TNL/Containers/Segments/SlicedEllpack.hpp b/src/TNL/Containers/Segments/SlicedEllpack.hpp
index c91a13473f..e23ee5f15b 100644
--- a/src/TNL/Containers/Segments/SlicedEllpack.hpp
+++ b/src/TNL/Containers/Segments/SlicedEllpack.hpp
@@ -26,17 +26,30 @@ template< typename Device,
           int SliceSize >
 SlicedEllpack< Device, Index, RowMajorOrder, SliceSize >::
 SlicedEllpack()
-   : size( 0 )
+   : size( 0 ), alignedSize( 0 ), segmentsCount( 0 )
 {
 }
 
+template< typename Device,
+          typename Index,
+          bool RowMajorOrder,
+          int SliceSize >
+SlicedEllpack< Device, Index, RowMajorOrder, SliceSize >::
+SlicedEllpack( const Vector< IndexType, DeviceType, IndexType >& sizes )
+   : size( 0 ), alignedSize( 0 ), segmentsCount( 0 )
+{
+   this->setSegmentsSizes( sizes );
+}
+
 template< typename Device,
           typename Index,
           bool RowMajorOrder,
           int SliceSize >
 SlicedEllpack< Device, Index, RowMajorOrder, SliceSize >::
 SlicedEllpack( const SlicedEllpack& slicedEllpack )
-   : size( slicedEllpack.size ), sliceOffsets( slicedEllpack.sliceOffsets )
+   : size( slicedEllpack.size ), alignedSize( slicedEllpack.alignedSize ),
+     segmentsCount( slicedEllpack.segmentsCount ), sliceOffsets( slicedEllpack.sliceOffsets ),
+     sliceSegmentSizes( slicedEllpack.sliceSegmentSizes )
 {
 }
 
@@ -46,7 +59,9 @@ template< typename Device,
           int SliceSize >
 SlicedEllpack< Device, Index, RowMajorOrder, SliceSize >::
 SlicedEllpack( const SlicedEllpack&& slicedEllpack )
-   : size( slicedEllpack.size ), sliceOffsets( slicedEllpack.sliceOffsets )
+   : size( slicedEllpack.size ), alignedSize( slicedEllpack.alignedSize ),
+     segmentsCount( slicedEllpack.segmentsCount ), sliceOffsets( slicedEllpack.sliceOffsets ),
+     sliceSegmentSizes( slicedEllpack.sliceSegmentSizes )
 {
 }
 
@@ -57,36 +72,36 @@ template< typename Device,
    template< typename SizesHolder >
 void
 SlicedEllpack< Device, Index, RowMajorOrder, SliceSize >::
-setSizes( const SizesHolder& sizes )
+setSegmentsSizes( const SizesHolder& sizes )
 {
-   this->size = sizes.getSize();
-   const IndexType segmentsCount = roundUpDivision( this->size, getSliceSize() );
-   this->segmentOffsets.setSize( segmentsCount + 1 );
+   this->segmentsCount = sizes.getSize();
+   const IndexType slicesCount = roundUpDivision( this->segmentsCount, getSliceSize() );
+   this->sliceOffsets.setSize( slicesCount + 1 );
+   this->sliceOffsets = 0;
+   this->sliceSegmentSizes.setSize( slicesCount );
    Ellpack< DeviceType, IndexType, true > ellpack;
-   ellpack.setSizes( segmentsCount, SliceSize );
+   ellpack.setSegmentsSizes( slicesCount, SliceSize );
 
-   const IndexType _size = this->getSize();
+   const IndexType _size = sizes.getSize();
    const auto sizes_view = sizes.getConstView();
-   auto offsets_view = this->segmentOffsets().getView();
+   auto slices_view = this->sliceOffsets.getView();
+   auto slice_segment_size_view = this->sliceSegmentSizes.getView();
    auto fetch = [=] __cuda_callable__ ( IndexType segmentIdx, IndexType globalIdx ) -> IndexType {
-      if( globalIdx < size )
+      if( globalIdx < _size )
          return sizes_view[ globalIdx ];
+      return 0;
    };
    auto reduce = [] __cuda_callable__ ( IndexType& aux, const IndexType i ) {
       aux = TNL::max( aux, i );
    };
-   auto keep = [=] __cuda_callable__ ( IndexType i, IndexType res ) {
-      offsets_view[ i ] = res;
-   }
-
-   std::cerr << offsets_view << std::endl;
-
-
-
-   if( RowMajorOrder )
-      this->alignedSize = this->size;
-   else
-      this->alignedSize = roundUpDivision( size, this->getSliceSize() ) * this->getSliceSize();
+   auto keep = [=] __cuda_callable__ ( IndexType i, IndexType res ) mutable {
+      slices_view[ i ] = res * SliceSize;
+      slice_segment_size_view[ i ] = res;
+   };
+   ellpack.allReduction( fetch, reduce, keep, std::numeric_limits< IndexType >::min() );
+   this->sliceOffsets.template scan< Algorithms::ScanType::Exclusive >();
+   this->size = sum( sizes );
+   this->alignedSize = this->sliceOffsets.getElement( slicesCount );
 }
 
 template< typename Device,
@@ -96,9 +111,9 @@ template< typename Device,
 __cuda_callable__
 Index
 SlicedEllpack< Device, Index, RowMajorOrder, SliceSize >::
-getSize() const
+getSegmentsCount() const
 {
-   return this->size;
+   return this->segmentsCount;
 }
 
 template< typename Device,
@@ -110,7 +125,29 @@ Index
 SlicedEllpack< Device, Index, RowMajorOrder, SliceSize >::
 getSegmentSize( const IndexType segmentIdx ) const
 {
-   return this->segmentSize;
+   const Index sliceIdx = segmentIdx / SliceSize;
+   if( std::is_same< DeviceType, Devices::Host >::value )
+      return this->sliceSegmentSizes[ sliceIdx ];
+   else
+   {
+#ifdef __CUDA_ARCH__
+   return this->sliceSegmentSizes[ sliceIdx ];
+#else
+   return this->sliceSegmentSizes.getElement( sliceIdx );
+#endif
+   }
+}
+
+template< typename Device,
+          typename Index,
+          bool RowMajorOrder,
+          int SliceSize >
+__cuda_callable__
+Index
+SlicedEllpack< Device, Index, RowMajorOrder, SliceSize >::
+getSize() const
+{
+   return this->size;
 }
 
 template< typename Device,
@@ -122,7 +159,7 @@ Index
 SlicedEllpack< Device, Index, RowMajorOrder, SliceSize >::
 getStorageSize() const
 {
-   return this->alignedSize * this->segmentSize;
+   return this->alignedSize;
 }
 
 template< typename Device,
@@ -134,10 +171,28 @@ Index
 SlicedEllpack< Device, Index, RowMajorOrder, SliceSize >::
 getGlobalIndex( const Index segmentIdx, const Index localIdx ) const
 {
+   const IndexType sliceIdx = segmentIdx / SliceSize;           // slice that holds the segment
+   const IndexType segmentInSliceIdx = segmentIdx % SliceSize;  // position of the segment inside that slice
+   IndexType sliceOffset, segmentSize;
+   if( std::is_same< DeviceType, Devices::Host >::value )
+   {
+      sliceOffset = this->sliceOffsets[ sliceIdx ];
+      segmentSize = this->sliceSegmentSizes[ sliceIdx ];
+   }
+   else
+   {
+#ifdef __CUDA_ARCH__  // device compilation pass: index the arrays directly
+      sliceOffset = this->sliceOffsets[ sliceIdx ];
+      segmentSize = this->sliceSegmentSizes[ sliceIdx ];
+#else                 // host pass, non-host DeviceType: read through getElement()
+      sliceOffset = this->sliceOffsets.getElement( sliceIdx );
+      segmentSize = this->sliceSegmentSizes.getElement( sliceIdx );
+#endif
+   }
    if( RowMajorOrder )
-      return segmentIdx * this->segmentSize + localIdx;
+      return sliceOffset + segmentInSliceIdx * segmentSize + localIdx;
    else
-      return segmentIdx + this->alignedSize * localIdx;
+      return sliceOffset + segmentInSliceIdx + SliceSize * localIdx;
 }
 
 template< typename Device,
@@ -160,28 +215,32 @@ void
 SlicedEllpack< Device, Index, RowMajorOrder, SliceSize >::
 forSegments( IndexType first, IndexType last, Function& f, Args... args ) const
 {
-   const auto offsetsView = this->offsets.getView();
+   const auto sliceSegmentSizes_view = this->sliceSegmentSizes.getConstView();
+   const auto sliceOffsets_view = this->sliceOffsets.getConstView();
    if( RowMajorOrder )
    {
-      const IndexType segmentSize = this->segmentSize;
-      auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) {
-         const IndexType begin = i * segmentSize;
+      auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable {  // mutable: f is captured by value and may have a non-const operator()
+         const IndexType sliceIdx = segmentIdx / SliceSize;
+         const IndexType segmentInSliceIdx = segmentIdx % SliceSize;
+         const IndexType segmentSize = sliceSegmentSizes_view[ sliceIdx ];
+         const IndexType begin = sliceOffsets_view[ sliceIdx ] + segmentInSliceIdx * segmentSize;
          const IndexType end = begin + segmentSize;
-         for( IndexType j = begin; j < end; j++  )
-            if( ! f( i, j, args... ) )
+         for( IndexType globalIdx = begin; globalIdx < end; globalIdx++  )
+            if( ! f( segmentIdx, globalIdx, args... ) )
                break;
       };
       Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
    }
    else
    {
-      const IndexType storageSize = this->getStorageSize();
-      const IndexType alignedSize = this->alignedSize;
-      auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) {
-         const IndexType begin = i;
-         const IndexType end = storageSize;
-         for( IndexType j = begin; j < end; j += alignedSize )
-            if( ! f( i, j, args... ) )
+      auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable {  // mutable: same reason as the row-major branch
+         const IndexType sliceIdx = segmentIdx / SliceSize;
+         const IndexType segmentInSliceIdx = segmentIdx % SliceSize;
+         const IndexType segmentSize = sliceSegmentSizes_view[ sliceIdx ];
+         const IndexType begin = sliceOffsets_view[ sliceIdx ] + segmentInSliceIdx;
+         const IndexType end = sliceOffsets_view[ sliceIdx + 1 ];
+         for( IndexType globalIdx = begin; globalIdx < end; globalIdx += SliceSize )
+            if( ! f( segmentIdx, globalIdx, args... ) )
                break;
       };
       Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
@@ -197,7 +256,7 @@ void
 SlicedEllpack< Device, Index, RowMajorOrder, SliceSize >::
 forAll( Function& f, Args... args ) const
 {
-   this->forSegments( 0, this->getSize(), f, args... );
+   this->forSegments( 0, this->getSegmentsCount(), f, args... );
 }
 
 template< typename Device,
@@ -209,32 +268,36 @@ void
 SlicedEllpack< Device, Index, RowMajorOrder, SliceSize >::
 segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
 {
+   using RealType = decltype( fetch( IndexType(), IndexType() ) );
+   const auto sliceSegmentSizes_view = this->sliceSegmentSizes.getConstView();
+   const auto sliceOffsets_view = this->sliceOffsets.getConstView();
    if( RowMajorOrder )
    {
-      using RealType = decltype( fetch( IndexType(), IndexType() ) );
-      const IndexType segmentSize = this->segmentSize;
-      auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) mutable {
-         const IndexType begin = i * segmentSize;
+      auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable {
+         const IndexType sliceIdx = segmentIdx / SliceSize;
+         const IndexType segmentInSliceIdx = segmentIdx % SliceSize;
+         const IndexType segmentSize = sliceSegmentSizes_view[ sliceIdx ];
+         const IndexType begin = sliceOffsets_view[ sliceIdx ] + segmentInSliceIdx * segmentSize;
          const IndexType end = begin + segmentSize;
          RealType aux( zero );
-         for( IndexType j = begin; j < end; j++  )
-            reduction( aux, fetch( i, j, args... ) );
-         keeper( i, aux );
+         for( IndexType globalIdx = begin; globalIdx< end; globalIdx++  )
+            reduction( aux, fetch( segmentIdx, globalIdx, args... ) );
+         keeper( segmentIdx, aux );
       };
       Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
    }
    else
    {
-      using RealType = decltype( fetch( IndexType(), IndexType() ) );
-      const IndexType storageSize = this->getStorageSize();
-      const IndexType alignedSize = this->alignedSize;
-      auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) mutable {
-         const IndexType begin = i;
-         const IndexType end = storageSize;
+      auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable {
+         const IndexType sliceIdx = segmentIdx / SliceSize;
+         const IndexType segmentInSliceIdx = segmentIdx % SliceSize;
+         const IndexType segmentSize = sliceSegmentSizes_view[ sliceIdx ];
+         const IndexType begin = sliceOffsets_view[ sliceIdx ] + segmentInSliceIdx;
+         const IndexType end = sliceOffsets_view[ sliceIdx + 1 ];
          RealType aux( zero );
-         for( IndexType j = begin; j < end; j += alignedSize  )
-            reduction( aux, fetch( i, j, args... ) );
-         keeper( i, aux );
+         for( IndexType globalIdx = begin; globalIdx < end; globalIdx += SliceSize  )
+            reduction( aux, fetch( segmentIdx, globalIdx, args... ) );
+         keeper( segmentIdx, aux );
       };
       Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
    }
@@ -249,7 +312,7 @@ void
 SlicedEllpack< Device, Index, RowMajorOrder, SliceSize >::
 allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
 {
-   this->segmentsReduction( 0, this->getSize(), fetch, reduction, keeper, zero, args... );
+   this->segmentsReduction( 0, this->getSegmentsCount(), fetch, reduction, keeper, zero, args... );
 }
 
 template< typename Device,
@@ -260,9 +323,11 @@ void
 SlicedEllpack< Device, Index, RowMajorOrder, SliceSize >::
 save( File& file ) const
 {
-   file.save( &segmentSize );
    file.save( &size );
    file.save( &alignedSize );
+   file.save( &segmentsCount );
+   this->sliceOffsets.save( file );
+   this->sliceSegmentSizes.save( file );
 }
 
 template< typename Device,
@@ -273,9 +338,11 @@ void
 SlicedEllpack< Device, Index, RowMajorOrder, SliceSize >::
 load( File& file )
 {
-   file.load( &segmentSize );
    file.load( &size );
    file.load( &alignedSize );
+   file.load( &segmentsCount );
+   this->sliceOffsets.load( file );
+   this->sliceSegmentSizes.load( file );
 }
 
       } // namespace Segments
-- 
GitLab