]>
Commit | Line | Data |
---|---|---|
7c673cae FG |
1 | //---------------------------------------------------------------------------// |
2 | // Copyright (c) 2015 Jakub Szuppe <j.szuppe@gmail.com> | |
3 | // | |
4 | // Distributed under the Boost Software License, Version 1.0 | |
5 | // See accompanying file LICENSE_1_0.txt or copy at | |
6 | // http://www.boost.org/LICENSE_1_0.txt | |
7 | // | |
8 | // See http://boostorg.github.com/compute for more information. | |
9 | //---------------------------------------------------------------------------// | |
10 | ||
11 | #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_REDUCE_BY_KEY_WITH_SCAN_HPP | |
12 | #define BOOST_COMPUTE_ALGORITHM_DETAIL_REDUCE_BY_KEY_WITH_SCAN_HPP | |
13 | ||
14 | #include <algorithm> | |
15 | #include <iterator> | |
16 | ||
17 | #include <boost/compute/command_queue.hpp> | |
18 | #include <boost/compute/functional.hpp> | |
19 | #include <boost/compute/algorithm/inclusive_scan.hpp> | |
20 | #include <boost/compute/container/vector.hpp> | |
21 | #include <boost/compute/container/detail/scalar.hpp> | |
22 | #include <boost/compute/detail/meta_kernel.hpp> | |
23 | #include <boost/compute/detail/iterator_range_size.hpp> | |
24 | #include <boost/compute/detail/read_write_single_value.hpp> | |
25 | #include <boost/compute/type_traits.hpp> | |
26 | #include <boost/compute/utility/program_cache.hpp> | |
27 | ||
28 | namespace boost { | |
29 | namespace compute { | |
30 | namespace detail { | |
31 | ||
32 | /// \internal_ | |
33 | /// | |
34 | /// Fills \p new_keys_first with unsigned integer keys generated from vector | |
35 | /// of original keys \p keys_first. New keys can be distinguish by simple equality | |
36 | /// predicate. | |
37 | /// | |
38 | /// \param keys_first iterator pointing to the first key | |
39 | /// \param number_of_keys number of keys | |
40 | /// \param predicate binary predicate for key comparison | |
41 | /// \param new_keys_first iterator pointing to the new keys vector | |
42 | /// \param preferred_work_group_size preferred work group size | |
43 | /// \param queue command queue to perform the operation | |
44 | /// | |
45 | /// Binary function \p predicate must take two keys as arguments and | |
46 | /// return true only if they are considered the same. | |
47 | /// | |
48 | /// The first new key equals zero and the last equals number of unique keys | |
49 | /// minus one. | |
50 | /// | |
51 | /// No local memory usage. | |
52 | template<class InputKeyIterator, class BinaryPredicate> | |
53 | inline void generate_uint_keys(InputKeyIterator keys_first, | |
54 | size_t number_of_keys, | |
55 | BinaryPredicate predicate, | |
56 | vector<uint_>::iterator new_keys_first, | |
57 | size_t preferred_work_group_size, | |
58 | command_queue &queue) | |
59 | { | |
60 | typedef typename | |
61 | std::iterator_traits<InputKeyIterator>::value_type key_type; | |
62 | ||
63 | detail::meta_kernel k("reduce_by_key_new_key_flags"); | |
64 | k.add_set_arg<const uint_>("count", uint_(number_of_keys)); | |
65 | ||
66 | k << | |
67 | k.decl<const uint_>("gid") << " = get_global_id(0);\n" << | |
68 | k.decl<uint_>("value") << " = 0;\n" << | |
69 | "if(gid >= count){\n return;\n}\n" << | |
70 | "if(gid > 0){ \n" << | |
71 | k.decl<key_type>("key") << " = " << | |
72 | keys_first[k.var<const uint_>("gid")] << ";\n" << | |
73 | k.decl<key_type>("previous_key") << " = " << | |
74 | keys_first[k.var<const uint_>("gid - 1")] << ";\n" << | |
75 | " value = " << predicate(k.var<key_type>("previous_key"), | |
76 | k.var<key_type>("key")) << | |
77 | " ? 0 : 1;\n" << | |
78 | "}\n else {\n" << | |
79 | " value = 0;\n" << | |
80 | "}\n" << | |
81 | new_keys_first[k.var<const uint_>("gid")] << " = value;\n"; | |
82 | ||
83 | const context &context = queue.get_context(); | |
84 | kernel kernel = k.compile(context); | |
85 | ||
86 | size_t work_group_size = preferred_work_group_size; | |
87 | size_t work_groups_no = static_cast<size_t>( | |
88 | std::ceil(float(number_of_keys) / work_group_size) | |
89 | ); | |
90 | ||
91 | queue.enqueue_1d_range_kernel(kernel, | |
92 | 0, | |
93 | work_groups_no * work_group_size, | |
94 | work_group_size); | |
95 | ||
96 | inclusive_scan(new_keys_first, new_keys_first + number_of_keys, | |
97 | new_keys_first, queue); | |
98 | } | |
99 | ||
100 | /// \internal_ | |
101 | /// Calculate carry-out for each work group. | |
102 | /// Carry-out is a pair of the last key processed by a work group and sum of all | |
103 | /// values under this key in this work group. | |
104 | template<class InputValueIterator, class OutputValueIterator, class BinaryFunction> | |
105 | inline void carry_outs(vector<uint_>::iterator keys_first, | |
106 | InputValueIterator values_first, | |
107 | size_t count, | |
108 | vector<uint_>::iterator carry_out_keys_first, | |
109 | OutputValueIterator carry_out_values_first, | |
110 | BinaryFunction function, | |
111 | size_t work_group_size, | |
112 | command_queue &queue) | |
113 | { | |
114 | typedef typename | |
115 | std::iterator_traits<OutputValueIterator>::value_type value_out_type; | |
116 | ||
117 | detail::meta_kernel k("reduce_by_key_with_scan_carry_outs"); | |
118 | k.add_set_arg<const uint_>("count", uint_(count)); | |
119 | size_t local_keys_arg = k.add_arg<uint_ *>(memory_object::local_memory, "lkeys"); | |
120 | size_t local_vals_arg = k.add_arg<value_out_type *>(memory_object::local_memory, "lvals"); | |
121 | ||
122 | k << | |
123 | k.decl<const uint_>("gid") << " = get_global_id(0);\n" << | |
124 | k.decl<const uint_>("wg_size") << " = get_local_size(0);\n" << | |
125 | k.decl<const uint_>("lid") << " = get_local_id(0);\n" << | |
126 | k.decl<const uint_>("group_id") << " = get_group_id(0);\n" << | |
127 | ||
128 | k.decl<uint_>("key") << ";\n" << | |
129 | k.decl<value_out_type>("value") << ";\n" << | |
130 | "if(gid < count){\n" << | |
131 | k.var<uint_>("key") << " = " << | |
132 | keys_first[k.var<const uint_>("gid")] << ";\n" << | |
133 | k.var<value_out_type>("value") << " = " << | |
134 | values_first[k.var<const uint_>("gid")] << ";\n" << | |
135 | "lkeys[lid] = key;\n" << | |
136 | "lvals[lid] = value;\n" << | |
137 | "}\n" << | |
138 | ||
139 | // Calculate carry out for each work group by performing Hillis/Steele scan | |
140 | // where only last element (key-value pair) is saved | |
141 | k.decl<value_out_type>("result") << " = value;\n" << | |
142 | k.decl<uint_>("other_key") << ";\n" << | |
143 | k.decl<value_out_type>("other_value") << ";\n" << | |
144 | ||
145 | "for(" << k.decl<uint_>("offset") << " = 1; " << | |
146 | "offset < wg_size; offset *= 2){\n" | |
147 | " barrier(CLK_LOCAL_MEM_FENCE);\n" << | |
148 | " if(lid >= offset){\n" | |
149 | " other_key = lkeys[lid - offset];\n" << | |
150 | " if(other_key == key){\n" << | |
151 | " other_value = lvals[lid - offset];\n" << | |
152 | " result = " << function(k.var<value_out_type>("result"), | |
153 | k.var<value_out_type>("other_value")) << ";\n" << | |
154 | " }\n" << | |
155 | " }\n" << | |
156 | " barrier(CLK_LOCAL_MEM_FENCE);\n" << | |
157 | " lvals[lid] = result;\n" << | |
158 | "}\n" << | |
159 | ||
160 | // save carry out | |
161 | "if(lid == (wg_size - 1)){\n" << | |
162 | carry_out_keys_first[k.var<const uint_>("group_id")] << " = key;\n" << | |
163 | carry_out_values_first[k.var<const uint_>("group_id")] << " = result;\n" << | |
164 | "}\n"; | |
165 | ||
166 | size_t work_groups_no = static_cast<size_t>( | |
167 | std::ceil(float(count) / work_group_size) | |
168 | ); | |
169 | ||
170 | const context &context = queue.get_context(); | |
171 | kernel kernel = k.compile(context); | |
172 | kernel.set_arg(local_keys_arg, local_buffer<uint_>(work_group_size)); | |
173 | kernel.set_arg(local_vals_arg, local_buffer<value_out_type>(work_group_size)); | |
174 | ||
175 | queue.enqueue_1d_range_kernel(kernel, | |
176 | 0, | |
177 | work_groups_no * work_group_size, | |
178 | work_group_size); | |
179 | } | |
180 | ||
181 | /// \internal_ | |
182 | /// Calculate carry-in by performing inclusive scan by key on carry-outs vector. | |
183 | template<class OutputValueIterator, class BinaryFunction> | |
184 | inline void carry_ins(vector<uint_>::iterator carry_out_keys_first, | |
185 | OutputValueIterator carry_out_values_first, | |
186 | OutputValueIterator carry_in_values_first, | |
187 | size_t carry_out_size, | |
188 | BinaryFunction function, | |
189 | size_t work_group_size, | |
190 | command_queue &queue) | |
191 | { | |
192 | typedef typename | |
193 | std::iterator_traits<OutputValueIterator>::value_type value_out_type; | |
194 | ||
195 | uint_ values_pre_work_item = static_cast<uint_>( | |
196 | std::ceil(float(carry_out_size) / work_group_size) | |
197 | ); | |
198 | ||
199 | detail::meta_kernel k("reduce_by_key_with_scan_carry_ins"); | |
200 | k.add_set_arg<const uint_>("carry_out_size", uint_(carry_out_size)); | |
201 | k.add_set_arg<const uint_>("values_per_work_item", values_pre_work_item); | |
202 | size_t local_keys_arg = k.add_arg<uint_ *>(memory_object::local_memory, "lkeys"); | |
203 | size_t local_vals_arg = k.add_arg<value_out_type *>(memory_object::local_memory, "lvals"); | |
204 | ||
205 | k << | |
206 | k.decl<uint_>("id") << " = get_global_id(0) * values_per_work_item;\n" << | |
207 | k.decl<uint_>("idx") << " = id;\n" << | |
208 | k.decl<const uint_>("wg_size") << " = get_local_size(0);\n" << | |
209 | k.decl<const uint_>("lid") << " = get_local_id(0);\n" << | |
210 | k.decl<const uint_>("group_id") << " = get_group_id(0);\n" << | |
211 | ||
212 | k.decl<uint_>("key") << ";\n" << | |
213 | k.decl<value_out_type>("value") << ";\n" << | |
214 | k.decl<uint_>("previous_key") << ";\n" << | |
215 | k.decl<value_out_type>("result") << ";\n" << | |
216 | ||
217 | "if(id < carry_out_size){\n" << | |
218 | k.var<uint_>("previous_key") << " = " << | |
219 | carry_out_keys_first[k.var<const uint_>("id")] << ";\n" << | |
220 | k.var<value_out_type>("result") << " = " << | |
221 | carry_out_values_first[k.var<const uint_>("id")] << ";\n" << | |
222 | carry_in_values_first[k.var<const uint_>("id")] << " = result;\n" << | |
223 | "}\n" << | |
224 | ||
225 | k.decl<const uint_>("end") << " = (id + values_per_work_item) <= carry_out_size" << | |
226 | " ? (values_per_work_item + id) : carry_out_size;\n" << | |
227 | ||
228 | "for(idx = idx + 1; idx < end; idx += 1){\n" << | |
229 | " key = " << carry_out_keys_first[k.var<const uint_>("idx")] << ";\n" << | |
230 | " value = " << carry_out_values_first[k.var<const uint_>("idx")] << ";\n" << | |
231 | " if(previous_key == key){\n" << | |
232 | " result = " << function(k.var<value_out_type>("result"), | |
233 | k.var<value_out_type>("value")) << ";\n" << | |
234 | " }\n else { \n" << | |
235 | " result = value;\n" | |
236 | " }\n" << | |
237 | " " << carry_in_values_first[k.var<const uint_>("idx")] << " = result;\n" << | |
238 | " previous_key = key;\n" | |
239 | "}\n" << | |
240 | ||
241 | // save the last key and result to local memory | |
242 | "lkeys[lid] = previous_key;\n" << | |
243 | "lvals[lid] = result;\n" << | |
244 | ||
245 | // Hillis/Steele scan | |
246 | "for(" << k.decl<uint_>("offset") << " = 1; " << | |
247 | "offset < wg_size; offset *= 2){\n" | |
248 | " barrier(CLK_LOCAL_MEM_FENCE);\n" << | |
249 | " if(lid >= offset){\n" | |
250 | " key = lkeys[lid - offset];\n" << | |
251 | " if(previous_key == key){\n" << | |
252 | " value = lvals[lid - offset];\n" << | |
253 | " result = " << function(k.var<value_out_type>("result"), | |
254 | k.var<value_out_type>("value")) << ";\n" << | |
255 | " }\n" << | |
256 | " }\n" << | |
257 | " barrier(CLK_LOCAL_MEM_FENCE);\n" << | |
258 | " lvals[lid] = result;\n" << | |
259 | "}\n" << | |
260 | "barrier(CLK_LOCAL_MEM_FENCE);\n" << | |
261 | ||
262 | "if(lid > 0){\n" << | |
263 | // load key-value reduced by previous work item | |
264 | " previous_key = lkeys[lid - 1];\n" << | |
265 | " result = lvals[lid - 1];\n" << | |
266 | "}\n" << | |
267 | ||
268 | // add key-value reduced by previous work item | |
269 | "for(idx = id; idx < id + values_per_work_item; idx += 1){\n" << | |
270 | // make sure all carry-ins are saved in global memory | |
271 | " barrier( CLK_GLOBAL_MEM_FENCE );\n" << | |
272 | " if(lid > 0 && idx < carry_out_size) {\n" | |
273 | " key = " << carry_out_keys_first[k.var<const uint_>("idx")] << ";\n" << | |
274 | " value = " << carry_in_values_first[k.var<const uint_>("idx")] << ";\n" << | |
275 | " if(previous_key == key){\n" << | |
276 | " value = " << function(k.var<value_out_type>("result"), | |
277 | k.var<value_out_type>("value")) << ";\n" << | |
278 | " }\n" << | |
279 | " " << carry_in_values_first[k.var<const uint_>("idx")] << " = value;\n" << | |
280 | " }\n" << | |
281 | "}\n"; | |
282 | ||
283 | ||
284 | const context &context = queue.get_context(); | |
285 | kernel kernel = k.compile(context); | |
286 | kernel.set_arg(local_keys_arg, local_buffer<uint_>(work_group_size)); | |
287 | kernel.set_arg(local_vals_arg, local_buffer<value_out_type>(work_group_size)); | |
288 | ||
289 | queue.enqueue_1d_range_kernel(kernel, | |
290 | 0, | |
291 | work_group_size, | |
292 | work_group_size); | |
293 | } | |
294 | ||
295 | /// \internal_ | |
296 | /// | |
297 | /// Perform final reduction by key. Each work item: | |
298 | /// 1. Perform local work-group reduction (Hillis/Steele scan) | |
299 | /// 2. Add carry-in (if keys are right) | |
300 | /// 3. Save reduced value if next key is different than processed one | |
301 | template<class InputKeyIterator, class InputValueIterator, | |
302 | class OutputKeyIterator, class OutputValueIterator, | |
303 | class BinaryFunction> | |
304 | inline void final_reduction(InputKeyIterator keys_first, | |
305 | InputValueIterator values_first, | |
306 | OutputKeyIterator keys_result, | |
307 | OutputValueIterator values_result, | |
308 | size_t count, | |
309 | BinaryFunction function, | |
310 | vector<uint_>::iterator new_keys_first, | |
311 | vector<uint_>::iterator carry_in_keys_first, | |
312 | OutputValueIterator carry_in_values_first, | |
313 | size_t carry_in_size, | |
314 | size_t work_group_size, | |
315 | command_queue &queue) | |
316 | { | |
317 | typedef typename | |
318 | std::iterator_traits<OutputValueIterator>::value_type value_out_type; | |
319 | ||
320 | detail::meta_kernel k("reduce_by_key_with_scan_final_reduction"); | |
321 | k.add_set_arg<const uint_>("count", uint_(count)); | |
322 | size_t local_keys_arg = k.add_arg<uint_ *>(memory_object::local_memory, "lkeys"); | |
323 | size_t local_vals_arg = k.add_arg<value_out_type *>(memory_object::local_memory, "lvals"); | |
324 | ||
325 | k << | |
326 | k.decl<const uint_>("gid") << " = get_global_id(0);\n" << | |
327 | k.decl<const uint_>("wg_size") << " = get_local_size(0);\n" << | |
328 | k.decl<const uint_>("lid") << " = get_local_id(0);\n" << | |
329 | k.decl<const uint_>("group_id") << " = get_group_id(0);\n" << | |
330 | ||
331 | k.decl<uint_>("key") << ";\n" << | |
332 | k.decl<value_out_type>("value") << ";\n" | |
333 | ||
334 | "if(gid < count){\n" << | |
335 | k.var<uint_>("key") << " = " << | |
336 | new_keys_first[k.var<const uint_>("gid")] << ";\n" << | |
337 | k.var<value_out_type>("value") << " = " << | |
338 | values_first[k.var<const uint_>("gid")] << ";\n" << | |
339 | "lkeys[lid] = key;\n" << | |
340 | "lvals[lid] = value;\n" << | |
341 | "}\n" << | |
342 | ||
343 | // Hillis/Steele scan | |
344 | k.decl<value_out_type>("result") << " = value;\n" << | |
345 | k.decl<uint_>("other_key") << ";\n" << | |
346 | k.decl<value_out_type>("other_value") << ";\n" << | |
347 | ||
348 | "for(" << k.decl<uint_>("offset") << " = 1; " << | |
349 | "offset < wg_size ; offset *= 2){\n" | |
350 | " barrier(CLK_LOCAL_MEM_FENCE);\n" << | |
351 | " if(lid >= offset) {\n" << | |
352 | " other_key = lkeys[lid - offset];\n" << | |
353 | " if(other_key == key){\n" << | |
354 | " other_value = lvals[lid - offset];\n" << | |
355 | " result = " << function(k.var<value_out_type>("result"), | |
356 | k.var<value_out_type>("other_value")) << ";\n" << | |
357 | " }\n" << | |
358 | " }\n" << | |
359 | " barrier(CLK_LOCAL_MEM_FENCE);\n" << | |
360 | " lvals[lid] = result;\n" << | |
361 | "}\n" << | |
362 | ||
363 | "if(gid >= count) {\n return;\n};\n" << | |
364 | ||
365 | k.decl<const bool>("save") << " = (gid < (count - 1)) ?" | |
366 | << new_keys_first[k.var<const uint_>("gid + 1")] << " != key" << | |
367 | ": true;\n" << | |
368 | ||
369 | // Add carry in | |
370 | k.decl<uint_>("carry_in_key") << ";\n" << | |
371 | "if(group_id > 0 && save) {\n" << | |
372 | " carry_in_key = " << carry_in_keys_first[k.var<const uint_>("group_id - 1")] << ";\n" << | |
373 | " if(key == carry_in_key){\n" << | |
374 | " other_value = " << carry_in_values_first[k.var<const uint_>("group_id - 1")] << ";\n" << | |
375 | " result = " << function(k.var<value_out_type>("result"), | |
376 | k.var<value_out_type>("other_value")) << ";\n" << | |
377 | " }\n" << | |
378 | "}\n" << | |
379 | ||
380 | // Save result only if the next key is different or it's the last element. | |
381 | "if(save){\n" << | |
382 | keys_result[k.var<uint_>("key")] << " = " << keys_first[k.var<const uint_>("gid")] << ";\n" << | |
383 | values_result[k.var<uint_>("key")] << " = result;\n" << | |
384 | "}\n" | |
385 | ; | |
386 | ||
387 | size_t work_groups_no = static_cast<size_t>( | |
388 | std::ceil(float(count) / work_group_size) | |
389 | ); | |
390 | ||
391 | const context &context = queue.get_context(); | |
392 | kernel kernel = k.compile(context); | |
393 | kernel.set_arg(local_keys_arg, local_buffer<uint_>(work_group_size)); | |
394 | kernel.set_arg(local_vals_arg, local_buffer<value_out_type>(work_group_size)); | |
395 | ||
396 | queue.enqueue_1d_range_kernel(kernel, | |
397 | 0, | |
398 | work_groups_no * work_group_size, | |
399 | work_group_size); | |
400 | } | |
401 | ||
402 | /// \internal_ | |
403 | /// Returns preferred work group size for reduce by key with scan algorithm. | |
404 | template<class KeyType, class ValueType> | |
405 | inline size_t get_work_group_size(const device& device) | |
406 | { | |
407 | std::string cache_key = std::string("__boost_reduce_by_key_with_scan") | |
408 | + "k_" + type_name<KeyType>() + "_v_" + type_name<ValueType>(); | |
409 | ||
410 | // load parameters | |
411 | boost::shared_ptr<parameter_cache> parameters = | |
412 | detail::parameter_cache::get_global_cache(device); | |
413 | ||
414 | return (std::max)( | |
415 | static_cast<size_t>(parameters->get(cache_key, "wgsize", 256)), | |
416 | static_cast<size_t>(device.get_info<CL_DEVICE_MAX_WORK_GROUP_SIZE>()) | |
417 | ); | |
418 | } | |
419 | ||
420 | /// \internal_ | |
421 | /// | |
422 | /// 1. For each work group carry-out value is calculated (it's done by key-oriented | |
423 | /// Hillis/Steele scan). Carry-out is a pair of the last key processed by work | |
424 | /// group and sum of all values under this key in work group. | |
425 | /// 2. From every carry-out carry-in is calculated by performing inclusive scan | |
426 | /// by key. | |
427 | /// 3. Final reduction by key is performed (key-oriented Hillis/Steele scan), | |
428 | /// carry-in values are added where needed. | |
429 | template<class InputKeyIterator, class InputValueIterator, | |
430 | class OutputKeyIterator, class OutputValueIterator, | |
431 | class BinaryFunction, class BinaryPredicate> | |
432 | inline size_t reduce_by_key_with_scan(InputKeyIterator keys_first, | |
433 | InputKeyIterator keys_last, | |
434 | InputValueIterator values_first, | |
435 | OutputKeyIterator keys_result, | |
436 | OutputValueIterator values_result, | |
437 | BinaryFunction function, | |
438 | BinaryPredicate predicate, | |
439 | command_queue &queue) | |
440 | { | |
441 | typedef typename | |
442 | std::iterator_traits<InputValueIterator>::value_type value_type; | |
443 | typedef typename | |
444 | std::iterator_traits<InputKeyIterator>::value_type key_type; | |
445 | typedef typename | |
446 | std::iterator_traits<OutputValueIterator>::value_type value_out_type; | |
447 | ||
448 | const context &context = queue.get_context(); | |
449 | size_t count = detail::iterator_range_size(keys_first, keys_last); | |
450 | ||
451 | if(count == 0){ | |
452 | return size_t(0); | |
453 | } | |
454 | ||
455 | const device &device = queue.get_device(); | |
456 | size_t work_group_size = get_work_group_size<value_type, key_type>(device); | |
457 | ||
458 | // Replace original key with unsigned integer keys generated based on given | |
459 | // predicate. New key is also an index for keys_result and values_result vectors, | |
460 | // which points to place where reduced value should be saved. | |
461 | vector<uint_> new_keys(count, context); | |
462 | vector<uint_>::iterator new_keys_first = new_keys.begin(); | |
463 | generate_uint_keys(keys_first, count, predicate, new_keys_first, | |
464 | work_group_size, queue); | |
465 | ||
466 | // Calculate carry-out and carry-in vectors size | |
467 | const size_t carry_out_size = static_cast<size_t>( | |
468 | std::ceil(float(count) / work_group_size) | |
469 | ); | |
470 | vector<uint_> carry_out_keys(carry_out_size, context); | |
471 | vector<value_out_type> carry_out_values(carry_out_size, context); | |
472 | carry_outs(new_keys_first, values_first, count, carry_out_keys.begin(), | |
473 | carry_out_values.begin(), function, work_group_size, queue); | |
474 | ||
475 | vector<value_out_type> carry_in_values(carry_out_size, context); | |
476 | carry_ins(carry_out_keys.begin(), carry_out_values.begin(), | |
477 | carry_in_values.begin(), carry_out_size, function, work_group_size, | |
478 | queue); | |
479 | ||
480 | final_reduction(keys_first, values_first, keys_result, values_result, | |
481 | count, function, new_keys_first, carry_out_keys.begin(), | |
482 | carry_in_values.begin(), carry_out_size, work_group_size, | |
483 | queue); | |
484 | ||
485 | const size_t result = read_single_value<uint_>(new_keys.get_buffer(), | |
486 | count - 1, queue); | |
487 | return result + 1; | |
488 | } | |
489 | ||
490 | /// \internal_ | |
491 | /// Return true if requirements for running reduce by key with scan on given | |
492 | /// device are met (at least one work group of preferred size can be run). | |
493 | template<class InputKeyIterator, class InputValueIterator, | |
494 | class OutputKeyIterator, class OutputValueIterator> | |
495 | bool reduce_by_key_with_scan_requirements_met(InputKeyIterator keys_first, | |
496 | InputValueIterator values_first, | |
497 | OutputKeyIterator keys_result, | |
498 | OutputValueIterator values_result, | |
499 | const size_t count, | |
500 | command_queue &queue) | |
501 | { | |
502 | typedef typename | |
503 | std::iterator_traits<InputValueIterator>::value_type value_type; | |
504 | typedef typename | |
505 | std::iterator_traits<InputKeyIterator>::value_type key_type; | |
506 | typedef typename | |
507 | std::iterator_traits<OutputValueIterator>::value_type value_out_type; | |
508 | ||
509 | (void) keys_first; | |
510 | (void) values_first; | |
511 | (void) keys_result; | |
512 | (void) values_result; | |
513 | ||
514 | const device &device = queue.get_device(); | |
515 | // device must have dedicated local memory storage | |
516 | if(device.get_info<CL_DEVICE_LOCAL_MEM_TYPE>() != CL_LOCAL) | |
517 | { | |
518 | return false; | |
519 | } | |
520 | ||
521 | // local memory size in bytes (per compute unit) | |
522 | const size_t local_mem_size = device.get_info<CL_DEVICE_LOCAL_MEM_SIZE>(); | |
523 | ||
524 | // preferred work group size | |
525 | size_t work_group_size = get_work_group_size<key_type, value_type>(device); | |
526 | ||
527 | // local memory size needed to perform parallel reduction | |
528 | size_t required_local_mem_size = 0; | |
529 | // keys size | |
530 | required_local_mem_size += sizeof(uint_) * work_group_size; | |
531 | // reduced values size | |
532 | required_local_mem_size += sizeof(value_out_type) * work_group_size; | |
533 | ||
534 | return (required_local_mem_size <= local_mem_size); | |
535 | } | |
536 | ||
537 | } // end detail namespace | |
538 | } // end compute namespace | |
539 | } // end boost namespace | |
540 | ||
541 | #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_REDUCE_BY_KEY_WITH_SCAN_HPP |