diff --git a/src/TNL/Containers/Segments/CSRView.h b/src/TNL/Containers/Segments/CSRView.h new file mode 100644 index 0000000000000000000000000000000000000000..5eeb7ecb3efafb5258f9100505874df4f8322538 --- /dev/null +++ b/src/TNL/Containers/Segments/CSRView.h @@ -0,0 +1,105 @@ +/*************************************************************************** + CSRView.h - description + ------------------- + begin : Dec 11, 2019 + copyright : (C) 2019 by Tomas Oberhuber + email : tomas.oberhuber@fjfi.cvut.cz + ***************************************************************************/ + +/* See Copyright Notice in tnl/Copyright */ + +#pragma once + +#include <TNL/Containers/Vector.h> + +namespace TNL { + namespace Containers { + namespace Segments { + +template< typename Device, + typename Index > +class CSRView +{ + public: + + using DeviceType = Device; + using IndexType = Index; + using OffsetsHolderView = typedef Containers::Vector< IndexType, DeviceType, IndexType >::ViewType; + + __cuda_callable__ + CSRView(); + + __cuda_callable__ + CSRView( const OffsetsHolderView& offsets ); + + __cuda_callable__ + CSRView( const CSRView& csr_view ); + + __cuda_callable__ + CSRView( const CSRView&& csr_view ); + + /** + * \brief Number segments. + */ + __cuda_callable__ + IndexType getSegmentsCount() const; + + /*** + * \brief Returns size of the segment number \r segmentIdx + */ + __cuda_callable__ + IndexType getSegmentSize( const IndexType segmentIdx ) const; + + /*** + * \brief Returns number of elements managed by all segments. + */ + __cuda_callable__ + IndexType getSize() const; + + /*** + * \brief Returns number of elements that needs to be allocated. + */ + __cuda_callable__ + IndexType getStorageSize() const; + + __cuda_callable__ + IndexType getGlobalIndex( const Index segmentIdx, const Index localIdx ) const; + + __cuda_callable__ + void getSegmentAndLocalIndex( const Index globalIdx, Index& segmentIdx, Index& localIdx ) const; + + /*** + * \brief Go over all segments and for each segment element call + * function 'f' with arguments 'args'. The return type of 'f' is bool. + * When its true, the for-loop continues. Once 'f' returns false, the for-loop + * is terminated. + */ + template< typename Function, typename... Args > + void forSegments( IndexType first, IndexType last, Function& f, Args... args ) const; + + template< typename Function, typename... Args > + void forAll( Function& f, Args... args ) const; + + + /*** + * \brief Go over all segments and perform a reduction in each of them. + */ + template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args > + void segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const; + + template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args > + void allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const; + + void save( File& file ) const; + + void load( File& file ); + + protected: + + OffsetsHolderView offsets; +}; + } // namespace Segements + } // namespace Conatiners +} // namespace TNL + +#include <TNL/Containers/Segments/CSRView.hpp> diff --git a/src/TNL/Containers/Segments/CSRView.hpp b/src/TNL/Containers/Segments/CSRView.hpp new file mode 100644 index 0000000000000000000000000000000000000000..30ed24071e7b42b8cd27cf959e890366682319ed --- /dev/null +++ b/src/TNL/Containers/Segments/CSRView.hpp @@ -0,0 +1,221 @@ +/*************************************************************************** + CSRView.hpp - description + ------------------- + begin : Dec 11, 2019 + copyright : (C) 2019 by Tomas Oberhuber + email : tomas.oberhuber@fjfi.cvut.cz + ***************************************************************************/ + +/* See Copyright Notice in tnl/Copyright */ + +#pragma once + +#include <TNL/Containers/Vector.h> +#include <TNL/Algorithms/ParallelFor.h> +#include <TNL/Containers/Segments/CSRView.h> + +namespace TNL { + namespace Containers { + namespace Segments { + + +template< typename Device, + typename Index > +__cuda_callable__ +CSRView< Device, Index >:: +CSRView() +{ +} + +template< typename Device, + typename Index > +__cuda_callable__ +CSRView< Device, Index >:: +CSRView( const OffsetsHolderView& offsets_view ) + : offsets( offsets_view ) +{ +} + +template< typename Device, + typename Index > +__cuda_callable__ +CSRView< Device, Index >:: +CSRView( const CSRView& csr_view ) + : offsets( csr_view.offsest ) +{ + +} + +template< typename Device, + typename Index > +__cuda_callable__ +CSRView< Device, Index >:: +CSRView( const CSRView&& csr_view ) + : offsets( std::move( csr_view.offsest ) ) +{ + +} + +template< typename Device, + typename Index > +__cuda_callable__ +Index +CSRView< Device, Index >:: +getSegmentsCount() const +{ + return this->offsets.getSize() - 1; +} + +template< typename Device, + typename Index > +__cuda_callable__ +Index +CSRView< Device, Index >:: +getSegmentSize( const IndexType segmentIdx ) const +{ + if( ! std::is_same< DeviceType, Devices::Host >::value ) + { +#ifdef __CUDA_ARCH__ + return offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ]; +#else + return offsets.getElement( segmentIdx + 1 ) - offsets.getElement( segmentIdx ); +#endif + } + return offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ]; +} + +template< typename Device, + typename Index > +__cuda_callable__ +Index +CSRView< Device, Index >:: +getSize() const +{ + return this->getStorageSize(); +} + +template< typename Device, + typename Index > +__cuda_callable__ +Index +CSRView< Device, Index >:: +getStorageSize() const +{ + if( ! std::is_same< DeviceType, Devices::Host >::value ) + { +#ifdef __CUDA_ARCH__ + return offsets[ this->getSegmentsCount() ]; +#else + return offsets.getElement( this->getSegmentsCount() ); +#endif + } + return offsets[ this->getSegmentsCount() ]; +} + +template< typename Device, + typename Index > +__cuda_callable__ +Index +CSRView< Device, Index >:: +getGlobalIndex( const Index segmentIdx, const Index localIdx ) const +{ + if( ! std::is_same< DeviceType, Devices::Host >::value ) + { +#ifdef __CUDA_ARCH__ + return offsets[ segmentIdx ] + localIdx; +#else + return offsets.getElement( segmentIdx ) + localIdx; +#endif + } + return offsets[ segmentIdx ] + localIdx; +} + +template< typename Device, + typename Index > +__cuda_callable__ +void +CSRView< Device, Index >:: +getSegmentAndLocalIndex( const Index globalIdx, Index& segmentIdx, Index& localIdx ) const +{ +} + +template< typename Device, + typename Index > + template< typename Function, typename... Args > +void +CSRView< Device, Index >:: +forSegments( IndexType first, IndexType last, Function& f, Args... args ) const +{ + const auto offsetsView = this->offsets.getConstView(); + auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable { + const IndexType begin = offsetsView[ segmentIdx ]; + const IndexType end = offsetsView[ segmentIdx + 1 ]; + IndexType localIdx( 0 ); + for( IndexType globalIdx = begin; globalIdx < end; globalIdx++ ) + if( ! f( segmentIdx, localIdx++, globalIdx, args... ) ) + break; + }; + Algorithms::ParallelFor< Device >::exec( first, last, l, args... ); +} + +template< typename Device, + typename Index > + template< typename Function, typename... Args > +void +CSRView< Device, Index >:: +forAll( Function& f, Args... args ) const +{ + this->forSegments( 0, this->getSize(), f, args... ); +} + +template< typename Device, + typename Index > + template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args > +void +CSRView< Device, Index >:: +segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const +{ + using RealType = decltype( fetch( IndexType(), IndexType() ) ); + const auto offsetsView = this->offsets.getConstView(); + auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) mutable { + const IndexType begin = offsetsView[ i ]; + const IndexType end = offsetsView[ i + 1 ]; + RealType aux( zero ); + for( IndexType j = begin; j < end; j++ ) + reduction( aux, fetch( i, j, args... ) ); + keeper( i, aux ); + }; + Algorithms::ParallelFor< Device >::exec( first, last, l, args... ); +} + +template< typename Device, + typename Index > + template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args > +void +CSRView< Device, Index >:: +allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const +{ + this->segmentsReduction( 0, this->getSegmentsCount(), fetch, reduction, keeper, zero, args... ); +} + +template< typename Device, + typename Index > +void +CSRView< Device, Index >:: +save( File& file ) const +{ + file << this->offsets; +} + +template< typename Device, + typename Index > +void +CSRView< Device, Index >:: +load( File& file ) +{ + file >> this->offsets; +} + + } // namespace Segments + } // namespace Conatiners +} // namespace TNL