Commit 3717d64a authored by Xuan Thang Nguyen's avatar Xuan Thang Nguyen
Browse files

versions that don't use shared memory

parent 24899efe
Loading
Loading
Loading
Loading
+73 −7
Original line number Diff line number Diff line
@@ -120,6 +120,44 @@ void bitonicMergeSharedMemory(TNL::Containers::ArrayView<Value, TNL::Devices::Cu
    }
}


template <typename Value, typename Function>
__global__
void bitonicMerge(TNL::Containers::ArrayView<Value, TNL::Devices::Cuda> arr,
                            const Function & Cmp,
                            int monotonicSeqLen, int len, int partsInSeq)
{
    //Global-memory variant of one bitonic merge step (fallback used when the
    //shared-memory version does not fit into sharedMemPerBlock).
    //Expected launch: 1D grid; each thread block merges one contiguous
    //subarray of 2*blockDim.x elements, one compare-swap per thread per pass.

    //1st index and one-past-last index of the subarray this thread block merges
    int myBlockStart = blockIdx.x * (2 * blockDim.x);
    int myBlockEnd = TNL::min(arr.getSize(), myBlockStart + (2 * blockDim.x));

    auto src = arr.getView(myBlockStart, myBlockEnd);

    //calculate the direction of swapping from the global thread index:
    //odd-numbered monotonic sequences are merged ascending, even ones descending
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    int globalPart = i / (len / 2);
    int monotonicSeqIdx = globalPart / partsInSeq;

    bool ascending = (monotonicSeqIdx & 1) != 0;
    //special case: the tail sequence with no "partner" is always merged ascending
    if ((monotonicSeqIdx + 1) * monotonicSeqLen >= arr.getSize())
        ascending = true;
    //------------------------------------------

    //do bitonic merge; 'len' is uniform across the block, so every thread runs
    //the same number of iterations and all of them reach __syncthreads()
    for (; len > 1; len /= 2)
    {
        //calculate which 2 indexes will be compared and possibly swapped
        //(renamed from 'part' to avoid shadowing the outer 'part' above)
        int localPart = threadIdx.x / (len / 2);
        int s = localPart * len + (threadIdx.x & ((len / 2) - 1));
        int e = s + len / 2;

        if(e < myBlockEnd - myBlockStart) //not touching virtual padding
            cmpSwap(src[s], src[e], ascending, Cmp);
        __syncthreads();
    }
}

//---------------------------------------------

template <typename Value, typename Function>
@@ -232,6 +270,19 @@ __global__ void bitoniSort1stStepSharedMemory(TNL::Containers::ArrayView<Value,
        );
}

template <typename Value, typename Function>
__global__ void bitoniSort1stStep(TNL::Containers::ArrayView<Value, TNL::Devices::Cuda> arr, const Function & Cmp)
{
    //Global-memory variant of the first sorting step: each thread block sorts
    //its own contiguous tile of 2*blockDim.x elements in place.
    const int tileBegin = blockIdx.x * (2 * blockDim.x);
    const int tileEnd = TNL::min(arr.getSize(), tileBegin + (2 * blockDim.x));

    //odd-indexed blocks and the very last block sort with Cmp directly;
    //even-indexed blocks sort with the reversed comparator so that adjacent
    //tiles form bitonic sequences for the subsequent merge steps
    const bool useGivenOrder = (blockIdx.x % 2 != 0) || (blockIdx.x + 1 == gridDim.x);

    if (useGivenOrder)
        bitonicSort_Block(arr.getView(tileBegin, tileEnd), arr.getView(tileBegin, tileEnd), Cmp);
    else
        bitonicSort_Block(arr.getView(tileBegin, tileEnd), arr.getView(tileBegin, tileEnd),
            [&] __cuda_callable__ (const Value& a, const Value& b){ return Cmp(b, a); }
        );
}

//---------------------------------------------
template <typename Value, typename Function>
@@ -246,12 +297,20 @@ void bitonicSort(TNL::Containers::ArrayView<Value, TNL::Devices::Cuda> src, int
    int threadPerBlock = maxThreadsPerBlock;
    int blocks = threadsNeeded / threadPerBlock + (threadsNeeded % threadPerBlock != 0);

    const int sharedMemLen = threadPerBlock * 2;
    const int sharedMemSize = sharedMemLen* sizeof(Value);
    int sharedMemLen = threadPerBlock * 2;
    int sharedMemSize = sharedMemLen* sizeof(Value);

    //---------------------------------------------------------------------------------

    cudaDeviceProp deviceProp;
    cudaGetDeviceProperties(&deviceProp, 0);

    //---------------------------------------------------------------------------------

    if(sharedMemSize <= deviceProp.sharedMemPerBlock)
        bitoniSort1stStepSharedMemory<<<blocks, threadPerBlock, sharedMemSize>>>(arr, Cmp);
    else
        bitoniSort1stStep<<<blocks, threadPerBlock>>>(arr, Cmp);
    
    for (int monotonicSeqLen = 2*sharedMemLen; monotonicSeqLen <= paddedSize; monotonicSeqLen *= 2)
    {
@@ -264,9 +323,16 @@ void bitonicSort(TNL::Containers::ArrayView<Value, TNL::Devices::Cuda> src, int
            }
            else
            {

                if(sharedMemSize <= deviceProp.sharedMemPerBlock)
                {
                    bitonicMergeSharedMemory<<<blocks, threadPerBlock, sharedMemSize>>>(
                        arr, Cmp, monotonicSeqLen, len, partsInSeq);
                }
                else
                {
                    bitonicMerge<<<blocks, threadPerBlock>>>(
                        arr, Cmp, monotonicSeqLen, len, partsInSeq);
                }
                break;
            }
        }