1 //---------------------------------------------------------------------------//
2 // Copyright (c) 2016 Jakub Szuppe <j.szuppe@gmail.com>
4 // Distributed under the Boost Software License, Version 1.0
5 // See accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt
8 // See http://boostorg.github.com/compute for more information.
9 //---------------------------------------------------------------------------//
11 #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_SORT_ON_GPU_HPP_
12 #define BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_SORT_ON_GPU_HPP_
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iterator>

#include <boost/compute/kernel.hpp>
#include <boost/compute/program.hpp>
#include <boost/compute/command_queue.hpp>
#include <boost/compute/container/vector.hpp>
#include <boost/compute/memory/local_buffer.hpp>
#include <boost/compute/detail/meta_kernel.hpp>
#include <boost/compute/detail/iterator_range_size.hpp>
/// Picks the work-group (block) size for bitonic block sort.
///
/// Starts from the device-proposed work-group size and halves it until the
/// local memory required by one work-group fits into \p lmem_size, then
/// rounds the result down to a power of two (bitonic sort requires it).
///
/// \param proposed_wg work-group size proposed by the device/kernel
/// \param lmem_size   available local memory per compute unit, in bytes
/// \param sort_by_key when true, local memory for per-element value indices
///                    (of type \c ValueType) is accounted for as well
///
/// \return chosen block size; always a power of two in [1, 256]
template<class KeyType, class ValueType>
inline size_t pick_bitonic_block_sort_block_size(size_t proposed_wg,
                                                 size_t lmem_size,
                                                 bool sort_by_key)
{
    size_t n = proposed_wg;

    size_t lmem_required = n * sizeof(KeyType);
    if(sort_by_key) {
        lmem_required += n * sizeof(ValueType);
    }

    // try to force at least 4 work-groups of >64 elements
    // for better occupancy
    while(lmem_size < (lmem_required * 4) && (n > 64)) {
        n /= 2;
        lmem_required = n * sizeof(KeyType);
    }
    // if it still does not fit, keep halving down to a single element
    while(lmem_size < lmem_required && (n != 1)) {
        n /= 2;
        lmem_required = n * sizeof(KeyType);
    }

    // round down to a power of two (required by the bitonic sorter)
    if(n < 2) { return 1; }
    else if(n < 4) { return 2; }
    else if(n < 8) { return 4; }
    else if(n < 16) { return 8; }
    else if(n < 32) { return 16; }
    else if(n < 64) { return 32; }
    else if(n < 128) { return 64; }
    else if(n < 256) { return 128; }
    else { return 256; }
}
64 /// Performs bitonic block sort according to \p compare.
66 /// Since bitonic sort can be only performed when input size is equal to 2^n,
67 /// in this case input size is block size (\p work_group_size), we would have
/// to require \p count to be an exact multiple of the block size. That would
/// not be practical for inputs of arbitrary size.
70 /// Instead, bitonic sort kernel is merged with odd-even merge sort so if the
71 /// last block is not equal to 2^n (where n is some natural number) the odd-even
72 /// sort is performed for that block. That way bitonic_block_sort() works for
/// input of any size. The block size (\p work_group_size) still has to be a
/// power of two.
76 /// This is NOT stable.
78 /// \param keys_first first key element in the range to sort
79 /// \param values_first first value element in the range to sort
80 /// \param compare comparison function for keys
81 /// \param count number of elements in the range; count > 0
/// \param work_group_size size of the work group, also the block size; must be
/// equal to 2^n for some natural number n (i.e. a power of two)
84 /// \param queue command queue to perform the operation
85 template<class KeyIterator, class ValueIterator, class Compare>
86 inline size_t bitonic_block_sort(KeyIterator keys_first,
87 ValueIterator values_first,
90 const bool sort_by_key,
93 typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
94 typedef typename std::iterator_traits<ValueIterator>::value_type value_type;
96 meta_kernel k("bitonic_block_sort");
97 size_t count_arg = k.add_arg<const uint_>("count");
99 size_t local_keys_arg = k.add_arg<key_type *>(memory_object::local_memory, "lkeys");
100 size_t local_vals_arg = 0;
102 local_vals_arg = k.add_arg<uchar_ *>(memory_object::local_memory, "lidx");
106 // Work item global and local ids
107 k.decl<const uint_>("gid") << " = get_global_id(0);\n" <<
108 k.decl<const uint_>("lid") << " = get_local_id(0);\n";
110 // declare my_key and my_value
112 k.decl<key_type>("my_key") << ";\n";
113 // Instead of copying values (my_value) in local memory with keys
114 // we save local index (uchar) and copy my_value at the end at
115 // final index. This saves local memory.
119 k.decl<uchar_>("my_index") << " = (uchar)(lid);\n";
124 "if(gid < count) {\n" <<
125 k.var<key_type>("my_key") << " = " <<
126 keys_first[k.var<const uint_>("gid")] << ";\n" <<
129 // load key and index to local memory
131 "lkeys[lid] = my_key;\n";
135 "lidx[lid] = my_index;\n";
138 k.decl<const uint_>("offset") << " = get_group_id(0) * get_local_size(0);\n" <<
139 k.decl<const uint_>("n") << " = min((uint)(get_local_size(0)),(count - offset));\n";
141 // When work group size is a power of 2 bitonic sorter can be used;
142 // otherwise, slower odd-even sort is used.
145 // check if n is power of 2
146 "if(((n != 0) && ((n & (~n + 1)) == n))) {\n";
148 // bitonic sort, not stable
150 // wait for keys and vals to be stored in local memory
151 "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
153 "#pragma unroll\n" <<
155 k.decl<uint_>("length") << " = 1; " <<
159 // direction of sort: false -> asc, true -> desc
160 k.decl<bool>("direction") << "= ((lid & (length<<1)) != 0);\n" <<
162 k.decl<uint_>("k") << " = length; " <<
167 // sibling to compare with my key
168 k.decl<uint_>("sibling_idx") << " = lid ^ k;\n" <<
169 k.decl<key_type>("sibling_key") << " = lkeys[sibling_idx];\n" <<
170 k.decl<bool>("compare") << " = " <<
171 compare(k.var<key_type>("sibling_key"),
172 k.var<key_type>("my_key")) << ";\n" <<
173 k.decl<bool>("swap") <<
174 " = compare ^ (sibling_idx < lid) ^ direction;\n" <<
175 "my_key = swap ? sibling_key : my_key;\n";
179 "my_index = swap ? lidx[sibling_idx] : my_index;\n";
182 "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
183 "lkeys[lid] = my_key;\n";
187 "lidx[lid] = my_index;\n";
190 "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
194 // end of bitonic sort
196 // odd-even sort, not stable
202 k.decl<bool>("lid_is_even") << " = (lid%2) == 0;\n" <<
203 k.decl<uint_>("oddsibling_idx") << " = " <<
204 "(lid_is_even) ? max(lid,(uint)(1)) - 1 : min(lid+1,n-1);\n" <<
205 k.decl<uint_>("evensibling_idx") << " = " <<
206 "(lid_is_even) ? min(lid+1,n-1) : max(lid,(uint)(1)) - 1;\n" <<
208 // wait for keys and vals to be stored in local memory
209 "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
211 "#pragma unroll\n" <<
213 k.decl<uint_>("i") << " = 0; " <<
217 k.decl<uint_>("sibling_idx") <<
218 " = i%2 == 0 ? evensibling_idx : oddsibling_idx;\n" <<
219 k.decl<key_type>("sibling_key") << " = lkeys[sibling_idx];\n" <<
220 k.decl<bool>("compare") << " = " <<
221 compare(k.var<key_type>("sibling_key"),
222 k.var<key_type>("my_key")) << ";\n" <<
223 k.decl<bool>("swap") <<
224 " = compare ^ (sibling_idx < lid);\n" <<
225 "my_key = swap ? sibling_key : my_key;\n";
229 "my_index = swap ? lidx[sibling_idx] : my_index;\n";
232 "barrier(CLK_LOCAL_MEM_FENCE);\n" <<
233 "lkeys[lid] = my_key;\n";
237 "lidx[lid] = my_index;\n";
240 "barrier(CLK_LOCAL_MEM_FENCE);\n"
244 // end of odd-even sort
246 // save key and value
248 "if(gid < count) {\n" <<
249 keys_first[k.var<const uint_>("gid")] << " = " <<
250 k.var<key_type>("my_key") << ";\n";
254 k.decl<value_type>("my_value") << " = " <<
255 values_first[k.var<const uint_>("offset + my_index")] << ";\n" <<
256 "barrier(CLK_GLOBAL_MEM_FENCE);\n" <<
257 values_first[k.var<const uint_>("gid")] << " = my_value;\n";
263 const context &context = queue.get_context();
264 const device &device = queue.get_device();
265 ::boost::compute::kernel kernel = k.compile(context);
267 const size_t work_group_size =
268 pick_bitonic_block_sort_block_size<key_type, uchar_>(
269 kernel.get_work_group_info<size_t>(
270 device, CL_KERNEL_WORK_GROUP_SIZE
272 device.get_info<size_t>(CL_DEVICE_LOCAL_MEM_SIZE),
276 const size_t global_size =
277 work_group_size * static_cast<size_t>(
278 std::ceil(float(count) / work_group_size)
281 kernel.set_arg(count_arg, static_cast<uint_>(count));
282 kernel.set_arg(local_keys_arg, local_buffer<key_type>(work_group_size));
284 kernel.set_arg(local_vals_arg, local_buffer<uchar_>(work_group_size));
287 queue.enqueue_1d_range_kernel(kernel, 0, global_size, work_group_size);
288 // return size of the block
289 return work_group_size;
292 template<class KeyIterator, class ValueIterator, class Compare>
293 inline size_t block_sort(KeyIterator keys_first,
294 ValueIterator values_first,
297 const bool sort_by_key,
299 command_queue &queue)
302 // TODO: Implement stable block sort (stable odd-even merge sort)
305 return bitonic_block_sort(
306 keys_first, values_first,
312 /// space: O(n + m); n - number of keys, m - number of values
313 template<class KeyIterator, class ValueIterator, class Compare>
314 inline void merge_blocks_on_gpu(KeyIterator keys_first,
315 ValueIterator values_first,
316 KeyIterator out_keys_first,
317 ValueIterator out_values_first,
320 const size_t block_size,
321 const bool sort_by_key,
322 command_queue &queue)
324 typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
325 typedef typename std::iterator_traits<ValueIterator>::value_type value_type;
327 meta_kernel k("merge_blocks");
328 size_t count_arg = k.add_arg<const uint_>("count");
329 size_t block_size_arg = k.add_arg<const uint_>("block_size");
333 k.decl<const uint_>("gid") << " = get_global_id(0);\n" <<
334 "if(gid >= count) {\n" <<
338 k.decl<const key_type>("my_key") << " = " <<
339 keys_first[k.var<const uint_>("gid")] << ";\n";
343 k.decl<const value_type>("my_value") << " = " <<
344 values_first[k.var<const uint_>("gid")] << ";\n";
349 k.decl<const uint_>("my_block_idx") << " = gid / block_size;\n" <<
350 k.decl<const bool>("my_block_idx_is_odd") << " = " <<
351 "my_block_idx & 0x1;\n" <<
353 k.decl<const uint_>("other_block_idx") << " = " <<
354 // if(my_block_idx is odd) {} else {}
355 "my_block_idx_is_odd ? my_block_idx - 1 : my_block_idx + 1;\n" <<
357 // get ranges of my block and the other block
358 // [my_block_start; my_block_end)
359 // [other_block_start; other_block_end)
360 k.decl<const uint_>("my_block_start") << " = " <<
361 "min(my_block_idx * block_size, count);\n" << // including
362 k.decl<const uint_>("my_block_end") << " = " <<
363 "min((my_block_idx + 1) * block_size, count);\n" << // excluding
365 k.decl<const uint_>("other_block_start") << " = " <<
366 "min(other_block_idx * block_size, count);\n" << // including
367 k.decl<const uint_>("other_block_end") << " = " <<
368 "min((other_block_idx + 1) * block_size, count);\n" << // excluding
370 // other block is empty, nothing to merge here
371 "if(other_block_start == count){\n" <<
372 out_keys_first[k.var<uint_>("gid")] << " = my_key;\n";
375 out_values_first[k.var<uint_>("gid")] << " = my_value;\n";
383 // left_idx - lower bound
384 k.decl<uint_>("left_idx") << " = other_block_start;\n" <<
385 k.decl<uint_>("right_idx") << " = other_block_end;\n" <<
386 "while(left_idx < right_idx) {\n" <<
387 k.decl<uint_>("mid_idx") << " = (left_idx + right_idx) / 2;\n" <<
388 k.decl<key_type>("mid_key") << " = " <<
389 keys_first[k.var<const uint_>("mid_idx")] << ";\n" <<
390 k.decl<bool>("smaller") << " = " <<
391 compare(k.var<key_type>("mid_key"),
392 k.var<key_type>("my_key")) << ";\n" <<
393 "left_idx = smaller ? mid_idx + 1 : left_idx;\n" <<
394 "right_idx = smaller ? right_idx : mid_idx;\n" <<
396 // left_idx is found position in other block
398 // if my_block is odd we need to get the upper bound
399 "right_idx = other_block_end;\n" <<
400 "if(my_block_idx_is_odd && left_idx != right_idx) {\n" <<
401 k.decl<key_type>("upper_key") << " = " <<
402 keys_first[k.var<const uint_>("left_idx")] << ";\n" <<
404 "!(" << compare(k.var<key_type>("upper_key"),
405 k.var<key_type>("my_key")) <<
407 "!(" << compare(k.var<key_type>("my_key"),
408 k.var<key_type>("upper_key")) <<
410 "left_idx < right_idx" <<
413 k.decl<uint_>("mid_idx") << " = (left_idx + right_idx) / 2;\n" <<
414 k.decl<key_type>("mid_key") << " = " <<
415 keys_first[k.var<const uint_>("mid_idx")] << ";\n" <<
416 k.decl<bool>("equal") << " = " <<
417 "!(" << compare(k.var<key_type>("mid_key"),
418 k.var<key_type>("my_key")) <<
420 "!(" << compare(k.var<key_type>("my_key"),
421 k.var<key_type>("mid_key")) <<
423 "left_idx = equal ? mid_idx + 1 : left_idx + 1;\n" <<
424 "right_idx = equal ? right_idx : mid_idx;\n" <<
426 keys_first[k.var<const uint_>("left_idx")] << ";\n" <<
430 k.decl<uint_>("offset") << " = 0;\n" <<
431 "offset += gid - my_block_start;\n" <<
432 "offset += left_idx - other_block_start;\n" <<
433 "offset += min(my_block_start, other_block_start);\n" <<
434 out_keys_first[k.var<uint_>("offset")] << " = my_key;\n";
437 out_values_first[k.var<uint_>("offset")] << " = my_value;\n";
440 const context &context = queue.get_context();
441 ::boost::compute::kernel kernel = k.compile(context);
443 const size_t work_group_size = (std::min)(
445 kernel.get_work_group_info<size_t>(
446 queue.get_device(), CL_KERNEL_WORK_GROUP_SIZE
449 const size_t global_size =
450 work_group_size * static_cast<size_t>(
451 std::ceil(float(count) / work_group_size)
454 kernel.set_arg(count_arg, static_cast<uint_>(count));
455 kernel.set_arg(block_size_arg, static_cast<uint_>(block_size));
456 queue.enqueue_1d_range_kernel(kernel, 0, global_size, work_group_size);
459 template<class KeyIterator, class ValueIterator, class Compare>
460 inline void merge_sort_by_key_on_gpu(KeyIterator keys_first,
461 KeyIterator keys_last,
462 ValueIterator values_first,
465 command_queue &queue)
467 typedef typename std::iterator_traits<KeyIterator>::value_type key_type;
468 typedef typename std::iterator_traits<ValueIterator>::value_type value_type;
470 size_t count = iterator_range_size(keys_first, keys_last);
477 keys_first, values_first,
479 true /* sort_by_key */, stable /* stable */,
483 // for small input size only block sort is performed
484 if(count <= block_size) {
488 const context &context = queue.get_context();
490 bool result_in_temporary_buffer = false;
491 ::boost::compute::vector<key_type> temp_keys(count, context);
492 ::boost::compute::vector<value_type> temp_values(count, context);
494 for(; block_size < count; block_size *= 2) {
495 result_in_temporary_buffer = !result_in_temporary_buffer;
496 if(result_in_temporary_buffer) {
497 merge_blocks_on_gpu(keys_first, values_first,
498 temp_keys.begin(), temp_values.begin(),
499 compare, count, block_size,
500 true /* sort_by_key */, queue);
502 merge_blocks_on_gpu(temp_keys.begin(), temp_values.begin(),
503 keys_first, values_first,
504 compare, count, block_size,
505 true /* sort_by_key */, queue);
509 if(result_in_temporary_buffer) {
510 copy_async(temp_keys.begin(), temp_keys.end(), keys_first, queue);
511 copy_async(temp_values.begin(), temp_values.end(), values_first, queue);
515 template<class Iterator, class Compare>
516 inline void merge_sort_on_gpu(Iterator first,
520 command_queue &queue)
522 typedef typename std::iterator_traits<Iterator>::value_type key_type;
524 size_t count = iterator_range_size(first, last);
534 false /* sort_by_key */, stable /* stable */,
538 // for small input size only block sort is performed
539 if(count <= block_size) {
543 const context &context = queue.get_context();
545 bool result_in_temporary_buffer = false;
546 ::boost::compute::vector<key_type> temp_keys(count, context);
548 for(; block_size < count; block_size *= 2) {
549 result_in_temporary_buffer = !result_in_temporary_buffer;
550 if(result_in_temporary_buffer) {
551 merge_blocks_on_gpu(first, dummy, temp_keys.begin(), dummy,
552 compare, count, block_size,
553 false /* sort_by_key */, queue);
555 merge_blocks_on_gpu(temp_keys.begin(), dummy, first, dummy,
556 compare, count, block_size,
557 false /* sort_by_key */, queue);
561 if(result_in_temporary_buffer) {
562 copy_async(temp_keys.begin(), temp_keys.end(), first, queue);
566 template<class KeyIterator, class ValueIterator, class Compare>
567 inline void merge_sort_by_key_on_gpu(KeyIterator keys_first,
568 KeyIterator keys_last,
569 ValueIterator values_first,
571 command_queue &queue)
573 merge_sort_by_key_on_gpu(
574 keys_first, keys_last, values_first,
575 compare, false /* not stable */, queue
579 template<class Iterator, class Compare>
580 inline void merge_sort_on_gpu(Iterator first,
583 command_queue &queue)
586 first, last, compare, false /* not stable */, queue
590 } // end detail namespace
591 } // end compute namespace
592 } // end boost namespace
594 #endif /* BOOST_COMPUTE_ALGORITHM_DETAIL_MERGE_SORT_ON_GPU_HPP_ */