//---------------------------------------------------------------------------//
// Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
//
// Distributed under the Boost Software License, Version 1.0
// See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt
//
// See http://boostorg.github.com/compute for more information.
//---------------------------------------------------------------------------//

#ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP
#define BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP

#include <climits>
#include <iterator>
#include <sstream>
#include <string>

#include <boost/assert.hpp>
#include <boost/shared_ptr.hpp>
#include <boost/type_traits/is_signed.hpp>
#include <boost/type_traits/is_floating_point.hpp>

#include <boost/mpl/and.hpp>
#include <boost/mpl/not.hpp>

#include <boost/compute/kernel.hpp>
#include <boost/compute/program.hpp>
#include <boost/compute/command_queue.hpp>
#include <boost/compute/algorithm/exclusive_scan.hpp>
#include <boost/compute/container/vector.hpp>
#include <boost/compute/iterator/buffer_iterator.hpp>
#include <boost/compute/detail/iterator_range_size.hpp>
#include <boost/compute/detail/parameter_cache.hpp>
#include <boost/compute/type_traits/type_definition.hpp>
#include <boost/compute/type_traits/type_name.hpp>
#include <boost/compute/type_traits/is_fundamental.hpp>
#include <boost/compute/type_traits/is_vector_type.hpp>
#include <boost/compute/utility/program_cache.hpp>

namespace boost {
namespace compute {
namespace detail {

// meta-function returning true if type T is radix-sortable
template<class T>
struct is_radix_sortable :
    boost::mpl::and_<
        typename ::boost::compute::is_fundamental<T>::type,
        typename boost::mpl::not_<typename is_vector_type<T>::type>::type
    >
{
};

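// maps the size of a key type (in bytes) to the unsigned integer type
// used to hold and manipulate its bits during the sort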
template<size_t N>
struct radix_sort_value_type
{
};

template<>
struct radix_sort_value_type<1>
{
    typedef uchar_ type;
};

template<>
struct radix_sort_value_type<2>
{
    typedef ushort_ type;
};

template<>
struct radix_sort_value_type<4>
{
    typedef uint_ type;
};

template<>
struct radix_sort_value_type<8>
{
    typedef ulong_ type;
};

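// returns the compiler option defining T2_double, which tells the program
// whether the values type is double (so that cl_khr_fp64 can be enabled)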
template<typename T>
inline const char* enable_double()
{
    return " -DT2_double=0";
}

template<>
inline const char* enable_double<double>()
{
    return " -DT2_double=1";
}

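// OpenCL source for the radix sort kernels. The program is built with
// T (an unsigned integer type holding the key bits), K_BITS (key bits
// consumed per pass) and BLOCK_SIZE (work-group size) defined, plus the
// optional IS_FLOATING_POINT, IS_SIGNED, ASC and SORT_BY_KEY flags; see
// radix_sort_impl() below. K2_BITS is the number of buckets per pass.
// Three kernels are provided: "count" builds per-block bucket histograms,
// "scan" turns bucket totals into global bucket offsets, and "scatter"
// stably reorders the elements.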
const char radix_sort_source[] =
"#if T2_double\n"
"#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
"#endif\n"
"#define K2_BITS (1 << K_BITS)\n"
"#define RADIX_MASK ((((T)(1)) << K_BITS) - 1)\n"
"#define SIGN_BIT ((sizeof(T) * CHAR_BIT) - 1)\n"

"#if defined(ASC)\n" // asc order

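// radix() maps a key's bits to its bucket for the digit starting at
// low_bit. Floating-point keys use the classic bit-flip trick: flipping
// all bits of negative values and only the sign bit of non-negative
// values makes IEEE 754 values compare correctly as unsigned integers.
// Signed integer keys only need their sign bit flipped.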
100 "inline uint radix(const T x, const uint low_bit)\n"
101 "{\n"
102 "#if defined(IS_FLOATING_POINT)\n"
103 " const T mask = -(x >> SIGN_BIT) | (((T)(1)) << SIGN_BIT);\n"
104 " return ((x ^ mask) >> low_bit) & RADIX_MASK;\n"
105 "#elif defined(IS_SIGNED)\n"
106 " return ((x ^ (((T)(1)) << SIGN_BIT)) >> low_bit) & RADIX_MASK;\n"
107 "#else\n"
108 " return (x >> low_bit) & RADIX_MASK;\n"
109 "#endif\n"
110 "}\n"
111
112 "#else\n" // desc order
113
// For signed types we simply negate x; for unsigned types we subtract
// x from the maximum value of its type ((T)(-1) is the maximum value
// of type T when T is an unsigned type).
117 "inline uint radix(const T x, const uint low_bit)\n"
118 "{\n"
119 "#if defined(IS_FLOATING_POINT)\n"
120 " const T mask = -(x >> SIGN_BIT) | (((T)(1)) << SIGN_BIT);\n"
121 " return (((-x) ^ mask) >> low_bit) & RADIX_MASK;\n"
122 "#elif defined(IS_SIGNED)\n"
123 " return (((-x) ^ (((T)(1)) << SIGN_BIT)) >> low_bit) & RADIX_MASK;\n"
124 "#else\n"
125 " return (((T)(-1) - x) >> low_bit) & RADIX_MASK;\n"
126 "#endif\n"
127 "}\n"
128
129 "#endif\n" // #if defined(ASC)
130
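// count: each work-group builds a histogram of its block's radix values
// in local memory and writes it to global_counts (laid out as
// [block][bucket]). The last work-group additionally stashes its raw
// counts in global_offsets; the scan kernel uses them to compute bucket
// totals (see below).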
131 "__kernel void count(__global const T *input,\n"
132 " const uint input_offset,\n"
133 " const uint input_size,\n"
134 " __global uint *global_counts,\n"
135 " __global uint *global_offsets,\n"
136 " __local uint *local_counts,\n"
137 " const uint low_bit)\n"
138 "{\n"
139 // work-item parameters
140 " const uint gid = get_global_id(0);\n"
141 " const uint lid = get_local_id(0);\n"
142
143 // zero local counts
144 " if(lid < K2_BITS){\n"
145 " local_counts[lid] = 0;\n"
146 " }\n"
147 " barrier(CLK_LOCAL_MEM_FENCE);\n"
148
149 // reduce local counts
150 " if(gid < input_size){\n"
151 " T value = input[input_offset+gid];\n"
152 " uint bucket = radix(value, low_bit);\n"
153 " atomic_inc(local_counts + bucket);\n"
154 " }\n"
155 " barrier(CLK_LOCAL_MEM_FENCE);\n"
156
157 // write block-relative offsets
158 " if(lid < K2_BITS){\n"
159 " global_counts[K2_BITS*get_group_id(0) + lid] = local_counts[lid];\n"
160
        // store the last block's raw counts in global_offsets; the scan
        // kernel turns these into global bucket offsets
162 " if(get_group_id(0) == (get_num_groups(0) - 1)){\n"
163 " global_offsets[lid] = local_counts[lid];\n"
164 " }\n"
165 " }\n"
166 "}\n"
167
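// scan: enqueued as a single task after global_counts has been
// exclusive-scanned per bucket across blocks (see radix_sort_impl). At
// that point the last block's entry holds, per bucket, the total count
// over all preceding blocks, and global_offsets still holds the last
// block's raw counts, so their sum is each bucket's total size. An
// exclusive scan over these totals yields each bucket's global start.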
168 "__kernel void scan(__global const uint *block_offsets,\n"
169 " __global uint *global_offsets,\n"
170 " const uint block_count)\n"
171 "{\n"
172 " __global const uint *last_block_offsets =\n"
173 " block_offsets + K2_BITS * (block_count - 1);\n"
174
175 // calculate and scan global_offsets
176 " uint sum = 0;\n"
177 " for(uint i = 0; i < K2_BITS; i++){\n"
178 " uint x = global_offsets[i] + last_block_offsets[i];\n"
179 " global_offsets[i] = sum;\n"
180 " sum += x;\n"
181 " }\n"
182 "}\n"
183
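// scatter: writes each element to its final position for this pass. The
// destination is the bucket's global offset, plus the block's offset
// within that bucket (from the scanned counts), plus the number of
// earlier elements in the same work-group falling into the same bucket
// (counted via local memory), which keeps the sort stable.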
184 "__kernel void scatter(__global const T *input,\n"
185 " const uint input_offset,\n"
186 " const uint input_size,\n"
187 " const uint low_bit,\n"
188 " __global const uint *counts,\n"
189 " __global const uint *global_offsets,\n"
190 "#ifndef SORT_BY_KEY\n"
191 " __global T *output,\n"
192 " const uint output_offset)\n"
193 "#else\n"
194 " __global T *keys_output,\n"
195 " const uint keys_output_offset,\n"
196 " __global T2 *values_input,\n"
197 " const uint values_input_offset,\n"
198 " __global T2 *values_output,\n"
199 " const uint values_output_offset)\n"
200 "#endif\n"
201 "{\n"
202 // work-item parameters
203 " const uint gid = get_global_id(0);\n"
204 " const uint lid = get_local_id(0);\n"
205
206 // copy input to local memory
207 " T value;\n"
208 " uint bucket;\n"
209 " __local uint local_input[BLOCK_SIZE];\n"
210 " if(gid < input_size){\n"
211 " value = input[input_offset+gid];\n"
212 " bucket = radix(value, low_bit);\n"
213 " local_input[lid] = bucket;\n"
214 " }\n"
215
216 // copy block counts to local memory
217 " __local uint local_counts[(1 << K_BITS)];\n"
218 " if(lid < K2_BITS){\n"
219 " local_counts[lid] = counts[get_group_id(0) * K2_BITS + lid];\n"
220 " }\n"
221
222 // wait until local memory is ready
223 " barrier(CLK_LOCAL_MEM_FENCE);\n"
224
225 " if(gid >= input_size){\n"
226 " return;\n"
227 " }\n"
228
229 // get global offset
230 " uint offset = global_offsets[bucket] + local_counts[bucket];\n"
231
232 // calculate local offset
233 " uint local_offset = 0;\n"
234 " for(uint i = 0; i < lid; i++){\n"
235 " if(local_input[i] == bucket)\n"
236 " local_offset++;\n"
237 " }\n"
238
239 "#ifndef SORT_BY_KEY\n"
240 // write value to output
241 " output[output_offset + offset + local_offset] = value;\n"
242 "#else\n"
243 // write key and value if doing sort_by_key
244 " keys_output[keys_output_offset+offset + local_offset] = value;\n"
245 " values_output[values_output_offset+offset + local_offset] =\n"
246 " values_input[values_input_offset+gid];\n"
247 "#endif\n"
248 "}\n";
249
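// Least-significant-digit radix sort. The keys are processed k bits at a
// time, from the lowest bits to the highest, so a full sort takes
// sizeof(T) * CHAR_BIT / k passes. Each pass runs three steps: "count"
// builds per-block bucket histograms, an exclusive scan plus the "scan"
// kernel turn them into per-block and global bucket offsets, and
// "scatter" stably moves every element to its position for that pass.
// Input and output buffers are swapped after each pass; with k in
// {1, 2, 4} the pass count is even, so the final result lands back in
// the caller's buffer(s).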
template<class T, class T2>
inline void radix_sort_impl(const buffer_iterator<T> first,
                            const buffer_iterator<T> last,
                            const buffer_iterator<T2> values_first,
                            const bool ascending,
                            command_queue &queue)
{
    typedef T value_type;
    typedef typename radix_sort_value_type<sizeof(T)>::type sort_type;

    const device &device = queue.get_device();
    const context &context = queue.get_context();

    // if we have a valid values iterator then we are doing a
    // sort by key and have to set up the values buffer
    bool sort_by_key = (values_first.get_buffer().get() != 0);

    // load (or create) radix sort program
    std::string cache_key =
        std::string("__boost_radix_sort_") + type_name<value_type>();

    if(sort_by_key){
        cache_key += std::string("_with_") + type_name<T2>();
    }

    boost::shared_ptr<program_cache> cache =
        program_cache::get_global_cache(context);
    boost::shared_ptr<parameter_cache> parameters =
        detail::parameter_cache::get_global_cache(device);

    // sort parameters
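    // ("k" = bits sorted per pass, "tpb" = work-group size, i.e. threads
    // per block; both tunable per device via the parameter cache)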
    const uint_ k = parameters->get(cache_key, "k", 4);
    const uint_ k2 = 1 << k;
    const uint_ block_size = parameters->get(cache_key, "tpb", 128);

    // sort program compiler options
    std::stringstream options;
    options << "-DK_BITS=" << k;
    options << " -DT=" << type_name<sort_type>();
    options << " -DBLOCK_SIZE=" << block_size;

    if(boost::is_floating_point<value_type>::value){
        options << " -DIS_FLOATING_POINT";
    }

    if(boost::is_signed<value_type>::value){
        options << " -DIS_SIGNED";
    }

    if(sort_by_key){
        options << " -DSORT_BY_KEY";
        options << " -DT2=" << type_name<T2>();
        options << enable_double<T2>();
    }

    if(ascending){
        options << " -DASC";
    }

    // get type definition if it is a custom struct
    std::string custom_type_def = boost::compute::type_definition<T2>() + "\n";

    // load radix sort program
    program radix_sort_program = cache->get_or_build(
        cache_key, options.str(), custom_type_def + radix_sort_source, context
    );

    kernel count_kernel(radix_sort_program, "count");
    kernel scan_kernel(radix_sort_program, "scan");
    kernel scatter_kernel(radix_sort_program, "scatter");

    size_t count = detail::iterator_range_size(first, last);

    uint_ block_count = static_cast<uint_>(count / block_size);
    if(block_count * block_size != count){
        block_count++;
    }

    // setup temporary buffers
    vector<value_type> output(count, context);
    vector<T2> values_output(sort_by_key ? count : 0, context);
    vector<uint_> offsets(k2, context);
    vector<uint_> counts(block_count * k2, context);
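    // offsets holds one global start offset per bucket; counts holds one
    // entry per (block, bucket) pair, block-major, matching the layout
    // written by the count kernel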

    const buffer *input_buffer = &first.get_buffer();
    uint_ input_offset = static_cast<uint_>(first.get_index());
    const buffer *output_buffer = &output.get_buffer();
    uint_ output_offset = 0;
    const buffer *values_input_buffer = &values_first.get_buffer();
    uint_ values_input_offset = static_cast<uint_>(values_first.get_index());
    const buffer *values_output_buffer = &values_output.get_buffer();
    uint_ values_output_offset = 0;

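    // one pass per k-bit digit, from the least significant digit up; each
    // pass is stable, so the ordering established by earlier passes is
    // preserved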
    for(uint_ i = 0; i < sizeof(sort_type) * CHAR_BIT / k; i++){
        // write counts
        count_kernel.set_arg(0, *input_buffer);
        count_kernel.set_arg(1, input_offset);
        count_kernel.set_arg(2, static_cast<uint_>(count));
        count_kernel.set_arg(3, counts);
        count_kernel.set_arg(4, offsets);
        count_kernel.set_arg(5, block_size * sizeof(uint_), 0);
        count_kernel.set_arg(6, i * k);
        queue.enqueue_1d_range_kernel(count_kernel,
                                      0,
                                      block_count * block_size,
                                      block_size);

        // scan counts
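        // counts is laid out [block][bucket]; scanning it as vectors with
        // k2 components performs k2 independent exclusive scans, one per
        // bucket, across the blocks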
        if(k == 1){
            typedef uint2_ counter_type;
            ::boost::compute::exclusive_scan(
                make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
                make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 2),
                make_buffer_iterator<counter_type>(counts.get_buffer()),
                queue
            );
        }
        else if(k == 2){
            typedef uint4_ counter_type;
            ::boost::compute::exclusive_scan(
                make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
                make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 4),
                make_buffer_iterator<counter_type>(counts.get_buffer()),
                queue
            );
        }
        else if(k == 4){
            typedef uint16_ counter_type;
            ::boost::compute::exclusive_scan(
                make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
                make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 16),
                make_buffer_iterator<counter_type>(counts.get_buffer()),
                queue
            );
        }
        else {
            BOOST_ASSERT(false && "unknown k");
            break;
        }


        // scan global offsets
        scan_kernel.set_arg(0, counts);
        scan_kernel.set_arg(1, offsets);
        scan_kernel.set_arg(2, block_count);
        queue.enqueue_task(scan_kernel);

        // scatter values
        scatter_kernel.set_arg(0, *input_buffer);
        scatter_kernel.set_arg(1, input_offset);
        scatter_kernel.set_arg(2, static_cast<uint_>(count));
        scatter_kernel.set_arg(3, i * k);
        scatter_kernel.set_arg(4, counts);
        scatter_kernel.set_arg(5, offsets);
        scatter_kernel.set_arg(6, *output_buffer);
        scatter_kernel.set_arg(7, output_offset);
        if(sort_by_key){
            scatter_kernel.set_arg(8, *values_input_buffer);
            scatter_kernel.set_arg(9, values_input_offset);
            scatter_kernel.set_arg(10, *values_output_buffer);
            scatter_kernel.set_arg(11, values_output_offset);
        }
        queue.enqueue_1d_range_kernel(scatter_kernel,
                                      0,
                                      block_count * block_size,
                                      block_size);

        // swap input/output buffers for the next pass; the total number
        // of passes is even, so the sorted result ends up back in the
        // caller's buffer(s)
        std::swap(input_buffer, output_buffer);
        std::swap(values_input_buffer, values_output_buffer);
        std::swap(input_offset, output_offset);
        std::swap(values_input_offset, values_output_offset);
    }
}

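// Convenience overloads. A default-constructed buffer_iterator (null
// buffer) passed as the values iterator selects a keys-only sort.
//
// Illustrative host-side sketch (not part of this header; these detail
// functions are normally reached through higher-level algorithms such as
// boost::compute::sort, and `context`, `queue` and `n` are assumed to
// exist):
//
//   boost::compute::vector<float> keys(n, context);
//   boost::compute::vector<int> values(n, context);
//   // ... fill keys and values ...
//   detail::radix_sort_by_key(keys.begin(), keys.end(),
//                             values.begin(), queue);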
template<class Iterator>
inline void radix_sort(Iterator first,
                       Iterator last,
                       command_queue &queue)
{
    radix_sort_impl(first, last, buffer_iterator<int>(), true, queue);
}

template<class KeyIterator, class ValueIterator>
inline void radix_sort_by_key(KeyIterator keys_first,
                              KeyIterator keys_last,
                              ValueIterator values_first,
                              command_queue &queue)
{
    radix_sort_impl(keys_first, keys_last, values_first, true, queue);
}

template<class Iterator>
inline void radix_sort(Iterator first,
                       Iterator last,
                       const bool ascending,
                       command_queue &queue)
{
    radix_sort_impl(first, last, buffer_iterator<int>(), ascending, queue);
}

template<class KeyIterator, class ValueIterator>
inline void radix_sort_by_key(KeyIterator keys_first,
                              KeyIterator keys_last,
                              ValueIterator values_first,
                              const bool ascending,
                              command_queue &queue)
{
    radix_sort_impl(keys_first, keys_last, values_first, ascending, queue);
}

} // end detail namespace
} // end compute namespace
} // end boost namespace

#endif // BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP