]>
Commit | Line | Data |
---|---|---|
1 | //---------------------------------------------------------------------------// | |
2 | // Copyright (c) 2016 Jakub Szuppe <j.szuppe@gmail.com> | |
3 | // | |
4 | // Distributed under the Boost Software License, Version 1.0 | |
5 | // See accompanying file LICENSE_1_0.txt or copy at | |
6 | // http://www.boost.org/LICENSE_1_0.txt | |
7 | // | |
8 | // See http://boostorg.github.com/compute for more information. | |
9 | //---------------------------------------------------------------------------// | |
10 | ||
11 | #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_SORT_ON_GPU_HPP_ | |
12 | #define BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_SORT_ON_GPU_HPP_ | |
13 | ||
14 | #include <algorithm> | |
15 | ||
16 | #include <boost/compute/kernel.hpp> | |
17 | #include <boost/compute/program.hpp> | |
18 | #include <boost/compute/command_queue.hpp> | |
19 | #include <boost/compute/container/vector.hpp> | |
20 | #include <boost/compute/memory/local_buffer.hpp> | |
21 | #include <boost/compute/detail/meta_kernel.hpp> | |
22 | #include <boost/compute/detail/iterator_range_size.hpp> | |
23 | ||
24 | namespace boost { | |
25 | namespace compute { | |
26 | namespace detail { | |
27 | ||
/// Picks the work-group ("block") size for bitonic_block_sort().
///
/// Starting from \p proposed_wg (typically the kernel's maximum
/// work-group size), the size is halved until the required local
/// memory fits within \p lmem_size, then rounded down to a power of
/// two (bitonic sort requires 2^n elements per block), capped at 256.
///
/// \param proposed_wg upper bound for the work-group size
/// \param lmem_size   available local memory in bytes
/// \param sort_by_key when true, one ValueType slot per element is
///                    also kept in local memory
/// \return chosen work-group size; a power of two in [1, 256]
template<class KeyType, class ValueType>
inline size_t pick_bitonic_block_sort_block_size(size_t proposed_wg,
                                                 size_t lmem_size,
                                                 bool sort_by_key)
{
    size_t n = proposed_wg;

    // local memory needed for n elements: the keys, plus (when
    // sorting by key) one value slot per element
    size_t lmem_required = n * sizeof(KeyType);
    if(sort_by_key) {
        lmem_required += n * sizeof(ValueType);
    }

    // try to force at least 4 work-groups of >64 elements
    // for better occupancy
    while(lmem_size < (lmem_required * 4) && (n > 64)) {
        n /= 2;
        // recompute with the same formula as the initial estimate;
        // previously the ValueType term was dropped here, which
        // underestimated local memory usage when sorting by key
        lmem_required = n * sizeof(KeyType);
        if(sort_by_key) {
            lmem_required += n * sizeof(ValueType);
        }
    }
    // if still over budget, keep halving until it fits (or n == 1)
    while(lmem_size < lmem_required && (n != 1)) {
        n /= 2;
        if(n < 1) n = 1;
        lmem_required = n * sizeof(KeyType);
        if(sort_by_key) {
            lmem_required += n * sizeof(ValueType);
        }
    }

    // round n down to a power of two, capped at 256
    if(n < 2) { return 1; }
    else if(n < 4) { return 2; }
    else if(n < 8) { return 4; }
    else if(n < 16) { return 8; }
    else if(n < 32) { return 16; }
    else if(n < 64) { return 32; }
    else if(n < 128) { return 64; }
    else if(n < 256) { return 128; }
    else { return 256; }
}
62 | ||
63 | ||
/// Performs bitonic block sort according to \p compare.
///
/// Since bitonic sort can be only performed when input size is equal to 2^n,
/// in this case input size is block size (the work-group size), we would have
/// to require \p count be an exact multiple of block size. That would not be
/// great.
/// Instead, bitonic sort kernel is merged with odd-even sort so if the
/// last block is not equal to 2^n (where n is some natural number) the odd-even
/// sort is performed for that block. That way bitonic_block_sort() works for
/// input of any size. Block size (work-group size) still has to be equal
/// to 2^n.
///
/// This is NOT stable.
///
/// \param keys_first first key element in the range to sort
/// \param values_first first value element in the range to sort
/// \param compare comparison function for keys
/// \param count number of elements in the range; count > 0
/// \param sort_by_key when true, values are moved along with their keys
/// \param queue command queue to perform the operation
///
/// \return the block size (work-group size) that was used, i.e. the
///         length of each sorted run produced in-place in the input
template<class KeyIterator, class ValueIterator, class Compare>
inline size_t bitonic_block_sort(KeyIterator keys_first,
                                 ValueIterator values_first,
                                 Compare compare,
                                 const size_t count,
                                 const bool sort_by_key,
                                 command_queue &queue)
{
    typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
    typedef typename std::iterator_traits<ValueIterator>::value_type value_type;

    meta_kernel k("bitonic_block_sort");
    size_t count_arg = k.add_arg<const uint_>("count");

    // local memory: one key per work item, plus (when sorting by key)
    // one uchar index per work item -- see the my_index comment below
    size_t local_keys_arg = k.add_arg<key_type *>(memory_object::local_memory, "lkeys");
    size_t local_vals_arg = 0;
    if(sort_by_key) {
        local_vals_arg = k.add_arg<uchar_ *>(memory_object::local_memory, "lidx");
    }

    k <<
        // Work item global and local ids
        k.decl<const uint_>("gid") << " = get_global_id(0);\n" <<
        k.decl<const uint_>("lid") << " = get_local_id(0);\n";

    // declare my_key and my_value
    k <<
        k.decl<key_type>("my_key") << ";\n";
    // Instead of copying values (my_value) in local memory with keys
    // we save local index (uchar) and copy my_value at the end at
    // final index. This saves local memory.
    if(sort_by_key)
    {
        k <<
            k.decl<uchar_>("my_index") << " = (uchar)(lid);\n";
    }

    // load key (guarded: the last work-group may extend past count)
    k <<
        "if(gid < count) {\n" <<
            k.var<key_type>("my_key") << " = " <<
                keys_first[k.var<const uint_>("gid")] << ";\n" <<
        "}\n";

    // load key and index to local memory
    k <<
        "lkeys[lid] = my_key;\n";
    if(sort_by_key)
    {
        k <<
            "lidx[lid] = my_index;\n";
    }
    // n is the number of valid elements in this work-group's block
    k <<
        k.decl<const uint_>("offset") << " = get_group_id(0) * get_local_size(0);\n" <<
        k.decl<const uint_>("n") << " = min((uint)(get_local_size(0)),(count - offset));\n";

    // When the block size n is a power of 2 the bitonic sorter can be used;
    // otherwise, the slower odd-even sort is used.

    k <<
        // check if n is power of 2: (n & -n) == n via two's complement
        "if(((n != 0) && ((n & (~n + 1)) == n))) {\n";

    // bitonic sort, not stable
    k <<
        // wait for keys and vals to be stored in local memory
        "barrier(CLK_LOCAL_MEM_FENCE);\n" <<

        "#pragma unroll\n" <<
        "for(" <<
            k.decl<uint_>("length") << " = 1; " <<
            "length < n; " <<
            "length <<= 1" <<
        ") {\n" <<
            // direction of sort: false -> asc, true -> desc
            k.decl<bool>("direction") << "= ((lid & (length<<1)) != 0);\n" <<
            "for(" <<
                k.decl<uint_>("k") << " = length; " <<
                "k > 0; " <<
                "k >>= 1" <<
            ") {\n" <<

            // sibling to compare with my key
            k.decl<uint_>("sibling_idx") << " = lid ^ k;\n" <<
            k.decl<key_type>("sibling_key") << " = lkeys[sibling_idx];\n" <<
            k.decl<bool>("compare") << " = " <<
                compare(k.var<key_type>("sibling_key"),
                        k.var<key_type>("my_key")) << ";\n" <<
            // equal keys must never swap, otherwise both items could
            // decide to move and one key would be duplicated
            k.decl<bool>("equal") << " = !(compare || " <<
                compare(k.var<key_type>("my_key"),
                        k.var<key_type>("sibling_key")) << ");\n" <<
            k.decl<bool>("swap") <<
                " = compare ^ (sibling_idx < lid) ^ direction;\n" <<
            "swap = equal ? false : swap;\n" <<
            "my_key = swap ? sibling_key : my_key;\n";
    if(sort_by_key)
    {
        k <<
            "my_index = swap ? lidx[sibling_idx] : my_index;\n";
    }
    // publish updated keys/indices before the next compare step
    k <<
            "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
            "lkeys[lid] = my_key;\n";
    if(sort_by_key)
    {
        k <<
            "lidx[lid] = my_index;\n";
    }
    k <<
            "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
            "}\n" <<
        "}\n";

    // end of bitonic sort

    // odd-even (brick) sort, not stable; handles blocks whose size
    // is not a power of two (only the last block can be like that)
    k <<
        "}\n" <<
        "else { \n";

    // each work item precomputes its partner index for even and odd
    // phases; edge items are clamped so they compare with themselves
    k <<
        k.decl<bool>("lid_is_even") << " = (lid%2) == 0;\n" <<
        k.decl<uint_>("oddsibling_idx") << " = " <<
            "(lid_is_even) ? max(lid,(uint)(1)) - 1 : min(lid+1,n-1);\n" <<
        k.decl<uint_>("evensibling_idx") << " = " <<
            "(lid_is_even) ? min(lid+1,n-1) : max(lid,(uint)(1)) - 1;\n" <<

        // wait for keys and vals to be stored in local memory
        "barrier(CLK_LOCAL_MEM_FENCE);\n" <<

        // n phases guarantee a sorted block (odd-even transposition sort)
        "#pragma unroll\n" <<
        "for(" <<
            k.decl<uint_>("i") << " = 0; " <<
            "i < n; " <<
            "i++" <<
        ") {\n" <<
            k.decl<uint_>("sibling_idx") <<
                " = i%2 == 0 ? evensibling_idx : oddsibling_idx;\n" <<
            k.decl<key_type>("sibling_key") << " = lkeys[sibling_idx];\n" <<
            k.decl<bool>("compare") << " = " <<
                compare(k.var<key_type>("sibling_key"),
                        k.var<key_type>("my_key")) << ";\n" <<
            // equal keys must never swap (same reasoning as above)
            k.decl<bool>("equal") << " = !(compare || " <<
                compare(k.var<key_type>("my_key"),
                        k.var<key_type>("sibling_key")) << ");\n" <<
            k.decl<bool>("swap") <<
                " = compare ^ (sibling_idx < lid);\n" <<
            "swap = equal ? false : swap;\n" <<
            "my_key = swap ? sibling_key : my_key;\n";
    if(sort_by_key)
    {
        k <<
            "my_index = swap ? lidx[sibling_idx] : my_index;\n";
    }
    k <<
            "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
            "lkeys[lid] = my_key;\n";
    if(sort_by_key)
    {
        k <<
            "lidx[lid] = my_index;\n";
    }
    k <<
            "barrier(CLK_LOCAL_MEM_FENCE);\n"
        "}\n" << // for

    "}\n"; // else
    // end of odd-even sort

    // save key and value back to global memory
    k <<
        "if(gid < count) {\n" <<
        keys_first[k.var<const uint_>("gid")] << " = " <<
            k.var<key_type>("my_key") << ";\n";
    if(sort_by_key)
    {
        // gather my_value from its original (pre-sort) position,
        // then scatter it to the sorted position
        // NOTE(review): this barrier sits inside a divergent
        // "if(gid < count)" branch -- presumably safe only because
        // whole work-groups fall outside count together; verify
        k <<
            k.decl<value_type>("my_value") << " = " <<
                values_first[k.var<const uint_>("offset + my_index")] << ";\n" <<
            "barrier(CLK_GLOBAL_MEM_FENCE);\n" <<
            values_first[k.var<const uint_>("gid")] << " = my_value;\n";
    }
    k <<
        // end if
        "}\n";

    const context &context = queue.get_context();
    const device &device = queue.get_device();
    ::boost::compute::kernel kernel = k.compile(context);

    // choose a work-group size that fits the device's local memory
    const size_t work_group_size =
        pick_bitonic_block_sort_block_size<key_type, uchar_>(
            kernel.get_work_group_info<size_t>(
                device, CL_KERNEL_WORK_GROUP_SIZE
            ),
            device.get_info<size_t>(CL_DEVICE_LOCAL_MEM_SIZE),
            sort_by_key
        );

    // round global size up to a multiple of the work-group size
    const size_t global_size =
        work_group_size * static_cast<size_t>(
            std::ceil(float(count) / work_group_size)
        );

    kernel.set_arg(count_arg, static_cast<uint_>(count));
    kernel.set_arg(local_keys_arg, local_buffer<key_type>(work_group_size));
    if(sort_by_key) {
        kernel.set_arg(local_vals_arg, local_buffer<uchar_>(work_group_size));
    }

    queue.enqueue_1d_range_kernel(kernel, 0, global_size, work_group_size);
    // return size of the block
    return work_group_size;
}
299 | ||
300 | template<class KeyIterator, class ValueIterator, class Compare> | |
301 | inline size_t block_sort(KeyIterator keys_first, | |
302 | ValueIterator values_first, | |
303 | Compare compare, | |
304 | const size_t count, | |
305 | const bool sort_by_key, | |
306 | const bool stable, | |
307 | command_queue &queue) | |
308 | { | |
309 | if(stable) { | |
310 | // TODO: Implement stable block sort (stable odd-even merge sort) | |
311 | return size_t(1); | |
312 | } | |
313 | return bitonic_block_sort( | |
314 | keys_first, values_first, | |
315 | compare, count, | |
316 | sort_by_key, queue | |
317 | ); | |
318 | } | |
319 | ||
/// Merges pairs of adjacent sorted blocks of length \p block_size from
/// the input ranges into sorted blocks of length 2 * block_size in the
/// output ranges. One work item runs per element; each computes its
/// destination by binary-searching its rank in the partner block.
///
/// space: O(n + m); n - number of keys, m - number of values
template<class KeyIterator, class ValueIterator, class Compare>
inline void merge_blocks_on_gpu(KeyIterator keys_first,
                                ValueIterator values_first,
                                KeyIterator out_keys_first,
                                ValueIterator out_values_first,
                                Compare compare,
                                const size_t count,
                                const size_t block_size,
                                const bool sort_by_key,
                                command_queue &queue)
{
    typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
    typedef typename std::iterator_traits<ValueIterator>::value_type value_type;

    meta_kernel k("merge_blocks");
    size_t count_arg = k.add_arg<const uint_>("count");
    size_t block_size_arg = k.add_arg<const uint_>("block_size");

    k <<
        // get global id; one work item per input element
        k.decl<const uint_>("gid") << " = get_global_id(0);\n" <<
        "if(gid >= count) {\n" <<
            "return;\n" <<
        "}\n" <<

        k.decl<const key_type>("my_key") << " = " <<
            keys_first[k.var<const uint_>("gid")] << ";\n";

    if(sort_by_key) {
        k <<
            k.decl<const value_type>("my_value") << " = " <<
                values_first[k.var<const uint_>("gid")] << ";\n";
    }

    k <<
        // get my block idx; blocks are merged in even/odd pairs
        k.decl<const uint_>("my_block_idx") << " = gid / block_size;\n" <<
        k.decl<const bool>("my_block_idx_is_odd") << " = " <<
            "my_block_idx & 0x1;\n" <<

        k.decl<const uint_>("other_block_idx") << " = " <<
            // if(my_block_idx is odd) {} else {}
            "my_block_idx_is_odd ? my_block_idx - 1 : my_block_idx + 1;\n" <<

        // get ranges of my block and the other block, both clamped
        // to count since the last block may be short
        // [my_block_start; my_block_end)
        // [other_block_start; other_block_end)
        k.decl<const uint_>("my_block_start") << " = " <<
            "min(my_block_idx * block_size, count);\n" << // including
        k.decl<const uint_>("my_block_end") << " = " <<
            "min((my_block_idx + 1) * block_size, count);\n" << // excluding

        k.decl<const uint_>("other_block_start") << " = " <<
            "min(other_block_idx * block_size, count);\n" << // including
        k.decl<const uint_>("other_block_end") << " = " <<
            "min((other_block_idx + 1) * block_size, count);\n" << // excluding

        // other block is empty, nothing to merge here; copy through
        "if(other_block_start == count){\n" <<
            out_keys_first[k.var<uint_>("gid")] << " = my_key;\n";
    if(sort_by_key) {
        k <<
            out_values_first[k.var<uint_>("gid")] << " = my_value;\n";
    }

    k <<
        "return;\n" <<
        "}\n" <<

        // binary search for the lower bound of my_key in the other block
        // left_idx - lower bound
        k.decl<uint_>("left_idx") << " = other_block_start;\n" <<
        k.decl<uint_>("right_idx") << " = other_block_end;\n" <<
        "while(left_idx < right_idx) {\n" <<
            k.decl<uint_>("mid_idx") << " = (left_idx + right_idx) / 2;\n" <<
            k.decl<key_type>("mid_key") << " = " <<
                keys_first[k.var<const uint_>("mid_idx")] << ";\n" <<
            k.decl<bool>("smaller") << " = " <<
                compare(k.var<key_type>("mid_key"),
                        k.var<key_type>("my_key")) << ";\n" <<
            "left_idx = smaller ? mid_idx + 1 : left_idx;\n" <<
            "right_idx = smaller ? right_idx : mid_idx;\n" <<
        "}\n" <<
        // left_idx is found position in other block

        // if my_block is odd we need to get the upper bound instead:
        // ties between equal keys are broken by block order, so items
        // from the odd (second) block rank after equal items from the
        // even block and each element gets a distinct output position
        "right_idx = other_block_end;\n" <<
        "if(my_block_idx_is_odd && left_idx != right_idx) {\n" <<
            k.decl<key_type>("upper_key") << " = " <<
                keys_first[k.var<const uint_>("left_idx")] << ";\n" <<
            // advance left_idx while upper_key compares equal to my_key
            "while(" <<
                "!(" << compare(k.var<key_type>("upper_key"),
                                k.var<key_type>("my_key")) <<
                ") && " <<
                "!(" << compare(k.var<key_type>("my_key"),
                                k.var<key_type>("upper_key")) <<
                ") && " <<
                "left_idx < right_idx" <<
            ")" <<
            "{\n" <<
                k.decl<uint_>("mid_idx") << " = (left_idx + right_idx) / 2;\n" <<
                k.decl<key_type>("mid_key") << " = " <<
                    keys_first[k.var<const uint_>("mid_idx")] << ";\n" <<
                k.decl<bool>("equal") << " = " <<
                    "!(" << compare(k.var<key_type>("mid_key"),
                                    k.var<key_type>("my_key")) <<
                    ") && " <<
                    "!(" << compare(k.var<key_type>("my_key"),
                                    k.var<key_type>("mid_key")) <<
                    ");\n" <<
                "left_idx = equal ? mid_idx + 1 : left_idx + 1;\n" <<
                "right_idx = equal ? right_idx : mid_idx;\n" <<
                "upper_key = " <<
                    keys_first[k.var<const uint_>("left_idx")] << ";\n" <<
            "}\n" <<
        "}\n" <<

        // final position = rank within my block
        //                + rank within the other block
        //                + start of the merged pair of blocks
        k.decl<uint_>("offset") << " = 0;\n" <<
        "offset += gid - my_block_start;\n" <<
        "offset += left_idx - other_block_start;\n" <<
        "offset += min(my_block_start, other_block_start);\n" <<
        out_keys_first[k.var<uint_>("offset")] << " = my_key;\n";
    if(sort_by_key) {
        k <<
            out_values_first[k.var<uint_>("offset")] << " = my_value;\n";
    }

    const context &context = queue.get_context();
    ::boost::compute::kernel kernel = k.compile(context);

    // cap the work-group size at 256; the kernel has no intra-group
    // communication so any size the device accepts is correct
    const size_t work_group_size = (std::min)(
        size_t(256),
        kernel.get_work_group_info<size_t>(
            queue.get_device(), CL_KERNEL_WORK_GROUP_SIZE
        )
    );
    // round global size up to a multiple of the work-group size
    const size_t global_size =
        work_group_size * static_cast<size_t>(
            std::ceil(float(count) / work_group_size)
        );

    kernel.set_arg(count_arg, static_cast<uint_>(count));
    kernel.set_arg(block_size_arg, static_cast<uint_>(block_size));
    queue.enqueue_1d_range_kernel(kernel, 0, global_size, work_group_size);
}
466 | ||
467 | template<class KeyIterator, class ValueIterator, class Compare> | |
468 | inline void merge_sort_by_key_on_gpu(KeyIterator keys_first, | |
469 | KeyIterator keys_last, | |
470 | ValueIterator values_first, | |
471 | Compare compare, | |
472 | bool stable, | |
473 | command_queue &queue) | |
474 | { | |
475 | typedef typename std::iterator_traits<KeyIterator>::value_type key_type; | |
476 | typedef typename std::iterator_traits<ValueIterator>::value_type value_type; | |
477 | ||
478 | size_t count = iterator_range_size(keys_first, keys_last); | |
479 | if(count < 2){ | |
480 | return; | |
481 | } | |
482 | ||
483 | size_t block_size = | |
484 | block_sort( | |
485 | keys_first, values_first, | |
486 | compare, count, | |
487 | true /* sort_by_key */, stable /* stable */, | |
488 | queue | |
489 | ); | |
490 | ||
491 | // for small input size only block sort is performed | |
492 | if(count <= block_size) { | |
493 | return; | |
494 | } | |
495 | ||
496 | const context &context = queue.get_context(); | |
497 | ||
498 | bool result_in_temporary_buffer = false; | |
499 | ::boost::compute::vector<key_type> temp_keys(count, context); | |
500 | ::boost::compute::vector<value_type> temp_values(count, context); | |
501 | ||
502 | for(; block_size < count; block_size *= 2) { | |
503 | result_in_temporary_buffer = !result_in_temporary_buffer; | |
504 | if(result_in_temporary_buffer) { | |
505 | merge_blocks_on_gpu(keys_first, values_first, | |
506 | temp_keys.begin(), temp_values.begin(), | |
507 | compare, count, block_size, | |
508 | true /* sort_by_key */, queue); | |
509 | } else { | |
510 | merge_blocks_on_gpu(temp_keys.begin(), temp_values.begin(), | |
511 | keys_first, values_first, | |
512 | compare, count, block_size, | |
513 | true /* sort_by_key */, queue); | |
514 | } | |
515 | } | |
516 | ||
517 | if(result_in_temporary_buffer) { | |
518 | copy_async(temp_keys.begin(), temp_keys.end(), keys_first, queue); | |
519 | copy_async(temp_values.begin(), temp_values.end(), values_first, queue); | |
520 | } | |
521 | } | |
522 | ||
523 | template<class Iterator, class Compare> | |
524 | inline void merge_sort_on_gpu(Iterator first, | |
525 | Iterator last, | |
526 | Compare compare, | |
527 | bool stable, | |
528 | command_queue &queue) | |
529 | { | |
530 | typedef typename std::iterator_traits<Iterator>::value_type key_type; | |
531 | ||
532 | size_t count = iterator_range_size(first, last); | |
533 | if(count < 2){ | |
534 | return; | |
535 | } | |
536 | ||
537 | Iterator dummy; | |
538 | size_t block_size = | |
539 | block_sort( | |
540 | first, dummy, | |
541 | compare, count, | |
542 | false /* sort_by_key */, stable /* stable */, | |
543 | queue | |
544 | ); | |
545 | ||
546 | // for small input size only block sort is performed | |
547 | if(count <= block_size) { | |
548 | return; | |
549 | } | |
550 | ||
551 | const context &context = queue.get_context(); | |
552 | ||
553 | bool result_in_temporary_buffer = false; | |
554 | ::boost::compute::vector<key_type> temp_keys(count, context); | |
555 | ||
556 | for(; block_size < count; block_size *= 2) { | |
557 | result_in_temporary_buffer = !result_in_temporary_buffer; | |
558 | if(result_in_temporary_buffer) { | |
559 | merge_blocks_on_gpu(first, dummy, temp_keys.begin(), dummy, | |
560 | compare, count, block_size, | |
561 | false /* sort_by_key */, queue); | |
562 | } else { | |
563 | merge_blocks_on_gpu(temp_keys.begin(), dummy, first, dummy, | |
564 | compare, count, block_size, | |
565 | false /* sort_by_key */, queue); | |
566 | } | |
567 | } | |
568 | ||
569 | if(result_in_temporary_buffer) { | |
570 | copy_async(temp_keys.begin(), temp_keys.end(), first, queue); | |
571 | } | |
572 | } | |
573 | ||
574 | template<class KeyIterator, class ValueIterator, class Compare> | |
575 | inline void merge_sort_by_key_on_gpu(KeyIterator keys_first, | |
576 | KeyIterator keys_last, | |
577 | ValueIterator values_first, | |
578 | Compare compare, | |
579 | command_queue &queue) | |
580 | { | |
581 | merge_sort_by_key_on_gpu( | |
582 | keys_first, keys_last, values_first, | |
583 | compare, false /* not stable */, queue | |
584 | ); | |
585 | } | |
586 | ||
587 | template<class Iterator, class Compare> | |
588 | inline void merge_sort_on_gpu(Iterator first, | |
589 | Iterator last, | |
590 | Compare compare, | |
591 | command_queue &queue) | |
592 | { | |
593 | merge_sort_on_gpu( | |
594 | first, last, compare, false /* not stable */, queue | |
595 | ); | |
596 | } | |
597 | ||
598 | } // end detail namespace | |
599 | } // end compute namespace | |
600 | } // end boost namespace | |
601 | ||
602 | #endif /* BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_SORT_ON_GPU_HPP_ */ |