]>
Commit | Line | Data |
---|---|---|
9f95a23c TL |
1 | // -*- mode:C++; tab-width:8; c-basic-offset:2; indent-tabs-mode:t -*- |
2 | // vim: ts=8 sw=2 smarttab | |
3 | ||
4 | #include "common/weighted_shuffle.h" | |
5 | #include <array> | |
6 | #include <map> | |
7 | #include "gtest/gtest.h" | |
8 | ||
9 | TEST(WeightedShuffle, Basic) { | |
10 | std::array<char, 5> choices{'a', 'b', 'c', 'd', 'e'}; | |
11 | std::array<int, 5> weights{100, 50, 25, 10, 1}; | |
12 | std::map<char, std::array<unsigned, 5>> frequency { | |
13 | {'a', {0, 0, 0, 0, 0}}, | |
14 | {'b', {0, 0, 0, 0, 0}}, | |
15 | {'c', {0, 0, 0, 0, 0}}, | |
16 | {'d', {0, 0, 0, 0, 0}}, | |
17 | {'e', {0, 0, 0, 0, 0}} | |
18 | }; // count each element appearing in each position | |
19 | const int samples = 10000; | |
20 | std::random_device rd; | |
21 | for (auto i = 0; i < samples; i++) { | |
22 | weighted_shuffle(begin(choices), end(choices), | |
23 | begin(weights), end(weights), | |
24 | std::mt19937{rd()}); | |
25 | for (size_t j = 0; j < choices.size(); ++j) | |
26 | ++frequency[choices[j]][j]; | |
27 | } | |
28 | // verify that the probability that the nth choice is selected as the first | |
29 | // one is the nth weight divided by the sum of all weights | |
30 | const auto total_weight = std::accumulate(weights.begin(), weights.end(), 0); | |
31 | constexpr float epsilon = 0.02; | |
32 | for (unsigned i = 0; i < choices.size(); i++) { | |
33 | const auto& f = frequency[choices[i]]; | |
34 | const auto& w = weights[i]; | |
35 | ASSERT_NEAR(float(w) / total_weight, | |
36 | float(f.front()) / samples, | |
37 | epsilon); | |
38 | } | |
39 | } |