Commit 0a2e9a5a authored by Tat Dat Duong's avatar Tat Dat Duong
Browse files

chore: support key nodes, merge host and cuda tests

parent 7590a1db
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -60,5 +60,4 @@ include_directories(~/.local/include)
include_directories(src)

add_subdirectory(test)
add_subdirectory(unittest)
add_subdirectory(benchmark)
+9 −5
Original line number Diff line number Diff line
@@ -31,7 +31,7 @@ class BLinkTree<KeyType, ValueType, Order, Devices::Cuda> {
  template <typename Warp>
  __device__ static inline BNode *scan(BNode *node, KeyType key, Warp &warp) {

    if (node->mHighKeyFlag && key >= node->mHighKey) {
    if (node->mHighKeyFlag && !(key < node->mHighKey)) {
      return node->mSibling;
    }

@@ -188,7 +188,7 @@ class BLinkTree<KeyType, ValueType, Order, Devices::Cuda> {
  __device__ static inline uint32_t lowerBoundWarp(KeyType *haystack, uint32_t haystackSize,
                                                   KeyType needle, Warp &warp) {
    uint64_t rank = warp.thread_rank();
    uint32_t ballot = warp.ballot(haystack[rank] >= needle);
    uint32_t ballot = warp.ballot(!(haystack[rank] < needle));
    uint32_t index = __ffs(ballot) - 1;

    // Enforce one PTX instruction
@@ -200,7 +200,7 @@ class BLinkTree<KeyType, ValueType, Order, Devices::Cuda> {
  __device__ static inline uint32_t upperBoundWarp(KeyType *haystack, uint32_t haystackSize,
                                                   KeyType needle, Warp &warp) {
    uint64_t rank = warp.thread_rank();
    uint32_t ballot = warp.ballot(haystack[rank] > needle);
    uint32_t ballot = warp.ballot(needle < haystack[rank]);
    uint32_t index = __ffs(ballot) - 1;

    // Enforce one PTX instruction
@@ -229,9 +229,13 @@ public:
      uint16_t isLeaf = next->mLeaf;

      uint32_t size = next->mSize;
      uint32_t ballot = warp.ballot(next->mKeys[rank] >= key);
      uint32_t ballot = warp.ballot(!(next->mKeys[rank] < key));
      uint32_t targetIdx = min(__ffs(ballot) - 1, size);
      uint32_t foundKey = next->mKeys[targetIdx] == key;

      KeyType targetKey = next->mKeys[targetIdx];

      // targetKey == key
      uint32_t foundKey = !(targetKey < key || key < targetKey);

      if (isLeaf) {
        if (foundKey && targetIdx < size) {
+22 −18
Original line number Diff line number Diff line
@@ -27,8 +27,8 @@ class BLinkTree<KeyType, ValueType, Order, Devices::Host> {
  using Allocator = BumpAllocator<BNode, Device>;
  using _Operations = BLinkOperations<KeyType, ValueType, Order, Device>;

  __cuda_callable__ static inline BNode *scan(BNode *node, KeyType key) {
    if (node->mHighKeyFlag && key >= node->mHighKey) {
  static inline BNode *scan(BNode *node, KeyType key) {
    if (node->mHighKeyFlag && !(key < node->mHighKey)) {
      return node->mSibling;
    }

@@ -44,7 +44,7 @@ class BLinkTree<KeyType, ValueType, Order, Devices::Host> {
    return result;
  }

  __cuda_callable__ static inline BNode *moveSide(BNode *cursor, KeyType key) {
  static inline BNode *moveSide(BNode *cursor, KeyType key) {
    BNode *tmp, *result = cursor;
    while (result != nullptr && result->mSibling != nullptr &&
           (tmp = scan(result, key)) == result->mSibling) {
@@ -54,7 +54,7 @@ class BLinkTree<KeyType, ValueType, Order, Devices::Host> {
    return result;
  }

  __cuda_callable__ static inline BNode *findLeaf(BNode *root, KeyType key) {
  static inline BNode *findLeaf(BNode *root, KeyType key) {
    BNode *cursor = root;
    while (cursor != nullptr && cursor->mLeaf == false) {
      cursor = scan(cursor, key);
@@ -64,7 +64,7 @@ class BLinkTree<KeyType, ValueType, Order, Devices::Host> {
    return cursor;
  }

  __cuda_callable__ static BNode *splitNode(BNode *cursor, Allocator &alloc, size_t siblingStart,
  static BNode *splitNode(BNode *cursor, Allocator &alloc, size_t siblingStart,
                          size_t cursorCount) {

    BNode *sibling = alloc.allocate();
@@ -91,8 +91,7 @@ class BLinkTree<KeyType, ValueType, Order, Devices::Host> {
    return sibling;
  }

  __cuda_callable__ static inline void insertIntoFreeInternal(BNode *curr, KeyType insertKey,
                                                              BNode *insertNode) {
  static inline void insertIntoFreeInternal(BNode *curr, KeyType insertKey, BNode *insertNode) {
    size_t keyIdx = upperBound(curr->mKeys, curr->mSize, insertKey);
    _Operations::insertChild(curr, keyIdx + 1, insertNode, curr->childSize());
    _Operations::insertKey(curr, keyIdx, insertKey);
@@ -103,15 +102,15 @@ class BLinkTree<KeyType, ValueType, Order, Devices::Host> {
    }
  }

  __cuda_callable__ static inline void insertIntoFreeLeaf(BNode *curr, KeyType insertKey,
                                                          ValueType value, BNode *insertNode) {
  static inline void insertIntoFreeLeaf(BNode *curr, KeyType insertKey, ValueType value,
                                        BNode *insertNode) {
    size_t keyIdx = upperBound(curr->mKeys, curr->mSize, insertKey);
    _Operations::insertChild(curr, keyIdx, insertNode, curr->childSize());
    _Operations::insertKey(curr, keyIdx, insertKey, value);
  }

  __cuda_callable__ static inline void increaseTreeHeight(BNode *curr, KeyType insertKey,
                                                          BNode *insertNode, Allocator &alloc) {
  static inline void increaseTreeHeight(BNode *curr, KeyType insertKey, BNode *insertNode,
                                        Allocator &alloc) {

    BNode *leftNode = alloc.allocate();

@@ -141,27 +140,32 @@ public:
  using Node = BNode;
  using Operations = _Operations;

  __cuda_callable__ static inline BNode *init(Allocator &alloc) {
  static inline BNode *init(Allocator &alloc) {
    BNode *root = alloc.allocate();
    _Operations::init(root, true, nullptr, false);
    return root;
  }

  __cuda_callable__ static inline bool find(BNode *root, KeyType key, ValueType &result) {
  static inline bool find(BNode *root, KeyType key, ValueType &result) {
    BNode *leaf = findLeaf(root, key);
    if (leaf == nullptr)
      return false;

    size_t it = lowerBound(leaf->mKeys, leaf->mSize, key);
    if (it < leaf->mSize && leaf->mKeys[it] == key) {
    if (it >= leaf->mSize) {
      return false;
    }

    KeyType leafKey = leaf->mKeys[it];
    // leafKey == key
    if (!(leafKey < key || key < leafKey)) {
      result = leaf->mValues[it];
      return true;
    }
    return false;
  }

  __cuda_callable__ static bool insert(BNode *root, KeyType key, ValueType value, Allocator &alloc,
                                       Latch &latch) {
  static bool insert(BNode *root, KeyType key, ValueType value, Allocator &alloc, Latch &latch) {
    BNode *prev = nullptr, *curr = nullptr;
    do {
      BNode *tmpPrev = prev;
@@ -209,7 +213,7 @@ public:
    return true;
  }

  __cuda_callable__ static inline bool remove(BNode *root, KeyType key, Latch &latch) {
  static inline bool remove(BNode *root, KeyType key, Latch &latch) {
    BNode *cursor = findLeaf(root, key);

    if (cursor == nullptr)
+3 −3
Original line number Diff line number Diff line
@@ -45,8 +45,8 @@ struct BLinkOperations<KeyType, ValueType, Order, TNL::Devices::Cuda> {

    auto rank = warp.thread_rank();
    if (rank < Order) {
      KeyType thrKey = 0;
      ValueType thrVal = 0;
      KeyType thrKey;
      ValueType thrVal;

      if (rank == index) {
        thrKey = key;
@@ -74,7 +74,7 @@ struct BLinkOperations<KeyType, ValueType, Order, TNL::Devices::Cuda> {

    auto rank = warp.thread_rank();
    if (rank < Order) {
      KeyType thrKey = 0;
      KeyType thrKey;

      if (rank == index) {
        thrKey = key;
+2 −2
Original line number Diff line number Diff line
@@ -47,7 +47,7 @@ struct BNodeOperations<KeyType, ValueType, Order, TNL::Devices::Cuda> {

    auto rank = warp.thread_rank();
    if (rank < Order) {
      KeyType thrKey = 0;
      KeyType thrKey;
      ValueType thrVal;

      if (rank == index) {
@@ -78,7 +78,7 @@ struct BNodeOperations<KeyType, ValueType, Order, TNL::Devices::Cuda> {

    auto rank = warp.thread_rank();
    if (rank < Order) {
      KeyType thrKey = 0;
      KeyType thrKey;

      if (rank == index) {
        thrKey = key;
Loading