]>
Commit | Line | Data |
---|---|---|
7c673cae FG |
1 | // Copyright (C) 2005, 2006 Douglas Gregor. |
2 | ||
3 | // Use, modification and distribution is subject to the Boost Software | |
4 | // License, Version 1.0. (See accompanying file LICENSE_1_0.txt or copy at | |
5 | // http://www.boost.org/LICENSE_1_0.txt) | |
6 | ||
7 | // A test of the all_reduce() collective. | |
8 | #include <boost/mpi/collectives/all_reduce.hpp> | |
9 | #include <boost/mpi/communicator.hpp> | |
10 | #include <boost/mpi/environment.hpp> | |
7c673cae FG |
11 | #include <vector> |
12 | #include <algorithm> | |
13 | #include <boost/serialization/string.hpp> | |
14 | #include <boost/iterator/counting_iterator.hpp> | |
15 | #include <boost/lexical_cast.hpp> | |
16 | #include <numeric> | |
17 | ||
92f5a8d4 TL |
18 | #define BOOST_TEST_MODULE mpi_all_reduce |
19 | #include <boost/test/included/unit_test.hpp> | |
20 | ||
7c673cae FG |
21 | using boost::mpi::communicator; |
22 | ||
23 | // A simple point class that we can build, add, compare, and | |
24 | // serialize. | |
25 | struct point | |
26 | { | |
27 | point() : x(0), y(0), z(0) { } | |
28 | point(int x, int y, int z) : x(x), y(y), z(z) { } | |
29 | ||
30 | int x; | |
31 | int y; | |
32 | int z; | |
33 | ||
34 | private: | |
35 | template<typename Archiver> | |
36 | void serialize(Archiver& ar, unsigned int /*version*/) | |
37 | { | |
38 | ar & x & y & z; | |
39 | } | |
40 | ||
41 | friend class boost::serialization::access; | |
42 | }; | |
43 | ||
44 | std::ostream& operator<<(std::ostream& out, const point& p) | |
45 | { | |
46 | return out << p.x << ' ' << p.y << ' ' << p.z; | |
47 | } | |
48 | ||
49 | bool operator==(const point& p1, const point& p2) | |
50 | { | |
51 | return p1.x == p2.x && p1.y == p2.y && p1.z == p2.z; | |
52 | } | |
53 | ||
54 | bool operator!=(const point& p1, const point& p2) | |
55 | { | |
56 | return !(p1 == p2); | |
57 | } | |
58 | ||
59 | point operator+(const point& p1, const point& p2) | |
60 | { | |
61 | return point(p1.x + p2.x, p1.y + p2.y, p1.z + p2.z); | |
62 | } | |
63 | ||
64 | // test lexical order | |
65 | bool operator<(const point& p1, const point& p2) | |
66 | { | |
67 | return (p1.x < p2.x | |
68 | ? true | |
69 | : (p1.x > p2.x | |
70 | ? false | |
71 | : p1.y < p2.y )); | |
72 | } | |
73 | ||
74 | namespace boost { namespace mpi { | |
75 | ||
76 | template <> | |
77 | struct is_mpi_datatype<point> : public mpl::true_ { }; | |
78 | ||
79 | } } // end namespace boost::mpi | |
80 | ||
81 | template<typename Generator, typename Op> | |
82 | void | |
83 | all_reduce_one_test(const communicator& comm, Generator generator, | |
84 | const char* type_kind, Op op, const char* op_kind, | |
85 | typename Generator::result_type init, bool in_place) | |
86 | { | |
87 | typedef typename Generator::result_type value_type; | |
88 | value_type value = generator(comm.rank()); | |
89 | ||
90 | using boost::mpi::all_reduce; | |
91 | using boost::mpi::inplace; | |
92 | ||
93 | if (comm.rank() == 0) { | |
94 | std::cout << "Reducing to " << op_kind << " of " << type_kind << "..."; | |
95 | std::cout.flush(); | |
96 | } | |
97 | ||
98 | value_type result_value; | |
99 | if (in_place) { | |
100 | all_reduce(comm, inplace(value), op); | |
101 | result_value = value; | |
102 | } else { | |
103 | result_value = all_reduce(comm, value, op); | |
104 | } | |
105 | ||
106 | // Compute expected result | |
107 | std::vector<value_type> generated_values; | |
108 | for (int p = 0; p < comm.size(); ++p) | |
109 | generated_values.push_back(generator(p)); | |
110 | value_type expected_result = std::accumulate(generated_values.begin(), | |
111 | generated_values.end(), | |
112 | init, op); | |
113 | BOOST_CHECK(result_value == expected_result); | |
114 | if (result_value == expected_result && comm.rank() == 0) | |
115 | std::cout << "OK." << std::endl; | |
116 | ||
117 | (comm.barrier)(); | |
118 | } | |
119 | ||
120 | template<typename Generator, typename Op> | |
121 | void | |
122 | all_reduce_array_test(const communicator& comm, Generator generator, | |
123 | const char* type_kind, Op op, const char* op_kind, | |
124 | typename Generator::result_type init, bool in_place) | |
125 | { | |
126 | typedef typename Generator::result_type value_type; | |
127 | value_type value = generator(comm.rank()); | |
128 | std::vector<value_type> send(10, value); | |
129 | ||
130 | using boost::mpi::all_reduce; | |
131 | using boost::mpi::inplace; | |
132 | ||
133 | if (comm.rank() == 0) { | |
134 | char const* place = in_place ? "in place" : "out of place"; | |
135 | std::cout << "Reducing (" << place << ") array to " << op_kind << " of " << type_kind << "..."; | |
136 | std::cout.flush(); | |
137 | } | |
138 | std::vector<value_type> result; | |
139 | if (in_place) { | |
140 | all_reduce(comm, inplace(&(send[0])), send.size(), op); | |
141 | result.swap(send); | |
142 | } else { | |
143 | std::vector<value_type> recv(10, value_type()); | |
144 | all_reduce(comm, &(send[0]), send.size(), &(recv[0]), op); | |
145 | result.swap(recv); | |
146 | } | |
147 | ||
148 | // Compute expected result | |
149 | std::vector<value_type> generated_values; | |
150 | for (int p = 0; p < comm.size(); ++p) | |
151 | generated_values.push_back(generator(p)); | |
152 | value_type expected_result = std::accumulate(generated_values.begin(), | |
153 | generated_values.end(), | |
154 | init, op); | |
155 | ||
156 | bool got_expected_result = (std::equal_range(result.begin(), result.end(), | |
157 | expected_result) | |
158 | == std::make_pair(result.begin(), result.end())); | |
159 | BOOST_CHECK(got_expected_result); | |
160 | if (got_expected_result && comm.rank() == 0) | |
161 | std::cout << "OK." << std::endl; | |
162 | ||
163 | (comm.barrier)(); | |
164 | } | |
165 | ||
166 | // Test the 4 families of all reduce: (value, array) X (in place, out of place) | |
167 | template<typename Generator, typename Op> | |
168 | void | |
169 | all_reduce_test(const communicator& comm, Generator generator, | |
170 | const char* type_kind, Op op, const char* op_kind, | |
171 | typename Generator::result_type init) | |
172 | { | |
173 | const bool in_place = true; | |
174 | const bool out_of_place = false; | |
175 | all_reduce_one_test(comm, generator, type_kind, op, op_kind, init, in_place); | |
176 | all_reduce_one_test(comm, generator, type_kind, op, op_kind, init, out_of_place); | |
177 | all_reduce_array_test(comm, generator, type_kind, op, op_kind, | |
178 | init, in_place); | |
179 | all_reduce_array_test(comm, generator, type_kind, op, op_kind, | |
180 | init, out_of_place); | |
181 | } | |
182 | ||
// Integer source for all_reduce() tests: rank p yields base + p, where base
// defaults to 1, so the ranks contribute consecutive distinct values.
struct int_generator
{
  typedef int result_type;

  int_generator(int first = 1) : base(first) { }

  result_type operator()(int p) const
  {
    const int value = base + p;
    return value;
  }

private:
  int base;
};
195 | ||
196 | // Generate points to test with all_reduce() | |
197 | struct point_generator | |
198 | { | |
199 | typedef point result_type; | |
200 | ||
201 | point_generator(point origin) : origin(origin) { } | |
202 | ||
203 | point operator()(int p) const | |
204 | { | |
205 | return point(origin.x + 1, origin.y + 1, origin.z + 1); | |
206 | } | |
207 | ||
208 | private: | |
209 | point origin; | |
210 | }; | |
211 | ||
212 | struct string_generator | |
213 | { | |
214 | typedef std::string result_type; | |
215 | ||
216 | std::string operator()(int p) const | |
217 | { | |
218 | std::string result = boost::lexical_cast<std::string>(p); | |
219 | result += " rosebud"; | |
220 | if (p != 1) result += 's'; | |
221 | return result; | |
222 | } | |
223 | }; | |
224 | ||
// Bitwise AND on ints, wrapped in a custom functor type. It is not one of the
// standard operation types, so (presumably) Boost.MPI must treat it as a
// user-defined operation even though int is a built-in MPI datatype.
struct secret_int_bit_and
{
  int operator()(int lhs, int rhs) const
  {
    const int both = lhs & rhs;
    return both;
  }
};
229 | ||
// An int wrapped in a class type. No is_mpi_datatype specialization is given
// for it here, so it exercises Boost.MPI's path for arbitrary serializable
// types. Default-constructs to 0.
struct wrapped_int
{
  wrapped_int() : value(0) { }
  explicit wrapped_int(int v) : value(v) { }

  // Boost.Serialization hook: only the wrapped integer is archived.
  template<typename Archive>
  void serialize(Archive& archive, unsigned int /* version */)
  {
    archive & value;
  }

  int value;
};
243 | ||
244 | wrapped_int operator+(const wrapped_int& x, const wrapped_int& y) | |
245 | { | |
246 | return wrapped_int(x.value + y.value); | |
247 | } | |
248 | ||
249 | bool operator==(const wrapped_int& x, const wrapped_int& y) | |
250 | { | |
251 | return x.value == y.value; | |
252 | } | |
253 | ||
254 | bool operator<(const wrapped_int& x, const wrapped_int& y) | |
255 | { | |
256 | return x.value < y.value; | |
257 | } | |
258 | ||
259 | // Generates wrapped_its to test with all_reduce() | |
260 | struct wrapped_int_generator | |
261 | { | |
262 | typedef wrapped_int result_type; | |
263 | ||
264 | wrapped_int_generator(int base = 1) : base(base) { } | |
265 | ||
266 | wrapped_int operator()(int p) const { return wrapped_int(base + p); } | |
267 | ||
268 | private: | |
269 | int base; | |
270 | }; | |
271 | ||
272 | namespace boost { namespace mpi { | |
273 | ||
274 | // Make std::plus<wrapped_int> commutative. | |
275 | template<> | |
276 | struct is_commutative<std::plus<wrapped_int>, wrapped_int> | |
277 | : mpl::true_ { }; | |
278 | ||
279 | } } // end namespace boost::mpi | |
280 | ||
92f5a8d4 TL |
281 | BOOST_AUTO_TEST_CASE(test_all_reduce) |
282 | { | |
7c673cae | 283 | using namespace boost::mpi; |
92f5a8d4 | 284 | environment env; |
7c673cae FG |
285 | communicator comm; |
286 | ||
287 | // Built-in MPI datatypes with built-in MPI operations | |
92f5a8d4 TL |
288 | all_reduce_test(comm, int_generator(), "integers", std::plus<int>(), "sum", 0); |
289 | all_reduce_test(comm, int_generator(), "integers", std::multiplies<int>(), "product", 1); | |
290 | all_reduce_test(comm, int_generator(), "integers", maximum<int>(), "maximum", 0); | |
291 | all_reduce_test(comm, int_generator(), "integers", minimum<int>(), "minimum", 2); | |
7c673cae FG |
292 | |
293 | // User-defined MPI datatypes with operations that have the | |
294 | // same name as built-in operations. | |
92f5a8d4 TL |
295 | all_reduce_test(comm, point_generator(point(0,0,0)), "points", std::plus<point>(), |
296 | "sum", point()); | |
7c673cae FG |
297 | |
298 | // Built-in MPI datatypes with user-defined operations | |
92f5a8d4 | 299 | all_reduce_test(comm, int_generator(17), "integers", secret_int_bit_and(), |
7c673cae FG |
300 | "bitwise and", -1); |
301 | ||
302 | // Arbitrary types with user-defined, commutative operations. | |
303 | all_reduce_test(comm, wrapped_int_generator(17), "wrapped integers", | |
304 | std::plus<wrapped_int>(), "sum", wrapped_int(0)); | |
305 | ||
306 | // Arbitrary types with (non-commutative) user-defined operations | |
307 | all_reduce_test(comm, string_generator(), "strings", | |
308 | std::plus<std::string>(), "concatenation", std::string()); | |
7c673cae | 309 | } |