Commit fb5037da authored by Xuan Thang Nguyen
Browse files

add cmp function

parent 44a1ed3e
Loading
Loading
Loading
Loading
+15 −17
Original line number Diff line number Diff line
@@ -3,9 +3,6 @@
#include <TNL/Containers/Array.h>
#include "../util/reduction.cuh"
#include "task.h"
#include <iostream>

#define deb(x) std::cout << #x << " = " << x << std::endl;

using namespace TNL;
using namespace TNL::Containers;
@@ -73,26 +70,26 @@ __device__ Value pickPivotIdx(TNL::Containers::ArrayView<Value, Device> src, con
    }
}

// countElem: per-thread tally of how many elements of `arr` are ordered
// before the pivot (`smaller`) and after the pivot (`bigger`).
//
// NOTE(review): this listing is a diff view without +/- markers, so the
// pre-change lines (using operator< / operator>) and the post-change lines
// (using the Cmp callable) appear next to each other below.
//
// Parameters:
//   arr     - view of the elements to classify.
//   Cmp     - (post-change signature only) callable invoked as Cmp(a, b);
//             presumably a strict "a goes before b" ordering - TODO confirm
//             against the callers.
//   smaller - incremented once per visited element with Cmp(data, pivot)
//             (pre-change: data < pivot).
//   bigger  - incremented once per visited element with Cmp(pivot, data)
//             (pre-change: data > pivot), tested only when the first
//             condition is false.
//   pivot   - value every element is compared against.
//
// The counters are only incremented, never reset, so the caller supplies the
// starting values. An element that satisfies neither condition changes
// neither counter.
template <typename Value>
template <typename Value, typename Function>
__device__
void countElem(ArrayView<Value, Devices::Cuda> arr,
void countElem(ArrayView<Value, Devices::Cuda> arr, const Function & Cmp,
             int &smaller, int &bigger,
             const Value &pivot)
{
    // Each thread starts at its own threadIdx.x and advances by blockDim.x.
    for (int i = threadIdx.x; i < arr.getSize(); i += blockDim.x)
    {
        // Read the element once; both comparisons below use this copy.
        const Value data = arr[i];
        if(data < pivot)
        if(Cmp(data, pivot))
            smaller++;
        else if(data > pivot)
        else if(Cmp(pivot, data) )
            bigger++;
    }
}

template <typename Value>
template <typename Value, typename Function>
__device__
void copyDataShared(ArrayView<Value, Devices::Cuda> src,
              ArrayView<Value, Devices::Cuda> dst,
              ArrayView<Value, Devices::Cuda> dst, const Function & Cmp,
              Value *sharedMem,
              int smallerStart, int biggerStart,
              int smallerTotal, int biggerTotal,
@@ -103,9 +100,9 @@ void copyDataShared(ArrayView<Value, Devices::Cuda> src,
    for (int i = threadIdx.x; i < src.getSize(); i += blockDim.x)
    {
        const Value data = src[i];
        if (data < pivot)
        if (Cmp(data, pivot))
            sharedMem[smallerOffset++] = data;
        else if (data > pivot)
        else if (Cmp(pivot, data))
            sharedMem[smallerTotal + biggerOffset++] = data;
    }
    __syncthreads();
@@ -119,17 +116,18 @@ void copyDataShared(ArrayView<Value, Devices::Cuda> src,
    }
}

template <typename Value>
template <typename Value, typename Function>
__device__
void copyData(ArrayView<Value, Devices::Cuda> src,
              ArrayView<Value, Devices::Cuda> dst,
              const Function & Cmp, 
              int smallerStart, int biggerStart,
              const Value &pivot)
{
    for (int i = threadIdx.x; i < src.getSize(); i += blockDim.x)
    {
        const Value data = src[i];
        if (data < pivot)
        if ( Cmp(data, pivot) )
        {
            /*
            if(smallerStart >= dst.getSize() || smallerStart < 0)
@@ -137,7 +135,7 @@ void copyData(ArrayView<Value, Devices::Cuda> src,
            */
            dst[smallerStart++] = data;
        }
        else if (data > pivot)
        else if ( Cmp(pivot, data) )
        {
            /*
            if(biggerStart >= dst.getSize() || biggerStart < 0)
@@ -168,7 +166,7 @@ __device__ void cudaPartition(ArrayView<Value, Devices::Cuda> src,
    //-------------------------------------------------------------------------

    int smaller = 0, bigger = 0;
    countElem(srcView, smaller, bigger, pivot);
    countElem(srcView, Cmp, smaller, bigger, pivot);
    
    int smallerPrefSumInc = blockInclusivePrefixSum(smaller);
    int biggerPrefSumInc = blockInclusivePrefixSum(bigger);
@@ -191,7 +189,7 @@ __device__ void cudaPartition(ArrayView<Value, Devices::Cuda> src,
        }
        __syncthreads();

        copyDataShared(srcView, dst, sharedMem,
        copyDataShared(srcView, dst, Cmp, sharedMem,
                        smallerStart, biggerStart,
                        smallerTotal, biggerTotal,
                        smallerPrefSumInc - smaller, biggerPrefSumInc - bigger, //exclusive prefix sum of elements
@@ -201,6 +199,6 @@ __device__ void cudaPartition(ArrayView<Value, Devices::Cuda> src,
    {
        int destSmaller = smallerStart + smallerPrefSumInc - smaller;
        int destBigger = biggerStart + biggerPrefSumInc - bigger;
        copyData(srcView, dst, destSmaller, destBigger, pivot);
        copyData(srcView, dst, Cmp, destSmaller, destBigger, pivot);
    }
}
 No newline at end of file
+3 −3
Original line number Diff line number Diff line
@@ -130,7 +130,7 @@ __device__ void singleBlockQuickSort(ArrayView<Value, TNL::Devices::Cuda> arr,
        __syncthreads();

        int smaller = 0, bigger = 0;
        countElem(src.getView(begin, end), smaller, bigger, pivot);
        countElem(src.getView(begin, end), Cmp, smaller, bigger, pivot);

        //synchronization is in this function already
        int smallerPrefSumInc = blockInclusivePrefixSum(smaller);
@@ -160,7 +160,7 @@ __device__ void singleBlockQuickSort(ArrayView<Value, TNL::Devices::Cuda> arr,
            }
            __syncthreads();

            copyDataShared(src.getView(begin, end), dst.getView(begin, end),
            copyDataShared(src.getView(begin, end), dst.getView(begin, end), Cmp,
                sharedMem,
                0, pivotEnd,
                smallerTotal, biggerTotal,
@@ -172,7 +172,7 @@ __device__ void singleBlockQuickSort(ArrayView<Value, TNL::Devices::Cuda> arr,
            int destSmaller = 0 + (smallerPrefSumInc - smaller);
            int destBigger = pivotEnd  + (biggerPrefSumInc - bigger);

            copyData(src.getView(begin, end), dst.getView(begin, end), destSmaller, destBigger, pivot);
            copyData(src.getView(begin, end), dst.getView(begin, end), Cmp, destSmaller, destBigger, pivot);
        }

        __syncthreads();