Commit a70f200b authored by kolusask's avatar kolusask
Browse files

Store hash functions indeces with the content

parent d90cadb7
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@ template<typename Item, typename Key, typename Device>
class CuckooHashTable {
    using Self = CuckooHashTable<Item, Key, Device>;
    using ViewType = CuckooHashTableView<Item, Key, Device>;
    using Entry = typename ViewType::Entry;
    friend ViewType::CuckooHashTableView(Self&, const typename Array<Item, Device>::ConstViewType);

  public:
@@ -69,7 +70,7 @@ class CuckooHashTable {

    Array<HashFunction<Key>, Device> m_hashFunctions;
    Array<int, Device> m_table;
    Array<Item, Device> m_content;
    Array<Entry, Device> m_content;

  private:
    int m_iterations;
+9 −2
Original line number Diff line number Diff line
@@ -23,6 +23,11 @@ class CuckooHashTableView {
    using TableType = CuckooHashTable<Item, Key, Device>;

  public:
    struct Entry {
      Item item;
      int hashFunction;
    };

    CuckooHashTableView(TableType& table, const typename Array<Item, Device>::ConstViewType items);
  
    //! Find and return item having provided key
@@ -43,12 +48,14 @@ class CuckooHashTableView {
    bool insert(Item item, bool& duplicate);

    //! Populate the table with given values, regenerate hash functions on each failure
    void build(const typename Array<Item, Device>::ConstViewType items);
    void build();
    
    void copy_items(const typename Array<Item, Device>::ConstViewType items);

  protected:
    ArrayView<HashFunction<Key>, Device> m_hashFunctions;
    ArrayView<int, Device> m_table;
    ArrayView<Item, Device> m_content;
    ArrayView<Entry, Device> m_content;

  private:
    int m_iterations;
+30 −24
Original line number Diff line number Diff line
@@ -16,8 +16,10 @@ CuckooHashTableView<Item, Key, Device>::CuckooHashTableView(TableType& table, co
        m_generations(0),
        m_duplicates(0) {
    m_table.setValue(-1);
    if (!items.empty())
        build(items);
    if (!items.empty()) {
        copy_items(items);
        build();
    }
}

template<typename Item, typename Key, typename Device>
@@ -32,12 +34,13 @@ void CuckooHashTableView<Item, Key, Device>::init_hash_functions() {
}

template<typename Item, typename Key, typename Device>
__cuda_callable__ bool insert_impl(int entry, typename Array<Item, Device>::ConstViewType items, 
__cuda_callable__ bool insert_impl(int entry, 
                                   ArrayView<typename CuckooHashTableView<Item, Key, Device>::Entry, Device> items, 
                                   bool& duplicate, int iter,
                                   ArrayView<HashFunction<Key>, Device> hf, 
                                   ArrayView<int, Device> table) {
    duplicate = false;
    int index = hf[0](items[entry].key);
    int index = hf[0](items[entry].item.key);

    for (int i = 0; i < iter; i++) {
        #ifdef __CUDA_ARCH__
@@ -51,18 +54,15 @@ __cuda_callable__ bool insert_impl(int entry, typename Array<Item, Device>::Cons
            return true;

        // Don't allow inserting duplicates
        if (items[entry].key == items[table[index]].key) {
        if (items[entry].item.key == items[table[index]].item.key) {
            duplicate = true;
            return false;
        }

        // Find a new position for the old content
        int h;
        for (h = 0; h < hf.getSize(); h++)
            if (hf[h](items[entry].key) == index)
                break;
        int newHash = (h + 1) % hf.getSize();
        index = hf[newHash](items[entry].key); 
        int newHash = (items[entry].hashFunction + 1) % hf.getSize();
        items[entry].hashFunction = newHash;
        index = hf[newHash](items[entry].item.key); 
    }

    // maximum number of iterations reached
@@ -70,17 +70,20 @@ __cuda_callable__ bool insert_impl(int entry, typename Array<Item, Device>::Cons
}

template<typename Item, typename Key, typename Device>
void CuckooHashTableView<Item, Key, Device>::build(const typename Array<Item, Device>::ConstViewType items) {
    auto cont = m_content;
    auto _fill_cont = [cont, items] __cuda_callable__ (int i) mutable {
        cont[i] = items[i];
void CuckooHashTableView<Item, Key, Device>::copy_items(const typename Array<Item, Device>::ConstViewType items) {
    auto _init_content = [items] __cuda_callable__ (int i, ArrayView<Entry, Device> content) mutable {
        content[i] = {items[i], -1};
    };
    TNL::Algorithms::ParallelFor<Device>::exec(0, cont.getSize(), _fill_cont);
    TNL::Algorithms::ParallelFor<Device>::exec(0, items.getSize(), _init_content, m_content);
}

template<typename Item, typename Key, typename Device>
void CuckooHashTableView<Item, Key, Device>::build() {
    Array<unsigned, Device> result(2);
    auto rview = result.getView();
    rview.setValue(0);
    auto _build = [items, rview] __cuda_callable__ (int i, 
    auto _build = [rview] __cuda_callable__ (int i, 
                                             decltype(m_content) items,
                                             decltype(m_iterations) iter,
                                             decltype(m_hashFunctions) hf,
                                             decltype(m_table) table) mutable {
@@ -97,7 +100,7 @@ void CuckooHashTableView<Item, Key, Device>::build(const typename Array<Item, De
        m_table.setValue(-1);
        m_duplicates = 0;
        rview.setValue(0);
        TNL::Algorithms::ParallelFor<Device>::exec(0, items.getSize(), _build, m_iterations, m_hashFunctions, m_table);
        TNL::Algorithms::ParallelFor<Device>::exec(0, m_content.getSize(), _build, m_content, m_iterations, m_hashFunctions, m_table);
    } while (rview.getElement(1));
    m_duplicates = rview.getElement(0);
}
@@ -113,16 +116,19 @@ bool CuckooHashTableView<Item, Key, Device>::find(const Key& key, Item* item) co
    Array<int, Device> res(1);
    res.setValue(-1);
    auto rview = res.getView();
    auto _find = [rview, key] __cuda_callable__ (int i, ArrayView<HashFunction<Key>, Device> hf, ArrayView<int, Device> table, ArrayView<Item, Device> cont) mutable { 
    auto _find = [rview, key] __cuda_callable__ (int i, 
                                                 ArrayView<HashFunction<Key>, Device> hf, 
                                                 ArrayView<int, Device> table, 
                                                 ArrayView<Entry, Device> cont) mutable { 
        int pos = hf[i](key);
        if (table[pos] >= 0 && cont[table[pos]].key == key)
        if (table[pos] >= 0 && cont[table[pos]].item.key == key)
            rview[0] = pos;
    };
    TNL::Algorithms::ParallelFor<Device>::exec(0, m_hashFunctions.getSize(), _find, m_hashFunctions, m_table, m_content);
    int result = rview.getElement(0);
    if (result >= 0) {
        if (item)
            *item = m_content.getElement(m_table.getElement(result));
            *item = m_content.getElement(m_table.getElement(result)).item;
        return true;
    }
    return false;