diff --git a/src/TNL/Containers/Segments.h b/src/TNL/Containers/Segments.h new file mode 100644 index 0000000000000000000000000000000000000000..99ea2235722cc0d3f7b594c95e42ab781800903b --- /dev/null +++ b/src/TNL/Containers/Segments.h @@ -0,0 +1,29 @@ +/*************************************************************************** + Segments.h - description + ------------------- + begin : Nov 29, 2019 + copyright : (C) 2019 by Tomas Oberhuber + email : tomas.oberhuber@fjfi.cvut.cz + ***************************************************************************/ + +/* See Copyright Notice in tnl/Copyright */ + +#pragma once + +namespace TNL { +namespace Containers { + +template< typename Value, + typename Organization > +class Segments +{ + public: + + using ValueType = Value; + using OrganizationType = Organization; + using IndexType = typename Organization::IndexType; + +}; + +} // namespace Conatiners +} // namespace TNL \ No newline at end of file diff --git a/src/TNL/Containers/Segments/CSR.h b/src/TNL/Containers/Segments/CSR.h new file mode 100644 index 0000000000000000000000000000000000000000..3aa53e76cbda272ca828126e418fdc3fe43d069a --- /dev/null +++ b/src/TNL/Containers/Segments/CSR.h @@ -0,0 +1,83 @@ +/*************************************************************************** + CSR.h - description + ------------------- + begin : Nov 29, 2019 + copyright : (C) 2019 by Tomas Oberhuber + email : tomas.oberhuber@fjfi.cvut.cz + ***************************************************************************/ + +/* See Copyright Notice in tnl/Copyright */ + +#pragma once + +#include <TNL/Containers/Vector.h> + +namespace TNL { + namespace Containers { + namespace Segments { + + +template< typename Device, + typename Index > +class Segments +{ + public: + + using DeviceType = Device; + using IndexType = Index; + using OffsetsHolder = Containers::Vector< IndexType, DeviceType, IndexType >; + + CSR(); + + CSR( const SizesHolder& sizes ); + + CSR( const CSR& csr ); + + CSR( const CSR&& csr ); + + /** + * \brief Set number of segments + */ + //void setSegmentsCount(); + + /** + * \brief Set sizes of particular segmenets. + */ + template< typename SizesHolder = OffsetsHolder > + void setSizes( const SizesHolder& sizes ) + + /** + * \brief Number segments. + */ + Index getSize() const; + + Index getStorageSize() const; + + IndexType getGlobalIndex( const Index segmentIdx, const Index localIdx ) const; + + void getSegmentAndLocalIndex( const Index globalIdx, Index& segmentIdx, Index& localIdx ) const; + + /*** + * \brief Go over all segments and for each segment element call + * function 'f' with arguments 'args' + */ + template< typename Function, typename... Args > + void forAll( Function& f, Args args ) const; + + /*** + * \brief Go over all segments and perform a reduction in each of them. + */ + template< typename Fetch, typename Reduction, typename ResultKeeper, typename... Args > + void segmentsReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, Args args ); + + protected: + + OffsetsHolder offsets; + +}; + + } // namespace Segements + } // namespace Conatiners +} // namespace TNL + +#include <TNL/Containers/Segments/CSR.h> \ No newline at end of file diff --git a/src/TNL/Containers/Segments/CSR.hpp b/src/TNL/Containers/Segments/CSR.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ea45b40ba9d67b16203e0e7665859f5390653b9f --- /dev/null +++ b/src/TNL/Containers/Segments/CSR.hpp @@ -0,0 +1,110 @@ +/*************************************************************************** + CSR.hpp - description + ------------------- + begin : Nov 29, 2019 + copyright : (C) 2019 by Tomas Oberhuber + email : tomas.oberhuber@fjfi.cvut.cz + ***************************************************************************/ + +/* See Copyright Notice in tnl/Copyright */ + +#pragma once + +#include <TNL/Containers/Vector.h> +#include <TNL/Algorithms/ParalleFor.h> +#include <TNL/Containers/Segments/CSR.h> + +namespace TNL { + namespace Containers { + namespace Segments { + + +template< typename Device, + typename Index > +CSR< Device, Index >:: +CSR() +{ +} + +template< typename Device, + typename Index > +CSR< Device, Index >:: +CSR( const CSR& csr ) : offsets( csr.offsets ) +{ +} + +template< typename Device, + typename Index > +CSR< Device, Index >:: +CSR( const CSR&& csr ) : offsets( std::move( csr.offsets ) ) +{ + +} + +template< typename Device, + typename Index > +CSR< Device, Index >:: +void setSegmentsCount( const IndexType& size ) +{ + this->offsets.setSize( size + 1 ); +} + +template< typename Device, + typename Index > + template< typename SizesHolder = OffsetsHolder > +CSR< Device, Index >:: +void setSizes( const SizesHolder& sizes ) +{ + this->offsets.setSize( sizes.getSize() + 1 ); + auto view = this->offsets.getView( 0, sizes.getSize() ); + view = sizes; + this->offsets.setElement( sizes.getSize>(), 0 ); + this->offsets.scan< Algorithms::ScanType::Exclusive >(); +} + +template< typename Device, + typename Index > +CSR< Device, Index >:: +Index getSize() const +{ + return this->offsets.getSize() - 1; +} + +template< typename Device, + typename Index > + template< typename Function, typename... Args > +CSR< Device, Index >:: +void forAll( Function& f, Args args ) const +{ + const auto offsetsView = this->offsets.getView(); + auto f = [=] __cuda_callable__ ( const IndexType i, f, args ) { + const IndexType begin = offsetsView[ i ]; + const IndexType end = offsetsView[ i + 1 ]; + for( IndexType j = begin; j < end; j++ ) + f( i, j, args ); + }; + Algorithms::ParallelFor< Device >::exec( 0, this->getSize(), f ); +} + +template< typename Device, + typename Index > + template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args > +CSR< Device, Index >:: +void segmentsReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, Real zero, Args args ) +{ + const auto offsetsView = this->offsets.getView(); + auto f = [=] __cuda_callable__ ( const IndexType i, f, args ) { + const IndexType begin = offsetsView[ i ]; + const IndexType end = offsetsView[ i + 1 ]; + Real aux( zero ); + for( IndexType j = begin; j < end; j++ ) + reduction( aux, fetch( i, j, args ) ); + keeper( i, aux ); + }; + Algorithms::ParallelFor< Device >::exec( 0, this->getSize(), f ); +} + + + } // namespace Segements + } // namespace Conatiners +} // namespace TNL \ No newline at end of file