diff --git a/src/TNL/Containers/Segments/CSR.hpp b/src/TNL/Containers/Segments/CSR.hpp index a8f12e7dc3e3db1504a54a045e92b8cb5f0cbbfc..280ed6ebf212d0419ab16ac7dbccb7ddd664c507 100644 --- a/src/TNL/Containers/Segments/CSR.hpp +++ b/src/TNL/Containers/Segments/CSR.hpp @@ -63,11 +63,6 @@ CSR< Device, Index, IndexAllocator >:: setSegmentsSizes( const SizesHolder& sizes ) { details::CSR< Device, Index >::setSegmentsSizes( sizes, this->offsets ); - /*this->offsets.setSize( sizes.getSize() + 1 ); - auto view = this->offsets.getView( 0, sizes.getSize() ); - view = sizes; - this->offsets.setElement( sizes.getSize(), 0 ); - this->offsets.template scan< Algorithms::ScanType::Exclusive >();*/ } template< typename Device, @@ -109,15 +104,7 @@ Index CSR< Device, Index, IndexAllocator >:: getSegmentSize( const IndexType segmentIdx ) const { - if( ! std::is_same< DeviceType, Devices::Host >::value ) - { -#ifdef __CUDA_ARCH__ - return offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ]; -#else - return offsets.getElement( segmentIdx + 1 ) - offsets.getElement( segmentIdx ); -#endif - } - return offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ]; + return details::CSR< Device, Index >::getSegmentSize( this->offsets, segmentIdx ); } template< typename Device, @@ -139,15 +126,7 @@ Index CSR< Device, Index, IndexAllocator >:: getStorageSize() const { - if( ! std::is_same< DeviceType, Devices::Host >::value ) - { -#ifdef __CUDA_ARCH__ - return offsets[ this->getSegmentsCount() ]; -#else - return offsets.getElement( this->getSegmentsCount() ); -#endif - } - return offsets[ this->getSegmentsCount() ]; + return details::CSR< Device, Index >::getStorageSize( this->offsets ); } template< typename Device, diff --git a/src/TNL/Containers/Segments/CSRView.hpp b/src/TNL/Containers/Segments/CSRView.hpp index f50a74985f050f18b5913008636552e1bbf4f760..dd4c434ba079030df67ca7912931b7543fb506e9 100644 --- a/src/TNL/Containers/Segments/CSRView.hpp +++ b/src/TNL/Containers/Segments/CSRView.hpp @@ -13,6 +13,7 @@ #include <TNL/Containers/Vector.h> #include <TNL/Algorithms/ParallelFor.h> #include <TNL/Containers/Segments/CSRView.h> +#include <TNL/Containers/Segments/details/CSR.h> namespace TNL { namespace Containers { @@ -98,15 +99,7 @@ Index CSRView< Device, Index >:: getSegmentSize( const IndexType segmentIdx ) const { - if( ! std::is_same< DeviceType, Devices::Host >::value ) - { -#ifdef __CUDA_ARCH__ - return offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ]; -#else - return offsets.getElement( segmentIdx + 1 ) - offsets.getElement( segmentIdx ); -#endif - } - return offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ]; + return details::CSR< Device, Index >::getSegmentSize( this->offsets, segmentIdx ); } template< typename Device, @@ -126,15 +119,7 @@ Index CSRView< Device, Index >:: getStorageSize() const { - if( ! std::is_same< DeviceType, Devices::Host >::value ) - { -#ifdef __CUDA_ARCH__ - return offsets[ this->getSegmentsCount() ]; -#else - return offsets.getElement( this->getSegmentsCount() ); -#endif - } - return offsets[ this->getSegmentsCount() ]; + return details::CSR< Device, Index >::getStorageSize( this->offsets ); } template< typename Device, diff --git a/src/TNL/Containers/Segments/details/CSR.h b/src/TNL/Containers/Segments/details/CSR.h index 47e768d289cb307957c117402a4a45ce8cd54c7e..38f097669150b7e3f929bdeab3beb1af03ce3e7d 100644 --- a/src/TNL/Containers/Segments/details/CSR.h +++ b/src/TNL/Containers/Segments/details/CSR.h @@ -35,23 +35,48 @@ class CSR offsets.template scan< Algorithms::ScanType::Exclusive >(); } - /*** - * \brief Returns size of the segment number \r segmentIdx - */ + template< typename CSROffsets > __cuda_callable__ - IndexType getSegmentSize( const IndexType segmentIdx ) const; + static IndexType getSegmentsCount( const CSROffsets& offsets ) + { + return offsets.getSize() - 1; + } /*** - * \brief Returns number of elements managed by all segments. + * \brief Returns size of the segment number \r segmentIdx */ + template< typename CSROffsets > __cuda_callable__ - IndexType getSize() const; + static IndexType getSegmentSize( const CSROffsets& offsets, const IndexType segmentIdx ) + { + if( ! std::is_same< DeviceType, Devices::Host >::value ) + { +#ifdef __CUDA_ARCH__ + return offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ]; +#else + return offsets.getElement( segmentIdx + 1 ) - offsets.getElement( segmentIdx ); +#endif + } + return offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ]; + } /*** * \brief Returns number of elements that needs to be allocated. */ + template< typename CSROffsets > __cuda_callable__ - IndexType getStorageSize() const; + static IndexType getStorageSize( const CSROffsets& offsets ) + { + if( ! std::is_same< DeviceType, Devices::Host >::value ) + { +#ifdef __CUDA_ARCH__ + return offsets[ getSegmentsCount( offsets ) ]; +#else + return offsets.getElement( getSegmentsCount( offsets ) ); +#endif + } + return offsets[ getSegmentsCount( offsets ) ]; + } __cuda_callable__ IndexType getGlobalIndex( const Index segmentIdx, const Index localIdx ) const; @@ -85,5 +110,3 @@ class CSR } // namespace Segements } // namespace Conatiners } // namespace TNL - -#include <TNL/Containers/Segments/CSR.hpp> diff --git a/src/TNL/Containers/Segments/details/Ellpack.h b/src/TNL/Containers/Segments/details/Ellpack.h index b08ad0f04f9d316f6e2ce62aea5d8990c1204978..ecfe63107325793717482b3710c9533a153c34c1 100644 --- a/src/TNL/Containers/Segments/details/Ellpack.h +++ b/src/TNL/Containers/Segments/details/Ellpack.h @@ -103,5 +103,3 @@ class Ellpack } // namespace Segements } // namespace Conatiners } // namespace TNL - -#include <TNL/Containers/Segments/Ellpack.hpp> diff --git a/src/TNL/Containers/Segments/details/SlicedEllpack.h b/src/TNL/Containers/Segments/details/SlicedEllpack.h index ecc2c8c7ef1d8fa24d418372c07c2f769ab75cc9..6f185bc469e1c1826348b5662735d6a2992fc087 100644 --- a/src/TNL/Containers/Segments/details/SlicedEllpack.h +++ b/src/TNL/Containers/Segments/details/SlicedEllpack.h @@ -102,5 +102,3 @@ class SlicedEllpack } // namespace Segements } // namespace Conatiners } // namespace TNL - -#include <TNL/Containers/Segments/SlicedEllpack.hpp>