diff --git a/src/TNL/Containers/Segments/CSR.h b/src/TNL/Containers/Segments/CSR.h index df7cb5686e03c8f7209c02febe0d039c4a1152c9..3645e9f6a1964ed3a68f8d93ebf48c85d1c98582 100644 --- a/src/TNL/Containers/Segments/CSR.h +++ b/src/TNL/Containers/Segments/CSR.h @@ -14,7 +14,7 @@ #include <TNL/Containers/Vector.h> #include <TNL/Containers/Segments/CSRView.h> -#include <TNL/Containers/Segments/CSRSegmentView.h> +#include <TNL/Containers/Segments/SegmentView.h> namespace TNL { namespace Containers { @@ -35,7 +35,7 @@ class CSR using ViewTemplate = CSRView< Device_, Index_ >; using ViewType = CSRView< Device, Index >; using ConstViewType = CSRView< Device, std::add_const_t< Index > >; - using SegmentViewType = CSRSegmentView< IndexType >; + using SegmentViewType = SegmentView< IndexType, true >; CSR(); diff --git a/src/TNL/Containers/Segments/CSR.hpp b/src/TNL/Containers/Segments/CSR.hpp index 83da548fc9eeefbea54c4d45883bd21cbe33f632..8b8ddfff51d27e0139d3ac8ef2f4350739a15091 100644 --- a/src/TNL/Containers/Segments/CSR.hpp +++ b/src/TNL/Containers/Segments/CSR.hpp @@ -176,7 +176,7 @@ auto CSR< Device, Index, IndexAllocator >:: getSegmentView( const IndexType segmentIdx ) const -> SegmentViewType { - return SegmentView( offsets[ segmentIdx ], offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ] ); + return SegmentViewType( offsets[ segmentIdx ], offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ] ); } template< typename Device, diff --git a/src/TNL/Containers/Segments/CSRSegmentView.h b/src/TNL/Containers/Segments/CSRSegmentView.h deleted file mode 100644 index 3ab5ef9d2eb1afde64321a7741dc33cd54c4dcb8..0000000000000000000000000000000000000000 --- a/src/TNL/Containers/Segments/CSRSegmentView.h +++ /dev/null @@ -1,47 +0,0 @@ -/*************************************************************************** - CSRSegmentView.h - description - ------------------- - begin : Dec 28, 2019 - copyright : (C) 2019 by Tomas Oberhuber - email : tomas.oberhuber@fjfi.cvut.cz - ***************************************************************************/ - -/* See Copyright Notice in tnl/Copyright */ - -#pragma once - -namespace TNL { - namespace Containers { - namespace Segments { - -template< typename Index > -class CSRSegmentView -{ - public: - - using IndexType = Index; - - __cuda_callable__ - CSRSegmentView( const IndexType offset, const IndexType size ) - : segmentOffset( offset ), segmentSize( size ){}; - - __cuda_callable__ - IndexType getSize() const - { - return this->segmentSize; - }; - - __cuda_callable__ - IndexType getGlobalIndex( const IndexType localIndex ) const - { - TNL_ASSERT_LT( localIndex, segmentSize, "Local index exceeds segment bounds." ); - return segmentOffset + localIndex; - }; - - protected: - - IndexType segmentOffset, segmentSize; -}; - } //namespace Segments - } //namespace Containers -} //namespace TNL \ No newline at end of file diff --git a/src/TNL/Containers/Segments/CSRView.h b/src/TNL/Containers/Segments/CSRView.h index f8bcacd0fd0e9b46c3d9aa91ab01cf399ee5c6e8..a0f5cd200d1d95ec708b294a03619f42a91a6fa6 100644 --- a/src/TNL/Containers/Segments/CSRView.h +++ b/src/TNL/Containers/Segments/CSRView.h @@ -13,7 +13,7 @@ #include <type_traits> #include <TNL/Containers/Vector.h> -#include <TNL/Containers/Segments/CSRSegmentView.h> +#include <TNL/Containers/Segments/SegmentView.h> namespace TNL { namespace Containers { @@ -33,7 +33,7 @@ class CSRView template< typename Device_, typename Index_ > using ViewTemplate = CSRView< Device_, Index_ >; using ConstViewType = CSRView< Device, std::add_const_t< Index > >; - using SegmentViewType = CSRSegmentView< IndexType >; + using SegmentViewType = SegmentView< IndexType >; __cuda_callable__ CSRView(); diff --git a/src/TNL/Containers/Segments/CSRView.hpp b/src/TNL/Containers/Segments/CSRView.hpp index b0bb3531336618f1463ef083de530eac905991dc..bbed8e3cb032c61a2e8add669012b8d21f8b720a 100644 --- a/src/TNL/Containers/Segments/CSRView.hpp +++ b/src/TNL/Containers/Segments/CSRView.hpp @@ -156,7 +156,7 @@ auto CSRView< Device, Index >:: getSegmentView( const IndexType segmentIdx ) const -> SegmentViewType { - return SegmentViewType( offsets[ segmentIdx ], offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ] ); + return SegmentViewType( offsets[ segmentIdx ], offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ], 1 ); } template< typename Device, diff --git a/src/TNL/Containers/Segments/Ellpack.h b/src/TNL/Containers/Segments/Ellpack.h index f73155335d11de10e0b86188cb832d072b0b8694..429615647391a381e0e58bb3565975f60350eb59 100644 --- a/src/TNL/Containers/Segments/Ellpack.h +++ b/src/TNL/Containers/Segments/Ellpack.h @@ -12,7 +12,7 @@ #include <TNL/Containers/Vector.h> #include <TNL/Containers/Segments/EllpackView.h> -#include <TNL/Containers/Segments/EllpackSegmentView.h> +#include <TNL/Containers/Segments/SegmentView.h> namespace TNL { namespace Containers { @@ -37,7 +37,7 @@ class Ellpack using ViewTemplate = EllpackView< Device_, Index_ >; using ViewType = EllpackView< Device, Index, RowMajorOrder, Alignment >; //using ConstViewType = EllpackView< Device, std::add_const_t< Index >, RowMajorOrder, Alignment >; - using SegmentViewType = EllpackSegmentView< IndexType >; + using SegmentViewType = SegmentView< IndexType, RowMajorOrder >; Ellpack(); diff --git a/src/TNL/Containers/Segments/Ellpack.hpp b/src/TNL/Containers/Segments/Ellpack.hpp index ebc2b360eb91a4100db7a1d6599b33536c74f4b6..97a256c9e26bc86023c48bc8f6428c2e78c7e47c 100644 --- a/src/TNL/Containers/Segments/Ellpack.hpp +++ b/src/TNL/Containers/Segments/Ellpack.hpp @@ -239,9 +239,9 @@ Ellpack< Device, Index, IndexAllocator, RowMajorOrder, Alignment >:: getSegmentView( const IndexType segmentIdx ) const -> SegmentViewType { if( RowMajorOrder ) - return SegmentView( segmentIdx * this->segmentSize, this->segmentSize, 1 ); + return SegmentViewType( segmentIdx * this->segmentSize, this->segmentSize, 1 ); else - return SegmentView( segmentIdx, this->segmentSize, this->alignedSize ); + return SegmentViewType( segmentIdx, this->segmentSize, this->alignedSize ); } template< typename Device, diff --git a/src/TNL/Containers/Segments/EllpackView.h b/src/TNL/Containers/Segments/EllpackView.h index 682eeeb4a76ee1b69d4a6a78b1c95023021ee770..737810498a94485e8e65efdeabd4c18f90083a95 100644 --- a/src/TNL/Containers/Segments/EllpackView.h +++ b/src/TNL/Containers/Segments/EllpackView.h @@ -13,7 +13,7 @@ #include <type_traits> #include <TNL/Containers/Vector.h> -#include <TNL/Containers/Segments/EllpackSegmentView.h> +#include <TNL/Containers/Segments/SegmentView.h> namespace TNL { @@ -38,7 +38,7 @@ class EllpackView using ViewTemplate = EllpackView< Device_, Index_ >; using ViewType = EllpackView; //using ConstViewType = EllpackView< Device, std::add_const_t< Index > >; - using SegmentViewType = EllpackSegmentView< IndexType >; + using SegmentViewType = SegmentView< IndexType, RowMajorOrder >; __cuda_callable__ EllpackView(); diff --git a/src/TNL/Containers/Segments/EllpackSegmentView.h b/src/TNL/Containers/Segments/SegmentView.h similarity index 51% rename from src/TNL/Containers/Segments/EllpackSegmentView.h rename to src/TNL/Containers/Segments/SegmentView.h index 7a1638e3fe82fed958f00f30c82f3d7309b0657d..29f2e778132ebe304c85298bee7aad3a283edce8 100644 --- a/src/TNL/Containers/Segments/EllpackSegmentView.h +++ b/src/TNL/Containers/Segments/SegmentView.h @@ -1,5 +1,5 @@ /*************************************************************************** - EllpackSegmentView.h - description + SegmentView.h - description ------------------- begin : Dec 28, 2019 copyright : (C) 2019 by Tomas Oberhuber @@ -14,17 +14,21 @@ namespace TNL { namespace Containers { namespace Segments { +template< typename Index, + bool RowMajorOrder = false > +class SegmentView; + template< typename Index > -class EllpackSegmentView +class SegmentView< Index, false > { public: using IndexType = Index; __cuda_callable__ - EllpackSegmentView( const IndexType offset, - const IndexType size, - const IndexType step ) + SegmentView( const IndexType offset, + const IndexType size, + const IndexType step ) : segmentOffset( offset ), segmentSize( size ), step( step ){}; __cuda_callable__ @@ -44,6 +48,38 @@ class EllpackSegmentView IndexType segmentOffset, segmentSize, step; }; + +template< typename Index > +class SegmentView< Index, true > +{ + public: + + using IndexType = Index; + + __cuda_callable__ + SegmentView( const IndexType offset, + const IndexType size, + const IndexType step = 1 ) // For compatibility with previous specialization + : segmentOffset( offset ), segmentSize( size ){}; + + __cuda_callable__ + IndexType getSize() const + { + return this->segmentSize; + }; + + __cuda_callable__ + IndexType getGlobalIndex( const IndexType localIndex ) const + { + TNL_ASSERT_LT( localIndex, segmentSize, "Local index exceeds segment bounds." ); + return segmentOffset + localIndex; + }; + + protected: + + IndexType segmentOffset, segmentSize; +}; + } //namespace Segments } //namespace Containers } //namespace TNL diff --git a/src/TNL/Containers/Segments/SlicedEllpack.h b/src/TNL/Containers/Segments/SlicedEllpack.h index 76185bcace2661e85881a0b9dc80fab599f34b2d..5953cde3605148ebb268b098778973ae5d41eb4d 100644 --- a/src/TNL/Containers/Segments/SlicedEllpack.h +++ b/src/TNL/Containers/Segments/SlicedEllpack.h @@ -13,7 +13,7 @@ #include <TNL/Allocators/Default.h> #include <TNL/Containers/Vector.h> #include <TNL/Containers/Segments/SlicedEllpackView.h> -#include <TNL/Containers/Segments/EllpackSegmentView.h> +#include <TNL/Containers/Segments/SegmentView.h> namespace TNL { namespace Containers { @@ -37,7 +37,7 @@ class SlicedEllpack template< typename Device_, typename Index_ > using ViewTemplate = SlicedEllpackView< Device_, Index_ >; using ConstViewType = SlicedEllpackView< Device, std::add_const_t< Index >, RowMajorOrder, SliceSize >; - using SegmentViewType = EllpackSegmentView< IndexType >; + using SegmentViewType = SegmentView< IndexType, RowMajorOrder >; SlicedEllpack(); diff --git a/src/TNL/Containers/Segments/SlicedEllpack.hpp b/src/TNL/Containers/Segments/SlicedEllpack.hpp index b58b6a954f2b6b60474665ef3f55497bb7ff19e9..76790f393371a17da72759f9694129359d4a7495 100644 --- a/src/TNL/Containers/Segments/SlicedEllpack.hpp +++ b/src/TNL/Containers/Segments/SlicedEllpack.hpp @@ -269,9 +269,9 @@ getSegmentView( const IndexType segmentIdx ) const -> SegmentViewType const IndexType& segmentSize = this->sliceSegmentSizes[ sliceIdx ]; if( RowMajorOrder ) - return SegmentView( sliceOffset + segmentInSliceIdx * segmentSize, segmentSize, 1 ); + return SegmentViewType( sliceOffset + segmentInSliceIdx * segmentSize, segmentSize, 1 ); else - return SegmentView( sliceOffset + segmentInSliceIdx, segmentSize, SliceSize ); + return SegmentViewType( sliceOffset + segmentInSliceIdx, segmentSize, SliceSize ); } template< typename Device, diff --git a/src/TNL/Containers/Segments/SlicedEllpackView.h b/src/TNL/Containers/Segments/SlicedEllpackView.h index e87c75229a9c16ad17023204fdd838ecdc09c425..86745e7c086cc6ed961f738810c52e4220d4712b 100644 --- a/src/TNL/Containers/Segments/SlicedEllpackView.h +++ b/src/TNL/Containers/Segments/SlicedEllpackView.h @@ -13,7 +13,7 @@ #include <type_traits> #include <TNL/Containers/Vector.h> -#include <TNL/Containers/Segments/EllpackSegmentView.h> +#include <TNL/Containers/Segments/SegmentView.h> namespace TNL { namespace Containers { @@ -36,7 +36,7 @@ class SlicedEllpackView using ViewTemplate = SlicedEllpackView< Device_, Index_ >; using ViewType = SlicedEllpackView; using ConstViewType = SlicedEllpackView< Device, std::add_const_t< Index > >; - using SegmentViewType = EllpackSegmentView< IndexType >; + using SegmentViewType = SegmentView< IndexType, RowMajorOrder >; __cuda_callable__ SlicedEllpackView();