Skip to content
Snippets Groups Projects
Commit 91806e1d authored by Tomáš Oberhuber's avatar Tomáš Oberhuber
Browse files

Moving implementation of CSRAdaptiveKernel to .hpp file.

parent 37632fd9
No related branches found
No related tags found
1 merge request: !89 "To/matrices adaptive csr"
......@@ -63,10 +63,7 @@ struct CSRAdaptiveKernel
using BlocksType = typename ViewType::BlocksType;
using BlocksView = typename BlocksType::ViewType;
/**
 * \brief Returns the kernel type string reported by the view type.
 *
 * Forwards directly to ViewType::getKernelType().
 */
static TNL::String getKernelType()
{
   return ViewType::getKernelType();
}
static TNL::String getKernelType();
static constexpr Index THREADS_ADAPTIVE = sizeof(Index) == 8 ? 128 : 256;
......@@ -93,84 +90,16 @@ struct CSRAdaptiveKernel
const Offsets& offsets,
const Index size,
details::Type &type,
Index &sum )
{
sum = 0;
for (Index current = start; current < size - 1; current++ )
{
Index elements = offsets[ current + 1 ] - offsets[ current ];
sum += elements;
if( sum > SHARED_PER_WARP )
{
if( current - start > 0 ) // extra row
{
type = details::Type::STREAM;
return current;
}
else
{ // one long row
if( sum <= 2 * MAX_ELEMENTS_PER_WARP_ADAPT )
type = details::Type::VECTOR;
else
type = details::Type::LONG;
return current + 1;
}
}
}
type = details::Type::STREAM;
return size - 1; // return last row pointer
}
Index &sum );
template< typename Offsets >
void init( const Offsets& offsets )
{
using HostOffsetsType = TNL::Containers::Vector< typename Offsets::IndexType, TNL::Devices::Host, typename Offsets::IndexType >;
HostOffsetsType hostOffsets( offsets );
const Index rows = offsets.getSize();
Index sum, start( 0 ), nextStart( 0 );
// Fill blocks
std::vector< details::CSRAdaptiveKernelBlockDescriptor< Index > > inBlocks;
inBlocks.reserve( rows );
while( nextStart != rows - 1 )
{
details::Type type;
nextStart = findLimit( start, hostOffsets, rows, type, sum );
if( type == details::Type::LONG )
{
const Index blocksCount = inBlocks.size();
const Index warpsPerCudaBlock = THREADS_ADAPTIVE / TNL::Cuda::getWarpSize();
Index warpsLeft = roundUpDivision( blocksCount, warpsPerCudaBlock ) * warpsPerCudaBlock - blocksCount;
if( warpsLeft == 0 )
warpsLeft = warpsPerCudaBlock;
for( Index index = 0; index < warpsLeft; index++ )
inBlocks.emplace_back( start, details::Type::LONG, index, warpsLeft );
}
else
{
inBlocks.emplace_back(start, type,
nextStart,
offsets.getElement(nextStart),
offsets.getElement(start) );
}
start = nextStart;
}
inBlocks.emplace_back(nextStart);
this->blocks = inBlocks;
this->view.setBlocks( blocks );
}
void reset()
{
this->blocks.reset();
this->view.setBlocks( blocks );
}
/// Returns the stored view converted to ViewType (returned by value).
ViewType getView() { return this->view; }
/// Returns the stored view converted to ConstViewType (returned by value).
ConstViewType getConstView() const { return this->view; }
void init( const Offsets& offsets );
void reset();
ViewType getView();
ConstViewType getConstView() const;
template< typename OffsetsView,
typename Fetch,
......@@ -185,10 +114,7 @@ struct CSRAdaptiveKernel
const Reduction& reduction,
ResultKeeper& keeper,
const Real& zero,
Args... args ) const
{
view.segmentsReduction( offsets, first, last, fetch, reduction, keeper, zero, args... );
}
Args... args ) const;
protected:
BlocksType blocks;
......
......@@ -22,7 +22,148 @@ namespace TNL {
namespace Algorithms {
namespace Segments {
/**
 * \brief Returns the kernel type string reported by the view type.
 *
 * Forwards directly to ViewType::getKernelType().
 */
template< typename Index,
          typename Device >
TNL::String
CSRAdaptiveKernel< Index, Device >::
getKernelType()
{
   return ViewType::getKernelType();
}
/**
 * \brief Finds where the block of rows beginning at \e start ends and
 *        classifies that block.
 *
 * Rows are visited from \e start up to (but excluding) `size - 1`. For each
 * row the difference of two neighbouring offsets is added to \e sum. The scan
 * stops as soon as \e sum becomes larger than SHARED_PER_WARP:
 *  - if at least one row was passed before the overflow, \e type is set to
 *    details::Type::STREAM and the index of the overflowing row is returned
 *    (\e sum then still includes that row's contribution),
 *  - if the very first row overflows, \e type is details::Type::VECTOR when
 *    \e sum is at most `2 * MAX_ELEMENTS_PER_WARP_ADAPT`, otherwise
 *    details::Type::LONG, and the index following that row is returned.
 * If the limit is never exceeded, \e type is details::Type::STREAM and
 * `size - 1` is returned.
 *
 * \param start   Index of the first row to examine.
 * \param offsets Indexable container of row offsets.
 * \param size    Number of entries in \e offsets.
 * \param type    Output: classification of the block.
 * \param sum     Output: accumulated offset differences.
 * \return Row index at which the scan ended, as described above.
 */
template< typename Index,
          typename Device >
   template< typename Offsets >
Index
CSRAdaptiveKernel< Index, Device >::
findLimit( const Index start,
           const Offsets& offsets,
           const Index size,
           details::Type &type,
           Index &sum )
{
   sum = 0;
   Index row = start;
   while( row < size - 1 )
   {
      const Index rowLength = offsets[ row + 1 ] - offsets[ row ];
      sum += rowLength;
      if( sum > SHARED_PER_WARP )
      {
         if( row - start > 0 )
         {
            // Several rows were collected; the overflowing one starts the next block.
            type = details::Type::STREAM;
            return row;
         }
         // A single row is already over the limit.
         type = ( sum <= 2 * MAX_ELEMENTS_PER_WARP_ADAPT ) ? details::Type::VECTOR
                                                           : details::Type::LONG;
         return row + 1;
      }
      row++;
   }
   // The remaining rows never exceeded the limit.
   type = details::Type::STREAM;
   return size - 1;
}
/**
 * \brief Builds the block descriptors for the given row offsets and hands
 *        them to the view.
 *
 * The offsets are copied into a host vector, then findLimit() is called
 * repeatedly to cut the rows into consecutive blocks. A block reported as
 * details::Type::LONG is expanded into several descriptors (one per "warp
 * left" in the current group of THREADS_ADAPTIVE / warp-size descriptors);
 * any other block yields exactly one descriptor. A final descriptor built
 * from the last row index alone is appended before the result is stored in
 * `blocks` and published through `view.setBlocks()`.
 *
 * \param offsets Container of row offsets; must provide getSize() and be
 *                convertible to a host vector of its IndexType.
 */
template< typename Index,
          typename Device >
   template< typename Offsets >
void
CSRAdaptiveKernel< Index, Device >::
init( const Offsets& offsets )
{
   const Index rows = offsets.getSize();
   if( rows == 0 )
   {
      // Nothing to describe. Without this guard `rows - 1` is -1, findLimit()
      // returns -1 and the loop below would read element -1 of the offsets.
      this->blocks.reset();
      this->view.setBlocks( blocks );
      return;
   }

   using HostOffsetsType = TNL::Containers::Vector< typename Offsets::IndexType, TNL::Devices::Host, typename Offsets::IndexType >;
   HostOffsetsType hostOffsets( offsets );
   Index sum;
   Index start( 0 );
   Index nextStart( 0 );

   // Fill blocks
   std::vector< details::CSRAdaptiveKernelBlockDescriptor< Index > > inBlocks;
   inBlocks.reserve( rows );

   while( nextStart != rows - 1 )
   {
      details::Type type;
      nextStart = findLimit( start, hostOffsets, rows, type, sum );
      if( type == details::Type::LONG )
      {
         // Emit enough LONG descriptors to fill the current group of
         // warpsPerCudaBlock descriptors; if the group is already full,
         // a whole new group is used.
         const Index blocksCount = inBlocks.size();
         const Index warpsPerCudaBlock = THREADS_ADAPTIVE / TNL::Cuda::getWarpSize();
         Index warpsLeft = roundUpDivision( blocksCount, warpsPerCudaBlock ) * warpsPerCudaBlock - blocksCount;
         if( warpsLeft == 0 )
            warpsLeft = warpsPerCudaBlock;
         for( Index index = 0; index < warpsLeft; index++ )
            inBlocks.emplace_back( start, details::Type::LONG, index, warpsLeft );
      }
      else
      {
         // Read from the host copy that findLimit() already used instead of
         // calling offsets.getElement() per block (presumably a per-element
         // transfer when the offsets are not stored on the host).
         inBlocks.emplace_back( start, type,
                                nextStart,
                                hostOffsets[ nextStart ],
                                hostOffsets[ start ] );
      }
      start = nextStart;
   }
   inBlocks.emplace_back( nextStart );
   this->blocks = inBlocks;
   this->view.setBlocks( blocks );
}
/**
 * \brief Calls reset() on the stored blocks and passes them to the view again.
 */
template< typename Index, typename Device >
void
CSRAdaptiveKernel< Index, Device >::reset()
{
   blocks.reset();
   // Re-register the blocks so the view sees their state after the reset.
   view.setBlocks( this->blocks );
}
template< typename Index,
typename Device >
auto
CSRAdaptiveKernel< Index, Device >::
getView() -> ViewType
{
return this->view;
}
/**
 * \brief Returns the stored view converted to ConstViewType (returned by value).
 */
template< typename Index,
          typename Device >
auto
CSRAdaptiveKernel< Index, Device >::
getConstView() const -> ConstViewType
{
   return this->view;
}
/**
 * \brief Runs a segments reduction by delegating to the stored view.
 *
 * Every argument is passed on unchanged, in the same order, to
 * `view.segmentsReduction()`; this method adds no logic of its own.
 *
 * \param offsets   Offsets view handed to the view's implementation.
 * \param first     Forwarded as the `first` argument.
 * \param last      Forwarded as the `last` argument.
 * \param fetch     Forwarded fetch functor.
 * \param reduction Forwarded reduction functor.
 * \param keeper    Forwarded result keeper.
 * \param zero      Forwarded `zero` value.
 * \param args      Additional arguments, forwarded by value.
 */
template< typename Index, typename Device >
   template< typename OffsetsView,
             typename Fetch,
             typename Reduction,
             typename ResultKeeper,
             typename Real,
             typename... Args >
void
CSRAdaptiveKernel< Index, Device >::segmentsReduction( const OffsetsView& offsets,
                                                       Index first,
                                                       Index last,
                                                       Fetch& fetch,
                                                       const Reduction& reduction,
                                                       ResultKeeper& keeper,
                                                       const Real& zero,
                                                       Args... args ) const
{
   this->view.segmentsReduction( offsets, first, last, fetch, reduction, keeper, zero, args... );
}
} // namespace Segments
} // namespace Algorithms
} // namespace TNL
\ No newline at end of file
} // namespace TNL
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment