Commit 3040910c authored by Xuan Thang Nguyen
Browse files

add comments and reorder pivot calc place

parent abb378c4
Loading
Loading
Loading
Loading
+21 −16
Original line number Original line Diff line number Diff line
@@ -78,8 +78,8 @@ __device__ void singleBlockQuickSort(ArrayView<int, TNL::Devices::Cuda> arr,
{
{
    static __shared__ int stackTop;
    static __shared__ int stackTop;
    static __shared__ int stackArrBegin[stackSize], stackArrEnd[stackSize], stackDepth[stackSize];
    static __shared__ int stackArrBegin[stackSize], stackArrEnd[stackSize], stackDepth[stackSize];
    static __shared__ int begin, end, depth,pivotBegin, pivotEnd;
    static __shared__ int begin, end, depth;
    static __shared__ int pivot;
    static __shared__ int pivot, pivotBegin, pivotEnd;


    if (threadIdx.x == 0)
    if (threadIdx.x == 0)
    {
    {
@@ -93,53 +93,58 @@ __device__ void singleBlockQuickSort(ArrayView<int, TNL::Devices::Cuda> arr,


    while(stackTop > 0)
    while(stackTop > 0)
    {
    {
        //pick up partition to break up
        if (threadIdx.x == 0)
        if (threadIdx.x == 0)
        {
        {
            begin = stackArrBegin[stackTop-1];
            begin = stackArrBegin[stackTop-1];
            end = stackArrEnd[stackTop-1];
            end = stackArrEnd[stackTop-1];
            depth = stackDepth[stackTop-1];
            depth = stackDepth[stackTop-1];
            stackTop--;
            stackTop--;
            pivot = pickPivot((depth&1) == 0? 
                                    arr.getView(begin, end) :
                                    aux.getView(begin, end),
                                Cmp
                            );
        }
        }
        __syncthreads();
        __syncthreads();


        int size = end - begin;
        int size = end - begin;
        auto src = (depth&1) == 0 ? arr.getView(begin, end) : aux.getView(begin, end);
        auto &src = (depth&1) == 0 ? arr : aux;
        auto dst = (depth&1) == 0 ? aux.getView(begin, end) : arr.getView(begin, end);


        //small enough for bitonic sort
        if(size <= blockDim.x*2)
        if(size <= blockDim.x*2)
        {
        {
            externSort<Function, 2048>(src, arr.getView(begin, end), Cmp);
            externSort<Function, 2048>(src.getView(begin, end), arr.getView(begin, end), Cmp);
            __syncthreads();
            __syncthreads();
            continue;
            continue;
        }
        }
        //------------------------------------------------------

        //actually do partitioning from here on out
        if(threadIdx.x == 0)
            pivot = pickPivot(src.getView(begin, end),Cmp);
        __syncthreads();


        int smaller = 0, bigger = 0;
        int smaller = 0, bigger = 0;
        countElem(src, smaller, bigger, pivot);
        countElem(src.getView(begin, end), smaller, bigger, pivot);


        //synchronization is in this function already
        int smallerOffset = blockInclusivePrefixSum(smaller);
        int smallerOffset = blockInclusivePrefixSum(smaller);
        int biggerOffset = blockInclusivePrefixSum(bigger);
        int biggerOffset = blockInclusivePrefixSum(bigger);


        if (threadIdx.x == blockDim.x - 1)
        if (threadIdx.x == blockDim.x - 1) //has sum of all smaller and greater elements than pivot in src
        {
        {
            pivotBegin = smallerOffset;
            pivotBegin = 0 + smallerOffset;
            pivotEnd = size - biggerOffset;
            pivotEnd = size - biggerOffset;
        }
        }
        __syncthreads();
        __syncthreads();


        int destSmaller = 0 + smallerOffset - smaller;
        int destSmaller = 0 + (smallerOffset - smaller);
        int destBigger = pivotEnd  + (biggerOffset - bigger);
        int destBigger = pivotEnd  + (biggerOffset - bigger);
        auto &dst = (depth&1) == 0 ? aux : arr;


        copyData(src, dst, destSmaller, destBigger, pivot);
        copyData(src.getView(begin, end), dst.getView(begin, end), destSmaller, destBigger, pivot);
        __syncthreads();
        __syncthreads();


        for (int i = pivotBegin + threadIdx.x; i < pivotEnd; i += blockDim.x)
        for (int i = pivotBegin + threadIdx.x; i < pivotEnd; i += blockDim.x)
            arr[begin + i] = pivot;
            arr[begin + i] = pivot;


        //creates new tasks
        if(threadIdx.x == 0)
        if(threadIdx.x == 0)
        {
        {
            stackPush<stackSize>(stackArrBegin, stackArrEnd, stackDepth, stackTop,
            stackPush<stackSize>(stackArrBegin, stackArrEnd, stackDepth, stackTop,
@@ -147,6 +152,6 @@ __device__ void singleBlockQuickSort(ArrayView<int, TNL::Devices::Cuda> arr,
                    begin +pivotEnd, end,
                    begin +pivotEnd, end,
                    depth);
                    depth);
        }
        }
        __syncthreads();
        __syncthreads(); //sync to update stackTop
    } //ends while loop
    } //ends while loop
}
}
 No newline at end of file