]>
Commit | Line | Data |
---|---|---|
1e59de90 TL |
1 | // (C) Copyright Matt Borland and Nick Thompson 2022. |
2 | // Distributed under the Boost Software License, Version 1.0. | |
3 | // (See accompanying file LICENSE_1_0.txt or copy at | |
4 | // http://www.boost.org/LICENSE_1_0.txt) | |
5 | ||
#include "math_unit_test.hpp"
#include <cmath>
#include <limits>
#include <vector>
#include <boost/math/special_functions/logsumexp.hpp>
#include <boost/math/constants/constants.hpp>
#include <boost/math/tools/random_vector.hpp>
12 | ||
13 | template <typename Real> | |
14 | void test() | |
15 | { | |
16 | using boost::math::logsumexp; | |
17 | using std::log; | |
18 | using std::exp; | |
19 | ||
20 | // Spot check 2 values | |
21 | // Also validate that 2 values does not attempt to instantiate the iterator version | |
22 | // https://numpy.org/doc/stable/reference/generated/numpy.logaddexp.html | |
23 | // Calculated at higher precision using wolfram alpha | |
24 | Real x1 = 1e-50l; | |
25 | Real x2 = 2.5e-50l; | |
26 | Real spot1 = static_cast<Real>(exp(x1)); | |
27 | Real spot2 = static_cast<Real>(exp(x2)); | |
28 | Real spot12 = logsumexp(x1, x2); | |
29 | CHECK_ULP_CLOSE(log(spot1 + spot2), spot12, 1); | |
30 | ||
31 | // Spot check 3 values and compare result of each different interface | |
32 | Real x3 = 5e-50l; | |
33 | Real spot3 = static_cast<Real>(exp(x3)); | |
34 | std::vector<Real> x_vals {x1, x2, x3}; | |
35 | ||
36 | Real spot123 = logsumexp(x1, x2, x3); | |
37 | Real spot123_container = logsumexp(x_vals); | |
38 | Real spot123_iter = logsumexp(x_vals.begin(), x_vals.end()); | |
39 | ||
40 | CHECK_EQUAL(spot123, spot123_container); | |
41 | CHECK_EQUAL(spot123_container, spot123_iter); | |
42 | CHECK_ULP_CLOSE(log(spot1 + spot2 + spot3), spot123, 1); | |
43 | ||
44 | // Spot check 4 values with repeated largest value | |
45 | Real x4 = x3; | |
46 | Real spot4 = spot3; | |
47 | Real spot1234 = logsumexp(x1, x2, x3, x4); | |
48 | x_vals.emplace_back(x4); | |
49 | Real spot1234_container = logsumexp(x_vals); | |
50 | ||
51 | CHECK_EQUAL(spot1234, spot1234_container); | |
52 | CHECK_ULP_CLOSE(log(spot1 + spot2 + spot3 + spot4), spot1234, 1); | |
53 | ||
54 | // Check with a value of vastly different order of magnitude | |
55 | Real x5 = 1.0l; | |
56 | Real spot5 = static_cast<Real>(exp(x5)); | |
57 | x_vals.emplace_back(x5); | |
58 | Real spot12345 = logsumexp(x_vals); | |
59 | CHECK_ULP_CLOSE(log(spot1 + spot2 + spot3 + spot4 + spot5), spot12345, 1); | |
60 | } | |
61 | ||
62 | // The naive method of computation should overflow: | |
63 | template<typename Real> | |
64 | void test_overflow() | |
65 | { | |
66 | using boost::math::logsumexp; | |
67 | using std::exp; | |
68 | using std::log; | |
69 | ||
70 | Real x = ((std::numeric_limits<Real>::max)()/2); | |
71 | ||
72 | Real naive_result = log(exp(x) + exp(x)); | |
73 | CHECK_EQUAL(std::isfinite(naive_result), false); | |
74 | ||
75 | Real result = logsumexp(x, x); | |
76 | CHECK_EQUAL(std::isfinite(result), true); | |
77 | CHECK_ULP_CLOSE(result, x + boost::math::constants::ln_two<Real>(), 1); | |
78 | } | |
79 | ||
80 | template <typename Real> | |
81 | void test_random() | |
82 | { | |
83 | using std::exp; | |
84 | using std::log; | |
85 | using boost::math::logsumexp; | |
86 | using boost::math::generate_random_vector; | |
87 | ||
88 | std::vector<Real> test_values = generate_random_vector(128, 0, Real(1e-50l), Real(1e-40l)); | |
89 | Real naive_exp_sum = 0; | |
90 | ||
91 | for(const auto& val : test_values) | |
92 | { | |
93 | naive_exp_sum += exp(val); | |
94 | } | |
95 | ||
96 | CHECK_ULP_CLOSE(log(naive_exp_sum), logsumexp(test_values), 1); | |
97 | } | |
98 | ||
99 | int main (void) | |
100 | { | |
101 | test<float>(); | |
102 | test<double>(); | |
103 | test<long double>(); | |
104 | ||
105 | test_overflow<float>(); | |
106 | test_overflow<double>(); | |
107 | test_overflow<long double>(); | |
108 | ||
109 | test_random<float>(); | |
110 | test_random<double>(); | |
111 | test_random<long double>(); | |
112 | return boost::math::test::report_errors(); | |
113 | } |