//---------------------------------------------------------------------------//
// Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
//
// Distributed under the Boost Software License, Version 1.0
// See accompanying file LICENSE_1_0.txt or copy at
// http://www.boost.org/LICENSE_1_0.txt
//
// See http://boostorg.github.com/compute for more information.
//---------------------------------------------------------------------------//

#ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP
#define BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP

#include <climits>
#include <iterator>
#include <sstream>

#include <boost/assert.hpp>
#include <boost/mpl/and.hpp>
#include <boost/mpl/not.hpp>
#include <boost/type_traits/is_signed.hpp>
#include <boost/type_traits/is_floating_point.hpp>

#include <boost/compute/kernel.hpp>
#include <boost/compute/program.hpp>
#include <boost/compute/command_queue.hpp>
#include <boost/compute/algorithm/exclusive_scan.hpp>
#include <boost/compute/container/vector.hpp>
#include <boost/compute/detail/iterator_range_size.hpp>
#include <boost/compute/detail/parameter_cache.hpp>
#include <boost/compute/type_traits/type_name.hpp>
#include <boost/compute/type_traits/is_fundamental.hpp>
#include <boost/compute/type_traits/is_vector_type.hpp>
#include <boost/compute/utility/program_cache.hpp>
namespace boost {
namespace compute {
namespace detail {

// meta-function returning true if type T is radix-sortable
template<class T>
struct is_radix_sortable :
    boost::mpl::and_<
        typename ::boost::compute::is_fundamental<T>::type,
        typename boost::mpl::not_<typename is_vector_type<T>::type>::type
    >
{
};
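// For example, is_radix_sortable<int>::value is true, while
// is_radix_sortable<float4_>::value is false (vector types are excluded).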

template<size_t N>
struct radix_sort_value_type
{
};

template<>
struct radix_sort_value_type<1>
{
    typedef uchar_ type;
};

template<>
struct radix_sort_value_type<2>
{
    typedef ushort_ type;
};

template<>
struct radix_sort_value_type<4>
{
    typedef uint_ type;
};

template<>
struct radix_sort_value_type<8>
{
    typedef ulong_ type;
};
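// For example, radix_sort_value_type<sizeof(float)>::type is uint_, so
// 32-bit keys are reinterpreted and sorted as 32-bit unsigned integers
// (after the order-preserving bit transformation applied by radix() below).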

template<typename T>
inline const char* enable_double()
{
    return " -DT2_double=0";
}

template<>
inline const char* enable_double<double>()
{
    return " -DT2_double=1";
}
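// enable_double() supplies the T2_double compiler option which guards the
// cl_khr_fp64 pragma at the top of the kernel source, enabling
// double-precision support only when the value type requires it.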

const char radix_sort_source[] =
"#if T2_double\n"
"#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
"#endif\n"
"#define K2_BITS (1 << K_BITS)\n"
"#define RADIX_MASK ((((T)(1)) << K_BITS) - 1)\n"
"#define SIGN_BIT ((sizeof(T) * CHAR_BIT) - 1)\n"

"#if defined(ASC)\n" // asc order

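// Maps a key (already reinterpreted as the unsigned type T) to its k-bit
// bucket index at low_bit. For floating-point keys the classic "float flip"
// is used: positive values get only their sign bit flipped, negative values
// get all bits flipped, making the unsigned bit patterns order-preserving.
// Signed integers only need their sign bit flipped.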
97"inline uint radix(const T x, const uint low_bit)\n"
98"{\n"
99"#if defined(IS_FLOATING_POINT)\n"
100" const T mask = -(x >> SIGN_BIT) | (((T)(1)) << SIGN_BIT);\n"
101" return ((x ^ mask) >> low_bit) & RADIX_MASK;\n"
102"#elif defined(IS_SIGNED)\n"
103" return ((x ^ (((T)(1)) << SIGN_BIT)) >> low_bit) & RADIX_MASK;\n"
104"#else\n"
105" return (x >> low_bit) & RADIX_MASK;\n"
106"#endif\n"
107"}\n"

"#else\n" // desc order

// For signed types we simply negate x; for unsigned types we subtract x
// from the maximum value of its type ((T)(-1) is the maximum value of
// type T when T is unsigned).
114"inline uint radix(const T x, const uint low_bit)\n"
115"{\n"
116"#if defined(IS_FLOATING_POINT)\n"
117" const T mask = -(x >> SIGN_BIT) | (((T)(1)) << SIGN_BIT);\n"
118" return (((-x) ^ mask) >> low_bit) & RADIX_MASK;\n"
119"#elif defined(IS_SIGNED)\n"
120" return (((-x) ^ (((T)(1)) << SIGN_BIT)) >> low_bit) & RADIX_MASK;\n"
121"#else\n"
122" return (((T)(-1) - x) >> low_bit) & RADIX_MASK;\n"
123"#endif\n"
124"}\n"
125
126"#endif\n" // #if defined(ASC)
127
128"__kernel void count(__global const T *input,\n"
129" const uint input_offset,\n"
130" const uint input_size,\n"
131" __global uint *global_counts,\n"
132" __global uint *global_offsets,\n"
133" __local uint *local_counts,\n"
134" const uint low_bit)\n"
135"{\n"
136 // work-item parameters
137" const uint gid = get_global_id(0);\n"
138" const uint lid = get_local_id(0);\n"
139
140 // zero local counts
141" if(lid < K2_BITS){\n"
142" local_counts[lid] = 0;\n"
143" }\n"
144" barrier(CLK_LOCAL_MEM_FENCE);\n"
145
146 // reduce local counts
147" if(gid < input_size){\n"
148" T value = input[input_offset+gid];\n"
149" uint bucket = radix(value, low_bit);\n"
150" atomic_inc(local_counts + bucket);\n"
151" }\n"
152" barrier(CLK_LOCAL_MEM_FENCE);\n"
153
154 // write block-relative offsets
155" if(lid < K2_BITS){\n"
156" global_counts[K2_BITS*get_group_id(0) + lid] = local_counts[lid];\n"
157
158 // write global offsets
159" if(get_group_id(0) == (get_num_groups(0) - 1)){\n"
160" global_offsets[lid] = local_counts[lid];\n"
161" }\n"
162" }\n"
163"}\n"

"__kernel void scan(__global const uint *block_offsets,\n"
"                   __global uint *global_offsets,\n"
"                   const uint block_count)\n"
"{\n"
"    __global const uint *last_block_offsets =\n"
"        block_offsets + K2_BITS * (block_count - 1);\n"

    // calculate and scan global_offsets
"    uint sum = 0;\n"
"    for(uint i = 0; i < K2_BITS; i++){\n"
"        uint x = global_offsets[i] + last_block_offsets[i];\n"
"        global_offsets[i] = sum;\n"
"        sum += x;\n"
"    }\n"
"}\n"

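// scatter() computes each key's final position for this pass:
// global_offsets[bucket] (start of the bucket) + counts (keys with the same
// bucket in preceding blocks) + local_offset (keys with the same bucket at
// lower local ids within this block). Each pass is therefore a stable
// counting sort, which LSD radix sort requires for correctness.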
181"__kernel void scatter(__global const T *input,\n"
182" const uint input_offset,\n"
183" const uint input_size,\n"
184" const uint low_bit,\n"
185" __global const uint *counts,\n"
186" __global const uint *global_offsets,\n"
187"#ifndef SORT_BY_KEY\n"
188" __global T *output,\n"
189" const uint output_offset)\n"
190"#else\n"
191" __global T *keys_output,\n"
192" const uint keys_output_offset,\n"
193" __global T2 *values_input,\n"
194" const uint values_input_offset,\n"
195" __global T2 *values_output,\n"
196" const uint values_output_offset)\n"
197"#endif\n"
198"{\n"
199 // work-item parameters
200" const uint gid = get_global_id(0);\n"
201" const uint lid = get_local_id(0);\n"
202
203 // copy input to local memory
204" T value;\n"
205" uint bucket;\n"
206" __local uint local_input[BLOCK_SIZE];\n"
207" if(gid < input_size){\n"
208" value = input[input_offset+gid];\n"
209" bucket = radix(value, low_bit);\n"
210" local_input[lid] = bucket;\n"
211" }\n"
212
213 // copy block counts to local memory
214" __local uint local_counts[(1 << K_BITS)];\n"
215" if(lid < K2_BITS){\n"
216" local_counts[lid] = counts[get_group_id(0) * K2_BITS + lid];\n"
217" }\n"
218
219 // wait until local memory is ready
220" barrier(CLK_LOCAL_MEM_FENCE);\n"
221
222" if(gid >= input_size){\n"
223" return;\n"
224" }\n"
225
226 // get global offset
227" uint offset = global_offsets[bucket] + local_counts[bucket];\n"
228
229 // calculate local offset
230" uint local_offset = 0;\n"
231" for(uint i = 0; i < lid; i++){\n"
232" if(local_input[i] == bucket)\n"
233" local_offset++;\n"
234" }\n"
235
236"#ifndef SORT_BY_KEY\n"
237 // write value to output
238" output[output_offset + offset + local_offset] = value;\n"
239"#else\n"
240 // write key and value if doing sort_by_key
241" keys_output[keys_output_offset+offset + local_offset] = value;\n"
242" values_output[values_output_offset+offset + local_offset] =\n"
243" values_input[values_input_offset+gid];\n"
244"#endif\n"
245"}\n";

template<class T, class T2>
inline void radix_sort_impl(const buffer_iterator<T> first,
                            const buffer_iterator<T> last,
                            const buffer_iterator<T2> values_first,
                            const bool ascending,
                            command_queue &queue)
{
    typedef T value_type;
    typedef typename radix_sort_value_type<sizeof(T)>::type sort_type;

    const device &device = queue.get_device();
    const context &context = queue.get_context();

    // if we have a valid values iterator then we are doing a
    // sort by key and have to set up the values buffer
    bool sort_by_key = (values_first.get_buffer().get() != 0);

    // load (or create) radix sort program
    std::string cache_key =
        std::string("__boost_radix_sort_") + type_name<value_type>();

    if(sort_by_key){
        cache_key += std::string("_with_") + type_name<T2>();
    }

    boost::shared_ptr<program_cache> cache =
        program_cache::get_global_cache(context);
    boost::shared_ptr<parameter_cache> parameters =
        detail::parameter_cache::get_global_cache(device);

    // sort parameters
    const uint_ k = parameters->get(cache_key, "k", 4);
    const uint_ k2 = 1 << k;
    const uint_ block_size = parameters->get(cache_key, "tpb", 128);
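    // e.g. with the defaults (unless overridden in the parameter cache)
    // each pass sorts on a k = 4 bit digit using k2 = 16 buckets, with 128
    // work-items per block ("tpb" presumably stands for threads per block)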

    // sort program compiler options
    std::stringstream options;
    options << "-DK_BITS=" << k;
    options << " -DT=" << type_name<sort_type>();
    options << " -DBLOCK_SIZE=" << block_size;

    if(boost::is_floating_point<value_type>::value){
        options << " -DIS_FLOATING_POINT";
    }

    if(boost::is_signed<value_type>::value){
        options << " -DIS_SIGNED";
    }

    if(sort_by_key){
        options << " -DSORT_BY_KEY";
        options << " -DT2=" << type_name<T2>();
        options << enable_double<T2>();
    }

    if(ascending){
        options << " -DASC";
    }

    // load radix sort program
    program radix_sort_program = cache->get_or_build(
        cache_key, options.str(), radix_sort_source, context
    );

    kernel count_kernel(radix_sort_program, "count");
    kernel scan_kernel(radix_sort_program, "scan");
    kernel scatter_kernel(radix_sort_program, "scatter");

    size_t count = detail::iterator_range_size(first, last);

    uint_ block_count = static_cast<uint_>(count / block_size);
    if(block_count * block_size != count){
        block_count++;
    }

    // setup temporary buffers
    vector<value_type> output(count, context);
    vector<T2> values_output(sort_by_key ? count : 0, context);
    vector<uint_> offsets(k2, context);
    vector<uint_> counts(block_count * k2, context);

    const buffer *input_buffer = &first.get_buffer();
    uint_ input_offset = static_cast<uint_>(first.get_index());
    const buffer *output_buffer = &output.get_buffer();
    uint_ output_offset = 0;
    const buffer *values_input_buffer = &values_first.get_buffer();
    uint_ values_input_offset = static_cast<uint_>(values_first.get_index());
    const buffer *values_output_buffer = &values_output.get_buffer();
    uint_ values_output_offset = 0;

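    // one count/scan/scatter pass per k-bit digit, from the least to the
    // most significant; e.g. 32-bit keys with k = 4 take eight passes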
    for(uint_ i = 0; i < sizeof(sort_type) * CHAR_BIT / k; i++){
        // write counts
        count_kernel.set_arg(0, *input_buffer);
        count_kernel.set_arg(1, input_offset);
        count_kernel.set_arg(2, static_cast<uint_>(count));
        count_kernel.set_arg(3, counts);
        count_kernel.set_arg(4, offsets);
        count_kernel.set_arg(5, block_size * sizeof(uint_), 0);
        count_kernel.set_arg(6, i * k);
        queue.enqueue_1d_range_kernel(count_kernel,
                                      0,
                                      block_count * block_size,
                                      block_size);

        // scan counts
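        // Packing the k2 per-block counters into a single OpenCL vector
        // type (uint2_, uint4_ or uint16_) lets one component-wise
        // exclusive_scan perform k2 independent scans across blocks;
        // afterwards counts[block][bucket] holds the number of keys with
        // that bucket in all preceding blocks.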
        if(k == 1){
            typedef uint2_ counter_type;
            ::boost::compute::exclusive_scan(
                make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
                make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 2),
                make_buffer_iterator<counter_type>(counts.get_buffer()),
                queue
            );
        }
        else if(k == 2){
            typedef uint4_ counter_type;
            ::boost::compute::exclusive_scan(
                make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
                make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 4),
                make_buffer_iterator<counter_type>(counts.get_buffer()),
                queue
            );
        }
        else if(k == 4){
            typedef uint16_ counter_type;
            ::boost::compute::exclusive_scan(
                make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
                make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 16),
                make_buffer_iterator<counter_type>(counts.get_buffer()),
                queue
            );
        }
        else {
            BOOST_ASSERT(false && "unknown k");
            break;
        }

        // scan global offsets
        scan_kernel.set_arg(0, counts);
        scan_kernel.set_arg(1, offsets);
        scan_kernel.set_arg(2, block_count);
        queue.enqueue_task(scan_kernel);

        // scatter values
        scatter_kernel.set_arg(0, *input_buffer);
        scatter_kernel.set_arg(1, input_offset);
        scatter_kernel.set_arg(2, static_cast<uint_>(count));
        scatter_kernel.set_arg(3, i * k);
        scatter_kernel.set_arg(4, counts);
        scatter_kernel.set_arg(5, offsets);
        scatter_kernel.set_arg(6, *output_buffer);
        scatter_kernel.set_arg(7, output_offset);
        if(sort_by_key){
            scatter_kernel.set_arg(8, *values_input_buffer);
            scatter_kernel.set_arg(9, values_input_offset);
            scatter_kernel.set_arg(10, *values_output_buffer);
            scatter_kernel.set_arg(11, values_output_offset);
        }
        queue.enqueue_1d_range_kernel(scatter_kernel,
                                      0,
                                      block_count * block_size,
                                      block_size);

        // swap buffers
        std::swap(input_buffer, output_buffer);
        std::swap(values_input_buffer, values_output_buffer);
        std::swap(input_offset, output_offset);
        std::swap(values_input_offset, values_output_offset);
    }
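
    // the number of passes is even for every supported k, so after the
    // final swap the sorted keys (and values) end up back in the caller's
    // original buffers and no extra copy is needed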
}

template<class Iterator>
inline void radix_sort(Iterator first,
                       Iterator last,
                       command_queue &queue)
{
    radix_sort_impl(first, last, buffer_iterator<int>(), true, queue);
}

template<class KeyIterator, class ValueIterator>
inline void radix_sort_by_key(KeyIterator keys_first,
                              KeyIterator keys_last,
                              ValueIterator values_first,
                              command_queue &queue)
{
    radix_sort_impl(keys_first, keys_last, values_first, true, queue);
}

template<class Iterator>
inline void radix_sort(Iterator first,
                       Iterator last,
                       const bool ascending,
                       command_queue &queue)
{
    radix_sort_impl(first, last, buffer_iterator<int>(), ascending, queue);
}

template<class KeyIterator, class ValueIterator>
inline void radix_sort_by_key(KeyIterator keys_first,
                              KeyIterator keys_last,
                              ValueIterator values_first,
                              const bool ascending,
                              command_queue &queue)
{
    radix_sort_impl(keys_first, keys_last, values_first, ascending, queue);
}
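
// Illustrative usage (these functions are implementation details, normally
// reached through boost::compute::sort() and sort_by_key(); assumes "queue"
// is an existing boost::compute::command_queue):
//
//   boost::compute::vector<int> vec(1000, queue.get_context());
//   // ... fill vec ...
//   detail::radix_sort(vec.begin(), vec.end(), queue);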

} // end detail namespace
} // end compute namespace
} // end boost namespace

#endif // BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP