Commit af5ad08b authored by Xuan Thang Nguyen's avatar Xuan Thang Nguyen
Browse files

fix blockwide pivot write

parent 5000c37b
Loading
Loading
Loading
Loading
+10 −15
Original line number Diff line number Diff line
@@ -34,8 +34,7 @@ __device__ void copyData(CudaArrayView arr, int myBegin, int myEnd, int pivot,

__global__ void cudaPartition(CudaArrayView arr, int begin, int end,
                              CudaArrayView aux, int *auxBeginIdx, int *auxEndIdx,
                              int pivotIdx, int *newPivotPos,
                              int elemPerBlock)
                              int pivotIdx, int elemPerBlock)
{
    static __shared__ int smallerStart, biggerStart;
    static __shared__ int pivot;
@@ -61,14 +60,6 @@ __global__ void cudaPartition(CudaArrayView arr, int begin, int end,
    __syncthreads();

    copyData(arr, myBegin, myEnd, pivot, aux, smallerStart + smallerOffset - smaller, biggerStart + biggerOffset - bigger);
    __syncthreads();
    
    //inserts pivot
    if (threadIdx.x * blockIdx.x == 0)
    {
        aux[*auxEndIdx - 1] = pivot;
        *newPivotPos = *auxEndIdx - 1; 
    }
}

int partition(CudaArrayView arr, int begin, int end, int pivotIdx)
@@ -95,18 +86,22 @@ int partition(CudaArrayView arr, int begin, int end, int pivotIdx)
    TNL::Algorithms::MultiDeviceMemoryOperations<TNL::Devices::Cuda, TNL::Devices::Cuda >::
    copy(aux.getData(), arr.getData(), arr.getSize());
    
    TNL::Containers::Array<int, TNL::Devices::Cuda> cudaAuxBegin({begin}), cudaAuxEnd({end}), newPivotPos(1);
    TNL::Containers::Array<int, TNL::Devices::Cuda> cudaAuxBegin({begin}), cudaAuxEnd({end});
    
    //------------------------------------

    int pivot = arr.getElement(pivotIdx);
    cudaPartition<<<blocks, threadsPerBlock>>>(arr, begin, end,
        aux, cudaAuxBegin.getData(), cudaAuxEnd.getData(),
        pivotIdx, newPivotPos.getData(),
        elemPerBlock);
        pivotIdx, elemPerBlock);
    cudaDeviceSynchronize();

    pivotIdx = cudaAuxEnd.getElement(0) - 1;
    aux.setElement(pivotIdx, pivot);
    //------------------------------------
    TNL::Algorithms::MultiDeviceMemoryOperations<TNL::Devices::Cuda, TNL::Devices::Cuda >::
    copy(arr.getData(), aux.getData(), aux.getSize());
    return newPivotPos.getElement(0);
    return pivotIdx;
}

//-----------------------------------------------------------------------------------------