]> git.proxmox.com Git - ceph.git/blame - ceph/src/boost/libs/mpi/test/all_reduce_test.cpp
import new upstream nautilus stable release 14.2.8
[ceph.git] / ceph / src / boost / libs / mpi / test / all_reduce_test.cpp
CommitLineData
7c673cae
FG
// Copyright (C) 2005, 2006 Douglas Gregor.

// Use, modification and distribution is subject to the Boost Software
// License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt)

// A test of the all_reduce() collective.
8#include <boost/mpi/collectives/all_reduce.hpp>
9#include <boost/mpi/communicator.hpp>
10#include <boost/mpi/environment.hpp>
7c673cae
FG
11#include <vector>
12#include <algorithm>
13#include <boost/serialization/string.hpp>
14#include <boost/iterator/counting_iterator.hpp>
15#include <boost/lexical_cast.hpp>
16#include <numeric>
17
92f5a8d4
TL
18#define BOOST_TEST_MODULE mpi_all_reduce
19#include <boost/test/included/unit_test.hpp>
20
7c673cae
FG
21using boost::mpi::communicator;
22
23// A simple point class that we can build, add, compare, and
24// serialize.
25struct point
26{
27 point() : x(0), y(0), z(0) { }
28 point(int x, int y, int z) : x(x), y(y), z(z) { }
29
30 int x;
31 int y;
32 int z;
33
34 private:
35 template<typename Archiver>
36 void serialize(Archiver& ar, unsigned int /*version*/)
37 {
38 ar & x & y & z;
39 }
40
41 friend class boost::serialization::access;
42};
43
44std::ostream& operator<<(std::ostream& out, const point& p)
45{
46 return out << p.x << ' ' << p.y << ' ' << p.z;
47}
48
49bool operator==(const point& p1, const point& p2)
50{
51 return p1.x == p2.x && p1.y == p2.y && p1.z == p2.z;
52}
53
54bool operator!=(const point& p1, const point& p2)
55{
56 return !(p1 == p2);
57}
58
59point operator+(const point& p1, const point& p2)
60{
61 return point(p1.x + p2.x, p1.y + p2.y, p1.z + p2.z);
62}
63
64// test lexical order
65bool operator<(const point& p1, const point& p2)
66{
67 return (p1.x < p2.x
68 ? true
69 : (p1.x > p2.x
70 ? false
71 : p1.y < p2.y ));
72}
73
74namespace boost { namespace mpi {
75
76 template <>
77 struct is_mpi_datatype<point> : public mpl::true_ { };
78
79} } // end namespace boost::mpi
80
81template<typename Generator, typename Op>
82void
83all_reduce_one_test(const communicator& comm, Generator generator,
84 const char* type_kind, Op op, const char* op_kind,
85 typename Generator::result_type init, bool in_place)
86{
87 typedef typename Generator::result_type value_type;
88 value_type value = generator(comm.rank());
89
90 using boost::mpi::all_reduce;
91 using boost::mpi::inplace;
92
93 if (comm.rank() == 0) {
94 std::cout << "Reducing to " << op_kind << " of " << type_kind << "...";
95 std::cout.flush();
96 }
97
98 value_type result_value;
99 if (in_place) {
100 all_reduce(comm, inplace(value), op);
101 result_value = value;
102 } else {
103 result_value = all_reduce(comm, value, op);
104 }
105
106 // Compute expected result
107 std::vector<value_type> generated_values;
108 for (int p = 0; p < comm.size(); ++p)
109 generated_values.push_back(generator(p));
110 value_type expected_result = std::accumulate(generated_values.begin(),
111 generated_values.end(),
112 init, op);
113 BOOST_CHECK(result_value == expected_result);
114 if (result_value == expected_result && comm.rank() == 0)
115 std::cout << "OK." << std::endl;
116
117 (comm.barrier)();
118}
119
120template<typename Generator, typename Op>
121void
122all_reduce_array_test(const communicator& comm, Generator generator,
123 const char* type_kind, Op op, const char* op_kind,
124 typename Generator::result_type init, bool in_place)
125{
126 typedef typename Generator::result_type value_type;
127 value_type value = generator(comm.rank());
128 std::vector<value_type> send(10, value);
129
130 using boost::mpi::all_reduce;
131 using boost::mpi::inplace;
132
133 if (comm.rank() == 0) {
134 char const* place = in_place ? "in place" : "out of place";
135 std::cout << "Reducing (" << place << ") array to " << op_kind << " of " << type_kind << "...";
136 std::cout.flush();
137 }
138 std::vector<value_type> result;
139 if (in_place) {
140 all_reduce(comm, inplace(&(send[0])), send.size(), op);
141 result.swap(send);
142 } else {
143 std::vector<value_type> recv(10, value_type());
144 all_reduce(comm, &(send[0]), send.size(), &(recv[0]), op);
145 result.swap(recv);
146 }
147
148 // Compute expected result
149 std::vector<value_type> generated_values;
150 for (int p = 0; p < comm.size(); ++p)
151 generated_values.push_back(generator(p));
152 value_type expected_result = std::accumulate(generated_values.begin(),
153 generated_values.end(),
154 init, op);
155
156 bool got_expected_result = (std::equal_range(result.begin(), result.end(),
157 expected_result)
158 == std::make_pair(result.begin(), result.end()));
159 BOOST_CHECK(got_expected_result);
160 if (got_expected_result && comm.rank() == 0)
161 std::cout << "OK." << std::endl;
162
163 (comm.barrier)();
164}
165
166// Test the 4 families of all reduce: (value, array) X (in place, out of place)
167template<typename Generator, typename Op>
168void
169all_reduce_test(const communicator& comm, Generator generator,
170 const char* type_kind, Op op, const char* op_kind,
171 typename Generator::result_type init)
172{
173 const bool in_place = true;
174 const bool out_of_place = false;
175 all_reduce_one_test(comm, generator, type_kind, op, op_kind, init, in_place);
176 all_reduce_one_test(comm, generator, type_kind, op, op_kind, init, out_of_place);
177 all_reduce_array_test(comm, generator, type_kind, op, op_kind,
178 init, in_place);
179 all_reduce_array_test(comm, generator, type_kind, op, op_kind,
180 init, out_of_place);
181}
182
// Generates integers to test with all_reduce(): rank p contributes
// base + p (base defaults to 1).
struct int_generator
{
  typedef int result_type;

  int_generator(int base = 1) : offset_(base) { }

  result_type operator()(int p) const
  {
    const result_type value = offset_ + p;
    return value;
  }

 private:
  int offset_;
};
195
196// Generate points to test with all_reduce()
197struct point_generator
198{
199 typedef point result_type;
200
201 point_generator(point origin) : origin(origin) { }
202
203 point operator()(int p) const
204 {
205 return point(origin.x + 1, origin.y + 1, origin.z + 1);
206 }
207
208 private:
209 point origin;
210};
211
212struct string_generator
213{
214 typedef std::string result_type;
215
216 std::string operator()(int p) const
217 {
218 std::string result = boost::lexical_cast<std::string>(p);
219 result += " rosebud";
220 if (p != 1) result += 's';
221 return result;
222 }
223};
224
// Bitwise AND on ints, written as a user-defined function object; the test
// case uses it as a user-defined operation on a built-in MPI datatype.
struct secret_int_bit_and
{
  int operator()(int lhs, int rhs) const
  {
    const int common_bits = lhs & rhs;
    return common_bits;
  }
};
229
// A serializable wrapper around a single int.  The test case uses it to
// exercise all_reduce() on an arbitrary user type.  Default-constructs to 0.
struct wrapped_int
{
  int value;

  wrapped_int() : value(0) { }
  explicit wrapped_int(int v) : value(v) { }

  // Boost.Serialization hook: the wrapped int is the only state.
  template<typename Archive>
  void serialize(Archive& archive, unsigned int /* version */)
  {
    archive & value;
  }
};
243
244wrapped_int operator+(const wrapped_int& x, const wrapped_int& y)
245{
246 return wrapped_int(x.value + y.value);
247}
248
249bool operator==(const wrapped_int& x, const wrapped_int& y)
250{
251 return x.value == y.value;
252}
253
254bool operator<(const wrapped_int& x, const wrapped_int& y)
255{
256 return x.value < y.value;
257}
258
259// Generates wrapped_its to test with all_reduce()
260struct wrapped_int_generator
261{
262 typedef wrapped_int result_type;
263
264 wrapped_int_generator(int base = 1) : base(base) { }
265
266 wrapped_int operator()(int p) const { return wrapped_int(base + p); }
267
268 private:
269 int base;
270};
271
272namespace boost { namespace mpi {
273
274// Make std::plus<wrapped_int> commutative.
275template<>
276struct is_commutative<std::plus<wrapped_int>, wrapped_int>
277 : mpl::true_ { };
278
279} } // end namespace boost::mpi
280
92f5a8d4
TL
281BOOST_AUTO_TEST_CASE(test_all_reduce)
282{
7c673cae 283 using namespace boost::mpi;
92f5a8d4 284 environment env;
7c673cae
FG
285 communicator comm;
286
287 // Built-in MPI datatypes with built-in MPI operations
92f5a8d4
TL
288 all_reduce_test(comm, int_generator(), "integers", std::plus<int>(), "sum", 0);
289 all_reduce_test(comm, int_generator(), "integers", std::multiplies<int>(), "product", 1);
290 all_reduce_test(comm, int_generator(), "integers", maximum<int>(), "maximum", 0);
291 all_reduce_test(comm, int_generator(), "integers", minimum<int>(), "minimum", 2);
7c673cae
FG
292
293 // User-defined MPI datatypes with operations that have the
294 // same name as built-in operations.
92f5a8d4
TL
295 all_reduce_test(comm, point_generator(point(0,0,0)), "points", std::plus<point>(),
296 "sum", point());
7c673cae
FG
297
298 // Built-in MPI datatypes with user-defined operations
92f5a8d4 299 all_reduce_test(comm, int_generator(17), "integers", secret_int_bit_and(),
7c673cae
FG
300 "bitwise and", -1);
301
302 // Arbitrary types with user-defined, commutative operations.
303 all_reduce_test(comm, wrapped_int_generator(17), "wrapped integers",
304 std::plus<wrapped_int>(), "sum", wrapped_int(0));
305
306 // Arbitrary types with (non-commutative) user-defined operations
307 all_reduce_test(comm, string_generator(), "strings",
308 std::plus<std::string>(), "concatenation", std::string());
7c673cae 309}