Commit 85fe6cf2 authored by kolusask's avatar kolusask
Browse files

Add unit tests for HashGraphV2

parent ff5330e5
Loading
Loading
Loading
Loading
+4 −4
Original line number Original line Diff line number Diff line
@@ -19,7 +19,7 @@ using Device = TNL::Devices::Host;




template<typename Key>
template<typename Key>
class HashGraphSetTest : public ::testing::Test {
class HashGraphV1SetTest : public ::testing::Test {
protected:
protected:
    using KeyType = Key;
    using KeyType = Key;
    using ArrayType = Array<KeyType, Device>;
    using ArrayType = Array<KeyType, Device>;
@@ -28,9 +28,9 @@ protected:


using HashGraphSetTypes = ::testing::Types<int, long, float, double>;
using HashGraphSetTypes = ::testing::Types<int, long, float, double>;


TYPED_TEST_SUITE(HashGraphSetTest, HashGraphSetTypes);
TYPED_TEST_SUITE(HashGraphV1SetTest, HashGraphSetTypes);


TYPED_TEST(HashGraphSetTest, correctQuery) {
TYPED_TEST(HashGraphV1SetTest, correctQuery) {
    using KeyType = typename TestFixture::KeyType;
    using KeyType = typename TestFixture::KeyType;
    using ArrayType = typename TestFixture::ArrayType;
    using ArrayType = typename TestFixture::ArrayType;
    using SetType = typename TestFixture::SetType;
    using SetType = typename TestFixture::SetType;
@@ -42,7 +42,7 @@ TYPED_TEST(HashGraphSetTest, correctQuery) {
        EXPECT_TRUE(set.contains(values.getElement(i)));
        EXPECT_TRUE(set.contains(values.getElement(i)));
}
}


TYPED_TEST(HashGraphSetTest, wrongQuery) {
TYPED_TEST(HashGraphV1SetTest, wrongQuery) {
    using KeyType = typename TestFixture::KeyType;
    using KeyType = typename TestFixture::KeyType;
    using ArrayType = typename TestFixture::ArrayType;
    using ArrayType = typename TestFixture::ArrayType;
    using SetType = typename TestFixture::SetType;
    using SetType = typename TestFixture::SetType;
+81 −0
Original line number Original line Diff line number Diff line
#pragma once

#include "../HashGraph/HashGraphMap.h"

#include "gtest/gtest.h"

#include <string>
#include <type_traits>
#include <vector>


#ifdef HAVE_CUDA
#include <TNL/Devices/Cuda.h>
using Device = TNL::Devices::Cuda;
#else
#include <TNL/Devices/Host.h>
using Device = TNL::Devices::Host;
#endif


template<typename tPair>
class HashGraphV2MapTest : public ::testing::Test {
protected:
    using KeyType = typename std::tuple_element<0, tPair>::type;
    using ValueType = typename std::tuple_element<1, tPair>::type;
    using ArrayType = Array<Pair<KeyType, ValueType>, Device>;
    using MapType = HashGraphV2Map<KeyType, ValueType, Device>;
};

using HashGraphV2MapTypes = ::testing::Types<
    std::pair<int, int>,
    std::pair<int, long>,
    std::pair<int, float>,
    std::pair<int, double>,
    std::pair<long, int>,
    std::pair<long, long>,
    std::pair<long, float>,
    std::pair<long, double>,
    std::pair<float, int>,
    std::pair<float, long>,
    std::pair<float, float>,
    std::pair<float, double>,
    std::pair<double, int>,
    std::pair<double, long>,
    std::pair<double, float>,
    std::pair<double, double>
>;

TYPED_TEST_SUITE(HashGraphV2MapTest, HashGraphV2MapTypes);

TYPED_TEST(HashGraphV2MapTest, correctQuery) {
    using KeyType = typename TestFixture::KeyType;
    using ValueType = typename TestFixture::ValueType;
    using ArrayType = typename TestFixture::ArrayType;
    using MapType = typename TestFixture::MapType;

    ArrayType pairs = get_pairs<KeyType, ValueType, Device>();
    MapType map(pairs);
    ValueType value;
    for (int i = 0; i < pairs.getSize(); i++) {
        auto pair = pairs.getElement(i);
        KeyType key = pair.key;
        EXPECT_TRUE(map.find(key, value));
        EXPECT_EQ(value, pair.value);
    }
}

TYPED_TEST(HashGraphV2MapTest, wrongQuery) {
    using KeyType = typename TestFixture::KeyType;
    using ValueType = typename TestFixture::ValueType;
    using ArrayType = typename TestFixture::ArrayType;
    using MapType = typename TestFixture::MapType;

    ArrayType pairs = get_pairs<KeyType, ValueType, Device>();
    MapType map(pairs);

    for (KeyType key = 8; key < 16; key++) {
        ValueType value;
        EXPECT_FALSE(map.find(key, value));
    }
}
+55 −0
Original line number Original line Diff line number Diff line
#pragma once

#include "../HashGraph/HashGraphSet.h"

#include "gtest/gtest.h"

#include <string>
#include <type_traits>
#include <vector>


#ifdef HAVE_CUDA
#include <TNL/Devices/Cuda.h>
using Device = TNL::Devices::Cuda;
#else
#include <TNL/Devices/Host.h>
using Device = TNL::Devices::Host;
#endif


template<typename Key>
class HashGraphV2SetTest : public ::testing::Test {
protected:
    using KeyType = Key;
    using ArrayType = Array<KeyType, Device>;
    using SetType = HashGraphV2Set<KeyType, Device>;
};

using HashGraphSetTypes = ::testing::Types<int, long, float, double>;

TYPED_TEST_SUITE(HashGraphV2SetTest, HashGraphSetTypes);

TYPED_TEST(HashGraphV2SetTest, correctQuery) {
    using KeyType = typename TestFixture::KeyType;
    using ArrayType = typename TestFixture::ArrayType;
    using SetType = typename TestFixture::SetType;

    ArrayType values = get_values<KeyType, Device>();
    SetType set(values);

    for (int i = 0; i < values.getSize(); i++)
        EXPECT_TRUE(set.contains(values.getElement(i)));
}

TYPED_TEST(HashGraphV2SetTest, wrongQuery) {
    using KeyType = typename TestFixture::KeyType;
    using ArrayType = typename TestFixture::ArrayType;
    using SetType = typename TestFixture::SetType;

    ArrayType values = get_values<KeyType, Device>();
    SetType set(values);

    for (KeyType key = 8; key < 16; key++)
        EXPECT_FALSE(set.contains(key));
}
+2 −0
Original line number Original line Diff line number Diff line
@@ -19,6 +19,7 @@ Array<Pair<KeyType, ValueType>, Device> get_pairs() {


#include "CuckooHashMapTest.h"
#include "CuckooHashMapTest.h"
#include "HashGraphV1MapTest.h"
#include "HashGraphV1MapTest.h"
#include "HashGraphV2MapTest.h"


template<typename KeyType, typename Device>
template<typename KeyType, typename Device>
Array<KeyType, Device> get_values() {
Array<KeyType, Device> get_values() {
@@ -33,6 +34,7 @@ Array<KeyType, Device> get_values() {


#include "CuckooHashSetTest.h"
#include "CuckooHashSetTest.h"
#include "HashGraphV1SetTest.h"
#include "HashGraphV1SetTest.h"
#include "HashGraphV2SetTest.h"


int main(int argc, char* argv[]) {
int main(int argc, char* argv[]) {
   ::testing::InitGoogleTest( &argc, argv );
   ::testing::InitGoogleTest( &argc, argv );