Commit fdb36d46 authored by Xuan Thang Nguyen's avatar Xuan Thang Nguyen
Browse files

refactor out big size check

parent cadd1522
Loading
Loading
Loading
Loading
+83 −47
Original line number Diff line number Diff line
@@ -7,7 +7,9 @@
/**
 * Bit Find Last One: returns the bit position of the most significant set bit
 * of @p word via the PTX `bfind.u32` instruction.
 *
 * NOTE(review): per the PTX ISA, `bfind` returns 0xFFFFFFFF when @p word is 0
 * — callers must not treat the result as a valid index in that case.
 */
static __device__ __forceinline__ unsigned int __btflo(unsigned int word)
{
    unsigned int ret;
    // The diff residue duplicated this statement (old + reformatted new line);
    // exactly one bfind is intended.
    asm volatile("bfind.u32 %0, %1;"
                 : "=r"(ret)
                 : "r"(word));
    return ret;
}

@@ -227,7 +229,6 @@ __device__ void bitonicSort_Block(TNL::Containers::ArrayView<Value, TNL::Devices
        dst[i] = sharedMem[i];
}


/**
 * IMPORTANT: all threads in block have to call this function to work properly
 * IMPORTANT: unlike the counterpart with shared memory, this function only works in-place
@@ -307,60 +308,95 @@ __global__ void bitoniSort1stStep(TNL::Containers::ArrayView<Value, TNL::Devices
                          [&] __cuda_callable__(const Value &a, const Value &b) { return Cmp(b, a); });
}

//---------------------------------------------
//---------------------------------------------
/**
 * Bitonic sort driver that uses shared-memory kernels where possible.
 *
 * Preconditions (established by the caller, see bitonicSort(src, begin, end, Cmp)):
 *   - sharedMemLen == 2 * blockDim (elements each block sorts in shared memory)
 *   - sharedMemSize == sharedMemLen * sizeof(Value) and fits in sharedMemPerBlock
 *   - gridDim covers getSize()/2 compare-exchange threads
 *
 * @param view          device array view to sort in place
 * @param Cmp           comparison functor (callable on device)
 * @param gridDim       number of blocks to launch
 * @param blockDim      threads per block
 * @param sharedMemLen  elements held in shared memory per block
 * @param sharedMemSize dynamic shared memory per block, in bytes
 */
template <typename Value, typename CMP>
void bitonicSortWithShared(TNL::Containers::ArrayView<Value, TNL::Devices::Cuda> view, const CMP &Cmp,
                           int gridDim, int blockDim, int sharedMemLen, int sharedMemSize)
{
    int paddedSize = closestPow2(view.getSize());

    bitoniSort1stStepSharedMemory<<<gridDim, blockDim, sharedMemSize>>>(view, Cmp);
    // now alternating monotonic sequences with length of sharedMemLen

    // \/ has length of 2 * sharedMemLen
    for (int monotonicSeqLen = 2 * sharedMemLen; monotonicSeqLen <= paddedSize; monotonicSeqLen *= 2)
    {
        for (int len = monotonicSeqLen, partsInSeq = 1; len > 1; len /= 2, partsInSeq *= 2)
        {
            if (len > sharedMemLen)
            {
                // merge step too wide for one block's shared memory -> global kernel
                bitonicMergeGlobal<<<gridDim, blockDim>>>(
                    view, Cmp, monotonicSeqLen, len, partsInSeq);
            }
            else
            {
                bitonicMergeSharedMemory<<<gridDim, blockDim, sharedMemSize>>>(
                    view, Cmp, monotonicSeqLen, len, partsInSeq);

                // simulates sorts until len == 2 already, no need to continue this loop
                break;
            }
        }
    }
    cudaDeviceSynchronize();
}

    cudaDeviceProp deviceProp;
    cudaGetDeviceProperties(&deviceProp, 0);
//---------------------------------------------

    //---------------------------------------------------------------------------------
template <typename Value, typename CMP>
void bitonicSort(TNL::Containers::ArrayView<Value, TNL::Devices::Cuda> view,
                 const CMP &Cmp,
                 int gridDim, int blockDim)

    if (sharedMemSize <= deviceProp.sharedMemPerBlock)
        bitoniSort1stStepSharedMemory<<<blocks, threadPerBlock, sharedMemSize>>>(arr, Cmp);
    else
        bitoniSort1stStep<<<blocks, threadPerBlock>>>(arr, Cmp);
{
    int paddedSize = closestPow2(view.getSize());

    for (int monotonicSeqLen = 2 * sharedMemLen; monotonicSeqLen <= paddedSize; monotonicSeqLen *= 2)
    for (int monotonicSeqLen = 2; monotonicSeqLen <= paddedSize; monotonicSeqLen *= 2)
    {
        for (int len = monotonicSeqLen, partsInSeq = 1; len > 1; len /= 2, partsInSeq *= 2)
        {
            if (len > sharedMemLen)
            {
                bitonicMergeGlobal<<<blocks, threadPerBlock>>>(
                    arr, Cmp, monotonicSeqLen, len, partsInSeq);
            bitonicMergeGlobal<<<gridDim, blockDim>>>(view, Cmp, monotonicSeqLen, len, partsInSeq);
        }
            else
    }
    cudaDeviceSynchronize();
}

//---------------------------------------------
/**
 * Public entry point: sorts src[begin, end) on the GPU with a bitonic sort.
 *
 * Picks the launch configuration based on the device's shared memory per block:
 *   1. full shared-memory path (512 threads, 1024 elements per block),
 *   2. halved shared-memory path (256 threads, 512 elements per block),
 *   3. global-memory-only fallback when neither fits.
 *
 * @param src   device array to sort (only the [begin, end) slice is touched)
 * @param begin first index of the slice (inclusive)
 * @param end   one past the last index of the slice (exclusive)
 * @param Cmp   comparison functor (callable on device)
 */
template <typename Value, typename CMP>
void bitonicSort(TNL::Containers::ArrayView<Value, TNL::Devices::Cuda> src, int begin, int end, const CMP &Cmp)
{
    auto view = src.getView(begin, end);

    // one thread per compare-exchange pair, rounded up for odd sizes
    int threadsNeeded = view.getSize() / 2 + (view.getSize() % 2 != 0);

    cudaDeviceProp deviceProp;
    cudaGetDeviceProperties(&deviceProp, 0);

    const int maxThreadsPerBlock = 512;

    int sharedMemLen = maxThreadsPerBlock * 2;
    int sharedMemSize = sharedMemLen * sizeof(Value);

    if (sharedMemSize <= deviceProp.sharedMemPerBlock)
    {
        int blockDim = maxThreadsPerBlock;
        int gridDim = threadsNeeded / blockDim + (threadsNeeded % blockDim != 0);
        bitonicSortWithShared(view, Cmp, gridDim, blockDim, sharedMemLen, sharedMemSize);
    }
    else if (sharedMemSize / 2 <= deviceProp.sharedMemPerBlock)
    {
        // halve the footprint: fewer threads per block, half the tile
        int blockDim = maxThreadsPerBlock / 2; //256
        int gridDim = threadsNeeded / blockDim + (threadsNeeded % blockDim != 0);
        sharedMemSize /= 2;
        sharedMemLen /= 2;
        bitonicSortWithShared(view, Cmp, gridDim, blockDim, sharedMemLen, sharedMemSize);
    }
    else
    {
        // Value too large for shared memory at all -> global-memory-only sort
        int gridDim = threadsNeeded / maxThreadsPerBlock + (threadsNeeded % maxThreadsPerBlock != 0);
        bitonicSort(view, Cmp, gridDim, maxThreadsPerBlock);
    }
    cudaDeviceSynchronize();
}

//---------------------------------------------