Commit 45bf0dae authored by Xuan Thang Nguyen's avatar Xuan Thang Nguyen
Browse files

roll back 1st step

parent cbe260c1
Loading
Loading
Loading
Loading
+41 −7
Original line number Diff line number Diff line
@@ -108,15 +108,49 @@ template <typename Value, typename CMP>
__global__ void bitoniSort1stStepSharedMemory(TNL::Containers::ArrayView<Value, TNL::Devices::Cuda> arr, CMP Cmp)
{
    extern __shared__ int externMem[];
    
    Value * sharedMem = (Value *)externMem;
    int sharedMemLen = 2*blockDim.x;

    int myBlockStart = blockIdx.x * sharedMemLen;
    int myBlockEnd = TNL::min(arr.getSize(), myBlockStart+sharedMemLen);

    if (blockIdx.x % 2 || blockIdx.x + 1 == gridDim.x)
        bitonicSort_Block(arr.getView(myBlockStart, myBlockEnd), arr.getView(myBlockStart, myBlockEnd), (Value *)externMem, Cmp);
    else
        bitonicSort_Block(arr.getView(myBlockStart, myBlockEnd), arr.getView(myBlockStart, myBlockEnd), (Value *)externMem,
                          [&] __cuda_callable__(const Value &a, const Value &b) { return Cmp(b, a); });
    //copy from globalMem into sharedMem
    for (int i = threadIdx.x; myBlockStart + i < myBlockEnd; i += blockDim.x)
        sharedMem[i] = arr[myBlockStart + i];
    __syncthreads();

    //------------------------------------------
    //bitonic activity
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        int paddedSize = closestPow2(myBlockEnd - myBlockStart);

        for (int monotonicSeqLen = 2; monotonicSeqLen <= paddedSize; monotonicSeqLen *= 2)
        {
            //calculate the direction of swapping
            int monotonicSeqIdx = i / (monotonicSeqLen/2);
            bool ascending = (monotonicSeqIdx & 1) != 0;
            if ((monotonicSeqIdx + 1) * monotonicSeqLen >= arr.getSize()) //special case for parts with no "partner"
                ascending = true;

            for (int len = monotonicSeqLen; len > 1; len /= 2)
            {
                //calculates which 2 indexes will be compared and swap
                int part = threadIdx.x / (len / 2);
                int s = part * len + (threadIdx.x & ((len / 2) - 1));
                int e = s + len / 2;

                if(e < myBlockEnd - myBlockStart) //touching virtual padding
                    cmpSwap(sharedMem[s], sharedMem[e], ascending, Cmp);
                __syncthreads();
            }
        }
    }

    //writeback to global memory
    for (int i = threadIdx.x; myBlockStart + i < myBlockEnd; i += blockDim.x)
        arr[myBlockStart + i] = sharedMem[i];
}

//---------------------------------------------