diff --git a/src/quicksort/cudaPartition.cuh b/src/quicksort/cudaPartition.cuh
index ebc28137e617fcba2222c572c8e98d495422c430..93f4e5fe3651f4dac5c8ddeaa3e17ae46da78479 100644
--- a/src/quicksort/cudaPartition.cuh
+++ b/src/quicksort/cudaPartition.cuh
@@ -39,7 +39,7 @@ __device__ Value pickPivot(TNL::Containers::ArrayView<Value, Device> src, const
 }
 
 template <typename Value, typename Device, typename Function>
-__device__ Value pickPivotIdx(TNL::Containers::ArrayView<Value, Device> src, const Function &Cmp)
+__device__ int pickPivotIdx(TNL::Containers::ArrayView<Value, Device> src, const Function &Cmp)
 {
     //return 0;
     //return src.getSize()-1;
diff --git a/src/quicksort/quicksort.cuh b/src/quicksort/quicksort.cuh
index 22dd4cbd25b9d37527d6fc368ead9db5aebc72e9..1c4e4a3dc78bb6c6f4c63d40b04c0dfefdac0f63 100644
--- a/src/quicksort/quicksort.cuh
+++ b/src/quicksort/quicksort.cuh
@@ -75,15 +75,12 @@ __global__ void cudaQuickSort1stPhase(ArrayView<Value, Devices::Cuda> arr, Array
     extern __shared__ int externMem[];
     Value *sharedMem = (Value *)externMem;
 
-    static __shared__ Value pivot;
 
     TASK &myTask = tasks[taskMapping[blockIdx.x]];
     auto &src = (myTask.depth & 1) == 0 ? arr : aux;
     auto &dst = (myTask.depth & 1) == 0 ? aux : arr;
 
-    if (threadIdx.x == 0)
-        pivot = src[myTask.pivotIdx];
-    __syncthreads();
+    Value pivot = src[myTask.pivotIdx];
 
     cudaPartition<Value, Function, useShared>(
         src.getView(myTask.partitionBegin, myTask.partitionEnd),
@@ -99,17 +96,9 @@ __global__ void cudaWritePivot(ArrayView<Value, Devices::Cuda> arr, ArrayView<Va
                                ArrayView<TASK, Devices::Cuda> tasks, ArrayView<TASK, Devices::Cuda> newTasks, int *newTasksCnt,
                                ArrayView<TASK, Devices::Cuda> secondPhaseTasks, int *secondPhaseTasksCnt)
 {
-    static __shared__ Value pivot;
     TASK &myTask = tasks[blockIdx.x];
 
-    if (threadIdx.x == 0)
-    {
-        if ((myTask.depth & 1) == 0)
-            pivot = arr[myTask.pivotIdx];
-        else
-            pivot = aux[myTask.pivotIdx];
-    }
-    __syncthreads();
+    Value pivot = (myTask.depth & 1) == 0 ? arr[myTask.pivotIdx] : aux[myTask.pivotIdx];
 
     int leftBegin = myTask.partitionBegin, leftEnd = myTask.partitionBegin + myTask.dstBegin;
     int rightBegin = myTask.partitionBegin + myTask.dstEnd, rightEnd = myTask.partitionEnd;
diff --git a/src/quicksort/quicksort_1Block.cuh b/src/quicksort/quicksort_1Block.cuh
index e63e4e0062cef7f424d449942f691d1b8e7423d7..45967366cb9da52f54190a3e02a73065f31c1767 100644
--- a/src/quicksort/quicksort_1Block.cuh
+++ b/src/quicksort/quicksort_1Block.cuh
@@ -17,6 +17,14 @@ __device__ void externSort(ArrayView<Value, TNL::Devices::Cuda> src,
     bitonicSort_Block(src, dst, sharedMem, Cmp);
 }
 
+template <typename Value, typename Function>
+__device__ void externSort(ArrayView<Value, TNL::Devices::Cuda> src,
+                           ArrayView<Value, TNL::Devices::Cuda> dst,
+                           const Function &Cmp)
+{
+    bitonicSort_Block(src, dst, Cmp);
+}
+
 template <int stackSize>
 __device__ void stackPush(int stackArrBegin[], int stackArrEnd[],
                           int stackDepth[], int &stackTop,
@@ -79,7 +87,11 @@ __device__ void singleBlockQuickSort(ArrayView<Value, TNL::Devices::Cuda> arr,
     if (arr.getSize() <= blockDim.x * 2)
     {
         auto src = (_depth & 1) == 0 ? arr : aux;
-        externSort<Value, Function>(src, arr, Cmp, sharedMem);
+        if(useShared)
+            externSort<Value, Function>(src, arr, Cmp, sharedMem);
+        else
+            externSort<Value, Function>(src, arr, Cmp);
+
         return;
     }
 
@@ -87,7 +99,6 @@ __device__ void singleBlockQuickSort(ArrayView<Value, TNL::Devices::Cuda> arr,
     static __shared__ int stackArrBegin[stackSize], stackArrEnd[stackSize], stackDepth[stackSize];
     static __shared__ int begin, end, depth;
     static __shared__ int pivotBegin, pivotEnd;
-    static __shared__ Value pivot;
 
     if (threadIdx.x == 0)
     {
@@ -117,16 +128,16 @@ __device__ void singleBlockQuickSort(ArrayView<Value, TNL::Devices::Cuda> arr,
         //small enough for for bitonic
         if (size <= blockDim.x * 2)
         {
-            externSort<Value, Function>(src.getView(begin, end), arr.getView(begin, end), Cmp, sharedMem);
+            if(useShared)
+                externSort<Value, Function>(src.getView(begin, end), arr.getView(begin, end), Cmp, sharedMem);
+            else
+                externSort<Value, Function>(src.getView(begin, end), arr.getView(begin, end), Cmp);
             __syncthreads();
             continue;
         }
         //------------------------------------------------------
 
-        //actually do partitioning from here on out
-        if (threadIdx.x == 0)
-            pivot = pickPivot(src.getView(begin, end), Cmp);
-        __syncthreads();
+        Value pivot = pickPivot(src.getView(begin, end), Cmp);
 
         int smaller = 0, bigger = 0;
         countElem(src.getView(begin, end), Cmp, smaller, bigger, pivot);
diff --git a/tests/quicksort_unitTests/unitTests.cu b/tests/quicksort_unitTests/unitTests.cu
index e87493add48150a2622c12666e8e0e57d7643c7f..dc3518e494c246e530dfbe9f33d73c038eb01d22 100644
--- a/tests/quicksort_unitTests/unitTests.cu
+++ b/tests/quicksort_unitTests/unitTests.cu
@@ -141,6 +141,28 @@ TEST(types, type_double)
     ASSERT_TRUE(view == cudaArr2.getView());
 }
 
+struct TMPSTRUCT{
+    uint8_t m_data[16];
+
+    __cuda_callable__ TMPSTRUCT(){m_data[0] = 0;}
+    __cuda_callable__ TMPSTRUCT(int first){m_data[0] = first;};
+    __cuda_callable__ bool operator <(const TMPSTRUCT& other) const { return m_data[0] < other.m_data[0];}
+};
+
+
+TEST(types, struct)
+{
+    std::srand(8451);
+
+    int size = (1<<13);
+    std::vector<TMPSTRUCT> arr(size);
+    for(auto & x : arr) x = TMPSTRUCT(std::rand());
+
+    TNL::Containers::Array<TMPSTRUCT, TNL::Devices::Cuda> cudaArr(arr);
+    auto view = cudaArr.getView();
+    quicksort(view);
+}
+
 //----------------------------------------------------------------------------------
 
 int main(int argc, char **argv)