Commit 6420f6b1 authored by kolusask's avatar kolusask
Browse files

Re-enabled test on std::set

parent 62cd2d36
Loading
Loading
Loading
Loading
+39 −16
Original line number Diff line number Diff line
@@ -26,7 +26,7 @@ using Cell = TNL::Containers::StaticArray<4, int>;
/**
 * Wrapper class for testing std::set
 */
template<typename T>
template<typename T, typename Device>
class StdSetWrapper {
    struct Comparator {
        bool operator()(const T& v1, const T& v2) const {
@@ -38,6 +38,8 @@ class StdSetWrapper {
    StdSetWrapper(const std::vector<T>& values) :
                    m_set(values.begin(), values.end()) {}
    
    StdSetWrapper(StdSetWrapper<T, Device>&& other) : m_set(std::move(other.m_set)) {}

    bool contains(const T& value) const {
        return m_set.count(value);
    }
@@ -50,7 +52,7 @@ class StdSetWrapper {
};

template<>
bool StdSetWrapper<Cell>::Comparator::operator()(const Cell& v1, const Cell& v2) const {
bool StdSetWrapper<Cell, Device>::Comparator::operator()(const Cell& v1, const Cell& v2) const {
    for (int i = 0; i < 4; i++)
        if (v1[i] != v2[i])
            return v1[i] < v2[i];
@@ -62,7 +64,7 @@ template<template<class, class> class S, typename T>
S<T, Device> test_building(const Array<T, Device>& values) {
    BEGIN_TEST(BUILDING)
    std::cerr << "\t\tInserting " << values.getSize() << " entries" << std::endl;
    auto table = S<T, Device>(values);
    S<T, Device> table(values);
    if constexpr (std::is_same<S<T, Device>, CuckooHashSet<T, Device>>::value) {
        std::cerr << "\t\t - " << table.generations() 
                  << " generation(s) before success" << std::endl;
@@ -74,6 +76,16 @@ S<T, Device> test_building(const Array<T, Device>& values) {
}


template<template<class, class> class S, typename T>
S<T, Device> test_building(const std::vector<T>& values) {
    BEGIN_TEST(BUILDING)
    std::cerr << "\t\tInserting " << values.size() << " entries" << std::endl;
    auto table = S<T, Device>(values);
    END_TEST
    return table; 
}


template<template<class, class> class S, typename T>
void test_correct_query(const S<T, Device>& table, 
                        const std::vector<T>& values, 
@@ -107,16 +119,24 @@ void test_wrong_query(const S<T, Device>& table,
    END_TEST
}


template<template<class, class> class S, typename T>
void test_random(const typename Array<T, Device>::ConstViewType values, const std::vector<T>& cpuValues, int nTests, double addedPart) {
    int nAdded = int(values.getSize() * addedPart);
    S<T, Device> table([&] {
        if constexpr (std::is_same<S<T, Device>, StdSetWrapper<T, Device>>::value)
            return S<T, Device>(std::vector(cpuValues.begin(), cpuValues.begin() + nAdded));
        else {
            auto arr = Array<T, Device>(nAdded);
            auto aView = arr.getView();
            auto copy_part = [aView, values] __cuda_callable__ (int i) mutable {
                aView[i] = values[i];
            };
            TNL::Algorithms::ParallelFor<Device>::exec(0, nAdded, copy_part);
    auto table = S<T, Device>(arr);
            return S<T, Device>(arr);
        }
    }());
    //auto table = build_table();
    int correct = 0, wrong = 0;
    BEGIN_TEST(RANDOM)
    std::cerr << "\t\tFor " << cpuValues.size() << " entries, "
@@ -158,7 +178,7 @@ Array<T, Device> get_data(std::string fileName, std::vector<T>& cpu) {
    }
    auto dataPair = dataMap[fileName];
    cpu = dataPair.second;
    return std::move(dataPair.first);
    return Array<T, Device>(dataPair.first);
}

template<>
@@ -213,8 +233,12 @@ template<template<class, class> class S, typename T>
void test(std::string fileName, bool testRandom = true) {
    std::vector<T> cpuValues;
    auto values = get_data<T>(fileName, cpuValues);
    auto table = test_building<S, T>(values);
    //table.debug_print();
    S<T, Device> table([&] {
        if constexpr (std::is_same<S<T, Device>, StdSetWrapper<T, Device>>::value)
            return test_building<S, T>(cpuValues);
        else
            return test_building<S, T>(values);
    }());
    test_correct_query<S, T>(table, cpuValues, cpuValues.size());
    test_wrong_query<S, T>(table, cpuValues, cpuValues.size());
    if (testRandom)
@@ -246,7 +270,6 @@ Triangle ruin<Triangle>(const Triangle& value) {
template<template<class, class> class S>
void test_class(const std::string className, const std::string dataFile) {
    std::cerr << "Testing " << className << "<TNL::Containers::StaticArray<4, int>> with " << dataFile << ':' << std::endl;
    // TODO testRandom=true
    test<S, Cell>(dataFile, false);

    std::cerr <<  "Testing " << className << "<Triangle> with " << dataFile << ':' << std::endl;
@@ -280,7 +303,7 @@ void run() {
    for (int i = 1; i <= 5; i++) {
        std::string dataFile = "data/cube1m_" + std::to_string(i) + "_cells.txt";

        //test_class<StdSetWrapper>("std::set", dataFile);
        test_class<StdSetWrapper>("std::set", dataFile);
        test_class<CuckooHashSet>("CuckooHashSet", dataFile);
        test_class<HashGraphV1Set>("HashGraphV1Set", dataFile);
    }