From 5b7045ed2146a878cead708413ed970effa6a6bf Mon Sep 17 00:00:00 2001
From: Xuan Thang Nguyen <nguyexu2@fit.cvut.cz>
Date: Wed, 7 Apr 2021 23:56:08 +0200
Subject: [PATCH] support for structure sorting

---
 src/quicksort/cudaPartition.cuh        |  2 +-
 src/quicksort/quicksort.cuh            | 15 ++-------------
 src/quicksort/quicksort_1Block.cuh     | 25 ++++++++++++++++++-------
 tests/quicksort_unitTests/unitTests.cu | 22 ++++++++++++++++++++++
 4 files changed, 43 insertions(+), 21 deletions(-)

diff --git a/src/quicksort/cudaPartition.cuh b/src/quicksort/cudaPartition.cuh
index ebc2813..93f4e5f 100644
--- a/src/quicksort/cudaPartition.cuh
+++ b/src/quicksort/cudaPartition.cuh
@@ -39,7 +39,7 @@ __device__ Value pickPivot(TNL::Containers::ArrayView<Value, Device> src, const
 }
 
 template <typename Value, typename Device, typename Function>
-__device__ Value pickPivotIdx(TNL::Containers::ArrayView<Value, Device> src, const Function &Cmp)
+__device__ int pickPivotIdx(TNL::Containers::ArrayView<Value, Device> src, const Function &Cmp)
 {
     //return 0;
     //return src.getSize()-1;
diff --git a/src/quicksort/quicksort.cuh b/src/quicksort/quicksort.cuh
index 22dd4cb..1c4e4a3 100644
--- a/src/quicksort/quicksort.cuh
+++ b/src/quicksort/quicksort.cuh
@@ -75,15 +75,12 @@ __global__ void cudaQuickSort1stPhase(ArrayView<Value, Devices::Cuda> arr, Array
     extern __shared__ int externMem[];
     Value *sharedMem = (Value *)externMem;
 
-    static __shared__ Value pivot;
 
     TASK &myTask = tasks[taskMapping[blockIdx.x]];
     auto &src = (myTask.depth & 1) == 0 ? arr : aux;
     auto &dst = (myTask.depth & 1) == 0 ? aux : arr;
 
-    if (threadIdx.x == 0)
-        pivot = src[myTask.pivotIdx];
-    __syncthreads();
+    Value pivot = src[myTask.pivotIdx];
 
     cudaPartition<Value, Function, useShared>(
         src.getView(myTask.partitionBegin, myTask.partitionEnd),
@@ -99,17 +96,9 @@ __global__ void cudaWritePivot(ArrayView<Value, Devices::Cuda> arr, ArrayView<Va
                                ArrayView<TASK, Devices::Cuda> tasks, ArrayView<TASK, Devices::Cuda> newTasks, int *newTasksCnt,
                                ArrayView<TASK, Devices::Cuda> secondPhaseTasks, int *secondPhaseTasksCnt)
 {
-    static __shared__ Value pivot;
     TASK &myTask = tasks[blockIdx.x];
 
-    if (threadIdx.x == 0)
-    {
-        if ((myTask.depth & 1) == 0)
-            pivot = arr[myTask.pivotIdx];
-        else
-            pivot = aux[myTask.pivotIdx];
-    }
-    __syncthreads();
+    Value pivot = (myTask.depth & 1) == 0 ? arr[myTask.pivotIdx] : aux[myTask.pivotIdx];
 
     int leftBegin = myTask.partitionBegin, leftEnd = myTask.partitionBegin + myTask.dstBegin;
     int rightBegin = myTask.partitionBegin + myTask.dstEnd, rightEnd = myTask.partitionEnd;
diff --git a/src/quicksort/quicksort_1Block.cuh b/src/quicksort/quicksort_1Block.cuh
index e63e4e0..4596736 100644
--- a/src/quicksort/quicksort_1Block.cuh
+++ b/src/quicksort/quicksort_1Block.cuh
@@ -17,6 +17,14 @@ __device__ void externSort(ArrayView<Value, TNL::Devices::Cuda> src,
     bitonicSort_Block(src, dst, sharedMem, Cmp);
 }
 
+template <typename Value, typename Function>
+__device__ void externSort(ArrayView<Value, TNL::Devices::Cuda> src,
+                           ArrayView<Value, TNL::Devices::Cuda> dst,
+                           const Function &Cmp)
+{
+    bitonicSort_Block(src, dst, Cmp);
+}
+
 template <int stackSize>
 __device__ void stackPush(int stackArrBegin[], int stackArrEnd[],
                           int stackDepth[], int &stackTop,
@@ -79,7 +87,11 @@ __device__ void singleBlockQuickSort(ArrayView<Value, TNL::Devices::Cuda> arr,
     if (arr.getSize() <= blockDim.x * 2)
     {
         auto src = (_depth & 1) == 0 ? arr : aux;
-        externSort<Value, Function>(src, arr, Cmp, sharedMem);
+        if(useShared)
+            externSort<Value, Function>(src, arr, Cmp, sharedMem);
+        else
+            externSort<Value, Function>(src, arr, Cmp);
+
         return;
     }
 
@@ -87,7 +99,6 @@ __device__ void singleBlockQuickSort(ArrayView<Value, TNL::Devices::Cuda> arr,
     static __shared__ int stackArrBegin[stackSize], stackArrEnd[stackSize], stackDepth[stackSize];
     static __shared__ int begin, end, depth;
     static __shared__ int pivotBegin, pivotEnd;
-    static __shared__ Value pivot;
 
     if (threadIdx.x == 0)
     {
@@ -117,16 +128,16 @@ __device__ void singleBlockQuickSort(ArrayView<Value, TNL::Devices::Cuda> arr,
         //small enough for for bitonic
         if (size <= blockDim.x * 2)
         {
-            externSort<Value, Function>(src.getView(begin, end), arr.getView(begin, end), Cmp, sharedMem);
+            if(useShared)
+                externSort<Value, Function>(src.getView(begin, end), arr.getView(begin, end), Cmp, sharedMem);
+            else
+                externSort<Value, Function>(src.getView(begin, end), arr.getView(begin, end), Cmp);
             __syncthreads();
             continue;
         }
         //------------------------------------------------------
 
-        //actually do partitioning from here on out
-        if (threadIdx.x == 0)
-            pivot = pickPivot(src.getView(begin, end), Cmp);
-        __syncthreads();
+        Value pivot = pickPivot(src.getView(begin, end), Cmp);
 
         int smaller = 0, bigger = 0;
         countElem(src.getView(begin, end), Cmp, smaller, bigger, pivot);
diff --git a/tests/quicksort_unitTests/unitTests.cu b/tests/quicksort_unitTests/unitTests.cu
index e87493a..dc3518e 100644
--- a/tests/quicksort_unitTests/unitTests.cu
+++ b/tests/quicksort_unitTests/unitTests.cu
@@ -141,6 +141,28 @@ TEST(types, type_double)
     ASSERT_TRUE(view == cudaArr2.getView());
 }
 
+struct TMPSTRUCT{
+    uint8_t m_data[16];
+
+    __cuda_callable__ TMPSTRUCT(){m_data[0] = 0;}
+    __cuda_callable__ TMPSTRUCT(int first){m_data[0] = first;};
+    __cuda_callable__ bool operator <(const TMPSTRUCT& other) const { return m_data[0] < other.m_data[0];}
+};
+
+
+TEST(types, struct)
+{
+    std::srand(8451);
+
+    int size = (1<<13);
+    std::vector<TMPSTRUCT> arr(size);
+    for(auto & x : arr) x = TMPSTRUCT(std::rand());
+
+    TNL::Containers::Array<TMPSTRUCT, TNL::Devices::Cuda> cudaArr(arr);
+    auto view = cudaArr.getView();
+    quicksort(view);
+}
+
 //----------------------------------------------------------------------------------
 
 int main(int argc, char **argv)
-- 
GitLab