// cudaPartition.cuh
// Author (per repository history): Xuan Thang Nguyen
#pragma once

#include <TNL/Containers/Array.h>
#include "../util/reduction.cuh"
#include "task.h"

using namespace TNL;
using namespace TNL::Containers;

Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
template <typename Value, typename Device, typename Function>
__device__ Value pickPivot(TNL::Containers::ArrayView<Value, Device> src, const Function &Cmp)
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
{
    //return src[0];
    //return src[src.getSize()-1];

        return src[0];

    Value a = src[0], b = src[src.getSize() / 2], c = src[src.getSize() - 1];

    if (Cmp(a, b)) // ..a..b..
        if (Cmp(b, c)) // ..a..b..c
            return b;
        else if (Cmp(c, a)) //..c..a..b..
            return a;
        else //..a..c..b..
            return c;
    }
    else //..b..a..
    {
        if (Cmp(a, c)) //..b..a..c
            return a;
        else if (Cmp(c, b)) //..c..b..a..
            return b;
        else //..b..c..a..
            return c;
    }
template <typename Value, typename Device, typename Function>
__device__ int pickPivotIdx(TNL::Containers::ArrayView<Value, Device> src, const Function &Cmp)
{
    //return 0;
    //return src.getSize()-1;

    Value a = src[0], b = src[src.getSize() / 2], c = src[src.getSize() - 1];

    if (Cmp(a, b)) // ..a..b..
        if (Cmp(b, c)) // ..a..b..c
            return src.getSize() / 2;
        else if (Cmp(c, a)) //..c..a..b..
            return 0;
        else //..a..c..b..
            return src.getSize() - 1;
    }
    else //..b..a..
    {
        if (Cmp(a, c)) //..b..a..c
        else if (Cmp(c, b)) //..c..b..a..
            return src.getSize() / 2;
        else //..b..c..a..
            return src.getSize() - 1;
    }
}

Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
template <typename Value, typename Function>
__device__ void countElem(ArrayView<Value, Devices::Cuda> arr,
                          const Function &Cmp,
                          int &smaller, int &bigger,
                          const Value &pivot)
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
{
    for (int i = threadIdx.x; i < arr.getSize(); i += blockDim.x)
    {
        const Value data = arr[i];
        else if (Cmp(pivot, data))
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
template <typename Value, typename Function>
__device__ void copyDataShared(ArrayView<Value, Devices::Cuda> src,
                               ArrayView<Value, Devices::Cuda> dst,
                               const Function &Cmp,
                               Value *sharedMem,
                               int smallerStart, int biggerStart,
                               int smallerTotal, int biggerTotal,
                               int smallerOffset, int biggerOffset, //exclusive prefix sum of elements
                               const Value &pivot)
{

    for (int i = threadIdx.x; i < src.getSize(); i += blockDim.x)
    {
        const Value data = src[i];
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
        if (Cmp(data, pivot))
            sharedMem[smallerOffset++] = data;
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
        else if (Cmp(pivot, data))
            sharedMem[smallerTotal + biggerOffset++] = data;
    }
    __syncthreads();

    for (int i = threadIdx.x; i < smallerTotal + biggerTotal; i += blockDim.x)
    {
            dst[smallerStart + i] = sharedMem[i];
        else
            dst[biggerStart + i - smallerTotal] = sharedMem[i];
    }
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
template <typename Value, typename Function>
__device__ void copyData(ArrayView<Value, Devices::Cuda> src,
                         ArrayView<Value, Devices::Cuda> dst,
                         const Function &Cmp,
                         int smallerStart, int biggerStart,
                         const Value &pivot)
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
{
    for (int i = threadIdx.x; i < src.getSize(); i += blockDim.x)
    {
        const Value data = src[i];
            if(smallerStart >= dst.getSize() || smallerStart < 0)
                printf("failed smaller: b:%d t:%d: tried to write into [%d]/%d\n", blockDim.x, threadIdx.x, smallerStart, dst.getSize());
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
            dst[smallerStart++] = data;
            if(biggerStart >= dst.getSize() || biggerStart < 0)
                printf("failed bigger: b:%d t:%d: tried to write into [%d]/%d\n", blockDim.x, threadIdx.x, biggerStart, dst.getSize());
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
            dst[biggerStart++] = data;
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
    }
}

//----------------------------------------------------------------------------------

template <typename Value, typename Function, bool useShared>
__device__ void cudaPartition(ArrayView<Value, Devices::Cuda> src,
                              ArrayView<Value, Devices::Cuda> dst,
                              const Function &Cmp,
                              Value *sharedMem,
                              const Value &pivot,
                              int elemPerBlock, TASK &task)
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
{
    static __shared__ int smallerStart, biggerStart;

    int myBegin = elemPerBlock * (blockIdx.x - task.firstBlock);
    int myEnd = TNL::min(myBegin + elemPerBlock, src.getSize());
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed

    auto srcView = src.getView(myBegin, myEnd);

    //-------------------------------------------------------------------------

    int smaller = 0, bigger = 0;
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
    countElem(srcView, Cmp, smaller, bigger, pivot);
    int smallerPrefSumInc = blockInclusivePrefixSum(smaller);
    int biggerPrefSumInc = blockInclusivePrefixSum(bigger);
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed

    if (threadIdx.x == blockDim.x - 1) //last thread in block has sum of all values
    {
        smallerStart = atomicAdd(&(task.dstBegin), smallerPrefSumInc);
        biggerStart = atomicAdd(&(task.dstEnd), -biggerPrefSumInc) - biggerPrefSumInc;
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
    }
    __syncthreads();

    //-----------------------------------------------------------
    {
        static __shared__ int smallerTotal, biggerTotal;
        if (threadIdx.x == blockDim.x - 1)
        {
            smallerTotal = smallerPrefSumInc;
            biggerTotal = biggerPrefSumInc;
        }
        __syncthreads();
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
        copyDataShared(srcView, dst, Cmp, sharedMem,
                       smallerStart, biggerStart,
                       smallerTotal, biggerTotal,
                       smallerPrefSumInc - smaller, biggerPrefSumInc - bigger, //exclusive prefix sum of elements
                       pivot);
        int destSmaller = smallerStart + smallerPrefSumInc - smaller;
        int destBigger = biggerStart + biggerPrefSumInc - bigger;
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
        copyData(srcView, dst, Cmp, destSmaller, destBigger, pivot);
Xuan Thang Nguyen's avatar
Xuan Thang Nguyen committed
}