Skip to content
Snippets Groups Projects
Commit cbf4c7a3 authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Added Segments.

parent 7e5bf8aa
No related branches found
No related tags found
1 merge request!48Segments
/***************************************************************************
Segments.h - description
-------------------
begin : Nov 29, 2019
copyright : (C) 2019 by Tomas Oberhuber
email : tomas.oberhuber@fjfi.cvut.cz
***************************************************************************/
/* See Copyright Notice in tnl/Copyright */
#pragma once
namespace TNL {
namespace Containers {
template< typename Value,
typename Organization >
class Segments
{
public:
using ValueType = Value;
using OrganizationType = Organization;
using IndexType = typename Organization::IndexType;
};
} // namespace Conatiners
} // namespace TNL
\ No newline at end of file
/***************************************************************************
CSR.h - description
-------------------
begin : Nov 29, 2019
copyright : (C) 2019 by Tomas Oberhuber
email : tomas.oberhuber@fjfi.cvut.cz
***************************************************************************/
/* See Copyright Notice in tnl/Copyright */
#pragma once
#include <TNL/Containers/Vector.h>
namespace TNL {
namespace Containers {
namespace Segments {
template< typename Device,
typename Index >
class Segments
{
public:
using DeviceType = Device;
using IndexType = Index;
using OffsetsHolder = Containers::Vector< IndexType, DeviceType, IndexType >;
CSR();
CSR( const SizesHolder& sizes );
CSR( const CSR& csr );
CSR( const CSR&& csr );
/**
* \brief Set number of segments
*/
//void setSegmentsCount();
/**
* \brief Set sizes of particular segmenets.
*/
template< typename SizesHolder = OffsetsHolder >
void setSizes( const SizesHolder& sizes )
/**
* \brief Number segments.
*/
Index getSize() const;
Index getStorageSize() const;
IndexType getGlobalIndex( const Index segmentIdx, const Index localIdx ) const;
void getSegmentAndLocalIndex( const Index globalIdx, Index& segmentIdx, Index& localIdx ) const;
/***
* \brief Go over all segments and for each segment element call
* function 'f' with arguments 'args'
*/
template< typename Function, typename... Args >
void forAll( Function& f, Args args ) const;
/***
* \brief Go over all segments and perform a reduction in each of them.
*/
template< typename Fetch, typename Reduction, typename ResultKeeper, typename... Args >
void segmentsReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, Args args );
protected:
OffsetsHolder offsets;
};
} // namespace Segements
} // namespace Conatiners
} // namespace TNL
#include <TNL/Containers/Segments/CSR.h>
\ No newline at end of file
/***************************************************************************
CSR.hpp - description
-------------------
begin : Nov 29, 2019
copyright : (C) 2019 by Tomas Oberhuber
email : tomas.oberhuber@fjfi.cvut.cz
***************************************************************************/
/* See Copyright Notice in tnl/Copyright */
#pragma once
#include <TNL/Containers/Vector.h>
#include <TNL/Algorithms/ParalleFor.h>
#include <TNL/Containers/Segments/CSR.h>
namespace TNL {
namespace Containers {
namespace Segments {
template< typename Device,
typename Index >
CSR< Device, Index >::
CSR()
{
}
template< typename Device,
typename Index >
CSR< Device, Index >::
CSR( const CSR& csr ) : offsets( csr.offsets )
{
}
template< typename Device,
typename Index >
CSR< Device, Index >::
CSR( const CSR&& csr ) : offsets( std::move( csr.offsets ) )
{
}
template< typename Device,
typename Index >
CSR< Device, Index >::
void setSegmentsCount( const IndexType& size )
{
this->offsets.setSize( size + 1 );
}
template< typename Device,
typename Index >
template< typename SizesHolder = OffsetsHolder >
CSR< Device, Index >::
void setSizes( const SizesHolder& sizes )
{
this->offsets.setSize( sizes.getSize() + 1 );
auto view = this->offsets.getView( 0, sizes.getSize() );
view = sizes;
this->offsets.setElement( sizes.getSize>(), 0 );
this->offsets.scan< Algorithms::ScanType::Exclusive >();
}
template< typename Device,
typename Index >
CSR< Device, Index >::
Index getSize() const
{
return this->offsets.getSize() - 1;
}
template< typename Device,
typename Index >
template< typename Function, typename... Args >
CSR< Device, Index >::
void forAll( Function& f, Args args ) const
{
const auto offsetsView = this->offsets.getView();
auto f = [=] __cuda_callable__ ( const IndexType i, f, args ) {
const IndexType begin = offsetsView[ i ];
const IndexType end = offsetsView[ i + 1 ];
for( IndexType j = begin; j < end; j++ )
f( i, j, args );
};
Algorithms::ParallelFor< Device >::exec( 0, this->getSize(), f );
}
template< typename Device,
typename Index >
template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
CSR< Device, Index >::
void segmentsReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, Real zero, Args args )
{
const auto offsetsView = this->offsets.getView();
auto f = [=] __cuda_callable__ ( const IndexType i, f, args ) {
const IndexType begin = offsetsView[ i ];
const IndexType end = offsetsView[ i + 1 ];
Real aux( zero );
for( IndexType j = begin; j < end; j++ )
reduction( aux, fetch( i, j, args ) );
keeper( i, aux );
};
Algorithms::ParallelFor< Device >::exec( 0, this->getSize(), f );
}
} // namespace Segements
} // namespace Conatiners
} // namespace TNL
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment