From 7218a64dcb5d8a2d6fb616afc2bc66b6ca4bf1a3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1=C5=A1=20Oberhuber?= <oberhuber.tomas@gmail.com> Date: Fri, 3 Jan 2020 11:47:01 +0100 Subject: [PATCH] Replacing CSRSegmentView and EllpackSegment view with one general but specialized SegmentView. --- src/TNL/Containers/Segments/CSR.h | 4 +- src/TNL/Containers/Segments/CSR.hpp | 2 +- src/TNL/Containers/Segments/CSRSegmentView.h | 47 ------------------- src/TNL/Containers/Segments/CSRView.h | 4 +- src/TNL/Containers/Segments/CSRView.hpp | 2 +- src/TNL/Containers/Segments/Ellpack.h | 4 +- src/TNL/Containers/Segments/Ellpack.hpp | 4 +- src/TNL/Containers/Segments/EllpackView.h | 4 +- .../{EllpackSegmentView.h => SegmentView.h} | 46 ++++++++++++++++-- src/TNL/Containers/Segments/SlicedEllpack.h | 4 +- src/TNL/Containers/Segments/SlicedEllpack.hpp | 4 +- .../Containers/Segments/SlicedEllpackView.h | 4 +- 12 files changed, 59 insertions(+), 70 deletions(-) delete mode 100644 src/TNL/Containers/Segments/CSRSegmentView.h rename src/TNL/Containers/Segments/{EllpackSegmentView.h => SegmentView.h} (51%) diff --git a/src/TNL/Containers/Segments/CSR.h b/src/TNL/Containers/Segments/CSR.h index df7cb5686e..3645e9f6a1 100644 --- a/src/TNL/Containers/Segments/CSR.h +++ b/src/TNL/Containers/Segments/CSR.h @@ -14,7 +14,7 @@ #include <TNL/Containers/Vector.h> #include <TNL/Containers/Segments/CSRView.h> -#include <TNL/Containers/Segments/CSRSegmentView.h> +#include <TNL/Containers/Segments/SegmentView.h> namespace TNL { namespace Containers { @@ -35,7 +35,7 @@ class CSR using ViewTemplate = CSRView< Device_, Index_ >; using ViewType = CSRView< Device, Index >; using ConstViewType = CSRView< Device, std::add_const_t< Index > >; - using SegmentViewType = CSRSegmentView< IndexType >; + using SegmentViewType = SegmentView< IndexType, true >; CSR(); diff --git a/src/TNL/Containers/Segments/CSR.hpp b/src/TNL/Containers/Segments/CSR.hpp index 83da548fc9..8b8ddfff51 100644 --- a/src/TNL/Containers/Segments/CSR.hpp +++ b/src/TNL/Containers/Segments/CSR.hpp @@ -176,7 +176,7 @@ auto CSR< Device, Index, IndexAllocator >:: getSegmentView( const IndexType segmentIdx ) const -> SegmentViewType { - return SegmentView( offsets[ segmentIdx ], offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ] ); + return SegmentViewType( offsets[ segmentIdx ], offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ] ); } template< typename Device, diff --git a/src/TNL/Containers/Segments/CSRSegmentView.h b/src/TNL/Containers/Segments/CSRSegmentView.h deleted file mode 100644 index 3ab5ef9d2e..0000000000 --- a/src/TNL/Containers/Segments/CSRSegmentView.h +++ /dev/null @@ -1,47 +0,0 @@ -/*************************************************************************** - CSRSegmentView.h - description - ------------------- - begin : Dec 28, 2019 - copyright : (C) 2019 by Tomas Oberhuber - email : tomas.oberhuber@fjfi.cvut.cz - ***************************************************************************/ - -/* See Copyright Notice in tnl/Copyright */ - -#pragma once - -namespace TNL { - namespace Containers { - namespace Segments { - -template< typename Index > -class CSRSegmentView -{ - public: - - using IndexType = Index; - - __cuda_callable__ - CSRSegmentView( const IndexType offset, const IndexType size ) - : segmentOffset( offset ), segmentSize( size ){}; - - __cuda_callable__ - IndexType getSize() const - { - return this->segmentSize; - }; - - __cuda_callable__ - IndexType getGlobalIndex( const IndexType localIndex ) const - { - TNL_ASSERT_LT( localIndex, segmentSize, "Local index exceeds segment bounds." ); - return segmentOffset + localIndex; - }; - - protected: - - IndexType segmentOffset, segmentSize; -}; - } //namespace Segments - } //namespace Containers -} //namespace TNL \ No newline at end of file diff --git a/src/TNL/Containers/Segments/CSRView.h b/src/TNL/Containers/Segments/CSRView.h index f8bcacd0fd..a0f5cd200d 100644 --- a/src/TNL/Containers/Segments/CSRView.h +++ b/src/TNL/Containers/Segments/CSRView.h @@ -13,7 +13,7 @@ #include <type_traits> #include <TNL/Containers/Vector.h> -#include <TNL/Containers/Segments/CSRSegmentView.h> +#include <TNL/Containers/Segments/SegmentView.h> namespace TNL { namespace Containers { @@ -33,7 +33,7 @@ class CSRView template< typename Device_, typename Index_ > using ViewTemplate = CSRView< Device_, Index_ >; using ConstViewType = CSRView< Device, std::add_const_t< Index > >; - using SegmentViewType = CSRSegmentView< IndexType >; + using SegmentViewType = SegmentView< IndexType >; __cuda_callable__ CSRView(); diff --git a/src/TNL/Containers/Segments/CSRView.hpp b/src/TNL/Containers/Segments/CSRView.hpp index b0bb353133..bbed8e3cb0 100644 --- a/src/TNL/Containers/Segments/CSRView.hpp +++ b/src/TNL/Containers/Segments/CSRView.hpp @@ -156,7 +156,7 @@ auto CSRView< Device, Index >:: getSegmentView( const IndexType segmentIdx ) const -> SegmentViewType { - return SegmentViewType( offsets[ segmentIdx ], offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ] ); + return SegmentViewType( offsets[ segmentIdx ], offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ], 1 ); } template< typename Device, diff --git a/src/TNL/Containers/Segments/Ellpack.h b/src/TNL/Containers/Segments/Ellpack.h index f73155335d..4296156473 100644 --- a/src/TNL/Containers/Segments/Ellpack.h +++ b/src/TNL/Containers/Segments/Ellpack.h @@ -12,7 +12,7 @@ #include <TNL/Containers/Vector.h> #include <TNL/Containers/Segments/EllpackView.h> -#include <TNL/Containers/Segments/EllpackSegmentView.h> +#include <TNL/Containers/Segments/SegmentView.h> namespace TNL { namespace Containers { @@ -37,7 +37,7 @@ class Ellpack using ViewTemplate = EllpackView< Device_, Index_ >; using ViewType = EllpackView< Device, Index, RowMajorOrder, Alignment >; //using ConstViewType = EllpackView< Device, std::add_const_t< Index >, RowMajorOrder, Alignment >; - using SegmentViewType = EllpackSegmentView< IndexType >; + using SegmentViewType = SegmentView< IndexType, RowMajorOrder >; Ellpack(); diff --git a/src/TNL/Containers/Segments/Ellpack.hpp b/src/TNL/Containers/Segments/Ellpack.hpp index ebc2b360eb..97a256c9e2 100644 --- a/src/TNL/Containers/Segments/Ellpack.hpp +++ b/src/TNL/Containers/Segments/Ellpack.hpp @@ -239,9 +239,9 @@ Ellpack< Device, Index, IndexAllocator, RowMajorOrder, Alignment >:: getSegmentView( const IndexType segmentIdx ) const -> SegmentViewType { if( RowMajorOrder ) - return SegmentView( segmentIdx * this->segmentSize, this->segmentSize, 1 ); + return SegmentViewType( segmentIdx * this->segmentSize, this->segmentSize, 1 ); else - return SegmentView( segmentIdx, this->segmentSize, this->alignedSize ); + return SegmentViewType( segmentIdx, this->segmentSize, this->alignedSize ); } template< typename Device, diff --git a/src/TNL/Containers/Segments/EllpackView.h b/src/TNL/Containers/Segments/EllpackView.h index 682eeeb4a7..737810498a 100644 --- a/src/TNL/Containers/Segments/EllpackView.h +++ b/src/TNL/Containers/Segments/EllpackView.h @@ -13,7 +13,7 @@ #include <type_traits> #include <TNL/Containers/Vector.h> -#include <TNL/Containers/Segments/EllpackSegmentView.h> +#include <TNL/Containers/Segments/SegmentView.h> namespace TNL { @@ -38,7 +38,7 @@ class EllpackView using ViewTemplate = EllpackView< Device_, Index_ >; using ViewType = EllpackView; //using ConstViewType = EllpackView< Device, std::add_const_t< Index > >; - using SegmentViewType = EllpackSegmentView< IndexType >; + using SegmentViewType = SegmentView< IndexType, RowMajorOrder >; __cuda_callable__ EllpackView(); diff --git a/src/TNL/Containers/Segments/EllpackSegmentView.h b/src/TNL/Containers/Segments/SegmentView.h similarity index 51% rename from src/TNL/Containers/Segments/EllpackSegmentView.h rename to src/TNL/Containers/Segments/SegmentView.h index 7a1638e3fe..29f2e77813 100644 --- a/src/TNL/Containers/Segments/EllpackSegmentView.h +++ b/src/TNL/Containers/Segments/SegmentView.h @@ -1,5 +1,5 @@ /*************************************************************************** - EllpackSegmentView.h - description + SegmentView.h - description ------------------- begin : Dec 28, 2019 copyright : (C) 2019 by Tomas Oberhuber @@ -14,17 +14,21 @@ namespace TNL { namespace Containers { namespace Segments { +template< typename Index, + bool RowMajorOrder = false > +class SegmentView; + template< typename Index > -class EllpackSegmentView +class SegmentView< Index, false > { public: using IndexType = Index; __cuda_callable__ - EllpackSegmentView( const IndexType offset, - const IndexType size, - const IndexType step ) + SegmentView( const IndexType offset, + const IndexType size, + const IndexType step ) : segmentOffset( offset ), segmentSize( size ), step( step ){}; __cuda_callable__ @@ -44,6 +48,38 @@ class EllpackSegmentView IndexType segmentOffset, segmentSize, step; }; + +template< typename Index > +class SegmentView< Index, true > +{ + public: + + using IndexType = Index; + + __cuda_callable__ + SegmentView( const IndexType offset, + const IndexType size, + const IndexType step = 1 ) // For compatibility with previous specialization + : segmentOffset( offset ), segmentSize( size ){}; + + __cuda_callable__ + IndexType getSize() const + { + return this->segmentSize; + }; + + __cuda_callable__ + IndexType getGlobalIndex( const IndexType localIndex ) const + { + TNL_ASSERT_LT( localIndex, segmentSize, "Local index exceeds segment bounds." ); + return segmentOffset + localIndex; + }; + + protected: + + IndexType segmentOffset, segmentSize; +}; + } //namespace Segments } //namespace Containers } //namespace TNL diff --git a/src/TNL/Containers/Segments/SlicedEllpack.h b/src/TNL/Containers/Segments/SlicedEllpack.h index 76185bcace..5953cde360 100644 --- a/src/TNL/Containers/Segments/SlicedEllpack.h +++ b/src/TNL/Containers/Segments/SlicedEllpack.h @@ -13,7 +13,7 @@ #include <TNL/Allocators/Default.h> #include <TNL/Containers/Vector.h> #include <TNL/Containers/Segments/SlicedEllpackView.h> -#include <TNL/Containers/Segments/EllpackSegmentView.h> +#include <TNL/Containers/Segments/SegmentView.h> namespace TNL { namespace Containers { @@ -37,7 +37,7 @@ class SlicedEllpack template< typename Device_, typename Index_ > using ViewTemplate = SlicedEllpackView< Device_, Index_ >; using ConstViewType = SlicedEllpackView< Device, std::add_const_t< Index >, RowMajorOrder, SliceSize >; - using SegmentViewType = EllpackSegmentView< IndexType >; + using SegmentViewType = SegmentView< IndexType, RowMajorOrder >; SlicedEllpack(); diff --git a/src/TNL/Containers/Segments/SlicedEllpack.hpp b/src/TNL/Containers/Segments/SlicedEllpack.hpp index b58b6a954f..76790f3933 100644 --- a/src/TNL/Containers/Segments/SlicedEllpack.hpp +++ b/src/TNL/Containers/Segments/SlicedEllpack.hpp @@ -269,9 +269,9 @@ getSegmentView( const IndexType segmentIdx ) const -> SegmentViewType const IndexType& segmentSize = this->sliceSegmentSizes[ sliceIdx ]; if( RowMajorOrder ) - return SegmentView( sliceOffset + segmentInSliceIdx * segmentSize, segmentSize, 1 ); + return SegmentViewType( sliceOffset + segmentInSliceIdx * segmentSize, segmentSize, 1 ); else - return SegmentView( sliceOffset + segmentInSliceIdx, segmentSize, SliceSize ); + return SegmentViewType( sliceOffset + segmentInSliceIdx, segmentSize, SliceSize ); } template< typename Device, diff --git a/src/TNL/Containers/Segments/SlicedEllpackView.h b/src/TNL/Containers/Segments/SlicedEllpackView.h index e87c75229a..86745e7c08 100644 --- a/src/TNL/Containers/Segments/SlicedEllpackView.h +++ b/src/TNL/Containers/Segments/SlicedEllpackView.h @@ -13,7 +13,7 @@ #include <type_traits> #include <TNL/Containers/Vector.h> -#include <TNL/Containers/Segments/EllpackSegmentView.h> +#include <TNL/Containers/Segments/SegmentView.h> namespace TNL { namespace Containers { @@ -36,7 +36,7 @@ class SlicedEllpackView using ViewTemplate = SlicedEllpackView< Device_, Index_ >; using ViewType = SlicedEllpackView; using ConstViewType = SlicedEllpackView< Device, std::add_const_t< Index > >; - using SegmentViewType = EllpackSegmentView< IndexType >; + using SegmentViewType = SegmentView< IndexType, RowMajorOrder >; __cuda_callable__ SlicedEllpackView(); -- GitLab