Skip to content
Snippets Groups Projects
Commit 170d652a authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Adding CSRView.

parent 91d38ffa
No related branches found
No related tags found
1 merge request!48Segments
/***************************************************************************
CSRView.h - description
-------------------
begin : Dec 11, 2019
copyright : (C) 2019 by Tomas Oberhuber
email : tomas.oberhuber@fjfi.cvut.cz
***************************************************************************/
/* See Copyright Notice in tnl/Copyright */
#pragma once
#include <TNL/Containers/Vector.h>
namespace TNL {
namespace Containers {
namespace Segments {
template< typename Device,
typename Index >
class CSRView
{
public:
using DeviceType = Device;
using IndexType = Index;
using OffsetsHolderView = typedef Containers::Vector< IndexType, DeviceType, IndexType >::ViewType;
__cuda_callable__
CSRView();
__cuda_callable__
CSRView( const OffsetsHolderView& offsets );
__cuda_callable__
CSRView( const CSRView& csr_view );
__cuda_callable__
CSRView( const CSRView&& csr_view );
/**
* \brief Number segments.
*/
__cuda_callable__
IndexType getSegmentsCount() const;
/***
* \brief Returns size of the segment number \r segmentIdx
*/
__cuda_callable__
IndexType getSegmentSize( const IndexType segmentIdx ) const;
/***
* \brief Returns number of elements managed by all segments.
*/
__cuda_callable__
IndexType getSize() const;
/***
* \brief Returns number of elements that needs to be allocated.
*/
__cuda_callable__
IndexType getStorageSize() const;
__cuda_callable__
IndexType getGlobalIndex( const Index segmentIdx, const Index localIdx ) const;
__cuda_callable__
void getSegmentAndLocalIndex( const Index globalIdx, Index& segmentIdx, Index& localIdx ) const;
/***
* \brief Go over all segments and for each segment element call
* function 'f' with arguments 'args'. The return type of 'f' is bool.
* When its true, the for-loop continues. Once 'f' returns false, the for-loop
* is terminated.
*/
template< typename Function, typename... Args >
void forSegments( IndexType first, IndexType last, Function& f, Args... args ) const;
template< typename Function, typename... Args >
void forAll( Function& f, Args... args ) const;
/***
* \brief Go over all segments and perform a reduction in each of them.
*/
template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
void segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const;
template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
void allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const;
void save( File& file ) const;
void load( File& file );
protected:
OffsetsHolderView offsets;
};
} // namespace Segements
} // namespace Conatiners
} // namespace TNL
#include <TNL/Containers/Segments/CSRView.hpp>
/***************************************************************************
CSRView.hpp - description
-------------------
begin : Dec 11, 2019
copyright : (C) 2019 by Tomas Oberhuber
email : tomas.oberhuber@fjfi.cvut.cz
***************************************************************************/
/* See Copyright Notice in tnl/Copyright */
#pragma once
#include <TNL/Containers/Vector.h>
#include <TNL/Algorithms/ParallelFor.h>
#include <TNL/Containers/Segments/CSRView.h>
namespace TNL {
namespace Containers {
namespace Segments {
template< typename Device,
typename Index >
__cuda_callable__
CSRView< Device, Index >::
CSRView()
{
}
template< typename Device,
typename Index >
__cuda_callable__
CSRView< Device, Index >::
CSRView( const OffsetsHolderView& offsets_view )
: offsets( offsets_view )
{
}
template< typename Device,
typename Index >
__cuda_callable__
CSRView< Device, Index >::
CSRView( const CSRView& csr_view )
: offsets( csr_view.offsest )
{
}
template< typename Device,
typename Index >
__cuda_callable__
CSRView< Device, Index >::
CSRView( const CSRView&& csr_view )
: offsets( std::move( csr_view.offsest ) )
{
}
template< typename Device,
typename Index >
__cuda_callable__
Index
CSRView< Device, Index >::
getSegmentsCount() const
{
return this->offsets.getSize() - 1;
}
template< typename Device,
typename Index >
__cuda_callable__
Index
CSRView< Device, Index >::
getSegmentSize( const IndexType segmentIdx ) const
{
if( ! std::is_same< DeviceType, Devices::Host >::value )
{
#ifdef __CUDA_ARCH__
return offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ];
#else
return offsets.getElement( segmentIdx + 1 ) - offsets.getElement( segmentIdx );
#endif
}
return offsets[ segmentIdx + 1 ] - offsets[ segmentIdx ];
}
template< typename Device,
typename Index >
__cuda_callable__
Index
CSRView< Device, Index >::
getSize() const
{
return this->getStorageSize();
}
template< typename Device,
typename Index >
__cuda_callable__
Index
CSRView< Device, Index >::
getStorageSize() const
{
if( ! std::is_same< DeviceType, Devices::Host >::value )
{
#ifdef __CUDA_ARCH__
return offsets[ this->getSegmentsCount() ];
#else
return offsets.getElement( this->getSegmentsCount() );
#endif
}
return offsets[ this->getSegmentsCount() ];
}
template< typename Device,
typename Index >
__cuda_callable__
Index
CSRView< Device, Index >::
getGlobalIndex( const Index segmentIdx, const Index localIdx ) const
{
if( ! std::is_same< DeviceType, Devices::Host >::value )
{
#ifdef __CUDA_ARCH__
return offsets[ segmentIdx ] + localIdx;
#else
return offsets.getElement( segmentIdx ) + localIdx;
#endif
}
return offsets[ segmentIdx ] + localIdx;
}
template< typename Device,
typename Index >
__cuda_callable__
void
CSRView< Device, Index >::
getSegmentAndLocalIndex( const Index globalIdx, Index& segmentIdx, Index& localIdx ) const
{
}
template< typename Device,
typename Index >
template< typename Function, typename... Args >
void
CSRView< Device, Index >::
forSegments( IndexType first, IndexType last, Function& f, Args... args ) const
{
const auto offsetsView = this->offsets.getConstView();
auto l = [=] __cuda_callable__ ( const IndexType segmentIdx, Args... args ) mutable {
const IndexType begin = offsetsView[ segmentIdx ];
const IndexType end = offsetsView[ segmentIdx + 1 ];
IndexType localIdx( 0 );
for( IndexType globalIdx = begin; globalIdx < end; globalIdx++ )
if( ! f( segmentIdx, localIdx++, globalIdx, args... ) )
break;
};
Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
}
template< typename Device,
typename Index >
template< typename Function, typename... Args >
void
CSRView< Device, Index >::
forAll( Function& f, Args... args ) const
{
this->forSegments( 0, this->getSize(), f, args... );
}
template< typename Device,
typename Index >
template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
void
CSRView< Device, Index >::
segmentsReduction( IndexType first, IndexType last, Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
{
using RealType = decltype( fetch( IndexType(), IndexType() ) );
const auto offsetsView = this->offsets.getConstView();
auto l = [=] __cuda_callable__ ( const IndexType i, Args... args ) mutable {
const IndexType begin = offsetsView[ i ];
const IndexType end = offsetsView[ i + 1 ];
RealType aux( zero );
for( IndexType j = begin; j < end; j++ )
reduction( aux, fetch( i, j, args... ) );
keeper( i, aux );
};
Algorithms::ParallelFor< Device >::exec( first, last, l, args... );
}
template< typename Device,
typename Index >
template< typename Fetch, typename Reduction, typename ResultKeeper, typename Real, typename... Args >
void
CSRView< Device, Index >::
allReduction( Fetch& fetch, Reduction& reduction, ResultKeeper& keeper, const Real& zero, Args... args ) const
{
this->segmentsReduction( 0, this->getSegmentsCount(), fetch, reduction, keeper, zero, args... );
}
template< typename Device,
typename Index >
void
CSRView< Device, Index >::
save( File& file ) const
{
file << this->offsets;
}
template< typename Device,
typename Index >
void
CSRView< Device, Index >::
load( File& file )
{
file >> this->offsets;
}
} // namespace Segments
} // namespace Conatiners
} // namespace TNL
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment