Newer
Older
#pragma once
#include <TNL/Containers/Array.h>
#include "../util/reduction.cuh"
#include "task.h"
using namespace TNL;
using namespace TNL::Containers;
/**
 * Picks a pivot value as the median of three samples taken from `src`:
 * the first element, the element at index getSize() / 2, and the last element.
 *
 * @param src  view of the candidates; must hold at least one element because
 *             src[0] is read unconditionally.
 * @param Cmp  binary predicate; Cmp(x, y) is treated as "x is ordered before y".
 * @return     a copy of the sample that, according to Cmp, lies between the
 *             other two. For a one-element view this is src[0].
 */
template <typename Value, typename Device, typename Function>
__device__ Value pickPivot(TNL::Containers::ArrayView<Value, Device> src, const Function &Cmp)
{
    //return src[0];
    //return src[src.getSize()-1];

    // A single element is trivially its own median.
    if (src.getSize() == 1)
        return src[0];

    Value a = src[0], b = src[src.getSize() / 2], c = src[src.getSize() - 1];

    if (Cmp(a, b)) // ..a..b..
    {
        if (Cmp(b, c)) // ..a..b..c
            return b;
        else if (Cmp(c, a)) //..c..a..b..
            return a;
        else //..a..c..b..
            return c;
    }
    else //..b..a..
    {
        if (Cmp(a, c)) //..b..a..c
            return a;
        else if (Cmp(c, b)) //..c..b..a..
            return b;
        else //..b..c..a..
            return c;
    }
}
/**
 * Index-returning counterpart of the median-of-three pivot selection: samples
 * the first element, the element at index getSize() / 2, and the last element
 * of `src`, and reports the position of the one that Cmp orders between the
 * other two.
 *
 * @param src  view of the candidates.
 * @param Cmp  binary predicate; Cmp(x, y) is treated as "x is ordered before y".
 * @return     0, getSize() / 2 or getSize() - 1. Views with at most one element
 *             yield 0 without reading any element.
 */
template <typename Value, typename Device, typename Function>
__device__ int pickPivotIdx(TNL::Containers::ArrayView<Value, Device> src, const Function &Cmp)
{
    //return 0;
    //return src.getSize()-1;

    // Nothing to compare: the only possible index is 0.
    if (src.getSize() <= 1)
        return 0;

    Value a = src[0], b = src[src.getSize() / 2], c = src[src.getSize() - 1];

    if (Cmp(a, b)) // ..a..b..
    {
        if (Cmp(b, c)) // ..a..b..c
            return src.getSize() / 2;
        else if (Cmp(c, a)) //..c..a..b..
            return 0;
        else //..a..c..b..
            return src.getSize() - 1;
    }
    else //..b..a..
    {
        if (Cmp(a, c)) //..b..a..c
            return 0;
        else if (Cmp(c, b)) //..c..b..a..
            return src.getSize() / 2;
        else //..b..c..a..
            return src.getSize() - 1;
    }
}
/**
 * Counts, for the calling thread, how many of its elements of `arr` fall on
 * each side of `pivot`.
 *
 * The thread visits indices threadIdx.x, threadIdx.x + blockDim.x, ... below
 * arr.getSize().
 *
 * @param arr      elements to classify.
 * @param Cmp      binary predicate; Cmp(x, y) is treated as "x is ordered before y".
 * @param smaller  incremented once per visited element with Cmp(element, pivot).
 * @param bigger   incremented once per visited element with Cmp(pivot, element).
 * @param pivot    value the elements are compared against.
 *
 * Elements for which neither comparison holds (equivalent to the pivot) are
 * not counted. The counters are only incremented, never reset, so the caller
 * is expected to initialise them.
 */
template <typename Value, typename Function>
__device__ void countElem(ArrayView<Value, Devices::Cuda> arr,
                          const Function &Cmp,
                          int &smaller, int &bigger,
                          const Value &pivot)
{
    for (int i = threadIdx.x; i < arr.getSize(); i += blockDim.x)
    {
        const Value data = arr[i];
        if (Cmp(data, pivot))
            smaller++;
        else if (Cmp(pivot, data))
            bigger++;
    }
}
/**
 * Partitions the elements of `src` around `pivot`, staging them in `sharedMem`
 * before copying them to `dst`.
 *
 * Phase 1: each thread visits indices threadIdx.x, threadIdx.x + blockDim.x, ...
 * of `src` and stores elements with Cmp(element, pivot) at
 * sharedMem[smallerOffset...] and elements with Cmp(pivot, element) at
 * sharedMem[smallerTotal + biggerOffset...]. Elements equivalent to the pivot
 * are not copied.
 * Phase 2: after a block barrier, the first smallerTotal staged entries are
 * written to dst[smallerStart...] and the following biggerTotal entries to
 * dst[biggerStart...].
 *
 * Contains __syncthreads(), so every thread of the block must call it.
 *
 * @param src            elements handled by this block.
 * @param dst            destination view.
 * @param Cmp            binary predicate; Cmp(x, y) is treated as "x is ordered before y".
 * @param sharedMem      staging buffer with room for at least
 *                       smallerTotal + biggerTotal elements (presumably
 *                       block-shared memory, judging by the name - confirm
 *                       against the caller).
 * @param smallerStart   first index in dst for the "smaller" elements.
 * @param biggerStart    first index in dst for the "bigger" elements.
 * @param smallerTotal   number of "smaller" elements staged by the whole block.
 * @param biggerTotal    number of "bigger" elements staged by the whole block.
 * @param smallerOffset  this thread's first slot among the "smaller" elements
 *                       (exclusive prefix sum of the per-thread counts).
 * @param biggerOffset   this thread's first slot among the "bigger" elements
 *                       (exclusive prefix sum of the per-thread counts).
 * @param pivot          value the elements are compared against.
 */
template <typename Value, typename Function>
__device__ void copyDataShared(ArrayView<Value, Devices::Cuda> src,
                               ArrayView<Value, Devices::Cuda> dst,
                               const Function &Cmp,
                               Value *sharedMem,
                               int smallerStart, int biggerStart,
                               int smallerTotal, int biggerTotal,
                               int smallerOffset, int biggerOffset, //exclusive prefix sum of elements
                               const Value &pivot)
{
    for (int i = threadIdx.x; i < src.getSize(); i += blockDim.x)
    {
        const Value data = src[i];
        if (Cmp(data, pivot))
            sharedMem[smallerOffset++] = data;
        else if (Cmp(pivot, data))
            sharedMem[smallerTotal + biggerOffset++] = data;
    }
    // All staged writes must be visible before any thread reads the buffer.
    __syncthreads();

    for (int i = threadIdx.x; i < smallerTotal + biggerTotal; i += blockDim.x)
    {
        if (i < smallerTotal)
            dst[smallerStart + i] = sharedMem[i];
        else
            dst[biggerStart + i - smallerTotal] = sharedMem[i];
    }
}
/**
 * Partitions the elements of `src` around `pivot`, writing them straight into
 * `dst` without a staging buffer.
 *
 * Each thread visits indices threadIdx.x, threadIdx.x + blockDim.x, ... of
 * `src`. Elements with Cmp(element, pivot) go to consecutive positions of
 * `dst` starting at smallerStart; elements with Cmp(pivot, element) go to
 * consecutive positions starting at biggerStart. Elements equivalent to the
 * pivot are not copied.
 *
 * @param src           elements handled by this block.
 * @param dst           destination view.
 * @param Cmp           binary predicate; Cmp(x, y) is treated as "x is ordered before y".
 * @param smallerStart  this thread's first write position for "smaller"
 *                      elements (passed by value, advanced locally).
 * @param biggerStart   this thread's first write position for "bigger"
 *                      elements (passed by value, advanced locally).
 * @param pivot         value the elements are compared against.
 *
 * A write position outside [0, dst.getSize()) is reported with printf and the
 * write is skipped; the position still advances so later elements keep their
 * intended slots.
 */
template <typename Value, typename Function>
__device__ void copyData(ArrayView<Value, Devices::Cuda> src,
                         ArrayView<Value, Devices::Cuda> dst,
                         const Function &Cmp,
                         int smallerStart, int biggerStart,
                         const Value &pivot)
{
    for (int i = threadIdx.x; i < src.getSize(); i += blockDim.x)
    {
        const Value data = src[i];
        if (Cmp(data, pivot))
        {
            if(smallerStart >= dst.getSize() || smallerStart < 0)
                printf("failed smaller: b:%d t:%d: tried to write into [%d]/%d\n", blockDim.x, threadIdx.x, smallerStart, dst.getSize());
            else
                dst[smallerStart] = data;
            smallerStart++;
        }
        else if (Cmp(pivot, data))
        {
            if(biggerStart >= dst.getSize() || biggerStart < 0)
                printf("failed bigger: b:%d t:%d: tried to write into [%d]/%d\n", blockDim.x, threadIdx.x, biggerStart, dst.getSize());
            else
                dst[biggerStart] = data;
            biggerStart++;
        }
    }
}
//----------------------------------------------------------------------------------
template <typename Value, typename Function, bool useShared>
__device__ void cudaPartition(ArrayView<Value, Devices::Cuda> src,
const Function &Cmp,
Value *sharedMem,
const Value &pivot,
int elemPerBlock, TASK &task)
{
static __shared__ int smallerStart, biggerStart;
int myBegin = elemPerBlock * (blockIdx.x - task.firstBlock);
int myEnd = TNL::min(myBegin + elemPerBlock, src.getSize());
auto srcView = src.getView(myBegin, myEnd);
//-------------------------------------------------------------------------
int smaller = 0, bigger = 0;
int smallerPrefSumInc = blockInclusivePrefixSum(smaller);
int biggerPrefSumInc = blockInclusivePrefixSum(bigger);
if (threadIdx.x == blockDim.x - 1) //last thread in block has sum of all values
{
smallerStart = atomicAdd(&(task.dstBegin), smallerPrefSumInc);
biggerStart = atomicAdd(&(task.dstEnd), -biggerPrefSumInc) - biggerPrefSumInc;
}
__syncthreads();
//-----------------------------------------------------------
if (useShared)
{
static __shared__ int smallerTotal, biggerTotal;
if (threadIdx.x == blockDim.x - 1)
{
smallerTotal = smallerPrefSumInc;
biggerTotal = biggerPrefSumInc;
}
__syncthreads();
smallerStart, biggerStart,
smallerTotal, biggerTotal,
smallerPrefSumInc - smaller, biggerPrefSumInc - bigger, //exclusive prefix sum of elements
pivot);
int destSmaller = smallerStart + smallerPrefSumInc - smaller;
int destBigger = biggerStart + biggerPrefSumInc - bigger;
copyData(srcView, dst, Cmp, destSmaller, destBigger, pivot);