//---------------------------------------------------------------------------//
// Copyright (c) 2016 Jakub Szuppe <j.szuppe@gmail.com>
//
// Distributed under the Boost Software License, Version 1.0
// See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt
//
// See http://boostorg.github.com/compute for more information.
//---------------------------------------------------------------------------//

#ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_SORT_ON_GPU_HPP_
#define BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_SORT_ON_GPU_HPP_

#include <algorithm>
#include <cmath>

#include <boost/compute/kernel.hpp>
#include <boost/compute/program.hpp>
#include <boost/compute/command_queue.hpp>
#include <boost/compute/container/vector.hpp>
#include <boost/compute/memory/local_buffer.hpp>
#include <boost/compute/detail/meta_kernel.hpp>
#include <boost/compute/detail/iterator_range_size.hpp>

namespace boost {
namespace compute {
namespace detail {

template<class KeyType, class ValueType>
inline size_t pick_bitonic_block_sort_block_size(size_t proposed_wg,
                                                 size_t lmem_size,
                                                 bool sort_by_key)
{
    size_t n = proposed_wg;

    size_t lmem_required = n * sizeof(KeyType);
    if(sort_by_key) {
        lmem_required += n * sizeof(ValueType);
    }

    // shrink the work-group so that at least 4 work-groups fit in local
    // memory, for better occupancy, but keep at least 64 elements per group
    while(lmem_size < (lmem_required * 4) && (n > 64)) {
        n /= 2;
        lmem_required = n * sizeof(KeyType);
    }
    // if even a single work-group does not fit, keep shrinking
    while(lmem_size < lmem_required && (n != 1)) {
        n /= 2;
        if(n < 1) n = 1;
        lmem_required = n * sizeof(KeyType);
    }

    // round n down to a power of two (required by bitonic sort), capped at 256
    if(n < 2)        { return 1; }
    else if(n < 4)   { return 2; }
    else if(n < 8)   { return 4; }
    else if(n < 16)  { return 8; }
    else if(n < 32)  { return 16; }
    else if(n < 64)  { return 32; }
    else if(n < 128) { return 64; }
    else if(n < 256) { return 128; }
    else             { return 256; }
}
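
// Illustrative example with hypothetical numbers: with proposed_wg = 256, a
// 4-byte key type, sort_by_key = false and lmem_size = 32 KiB, four
// work-groups need 4 * 256 * 4 = 4096 bytes, which fits, so 256 is
// returned. With lmem_size = 2 KiB the first loop halves n once to 128
// (4 * 128 * 4 = 2048 bytes) and 128 is returned.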

/// Performs bitonic block sort according to \p compare.
///
/// Since bitonic sort can only be performed when the input size is a power
/// of two (here the input is a single block, whose size is the work-group
/// size), we would otherwise have to require \p count to be an exact
/// multiple of the block size. That would not be great.
/// Instead, the bitonic sort kernel is merged with odd-even transposition
/// sort: if the size of the last block is not a power of two, that block is
/// sorted with odd-even sort. That way bitonic_block_sort() works for input
/// of any size. The block size (the work-group size) still has to be a
/// power of two; it is chosen internally and returned to the caller.
///
/// This sort is NOT stable.
///
/// \param keys_first first key element in the range to sort
/// \param values_first first value element in the range to sort
/// \param compare comparison function for keys
/// \param count number of elements in the range; count > 0
/// \param sort_by_key if true, values are rearranged along with their keys
/// \param queue command queue to perform the operation
///
/// \return the block size (work-group size) that was used
template<class KeyIterator, class ValueIterator, class Compare>
inline size_t bitonic_block_sort(KeyIterator keys_first,
                                 ValueIterator values_first,
                                 Compare compare,
                                 const size_t count,
                                 const bool sort_by_key,
                                 command_queue &queue)
{
    typedef typename std::iterator_traits<KeyIterator>::value_type key_type;

    meta_kernel k("bitonic_block_sort");
    size_t count_arg = k.add_arg<const uint_>("count");

    size_t local_keys_arg = k.add_arg<key_type *>(memory_object::local_memory, "lkeys");
    size_t local_vals_arg = 0;
    if(sort_by_key) {
        local_vals_arg = k.add_arg<uchar_ *>(memory_object::local_memory, "lidx");
    }

    k <<
        // work-item global and local ids
        k.decl<const uint_>("gid") << " = get_global_id(0);\n" <<
        k.decl<const uint_>("lid") << " = get_local_id(0);\n";

    // declare my_key (and my_index when sorting by key)
    k <<
        k.decl<key_type>("my_key") << ";\n";
    // Instead of copying values (my_value) into local memory along with the
    // keys, we save only the local index (uchar) and copy my_value to its
    // final position at the very end. This saves local memory.
    if(sort_by_key)
    {
        k <<
            k.decl<uchar_>("my_index") << " = (uchar)(lid);\n";
    }

    // load key
    k <<
        "if(gid < count) {\n" <<
            k.var<key_type>("my_key") << " = " <<
                keys_first[k.var<const uint_>("gid")] << ";\n" <<
        "}\n";

    // load key and index to local memory
    k <<
        "lkeys[lid] = my_key;\n";
    if(sort_by_key)
    {
        k <<
            "lidx[lid] = my_index;\n";
    }
    k <<
        k.decl<const uint_>("offset") << " = get_group_id(0) * get_local_size(0);\n" <<
        k.decl<const uint_>("n") << " = min((uint)(get_local_size(0)),(count - offset));\n";

    // When the number of elements in the block (n) is a power of 2 the
    // bitonic sorter can be used; otherwise, the slower odd-even
    // transposition sort is used.

    k <<
        // check if n is a power of 2
        "if(((n != 0) && ((n & (~n + 1)) == n))) {\n";

    // bitonic sort, not stable
    k <<
        // wait for keys and vals to be stored in local memory
        "barrier(CLK_LOCAL_MEM_FENCE);\n" <<

        "#pragma unroll\n" <<
        "for(" <<
            k.decl<uint_>("length") << " = 1; " <<
            "length < n; " <<
            "length <<= 1" <<
        ") {\n" <<
            // direction of sort: false -> ascending, true -> descending
            k.decl<bool>("direction") << " = ((lid & (length<<1)) != 0);\n" <<
            "for(" <<
                k.decl<uint_>("k") << " = length; " <<
                "k > 0; " <<
                "k >>= 1" <<
            ") {\n" <<

                // sibling whose key is compared with my key
                k.decl<uint_>("sibling_idx") << " = lid ^ k;\n" <<
                k.decl<key_type>("sibling_key") << " = lkeys[sibling_idx];\n" <<
                k.decl<bool>("compare") << " = " <<
                    compare(k.var<key_type>("sibling_key"),
                            k.var<key_type>("my_key")) << ";\n" <<
                k.decl<bool>("swap") <<
                    " = compare ^ (sibling_idx < lid) ^ direction;\n" <<
                "my_key = swap ? sibling_key : my_key;\n";
    if(sort_by_key)
    {
        k <<
                "my_index = swap ? lidx[sibling_idx] : my_index;\n";
    }
    k <<
                "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
                "lkeys[lid] = my_key;\n";
    if(sort_by_key)
    {
        k <<
                "lidx[lid] = my_index;\n";
    }
    k <<
                "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
            "}\n" <<
        "}\n";

    // end of bitonic sort

    // odd-even transposition sort for blocks whose size is not a power of
    // two; not stable
    k <<
        "}\n" <<
        "else { \n";

    k <<
        k.decl<bool>("lid_is_even") << " = (lid%2) == 0;\n" <<
        k.decl<uint_>("oddsibling_idx") << " = " <<
            "(lid_is_even) ? max(lid,(uint)(1)) - 1 : min(lid+1,n-1);\n" <<
        k.decl<uint_>("evensibling_idx") << " = " <<
            "(lid_is_even) ? min(lid+1,n-1) : max(lid,(uint)(1)) - 1;\n" <<

        // wait for keys and vals to be stored in local memory
        "barrier(CLK_LOCAL_MEM_FENCE);\n" <<

        "#pragma unroll\n" <<
        "for(" <<
            k.decl<uint_>("i") << " = 0; " <<
            "i < n; " <<
            "i++" <<
        ") {\n" <<
            k.decl<uint_>("sibling_idx") <<
                " = i%2 == 0 ? evensibling_idx : oddsibling_idx;\n" <<
            k.decl<key_type>("sibling_key") << " = lkeys[sibling_idx];\n" <<
            k.decl<bool>("compare") << " = " <<
                compare(k.var<key_type>("sibling_key"),
                        k.var<key_type>("my_key")) << ";\n" <<
            k.decl<bool>("swap") <<
                " = compare ^ (sibling_idx < lid);\n" <<
            "my_key = swap ? sibling_key : my_key;\n";
    if(sort_by_key)
    {
        k <<
            "my_index = swap ? lidx[sibling_idx] : my_index;\n";
    }
    k <<
            "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
            "lkeys[lid] = my_key;\n";
    if(sort_by_key)
    {
        k <<
            "lidx[lid] = my_index;\n";
    }
    k <<
            "barrier(CLK_LOCAL_MEM_FENCE);\n"
        "}\n" << // for

        "}\n"; // else
    // end of odd-even sort

    // save key and value
    k <<
        "if(gid < count) {\n" <<
            keys_first[k.var<const uint_>("gid")] << " = " <<
                k.var<key_type>("my_key") << ";\n";
    if(sort_by_key)
    {
        k << values_first[k.var<const uint_>("gid")] << " = " <<
               values_first[k.var<const uint_>("offset + my_index")] << ";\n";
    }
    k <<
        // end if
        "}\n";

    const context &context = queue.get_context();
    const device &device = queue.get_device();
    ::boost::compute::kernel kernel = k.compile(context);

    const size_t work_group_size =
        pick_bitonic_block_sort_block_size<key_type, uchar_>(
            kernel.get_work_group_info<size_t>(
                device, CL_KERNEL_WORK_GROUP_SIZE
            ),
            device.get_info<size_t>(CL_DEVICE_LOCAL_MEM_SIZE),
            sort_by_key
        );

    const size_t global_size =
        work_group_size * static_cast<size_t>(
            std::ceil(float(count) / work_group_size)
        );
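    // Example: count = 1000 with work_group_size = 256 rounds global_size
    // up to 1024 (4 work-groups); the 24 extra work-items are masked off by
    // the "gid < count" checks in the kernel.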

    kernel.set_arg(count_arg, static_cast<uint_>(count));
    kernel.set_arg(local_keys_arg, local_buffer<key_type>(work_group_size));
    if(sort_by_key) {
        kernel.set_arg(local_vals_arg, local_buffer<uchar_>(work_group_size));
    }

    queue.enqueue_1d_range_kernel(kernel, 0, global_size, work_group_size);
    // return the size of the block
    return work_group_size;
}

template<class KeyIterator, class ValueIterator, class Compare>
inline size_t block_sort(KeyIterator keys_first,
                         ValueIterator values_first,
                         Compare compare,
                         const size_t count,
                         const bool sort_by_key,
                         const bool stable,
                         command_queue &queue)
{
    if(stable) {
        // TODO: Implement stable block sort (stable odd-even merge sort)
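        // Until then, report a block size of 1: each one-element "block" is
        // trivially sorted, so the stable merge passes in the callers do all
        // of the actual sorting work.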
        return size_t(1);
    }
    return bitonic_block_sort(
        keys_first, values_first,
        compare, count,
        sort_by_key, queue
    );
}

/// space: O(n + m); n - number of keys, m - number of values
template<class KeyIterator, class ValueIterator, class Compare>
inline void merge_blocks_on_gpu(KeyIterator keys_first,
                                ValueIterator values_first,
                                KeyIterator out_keys_first,
                                ValueIterator out_values_first,
                                Compare compare,
                                const size_t count,
                                const size_t block_size,
                                const bool sort_by_key,
                                command_queue &queue)
{
    typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
    typedef typename std::iterator_traits<ValueIterator>::value_type value_type;

    meta_kernel k("merge_blocks");
    size_t count_arg = k.add_arg<const uint_>("count");
    size_t block_size_arg = k.add_arg<const uint_>("block_size");

    k <<
        // get global id
        k.decl<const uint_>("gid") << " = get_global_id(0);\n" <<
        "if(gid >= count) {\n" <<
            "return;\n" <<
        "}\n" <<

        k.decl<const key_type>("my_key") << " = " <<
            keys_first[k.var<const uint_>("gid")] << ";\n";

    if(sort_by_key) {
        k <<
            k.decl<const value_type>("my_value") << " = " <<
                values_first[k.var<const uint_>("gid")] << ";\n";
    }

    k <<
        // get my block index
        k.decl<const uint_>("my_block_idx") << " = gid / block_size;\n" <<
        k.decl<const bool>("my_block_idx_is_odd") << " = " <<
            "my_block_idx & 0x1;\n" <<

        k.decl<const uint_>("other_block_idx") << " = " <<
            // an odd block merges with the preceding block,
            // an even block with the following one
            "my_block_idx_is_odd ? my_block_idx - 1 : my_block_idx + 1;\n" <<

        // get ranges of my block and the other block
        // [my_block_start; my_block_end)
        // [other_block_start; other_block_end)
        k.decl<const uint_>("my_block_start") << " = " <<
            "min(my_block_idx * block_size, count);\n" << // inclusive
        k.decl<const uint_>("my_block_end") << " = " <<
            "min((my_block_idx + 1) * block_size, count);\n" << // exclusive

        k.decl<const uint_>("other_block_start") << " = " <<
            "min(other_block_idx * block_size, count);\n" << // inclusive
        k.decl<const uint_>("other_block_end") << " = " <<
            "min((other_block_idx + 1) * block_size, count);\n" << // exclusive

        // the other block is empty, nothing to merge here
        "if(other_block_start == count){\n" <<
            out_keys_first[k.var<uint_>("gid")] << " = my_key;\n";
    if(sort_by_key) {
        k <<
            out_values_first[k.var<uint_>("gid")] << " = my_value;\n";
    }

    k <<
        "return;\n" <<
        "}\n" <<

        // branchless binary search for the lower bound:
        // left_idx - the first position in the other block whose key is
        // not ordered before my_key
        k.decl<uint_>("left_idx") << " = other_block_start;\n" <<
        k.decl<uint_>("right_idx") << " = other_block_end;\n" <<
        "while(left_idx < right_idx) {\n" <<
            k.decl<uint_>("mid_idx") << " = (left_idx + right_idx) / 2;\n" <<
            k.decl<key_type>("mid_key") << " = " <<
                keys_first[k.var<const uint_>("mid_idx")] << ";\n" <<
            k.decl<bool>("smaller") << " = " <<
                compare(k.var<key_type>("mid_key"),
                        k.var<key_type>("my_key")) << ";\n" <<
            "left_idx = smaller ? mid_idx + 1 : left_idx;\n" <<
            "right_idx = smaller ? right_idx : mid_idx;\n" <<
        "}\n" <<
        // left_idx is now the found position in the other block

        // If my block is odd we need the upper bound instead: equal keys
        // from the even (left) block must come first, which keeps the
        // merge stable.
        "right_idx = other_block_end;\n" <<
        "if(my_block_idx_is_odd && left_idx != right_idx) {\n" <<
            k.decl<key_type>("upper_key") << " = " <<
                keys_first[k.var<const uint_>("left_idx")] << ";\n" <<
            "while(" <<
                "!(" << compare(k.var<key_type>("upper_key"),
                                k.var<key_type>("my_key")) <<
                ") && " <<
                "!(" << compare(k.var<key_type>("my_key"),
                                k.var<key_type>("upper_key")) <<
                ") && " <<
                    "left_idx < right_idx" <<
            ")" <<
            "{\n" <<
                k.decl<uint_>("mid_idx") << " = (left_idx + right_idx) / 2;\n" <<
                k.decl<key_type>("mid_key") << " = " <<
                    keys_first[k.var<const uint_>("mid_idx")] << ";\n" <<
                k.decl<bool>("equal") << " = " <<
                    "!(" << compare(k.var<key_type>("mid_key"),
                                    k.var<key_type>("my_key")) <<
                    ") && " <<
                    "!(" << compare(k.var<key_type>("my_key"),
                                    k.var<key_type>("mid_key")) <<
                    ");\n" <<
                "left_idx = equal ? mid_idx + 1 : left_idx + 1;\n" <<
                "right_idx = equal ? right_idx : mid_idx;\n" <<
                "upper_key = equal ? upper_key : " <<
                    keys_first[k.var<const uint_>("left_idx")] << ";\n" <<
            "}\n" <<
        "}\n" <<

        // compute the final position of my_key (and my_value)
        k.decl<uint_>("offset") << " = 0;\n" <<
        "offset += gid - my_block_start;\n" <<
        "offset += left_idx - other_block_start;\n" <<
        "offset += min(my_block_start, other_block_start);\n" <<
        out_keys_first[k.var<uint_>("offset")] << " = my_key;\n";
    if(sort_by_key) {
        k <<
            out_values_first[k.var<uint_>("offset")] << " = my_value;\n";
    }
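
    // Example of the offset computation above (hypothetical data): merging
    // sorted blocks [1,3,5] and [2,4,6] with block_size = 3, the work-item
    // holding key 4 (gid = 4, in the odd block) counts 1 element before it
    // in its own block and finds left_idx - other_block_start = 2 elements
    // (1 and 3) of the other block ranked before it, so it writes to
    // offset = 0 + 1 + 2 = 3, its position in the merged output
    // [1,2,3,4,5,6].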

    const context &context = queue.get_context();
    ::boost::compute::kernel kernel = k.compile(context);

    const size_t work_group_size = (std::min)(
        size_t(256),
        kernel.get_work_group_info<size_t>(
            queue.get_device(), CL_KERNEL_WORK_GROUP_SIZE
        )
    );
    const size_t global_size =
        work_group_size * static_cast<size_t>(
            std::ceil(float(count) / work_group_size)
        );

    kernel.set_arg(count_arg, static_cast<uint_>(count));
    kernel.set_arg(block_size_arg, static_cast<uint_>(block_size));
    queue.enqueue_1d_range_kernel(kernel, 0, global_size, work_group_size);
}

template<class KeyIterator, class ValueIterator, class Compare>
inline void merge_sort_by_key_on_gpu(KeyIterator keys_first,
                                     KeyIterator keys_last,
                                     ValueIterator values_first,
                                     Compare compare,
                                     bool stable,
                                     command_queue &queue)
{
    typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
    typedef typename std::iterator_traits<ValueIterator>::value_type value_type;

    size_t count = iterator_range_size(keys_first, keys_last);
    if(count < 2){
        return;
    }

    size_t block_size =
        block_sort(
            keys_first, values_first,
            compare, count,
            true /* sort_by_key */, stable /* stable */,
            queue
        );

    // for small inputs the block sort alone suffices
    if(count <= block_size) {
        return;
    }

    const context &context = queue.get_context();

    bool result_in_temporary_buffer = false;
    ::boost::compute::vector<key_type> temp_keys(count, context);
    ::boost::compute::vector<value_type> temp_values(count, context);
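
    // Ping-pong between the input range and the temporary buffers: each
    // merge pass reads from one and writes to the other, and
    // result_in_temporary_buffer tracks where the current result lives, so
    // at most one final copy back is needed.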

    for(; block_size < count; block_size *= 2) {
        result_in_temporary_buffer = !result_in_temporary_buffer;
        if(result_in_temporary_buffer) {
            merge_blocks_on_gpu(keys_first, values_first,
                                temp_keys.begin(), temp_values.begin(),
                                compare, count, block_size,
                                true /* sort_by_key */, queue);
        } else {
            merge_blocks_on_gpu(temp_keys.begin(), temp_values.begin(),
                                keys_first, values_first,
                                compare, count, block_size,
                                true /* sort_by_key */, queue);
        }
    }

    if(result_in_temporary_buffer) {
        copy_async(temp_keys.begin(), temp_keys.end(), keys_first, queue);
        copy_async(temp_values.begin(), temp_values.end(), values_first, queue);
    }
}

template<class Iterator, class Compare>
inline void merge_sort_on_gpu(Iterator first,
                              Iterator last,
                              Compare compare,
                              bool stable,
                              command_queue &queue)
{
    typedef typename std::iterator_traits<Iterator>::value_type key_type;

    size_t count = iterator_range_size(first, last);
    if(count < 2){
        return;
    }

    // dummy values iterator; it is never dereferenced because
    // sort_by_key is false
    Iterator dummy;
    size_t block_size =
        block_sort(
            first, dummy,
            compare, count,
            false /* sort_by_key */, stable /* stable */,
            queue
        );

    // for small inputs the block sort alone suffices
    if(count <= block_size) {
        return;
    }

    const context &context = queue.get_context();

    bool result_in_temporary_buffer = false;
    ::boost::compute::vector<key_type> temp_keys(count, context);

    for(; block_size < count; block_size *= 2) {
        result_in_temporary_buffer = !result_in_temporary_buffer;
        if(result_in_temporary_buffer) {
            merge_blocks_on_gpu(first, dummy, temp_keys.begin(), dummy,
                                compare, count, block_size,
                                false /* sort_by_key */, queue);
        } else {
            merge_blocks_on_gpu(temp_keys.begin(), dummy, first, dummy,
                                compare, count, block_size,
                                false /* sort_by_key */, queue);
        }
    }

    if(result_in_temporary_buffer) {
        copy_async(temp_keys.begin(), temp_keys.end(), first, queue);
    }
}

template<class KeyIterator, class ValueIterator, class Compare>
inline void merge_sort_by_key_on_gpu(KeyIterator keys_first,
                                     KeyIterator keys_last,
                                     ValueIterator values_first,
                                     Compare compare,
                                     command_queue &queue)
{
    merge_sort_by_key_on_gpu(
        keys_first, keys_last, values_first,
        compare, false /* not stable */, queue
    );
}

template<class Iterator, class Compare>
inline void merge_sort_on_gpu(Iterator first,
                              Iterator last,
                              Compare compare,
                              command_queue &queue)
{
    merge_sort_on_gpu(
        first, last, compare, false /* not stable */, queue
    );
}
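
/// Minimal usage sketch (illustrative only; device, context and data setup
/// are assumed, and these detail:: functions are internal, so user code
/// would normally go through boost::compute::sort() instead):
///
/// \code
/// boost::compute::device gpu = boost::compute::system::default_device();
/// boost::compute::context ctx(gpu);
/// boost::compute::command_queue queue(ctx, gpu);
///
/// std::vector<int> host_data = { 5, 2, 4, 1, 3 };
/// boost::compute::vector<int> device_data(
///     host_data.begin(), host_data.end(), queue
/// );
///
/// boost::compute::detail::merge_sort_on_gpu(
///     device_data.begin(), device_data.end(),
///     boost::compute::less<int>(), queue
/// );
/// queue.finish();
/// \endcode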

} // end detail namespace
} // end compute namespace
} // end boost namespace

#endif /* BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_SORT_ON_GPU_HPP_ */