Commit 5c2ab2b3 authored by kolusask's avatar kolusask
Browse files

Refactored and fixed arithmetics in HashGraphV1

parent a6350ef8
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -41,7 +41,6 @@ class HashGraphV1 {
    Array<int, Device> m_items;
    Array<int, Device> m_offset;
    std::shared_ptr<ViewType> m_view;
    HashFunction<Key> m_hash;
  
};

+2 −3
Original line number Diff line number Diff line
@@ -11,10 +11,9 @@

template<typename Item, typename Key, typename Device>
HashGraphV1<Item, Key, Device>::HashGraphV1(const Array<Item, Device>& items) :
            m_content(items.getSize()),
            m_content(items),
            m_items(items.getSize()),
            m_offset((1 << int(ceil(log2(items.getSize())))), 0),
            m_hash(31, 10538, items.getSize()),
            m_offset((1 << int(ceil(log2(items.getSize())))) + 1, 0),
            m_view(std::make_shared<ViewType>(*this, items.getConstView(), 
                                              Array<int, Device>(items.getSize()).getView(),
                                              Array<int, Device>(items.getSize()).getView())) {}
+7 −11
Original line number Diff line number Diff line
@@ -22,11 +22,6 @@ void HashGraphV1View<Item, Key, Device>::build(const typename Array<Item, Device
                                               ArrayView<int, Device> hashes,
                                               ArrayView<int, Device> counter) {
    
    auto content = m_content;
    auto fill_content = [content, input] __cuda_callable__ (int i) mutable {
        content[i] = input[i];
    };
    TNL::Algorithms::ParallelFor<Device>::exec(0, input.getSize(), fill_content);
    auto hash = m_hash;
    auto init_hashes = [hash, hashes, input] __cuda_callable__ (int i) mutable {
        hashes[i] = hash(input[i].key);
@@ -73,19 +68,20 @@ void HashGraphV1View<Item, Key, Device>::fill_offset(const ArrayView<int, Device
                                           + offset[k + (1 << (d + 1)) - 1];
    };
    for (int d = 0; d < log2(offset.getSize() - 1); d++)
        TNL::Algorithms::ParallelFor<Device>::exec(0, offset.getSize(), reduce, d);
        TNL::Algorithms::ParallelFor<Device>::exec(0, offset.getSize() - 1, reduce, d);

    offset.setElement(offset.getSize() - 1, 0);
    offset.setElement(offset.getSize() - 1, offset.getElement(offset.getSize() - 2));
    offset.setElement(offset.getSize() - 2, 0);
    auto up_sweep = [offset] __cuda_callable__ (int r, int d) mutable {
        if ((offset.getSize() - 1 - r) % (1 << d) == 0) {
        if ((offset.getSize() - 2 - r) % (1 << d) == 0) {
            int l = r - (1 << (d - 1));
            int t = offset[r];
            offset[r] += offset[l];
            offset[l] = t;
        }
    };
    for (int d = log2(offset.getSize()); d > 0; d--)
        TNL::Algorithms::ParallelFor<Device>::exec(0, offset.getSize(), up_sweep, d);
    for (int d = log2(offset.getSize() - 1); d > 0; d--)
        TNL::Algorithms::ParallelFor<Device>::exec(0, offset.getSize() - 1, up_sweep, d);
}

template<typename Item, typename Key, typename Device>
@@ -110,5 +106,5 @@ bool HashGraphV1View<Item, Key, Device>::find(const Key& key, ArrayView<Item, De
        }
    };
    TNL::Algorithms::ParallelFor<Device>::exec(m_offset.getElement(hash), end, _find);
    return rView.getElement(0) > -1
    return rView.getElement(0) > -1;
}
+5 −5
Original line number Diff line number Diff line
@@ -19,7 +19,7 @@ using Device = TNL::Devices::Host;


template<typename tPair>
class HashGraphV1Test : public ::testing::Test {
class HashGraphV1MapTest : public ::testing::Test {
protected:
    using KeyType = typename std::tuple_element<0, tPair>::type;
    using ValueType = typename std::tuple_element<1, tPair>::type;
@@ -27,7 +27,7 @@ protected:
    using MapType = HashGraphV1Map<KeyType, ValueType, Device>;
};

using HashGraphV1Types = ::testing::Types<
using HashGraphV1MapTypes = ::testing::Types<
    std::pair<int, int>,
    std::pair<int, long>,
    std::pair<int, float>,
@@ -46,9 +46,9 @@ using HashGraphV1Types = ::testing::Types<
    std::pair<double, double>
>;

TYPED_TEST_SUITE(HashGraphV1Test, HashGraphV1Types);
TYPED_TEST_SUITE(HashGraphV1MapTest, HashGraphV1MapTypes);

TYPED_TEST(HashGraphV1Test, correctQuery) {
TYPED_TEST(HashGraphV1MapTest, correctQuery) {
    using KeyType = typename TestFixture::KeyType;
    using ValueType = typename TestFixture::ValueType;
    using ArrayType = typename TestFixture::ArrayType;
@@ -65,7 +65,7 @@ TYPED_TEST(HashGraphV1Test, correctQuery) {
    }
}

TYPED_TEST(HashGraphV1Test, wrongQuery) {
TYPED_TEST(HashGraphV1MapTest, wrongQuery) {
    using KeyType = typename TestFixture::KeyType;
    using ValueType = typename TestFixture::ValueType;
    using ArrayType = typename TestFixture::ArrayType;