From 5617f46270149b5f888608829ec3db5cc211f776 Mon Sep 17 00:00:00 2001
From: Xuan Thang Nguyen <nguyexu2@fit.cvut.cz>
Date: Fri, 9 Apr 2021 17:12:56 +0200
Subject: [PATCH] fix copy and add support for block bitonic sort of bigger
 size

---
 src/bitonicSort/bitonicSort.h | 64 ++++++++++++++++++++---------------
 1 file changed, 36 insertions(+), 28 deletions(-)

diff --git a/src/bitonicSort/bitonicSort.h b/src/bitonicSort/bitonicSort.h
index ac1c67f..958e8e4 100644
--- a/src/bitonicSort/bitonicSort.h
+++ b/src/bitonicSort/bitonicSort.h
@@ -191,27 +191,32 @@ __device__ void bitonicSort_Block(TNL::Containers::ArrayView<Value, TNL::Devices
     //------------------------------------------
     //bitonic activity
     {
-        int i = threadIdx.x;
         int paddedSize = closestPow2_ptx(src.getSize());
 
         for (int monotonicSeqLen = 2; monotonicSeqLen <= paddedSize; monotonicSeqLen *= 2)
         {
-            //calculate the direction of swapping
-            int monotonicSeqIdx = i / (monotonicSeqLen / 2);
-            bool ascending = (monotonicSeqIdx & 1) != 0;
-            if ((monotonicSeqIdx + 1) * monotonicSeqLen >= src.getSize()) //special case for parts with no "partner"
-                ascending = true;
-
             for (int len = monotonicSeqLen; len > 1; len /= 2)
             {
-                //calculates which 2 indexes will be compared and swap
-                int part = i / (len / 2);
-                int s = part * len + (i & ((len / 2) - 1));
-                int e = s + len / 2;
+                for(int i = threadIdx.x; ; i+=blockDim.x) //each thread also handles pairs i+blockDim.x, i+2*blockDim.x, ... so src.getSize() may exceed blockDim.x*2
+                {
+                    //calculates which 2 indexes will be compared and swap
+                    int part = i / (len / 2);
+                    int s = part * len + (i & ((len / 2) - 1));
+                    int e = s + len / 2;
+
+                    if(e >= src.getSize()) //partner index lies in the virtual padding: no swap; e only grows with i, so stop here
+                        break;
+
+                    //calculate the direction of swapping
+                    int monotonicSeqIdx = i / (monotonicSeqLen / 2);
+                    bool ascending = (monotonicSeqIdx & 1) != 0;
+                    if ((monotonicSeqIdx + 1) * monotonicSeqLen >= src.getSize()) //special case for parts with no "partner"
+                        ascending = true;
 
-                if (e < src.getSize()) //not touching virtual padding
                     cmpSwap(sharedMem[s], sharedMem[e], ascending, Cmp);
-                __syncthreads();
+                }
+                
+                __syncthreads(); //one barrier per stage, reached by every thread after its inner loop ends
             }
         }
     }
@@ -232,29 +237,32 @@ __device__ void bitonicSort_Block(TNL::Containers::ArrayView<Value, TNL::Devices
  * */
 template <typename Value, typename Function>
 __device__ void bitonicSort_Block(TNL::Containers::ArrayView<Value, TNL::Devices::Cuda> src,
-                                  TNL::Containers::ArrayView<Value, TNL::Devices::Cuda> dst,
                                   const Function &Cmp)
 {
-    int i = threadIdx.x;
     int paddedSize = closestPow2_ptx(src.getSize());
 
     for (int monotonicSeqLen = 2; monotonicSeqLen <= paddedSize; monotonicSeqLen *= 2)
     {
-        //calculate the direction of swapping
-        int monotonicSeqIdx = i / (monotonicSeqLen / 2);
-        bool ascending = (monotonicSeqIdx & 1) != 0;
-        if ((monotonicSeqIdx + 1) * monotonicSeqLen >= src.getSize()) //special case for parts with no "partner"
-            ascending = true;
-
         for (int len = monotonicSeqLen; len > 1; len /= 2)
         {
-            //calculates which 2 indexes will be compared and swap
-            int part = i / (len / 2);
-            int s = part * len + (i & ((len / 2) - 1));
-            int e = s + len / 2;
+            for(int i = threadIdx.x; ; i+=blockDim.x) //each thread also handles pairs i+blockDim.x, i+2*blockDim.x, ... so src.getSize() may exceed blockDim.x*2
+            {
+                //calculates which 2 indexes will be compared and swap
+                int part = i / (len / 2);
+                int s = part * len + (i & ((len / 2) - 1));
+                int e = s + len / 2;
+
+                if(e >= src.getSize())
+                    break;
+
+                //calculate the direction of swapping
+                int monotonicSeqIdx = i / (monotonicSeqLen / 2);
+                bool ascending = (monotonicSeqIdx & 1) != 0;
+                if ((monotonicSeqIdx + 1) * monotonicSeqLen >= src.getSize()) //special case for parts with no "partner"
+                    ascending = true;
 
-            if (e < src.getSize()) //not touching virtual padding
                 cmpSwap(src[s], src[e], ascending, Cmp);
+            }
             __syncthreads();
         }
     }
@@ -292,9 +300,9 @@ __global__ void bitoniSort1stStep(TNL::Containers::ArrayView<Value, TNL::Devices
     int myBlockEnd = TNL::min(arr.getSize(), myBlockStart + (2 * blockDim.x));
 
     if (blockIdx.x % 2 || blockIdx.x + 1 == gridDim.x)
-        bitonicSort_Block(arr.getView(myBlockStart, myBlockEnd), arr.getView(myBlockStart, myBlockEnd), Cmp);
+        bitonicSort_Block(arr.getView(myBlockStart, myBlockEnd), Cmp);
     else
-        bitonicSort_Block(arr.getView(myBlockStart, myBlockEnd), arr.getView(myBlockStart, myBlockEnd),
+        bitonicSort_Block(arr.getView(myBlockStart, myBlockEnd),
                           [&] __cuda_callable__(const Value &a, const Value &b) { return Cmp(b, a); });
 }
 
-- 
GitLab