]> git.proxmox.com Git - ceph.git/blame - ceph/src/boost/boost/compute/algorithm/detail/merge_sort_on_gpu.hpp
import new upstream nautilus stable release 14.2.8
[ceph.git] / ceph / src / boost / boost / compute / algorithm / detail / merge_sort_on_gpu.hpp
CommitLineData
7c673cae
FG
1//---------------------------------------------------------------------------//
2// Copyright (c) 2016 Jakub Szuppe <j.szuppe@gmail.com>
3//
4// Distributed under the Boost Software License, Version 1.0
5// See accompanying file LICENSE_1_0.txt or copy at
6// http://www.boost.org/LICENSE_1_0.txt
7//
8// See http://boostorg.github.com/compute for more information.
9//---------------------------------------------------------------------------//
10
11#ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_SORT_ON_GPU_HPP_
12#define BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_SORT_ON_GPU_HPP_
13
14#include <algorithm>
15
16#include <boost/compute/kernel.hpp>
17#include <boost/compute/program.hpp>
18#include <boost/compute/command_queue.hpp>
19#include <boost/compute/container/vector.hpp>
20#include <boost/compute/memory/local_buffer.hpp>
21#include <boost/compute/detail/meta_kernel.hpp>
22#include <boost/compute/detail/iterator_range_size.hpp>
23
24namespace boost {
25namespace compute {
26namespace detail {
27
/// Picks the work-group (block) size for bitonic_block_sort().
///
/// Starts from \p proposed_wg (normally the kernel's maximum work-group
/// size), shrinks it until the required local memory fits into
/// \p lmem_size, and rounds the result down to a power of two (capped at
/// 256), as required by the bitonic sorting network.
///
/// \param proposed_wg upper bound for the work-group size
/// \param lmem_size   available local memory in bytes
/// \param sort_by_key when true, one ValueType slot per element (the
///                    kernel's per-element local index) is also accounted
/// \return power-of-two work-group size in [1; 256]
template<class KeyType, class ValueType>
inline size_t pick_bitonic_block_sort_block_size(size_t proposed_wg,
                                                 size_t lmem_size,
                                                 bool sort_by_key)
{
    size_t n = proposed_wg;

    // local memory needed for n elements: keys, plus (when sorting by key)
    // one ValueType slot per element
    size_t lmem_required = n * sizeof(KeyType);
    if(sort_by_key) {
        lmem_required += n * sizeof(ValueType);
    }

    // try to force at least 4 work-groups of >64 elements
    // for better occupancy
    while(lmem_size < (lmem_required * 4) && (n > 64)) {
        n /= 2;
        // bug fix: the recomputation used to drop the ValueType
        // contribution, so the picked size could exceed available lmem
        lmem_required = n * sizeof(KeyType);
        if(sort_by_key) {
            lmem_required += n * sizeof(ValueType);
        }
    }
    // hard requirement: one work-group's data must fit into local memory
    while(lmem_size < lmem_required && (n != 1)) {
        n /= 2;
        lmem_required = n * sizeof(KeyType);
        if(sort_by_key) {
            lmem_required += n * sizeof(ValueType);
        }
    }

    // round down to a power of two, capped at 256
    if(n < 2)   { return 1; }
    else if(n < 4)   { return 2; }
    else if(n < 8)   { return 4; }
    else if(n < 16)  { return 8; }
    else if(n < 32)  { return 16; }
    else if(n < 64)  { return 32; }
    else if(n < 128) { return 64; }
    else if(n < 256) { return 128; }
    else { return 256; }
}
62
63
64/// Performs bitonic block sort according to \p compare.
65///
66/// Since bitonic sort can be only performed when input size is equal to 2^n,
67/// in this case input size is block size (\p work_group_size), we would have
/// to require \p count to be an exact multiple of block size. That would not be
69/// great.
70/// Instead, bitonic sort kernel is merged with odd-even merge sort so if the
71/// last block is not equal to 2^n (where n is some natural number) the odd-even
72/// sort is performed for that block. That way bitonic_block_sort() works for
/// input of any size. The block size (\p work_group_size) still has to be equal
74/// to 2^n.
75///
76/// This is NOT stable.
77///
78/// \param keys_first first key element in the range to sort
79/// \param values_first first value element in the range to sort
80/// \param compare comparison function for keys
81/// \param count number of elements in the range; count > 0
82/// \param work_group_size size of the work group, also the block size; must be
/// equal to 2^n where n is a natural number
84/// \param queue command queue to perform the operation
template<class KeyIterator, class ValueIterator, class Compare>
inline size_t bitonic_block_sort(KeyIterator keys_first,
                                 ValueIterator values_first,
                                 Compare compare,
                                 const size_t count,
                                 const bool sort_by_key,
                                 command_queue &queue)
{
    typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
    typedef typename std::iterator_traits<ValueIterator>::value_type value_type;

    meta_kernel k("bitonic_block_sort");
    size_t count_arg = k.add_arg<const uint_>("count");

    // local memory for keys; when sorting by key a second local buffer of
    // uchar element indices is used instead of a local copy of the values
    size_t local_keys_arg = k.add_arg<key_type *>(memory_object::local_memory, "lkeys");
    size_t local_vals_arg = 0;
    if(sort_by_key) {
        local_vals_arg = k.add_arg<uchar_ *>(memory_object::local_memory, "lidx");
    }

    k <<
        // Work item global and local ids
        k.decl<const uint_>("gid") << " = get_global_id(0);\n" <<
        k.decl<const uint_>("lid") << " = get_local_id(0);\n";

    // declare my_key and my_value
    k <<
        k.decl<key_type>("my_key") << ";\n";
    // Instead of copying values (my_value) in local memory with keys
    // we save local index (uchar) and copy my_value at the end at
    // final index. This saves local memory.
    if(sort_by_key)
    {
        k <<
            k.decl<uchar_>("my_index") << " = (uchar)(lid);\n";
    }

    // load key (out-of-range work items keep my_key uninitialized; they
    // never write it back because the final store is guarded by gid < count)
    k <<
        "if(gid < count) {\n" <<
            k.var<key_type>("my_key") << " = " <<
                keys_first[k.var<const uint_>("gid")] << ";\n" <<
        "}\n";

    // load key and index to local memory
    k <<
        "lkeys[lid] = my_key;\n";
    if(sort_by_key)
    {
        k <<
            "lidx[lid] = my_index;\n";
    }
    // n = number of in-range elements handled by this work-group; it equals
    // the work-group size for every block except possibly the last one
    k <<
        k.decl<const uint_>("offset") << " = get_group_id(0) * get_local_size(0);\n" <<
        k.decl<const uint_>("n") << " = min((uint)(get_local_size(0)),(count - offset));\n";

    // When work group size is a power of 2 bitonic sorter can be used;
    // otherwise, slower odd-even sort is used.

    k <<
        // check if n is power of 2; (~n + 1) is -n in two's complement,
        // so this tests (n & -n) == n
        "if(((n != 0) && ((n & (~n + 1)) == n))) {\n";

    // bitonic sort, not stable
    k <<
        // wait for keys and vals to be stored in local memory
        "barrier(CLK_LOCAL_MEM_FENCE);\n" <<

        "#pragma unroll\n" <<
        "for(" <<
            k.decl<uint_>("length") << " = 1; " <<
            "length < n; " <<
            "length <<= 1" <<
        ") {\n" <<
            // direction of sort: false -> asc, true -> desc
            k.decl<bool>("direction") << "= ((lid & (length<<1)) != 0);\n" <<
            "for(" <<
                k.decl<uint_>("k") << " = length; " <<
                "k > 0; " <<
                "k >>= 1" <<
            ") {\n" <<

            // sibling to compare with my key
            k.decl<uint_>("sibling_idx") << " = lid ^ k;\n" <<
            k.decl<key_type>("sibling_key") << " = lkeys[sibling_idx];\n" <<
            k.decl<bool>("compare") << " = " <<
                compare(k.var<key_type>("sibling_key"),
                        k.var<key_type>("my_key")) << ";\n" <<
            k.decl<bool>("equal") << " = !(compare || " <<
                compare(k.var<key_type>("my_key"),
                        k.var<key_type>("sibling_key")) << ");\n" <<
            k.decl<bool>("swap") <<
                " = compare ^ (sibling_idx < lid) ^ direction;\n" <<
            // never exchange equal keys (prevents a key from being
            // duplicated when both partners reach the same swap decision)
            "swap = equal ? false : swap;\n" <<
            "my_key = swap ? sibling_key : my_key;\n";
    if(sort_by_key)
    {
        k <<
            "my_index = swap ? lidx[sibling_idx] : my_index;\n";
    }
    k <<
            // publish the (possibly exchanged) key/index for the next step
            "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
            "lkeys[lid] = my_key;\n";
    if(sort_by_key)
    {
        k <<
            "lidx[lid] = my_index;\n";
    }
    k <<
            "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
            "}\n" <<
        "}\n";

    // end of bitonic sort

    // odd-even sort, not stable
    k <<
        "}\n" <<
        "else { \n";

    // precompute each work item's partner for odd and for even phases;
    // the min/max clamping keeps the indices inside [0; n-1]
    k <<
        k.decl<bool>("lid_is_even") << " = (lid%2) == 0;\n" <<
        k.decl<uint_>("oddsibling_idx") << " = " <<
            "(lid_is_even) ? max(lid,(uint)(1)) - 1 : min(lid+1,n-1);\n" <<
        k.decl<uint_>("evensibling_idx") << " = " <<
            "(lid_is_even) ? min(lid+1,n-1) : max(lid,(uint)(1)) - 1;\n" <<

        // wait for keys and vals to be stored in local memory
        "barrier(CLK_LOCAL_MEM_FENCE);\n" <<

        // n phases of odd-even transposition sort
        "#pragma unroll\n" <<
        "for(" <<
            k.decl<uint_>("i") << " = 0; " <<
            "i < n; " <<
            "i++" <<
        ") {\n" <<
            k.decl<uint_>("sibling_idx") <<
                " = i%2 == 0 ? evensibling_idx : oddsibling_idx;\n" <<
            k.decl<key_type>("sibling_key") << " = lkeys[sibling_idx];\n" <<
            k.decl<bool>("compare") << " = " <<
                compare(k.var<key_type>("sibling_key"),
                        k.var<key_type>("my_key")) << ";\n" <<
            k.decl<bool>("equal") << " = !(compare || " <<
                compare(k.var<key_type>("my_key"),
                        k.var<key_type>("sibling_key")) << ");\n" <<
            k.decl<bool>("swap") <<
                " = compare ^ (sibling_idx < lid);\n" <<
            // as above: equal keys are never exchanged
            "swap = equal ? false : swap;\n" <<
            "my_key = swap ? sibling_key : my_key;\n";
    if(sort_by_key)
    {
        k <<
            "my_index = swap ? lidx[sibling_idx] : my_index;\n";
    }
    k <<
            "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
            "lkeys[lid] = my_key;\n";
    if(sort_by_key)
    {
        k <<
            "lidx[lid] = my_index;\n";
    }
    k <<
            "barrier(CLK_LOCAL_MEM_FENCE);\n"
        "}\n" << // for

        "}\n"; // else
    // end of odd-even sort

    // save key and value
    k <<
        "if(gid < count) {\n" <<
        keys_first[k.var<const uint_>("gid")] << " = " <<
            k.var<key_type>("my_key") << ";\n";
    if(sort_by_key)
    {
        // gather my value from its pre-sort position (offset + my_index),
        // then wait until every work item in the group has read its value
        // before the in-place scattered writes below overwrite them
        k <<
            k.decl<value_type>("my_value") << " = " <<
                values_first[k.var<const uint_>("offset + my_index")] << ";\n" <<
            "barrier(CLK_GLOBAL_MEM_FENCE);\n" <<
            values_first[k.var<const uint_>("gid")] << " = my_value;\n";
    }
    k <<
        // end if
        "}\n";

    const context &context = queue.get_context();
    const device &device = queue.get_device();
    ::boost::compute::kernel kernel = k.compile(context);

    // the kernel must be compiled first: the work-group size choice depends
    // on CL_KERNEL_WORK_GROUP_SIZE and on the device's local memory
    const size_t work_group_size =
        pick_bitonic_block_sort_block_size<key_type, uchar_>(
            kernel.get_work_group_info<size_t>(
                device, CL_KERNEL_WORK_GROUP_SIZE
            ),
            device.get_info<size_t>(CL_DEVICE_LOCAL_MEM_SIZE),
            sort_by_key
        );

    // global size rounded up to a multiple of the work-group size
    const size_t global_size =
        work_group_size * static_cast<size_t>(
            std::ceil(float(count) / work_group_size)
        );

    kernel.set_arg(count_arg, static_cast<uint_>(count));
    kernel.set_arg(local_keys_arg, local_buffer<key_type>(work_group_size));
    if(sort_by_key) {
        kernel.set_arg(local_vals_arg, local_buffer<uchar_>(work_group_size));
    }

    queue.enqueue_1d_range_kernel(kernel, 0, global_size, work_group_size);
    // return size of the block
    return work_group_size;
}
299
300template<class KeyIterator, class ValueIterator, class Compare>
301inline size_t block_sort(KeyIterator keys_first,
302 ValueIterator values_first,
303 Compare compare,
304 const size_t count,
305 const bool sort_by_key,
306 const bool stable,
307 command_queue &queue)
308{
309 if(stable) {
310 // TODO: Implement stable block sort (stable odd-even merge sort)
311 return size_t(1);
312 }
313 return bitonic_block_sort(
314 keys_first, values_first,
315 compare, count,
316 sort_by_key, queue
317 );
318}
319
320/// space: O(n + m); n - number of keys, m - number of values
/// Merges every pair of adjacent sorted blocks of \p block_size elements
/// from [keys_first; keys_first + count) into sorted blocks of twice that
/// size, writing the result to out_keys_first (and out_values_first when
/// \p sort_by_key is true). Each work item binary-searches its element's
/// rank in the paired block and writes the element straight to its final
/// position, so input and output ranges must not overlap.
template<class KeyIterator, class ValueIterator, class Compare>
inline void merge_blocks_on_gpu(KeyIterator keys_first,
                                ValueIterator values_first,
                                KeyIterator out_keys_first,
                                ValueIterator out_values_first,
                                Compare compare,
                                const size_t count,
                                const size_t block_size,
                                const bool sort_by_key,
                                command_queue &queue)
{
    typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
    typedef typename std::iterator_traits<ValueIterator>::value_type value_type;

    meta_kernel k("merge_blocks");
    size_t count_arg = k.add_arg<const uint_>("count");
    size_t block_size_arg = k.add_arg<const uint_>("block_size");

    k <<
        // get global id
        k.decl<const uint_>("gid") << " = get_global_id(0);\n" <<
        "if(gid >= count) {\n" <<
            "return;\n" <<
        "}\n" <<

        k.decl<const key_type>("my_key") << " = " <<
            keys_first[k.var<const uint_>("gid")] << ";\n";

    if(sort_by_key) {
        k <<
            k.decl<const value_type>("my_value") << " = " <<
                values_first[k.var<const uint_>("gid")] << ";\n";
    }

    k <<
        // get my block idx
        k.decl<const uint_>("my_block_idx") << " = gid / block_size;\n" <<
        k.decl<const bool>("my_block_idx_is_odd") << " = " <<
            "my_block_idx & 0x1;\n" <<

        // even blocks pair with the following block, odd with the preceding
        k.decl<const uint_>("other_block_idx") << " = " <<
            // if(my_block_idx is odd) {} else {}
            "my_block_idx_is_odd ? my_block_idx - 1 : my_block_idx + 1;\n" <<

        // get ranges of my block and the other block
        // [my_block_start; my_block_end)
        // [other_block_start; other_block_end)
        // all clamped to count since the last block may be partial
        k.decl<const uint_>("my_block_start") << " = " <<
            "min(my_block_idx * block_size, count);\n" << // including
        k.decl<const uint_>("my_block_end") << " = " <<
            "min((my_block_idx + 1) * block_size, count);\n" << // excluding

        k.decl<const uint_>("other_block_start") << " = " <<
            "min(other_block_idx * block_size, count);\n" << // including
        k.decl<const uint_>("other_block_end") << " = " <<
            "min((other_block_idx + 1) * block_size, count);\n" << // excluding

        // other block is empty, nothing to merge here
        "if(other_block_start == count){\n" <<
            out_keys_first[k.var<uint_>("gid")] << " = my_key;\n";
    if(sort_by_key) {
        k <<
            out_values_first[k.var<uint_>("gid")] << " = my_value;\n";
    }

    k <<
        "return;\n" <<
        "}\n" <<

        // lower bound
        // left_idx - lower bound
        k.decl<uint_>("left_idx") << " = other_block_start;\n" <<
        k.decl<uint_>("right_idx") << " = other_block_end;\n" <<
        "while(left_idx < right_idx) {\n" <<
            k.decl<uint_>("mid_idx") << " = (left_idx + right_idx) / 2;\n" <<
            k.decl<key_type>("mid_key") << " = " <<
                keys_first[k.var<const uint_>("mid_idx")] << ";\n" <<
            k.decl<bool>("smaller") << " = " <<
                compare(k.var<key_type>("mid_key"),
                        k.var<key_type>("my_key")) << ";\n" <<
            "left_idx = smaller ? mid_idx + 1 : left_idx;\n" <<
            "right_idx = smaller ? right_idx : mid_idx;\n" <<
        "}\n" <<
        // left_idx is found position in other block

        // if my_block is odd we need to get the upper bound:
        // elements of the odd (right) block are placed after equal
        // elements of the even (left) block, so equal keys from the two
        // blocks never collide on the same output position
        "right_idx = other_block_end;\n" <<
        "if(my_block_idx_is_odd && left_idx != right_idx) {\n" <<
            k.decl<key_type>("upper_key") << " = " <<
                keys_first[k.var<const uint_>("left_idx")] << ";\n" <<
            // advance left_idx while keys at left_idx compare equal to
            // my_key (neither orders before the other)
            "while(" <<
                "!(" << compare(k.var<key_type>("upper_key"),
                                k.var<key_type>("my_key")) <<
                ") && " <<
                "!(" << compare(k.var<key_type>("my_key"),
                                k.var<key_type>("upper_key")) <<
                ") && " <<
                "left_idx < right_idx" <<
            ")" <<
            "{\n" <<
                k.decl<uint_>("mid_idx") << " = (left_idx + right_idx) / 2;\n" <<
                k.decl<key_type>("mid_key") << " = " <<
                    keys_first[k.var<const uint_>("mid_idx")] << ";\n" <<
                k.decl<bool>("equal") << " = " <<
                    "!(" << compare(k.var<key_type>("mid_key"),
                                    k.var<key_type>("my_key")) <<
                    ") && " <<
                    "!(" << compare(k.var<key_type>("my_key"),
                                    k.var<key_type>("mid_key")) <<
                    ");\n" <<
                "left_idx = equal ? mid_idx + 1 : left_idx + 1;\n" <<
                "right_idx = equal ? right_idx : mid_idx;\n" <<
                "upper_key = " <<
                    keys_first[k.var<const uint_>("left_idx")] << ";\n" <<
            "}\n" <<
        "}\n" <<

        // final position = (elements before me in my block)
        //                + (elements of the other block ranked before me)
        //                + (start of the merged output range)
        k.decl<uint_>("offset") << " = 0;\n" <<
        "offset += gid - my_block_start;\n" <<
        "offset += left_idx - other_block_start;\n" <<
        "offset += min(my_block_start, other_block_start);\n" <<
        out_keys_first[k.var<uint_>("offset")] << " = my_key;\n";
    if(sort_by_key) {
        k <<
            out_values_first[k.var<uint_>("offset")] << " = my_value;\n";
    }

    const context &context = queue.get_context();
    ::boost::compute::kernel kernel = k.compile(context);

    // one work item per element, capped at 256 per group
    const size_t work_group_size = (std::min)(
        size_t(256),
        kernel.get_work_group_info<size_t>(
            queue.get_device(), CL_KERNEL_WORK_GROUP_SIZE
        )
    );
    const size_t global_size =
        work_group_size * static_cast<size_t>(
            std::ceil(float(count) / work_group_size)
        );

    kernel.set_arg(count_arg, static_cast<uint_>(count));
    kernel.set_arg(block_size_arg, static_cast<uint_>(block_size));
    queue.enqueue_1d_range_kernel(kernel, 0, global_size, work_group_size);
}
466
467template<class KeyIterator, class ValueIterator, class Compare>
468inline void merge_sort_by_key_on_gpu(KeyIterator keys_first,
469 KeyIterator keys_last,
470 ValueIterator values_first,
471 Compare compare,
472 bool stable,
473 command_queue &queue)
474{
475 typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
476 typedef typename std::iterator_traits<ValueIterator>::value_type value_type;
477
478 size_t count = iterator_range_size(keys_first, keys_last);
479 if(count < 2){
480 return;
481 }
482
483 size_t block_size =
484 block_sort(
485 keys_first, values_first,
486 compare, count,
487 true /* sort_by_key */, stable /* stable */,
488 queue
489 );
490
491 // for small input size only block sort is performed
492 if(count <= block_size) {
493 return;
494 }
495
496 const context &context = queue.get_context();
497
498 bool result_in_temporary_buffer = false;
499 ::boost::compute::vector<key_type> temp_keys(count, context);
500 ::boost::compute::vector<value_type> temp_values(count, context);
501
502 for(; block_size < count; block_size *= 2) {
503 result_in_temporary_buffer = !result_in_temporary_buffer;
504 if(result_in_temporary_buffer) {
505 merge_blocks_on_gpu(keys_first, values_first,
506 temp_keys.begin(), temp_values.begin(),
507 compare, count, block_size,
508 true /* sort_by_key */, queue);
509 } else {
510 merge_blocks_on_gpu(temp_keys.begin(), temp_values.begin(),
511 keys_first, values_first,
512 compare, count, block_size,
513 true /* sort_by_key */, queue);
514 }
515 }
516
517 if(result_in_temporary_buffer) {
518 copy_async(temp_keys.begin(), temp_keys.end(), keys_first, queue);
519 copy_async(temp_values.begin(), temp_values.end(), values_first, queue);
520 }
521}
522
523template<class Iterator, class Compare>
524inline void merge_sort_on_gpu(Iterator first,
525 Iterator last,
526 Compare compare,
527 bool stable,
528 command_queue &queue)
529{
530 typedef typename std::iterator_traits<Iterator>::value_type key_type;
531
532 size_t count = iterator_range_size(first, last);
533 if(count < 2){
534 return;
535 }
536
537 Iterator dummy;
538 size_t block_size =
539 block_sort(
540 first, dummy,
541 compare, count,
542 false /* sort_by_key */, stable /* stable */,
543 queue
544 );
545
546 // for small input size only block sort is performed
547 if(count <= block_size) {
548 return;
549 }
550
551 const context &context = queue.get_context();
552
553 bool result_in_temporary_buffer = false;
554 ::boost::compute::vector<key_type> temp_keys(count, context);
555
556 for(; block_size < count; block_size *= 2) {
557 result_in_temporary_buffer = !result_in_temporary_buffer;
558 if(result_in_temporary_buffer) {
559 merge_blocks_on_gpu(first, dummy, temp_keys.begin(), dummy,
560 compare, count, block_size,
561 false /* sort_by_key */, queue);
562 } else {
563 merge_blocks_on_gpu(temp_keys.begin(), dummy, first, dummy,
564 compare, count, block_size,
565 false /* sort_by_key */, queue);
566 }
567 }
568
569 if(result_in_temporary_buffer) {
570 copy_async(temp_keys.begin(), temp_keys.end(), first, queue);
571 }
572}
573
574template<class KeyIterator, class ValueIterator, class Compare>
575inline void merge_sort_by_key_on_gpu(KeyIterator keys_first,
576 KeyIterator keys_last,
577 ValueIterator values_first,
578 Compare compare,
579 command_queue &queue)
580{
581 merge_sort_by_key_on_gpu(
582 keys_first, keys_last, values_first,
583 compare, false /* not stable */, queue
584 );
585}
586
587template<class Iterator, class Compare>
588inline void merge_sort_on_gpu(Iterator first,
589 Iterator last,
590 Compare compare,
591 command_queue &queue)
592{
593 merge_sort_on_gpu(
594 first, last, compare, false /* not stable */, queue
595 );
596}
597
598} // end detail namespace
599} // end compute namespace
600} // end boost namespace
601
602#endif /* BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_SORT_ON_GPU_HPP_ */