Commit c6f6d59e authored by Xuan Thang Nguyen's avatar Xuan Thang Nguyen
Browse files

function for choosing pivot

parent aaf607d6
Loading
Loading
Loading
Loading
+33 −0
Original line number Diff line number Diff line
@@ -54,3 +54,36 @@ __device__ void calcBlocksNeeded(int elemLeft, int elemRight, int &blocksLeft, i
    blocksRight = elemRight / minElemPerBlock + (elemRight% minElemPerBlock != 0);
    
}

template <typename Value, typename Device, typename Function>
__device__ Value pickPivot(TNL::Containers::ArrayView<Value, Device> src, const Function & Cmp)
{
    return src[0];
    //return src[src.getSize()-1];

    /*
    if(src.getSize() ==1)
        return src[0];
    
    Value a = src[0], b = src[src.getSize()/2], c = src[src.getSize() - 1];

    if(Cmp(a, b)) // ..a..b..
    {
        if(Cmp(b, c))// ..a..b..c
            return b;
        else if(Cmp(c, a))//..c..a..b..
            return a;
        else //..a..c..b..
            return c;
    }
    else //..b..a..
    {
        if(Cmp(a, c))//..b..a..c
            return a;
        else if(Cmp(c, b))//..c..b..a..
            return b;
        else //..b..c..a..
            return c;
    }
    */
}
 No newline at end of file
+6 −2
Original line number Diff line number Diff line
@@ -54,7 +54,7 @@ __device__ void multiBlockQuickSort(CudaArrayView arr, CudaArrayView aux, TASK *
    static __shared__ int pivot;

    if(threadIdx.x == 0)
        pivot = depth %2 == 0? arr[arr.getSize() - 1] : aux[aux.getSize() - 1];
        pivot = pickPivot(depth %2 == 0? arr: aux, Cmp);
    __syncthreads();
    
    bool isLast;
@@ -213,7 +213,11 @@ __device__ void singleBlockQuickSort(CudaArrayView arr, CudaArrayView aux, const
            end = stackArrEnd[stackTop-1];
            depth = stackDepth[stackTop-1];
            stackTop--;
            pivot = depth%2 == 0? arr[end - 1] : aux[end-1];
            pivot = pickPivot(depth%2 == 0? 
                                    arr.getView(begin, end) :
                                    aux.getView(begin, end),
                                Cmp
                            );
        }
        __syncthreads();