Commit 5617f462 authored by Xuan Thang Nguyen's avatar Xuan Thang Nguyen
Browse files

fix copy and add support for block bitonic sort of bigger size

parent e7519920
Loading
Loading
Loading
Loading
+36 −28
Original line number Original line Diff line number Diff line
@@ -191,27 +191,32 @@ __device__ void bitonicSort_Block(TNL::Containers::ArrayView<Value, TNL::Devices
    //------------------------------------------
    //------------------------------------------
    //bitonic activity
    //bitonic activity
    {
    {
        int i = threadIdx.x;
        int paddedSize = closestPow2_ptx(src.getSize());
        int paddedSize = closestPow2_ptx(src.getSize());


        for (int monotonicSeqLen = 2; monotonicSeqLen <= paddedSize; monotonicSeqLen *= 2)
        for (int monotonicSeqLen = 2; monotonicSeqLen <= paddedSize; monotonicSeqLen *= 2)
        {
        {
            //calculate the direction of swapping
            int monotonicSeqIdx = i / (monotonicSeqLen / 2);
            bool ascending = (monotonicSeqIdx & 1) != 0;
            if ((monotonicSeqIdx + 1) * monotonicSeqLen >= src.getSize()) //special case for parts with no "partner"
                ascending = true;

            for (int len = monotonicSeqLen; len > 1; len /= 2)
            for (int len = monotonicSeqLen; len > 1; len /= 2)
            {
                for(int i = threadIdx.x; ; i+=blockDim.x) //simulates other blocks in case src.size > blockDim.x*2
                {
                {
                    //calculates which 2 indexes will be compared and swap
                    //calculates which 2 indexes will be compared and swap
                    int part = i / (len / 2);
                    int part = i / (len / 2);
                    int s = part * len + (i & ((len / 2) - 1));
                    int s = part * len + (i & ((len / 2) - 1));
                    int e = s + len / 2;
                    int e = s + len / 2;


                if (e < src.getSize()) //not touching virtual padding
                    if(e >= src.getSize()) //touching virtual padding, the order dont swap
                        break;

                    //calculate the direction of swapping
                    int monotonicSeqIdx = i / (monotonicSeqLen / 2);
                    bool ascending = (monotonicSeqIdx & 1) != 0;
                    if ((monotonicSeqIdx + 1) * monotonicSeqLen >= src.getSize()) //special case for parts with no "partner"
                        ascending = true;

                    cmpSwap(sharedMem[s], sharedMem[e], ascending, Cmp);
                    cmpSwap(sharedMem[s], sharedMem[e], ascending, Cmp);
                __syncthreads();
                }
                
                __syncthreads(); //only 1 synchronization needed
            }
            }
        }
        }
    }
    }
@@ -232,29 +237,32 @@ __device__ void bitonicSort_Block(TNL::Containers::ArrayView<Value, TNL::Devices
 * */
 * */
template <typename Value, typename Function>
template <typename Value, typename Function>
__device__ void bitonicSort_Block(TNL::Containers::ArrayView<Value, TNL::Devices::Cuda> src,
__device__ void bitonicSort_Block(TNL::Containers::ArrayView<Value, TNL::Devices::Cuda> src,
                                  TNL::Containers::ArrayView<Value, TNL::Devices::Cuda> dst,
                                  const Function &Cmp)
                                  const Function &Cmp)
{
{
    int i = threadIdx.x;
    int paddedSize = closestPow2_ptx(src.getSize());
    int paddedSize = closestPow2_ptx(src.getSize());


    for (int monotonicSeqLen = 2; monotonicSeqLen <= paddedSize; monotonicSeqLen *= 2)
    for (int monotonicSeqLen = 2; monotonicSeqLen <= paddedSize; monotonicSeqLen *= 2)
    {
    {
        //calculate the direction of swapping
        int monotonicSeqIdx = i / (monotonicSeqLen / 2);
        bool ascending = (monotonicSeqIdx & 1) != 0;
        if ((monotonicSeqIdx + 1) * monotonicSeqLen >= src.getSize()) //special case for parts with no "partner"
            ascending = true;

        for (int len = monotonicSeqLen; len > 1; len /= 2)
        for (int len = monotonicSeqLen; len > 1; len /= 2)
        {
            for(int i = threadIdx.x; ; i+=blockDim.x) //simulates other blocks in case src.size > blockDim.x*2
            {
            {
                //calculates which 2 indexes will be compared and swap
                //calculates which 2 indexes will be compared and swap
                int part = i / (len / 2);
                int part = i / (len / 2);
                int s = part * len + (i & ((len / 2) - 1));
                int s = part * len + (i & ((len / 2) - 1));
                int e = s + len / 2;
                int e = s + len / 2;


            if (e < src.getSize()) //not touching virtual padding
                if(e >= src.getSize())
                    break;

                //calculate the direction of swapping
                int monotonicSeqIdx = i / (monotonicSeqLen / 2);
                bool ascending = (monotonicSeqIdx & 1) != 0;
                if ((monotonicSeqIdx + 1) * monotonicSeqLen >= src.getSize()) //special case for parts with no "partner"
                    ascending = true;

                cmpSwap(src[s], src[e], ascending, Cmp);
                cmpSwap(src[s], src[e], ascending, Cmp);
            }
            __syncthreads();
            __syncthreads();
        }
        }
    }
    }
@@ -292,9 +300,9 @@ __global__ void bitoniSort1stStep(TNL::Containers::ArrayView<Value, TNL::Devices
    int myBlockEnd = TNL::min(arr.getSize(), myBlockStart + (2 * blockDim.x));
    int myBlockEnd = TNL::min(arr.getSize(), myBlockStart + (2 * blockDim.x));


    if (blockIdx.x % 2 || blockIdx.x + 1 == gridDim.x)
    if (blockIdx.x % 2 || blockIdx.x + 1 == gridDim.x)
        bitonicSort_Block(arr.getView(myBlockStart, myBlockEnd), arr.getView(myBlockStart, myBlockEnd), Cmp);
        bitonicSort_Block(arr.getView(myBlockStart, myBlockEnd), Cmp);
    else
    else
        bitonicSort_Block(arr.getView(myBlockStart, myBlockEnd), arr.getView(myBlockStart, myBlockEnd),
        bitonicSort_Block(arr.getView(myBlockStart, myBlockEnd),
                          [&] __cuda_callable__(const Value &a, const Value &b) { return Cmp(b, a); });
                          [&] __cuda_callable__(const Value &a, const Value &b) { return Cmp(b, a); });
}
}