]>
Commit | Line | Data |
---|---|---|
7c673cae FG |
1 | //---------------------------------------------------------------------------// |
2 | // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com> | |
3 | // | |
4 | // Distributed under the Boost Software License, Version 1.0 | |
5 | // See accompanying file LICENSE_1_0.txt or copy at | |
6 | // http://www.boost.org/LICENSE_1_0.txt | |
7 | // | |
8 | // See http://boostorg.github.com/compute for more information. | |
9 | //---------------------------------------------------------------------------// | |
10 | ||
11 | #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP | |
12 | #define BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP | |
13 | ||
#include <climits>
#include <iterator>
#include <sstream>
#include <string>

#include <boost/assert.hpp>
#include <boost/mpl/and.hpp>
#include <boost/mpl/not.hpp>
#include <boost/shared_ptr.hpp>
#include <boost/type_traits/is_signed.hpp>
#include <boost/type_traits/is_floating_point.hpp>

#include <boost/compute/kernel.hpp>
#include <boost/compute/program.hpp>
#include <boost/compute/command_queue.hpp>
#include <boost/compute/algorithm/exclusive_scan.hpp>
#include <boost/compute/container/vector.hpp>
#include <boost/compute/detail/iterator_range_size.hpp>
#include <boost/compute/detail/parameter_cache.hpp>
#include <boost/compute/type_traits/type_name.hpp>
#include <boost/compute/type_traits/is_fundamental.hpp>
#include <boost/compute/type_traits/is_vector_type.hpp>
#include <boost/compute/utility/program_cache.hpp>
31 | ||
32 | namespace boost { | |
33 | namespace compute { | |
34 | namespace detail { | |
35 | ||
// Meta-function returning true if type T is radix-sortable.
//
// A type is radix-sortable when it is a fundamental scalar (integral or
// floating-point) and NOT one of the OpenCL vector types (e.g. float4_):
// the bit transformations performed by radix() in the kernel source below
// only make sense for scalar values.
template<class T>
struct is_radix_sortable :
    boost::mpl::and_<
        typename ::boost::compute::is_fundamental<T>::type,
        typename boost::mpl::not_<typename is_vector_type<T>::type>::type
    >
{
};
45 | ||
// Maps a key size in bytes (N) to the unsigned integer type of the same
// width that the OpenCL kernels actually operate on. Signed and
// floating-point keys are reinterpreted as this type and their bit
// patterns are adjusted by radix() so that unsigned digit extraction
// yields the correct ordering.
template<size_t N>
struct radix_sort_value_type
{
    // Primary template intentionally has no `type` member: an unsupported
    // key size fails to compile instead of sorting incorrectly.
};

template<>
struct radix_sort_value_type<1>
{
    typedef uchar_ type;    // 1-byte keys (char_, uchar_)
};

template<>
struct radix_sort_value_type<2>
{
    typedef ushort_ type;   // 2-byte keys (short_, ushort_)
};

template<>
struct radix_sort_value_type<4>
{
    typedef uint_ type;     // 4-byte keys (int_, uint_, float_)
};

template<>
struct radix_sort_value_type<8>
{
    typedef ulong_ type;    // 8-byte keys (long_, ulong_, double_)
};
74 | ||
/// Returns the program compile option that tells the kernel source whether
/// the value type T2 is `double` (and therefore needs the cl_khr_fp64
/// pragma enabled). The generic case disables double support.
template<typename T>
inline const char* enable_double()
{
    return " -DT2_double=0";
}

/// Specialization for `double` values: enables the fp64 pragma in the
/// kernel source via the T2_double macro.
template<>
inline const char* enable_double<double>()
{
    return " -DT2_double=1";
}
86 | ||
// OpenCL source for the radix-sort kernels. Compiled with the following
// macros, set up by radix_sort_impl() below:
//   T                 - unsigned integer type the keys are sorted as
//   K_BITS            - number of radix bits consumed per pass
//   BLOCK_SIZE        - work-group size used for count/scatter
//   IS_FLOATING_POINT / IS_SIGNED - select the key bit transformation
//   ASC               - ascending order (descending when absent)
//   SORT_BY_KEY, T2, T2_double    - optional key/value sorting support
const char radix_sort_source[] =
"#if T2_double\n"
"#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
"#endif\n"
"#define K2_BITS (1 << K_BITS)\n"
"#define RADIX_MASK ((((T)(1)) << K_BITS) - 1)\n"
"#define SIGN_BIT ((sizeof(T) * CHAR_BIT) - 1)\n"

"#if defined(ASC)\n" // ascending order

// radix() extracts the K_BITS-wide digit of x starting at low_bit, after
// remapping the bits so that an unsigned comparison of digits matches the
// desired ordering of the original type:
//  - floats: flip all bits of negatives, flip only the sign bit of
//    positives (the classic IEEE-754 radix trick)
//  - signed integers: flip the sign bit
//  - unsigned integers: use the value as-is
"inline uint radix(const T x, const uint low_bit)\n"
"{\n"
"#if defined(IS_FLOATING_POINT)\n"
"    const T mask = -(x >> SIGN_BIT) | (((T)(1)) << SIGN_BIT);\n"
"    return ((x ^ mask) >> low_bit) & RADIX_MASK;\n"
"#elif defined(IS_SIGNED)\n"
"    return ((x ^ (((T)(1)) << SIGN_BIT)) >> low_bit) & RADIX_MASK;\n"
"#else\n"
"    return (x >> low_bit) & RADIX_MASK;\n"
"#endif\n"
"}\n"

"#else\n" // descending order

// Descending order reuses the ascending transformations on the "negated"
// key: for signed/float types we negate x, and for unsigned types we
// subtract x from the maximum value of its type ((T)(-1) is the max value
// of T when T is an unsigned type).
"inline uint radix(const T x, const uint low_bit)\n"
"{\n"
"#if defined(IS_FLOATING_POINT)\n"
"    const T mask = -(x >> SIGN_BIT) | (((T)(1)) << SIGN_BIT);\n"
"    return (((-x) ^ mask) >> low_bit) & RADIX_MASK;\n"
"#elif defined(IS_SIGNED)\n"
"    return (((-x) ^ (((T)(1)) << SIGN_BIT)) >> low_bit) & RADIX_MASK;\n"
"#else\n"
"    return (((T)(-1) - x) >> low_bit) & RADIX_MASK;\n"
"#endif\n"
"}\n"

"#endif\n" // #if defined(ASC)

// count: builds a per-block histogram of the current digit in local memory
// and writes it to global_counts; the last block also seeds global_offsets
// with its (unscanned) local counts.
"__kernel void count(__global const T *input,\n"
"                    const uint input_offset,\n"
"                    const uint input_size,\n"
"                    __global uint *global_counts,\n"
"                    __global uint *global_offsets,\n"
"                    __local uint *local_counts,\n"
"                    const uint low_bit)\n"
"{\n"
     // work-item parameters
"    const uint gid = get_global_id(0);\n"
"    const uint lid = get_local_id(0);\n"

     // zero local counts (one counter per radix bucket)
"    if(lid < K2_BITS){\n"
"        local_counts[lid] = 0;\n"
"    }\n"
"    barrier(CLK_LOCAL_MEM_FENCE);\n"

     // accumulate this block's digit histogram in local memory
"    if(gid < input_size){\n"
"        T value = input[input_offset+gid];\n"
"        uint bucket = radix(value, low_bit);\n"
"        atomic_inc(local_counts + bucket);\n"
"    }\n"
"    barrier(CLK_LOCAL_MEM_FENCE);\n"

     // write block-relative counts (later exclusive-scanned on the host side)
"    if(lid < K2_BITS){\n"
"        global_counts[K2_BITS*get_group_id(0) + lid] = local_counts[lid];\n"

         // the last block additionally stores its raw counts so that scan()
         // below can recover per-bucket totals from the scanned counts
"        if(get_group_id(0) == (get_num_groups(0) - 1)){\n"
"            global_offsets[lid] = local_counts[lid];\n"
"        }\n"
"    }\n"
"}\n"

// scan: single-task kernel that turns per-bucket totals into an exclusive
// prefix sum giving each bucket's starting offset in the output. On entry,
// block_offsets holds the exclusive scan of the per-block counts and
// global_offsets holds the last block's raw counts, so
// global_offsets[i] + last_block_offsets[i] is the total count of bucket i.
"__kernel void scan(__global const uint *block_offsets,\n"
"                   __global uint *global_offsets,\n"
"                   const uint block_count)\n"
"{\n"
"    __global const uint *last_block_offsets =\n"
"        block_offsets + K2_BITS * (block_count - 1);\n"

     // exclusive scan of the bucket totals, in place in global_offsets
"    uint sum = 0;\n"
"    for(uint i = 0; i < K2_BITS; i++){\n"
"        uint x = global_offsets[i] + last_block_offsets[i];\n"
"        global_offsets[i] = sum;\n"
"        sum += x;\n"
"    }\n"
"}\n"

// scatter: moves each element to its final position for this digit pass:
// global bucket offset + this block's scanned offset within the bucket
// + the element's rank among equal digits inside the block (keeps the
// sort stable, which LSD radix sort requires).
"__kernel void scatter(__global const T *input,\n"
"                      const uint input_offset,\n"
"                      const uint input_size,\n"
"                      const uint low_bit,\n"
"                      __global const uint *counts,\n"
"                      __global const uint *global_offsets,\n"
"#ifndef SORT_BY_KEY\n"
"                      __global T *output,\n"
"                      const uint output_offset)\n"
"#else\n"
"                      __global T *keys_output,\n"
"                      const uint keys_output_offset,\n"
"                      __global T2 *values_input,\n"
"                      const uint values_input_offset,\n"
"                      __global T2 *values_output,\n"
"                      const uint values_output_offset)\n"
"#endif\n"
"{\n"
     // work-item parameters
"    const uint gid = get_global_id(0);\n"
"    const uint lid = get_local_id(0);\n"

     // compute this element's bucket and share it with the block via
     // local memory (local_input stores buckets, not raw values)
"    T value;\n"
"    uint bucket;\n"
"    __local uint local_input[BLOCK_SIZE];\n"
"    if(gid < input_size){\n"
"        value = input[input_offset+gid];\n"
"        bucket = radix(value, low_bit);\n"
"        local_input[lid] = bucket;\n"
"    }\n"

     // copy this block's scanned counts to local memory
"    __local uint local_counts[(1 << K_BITS)];\n"
"    if(lid < K2_BITS){\n"
"        local_counts[lid] = counts[get_group_id(0) * K2_BITS + lid];\n"
"    }\n"

     // wait until local memory is ready
"    barrier(CLK_LOCAL_MEM_FENCE);\n"

     // out-of-range work-items exit only after the barrier above
"    if(gid >= input_size){\n"
"        return;\n"
"    }\n"

     // base output position for this element's bucket
"    uint offset = global_offsets[bucket] + local_counts[bucket];\n"

     // rank among same-bucket elements earlier in this block (stability)
"    uint local_offset = 0;\n"
"    for(uint i = 0; i < lid; i++){\n"
"        if(local_input[i] == bucket)\n"
"            local_offset++;\n"
"    }\n"

"#ifndef SORT_BY_KEY\n"
     // write value to output
"    output[output_offset + offset + local_offset] = value;\n"
"#else\n"
     // write key and the corresponding value when doing sort_by_key
"    keys_output[keys_output_offset+offset + local_offset] = value;\n"
"    values_output[values_output_offset+offset + local_offset] =\n"
"        values_input[values_input_offset+gid];\n"
"#endif\n"
"}\n";
246 | ||
// Performs an LSD (least-significant-digit-first) radix sort of the keys in
// [first, last) on the device associated with `queue`.
//
// When `values_first` refers to a non-null buffer, the corresponding values
// range is permuted alongside the keys (sort-by-key). Sorts in ascending
// order when `ascending` is true, descending otherwise.
//
// Each pass sorts `k` bits of the key via three kernels: count (per-block
// digit histograms), scan (global bucket offsets) and scatter (stable
// reordering), ping-ponging between the input buffer and a temporary output
// buffer between passes.
template<class T, class T2>
inline void radix_sort_impl(const buffer_iterator<T> first,
                            const buffer_iterator<T> last,
                            const buffer_iterator<T2> values_first,
                            const bool ascending,
                            command_queue &queue)
{
    typedef T value_type;
    // same-width unsigned type the kernels treat the keys as
    typedef typename radix_sort_value_type<sizeof(T)>::type sort_type;

    const device &device = queue.get_device();
    const context &context = queue.get_context();

    // if we have a valid values iterator then we are doing a
    // sort by key and have to set up the values buffer
    bool sort_by_key = (values_first.get_buffer().get() != 0);

    // cache key identifying the program (and its tuning parameters) for
    // this key type (and value type, for sort-by-key)
    std::string cache_key =
        std::string("__boost_radix_sort_") + type_name<value_type>();

    if(sort_by_key){
        cache_key += std::string("_with_") + type_name<T2>();
    }

    boost::shared_ptr<program_cache> cache =
        program_cache::get_global_cache(context);
    boost::shared_ptr<parameter_cache> parameters =
        detail::parameter_cache::get_global_cache(device);

    // sort parameters: k = radix bits per pass (must be 1, 2 or 4 -- see
    // the scan step below), k2 = number of buckets, tpb = work-group size
    const uint_ k = parameters->get(cache_key, "k", 4);
    const uint_ k2 = 1 << k;
    const uint_ block_size = parameters->get(cache_key, "tpb", 128);

    // sort program compiler options (see radix_sort_source for the macros)
    std::stringstream options;
    options << "-DK_BITS=" << k;
    options << " -DT=" << type_name<sort_type>();
    options << " -DBLOCK_SIZE=" << block_size;

    if(boost::is_floating_point<value_type>::value){
        options << " -DIS_FLOATING_POINT";
    }

    if(boost::is_signed<value_type>::value){
        options << " -DIS_SIGNED";
    }

    if(sort_by_key){
        options << " -DSORT_BY_KEY";
        options << " -DT2=" << type_name<T2>();
        options << enable_double<T2>();
    }

    if(ascending){
        options << " -DASC";
    }

    // load (or build and cache) the radix sort program
    program radix_sort_program = cache->get_or_build(
        cache_key, options.str(), radix_sort_source, context
    );

    kernel count_kernel(radix_sort_program, "count");
    kernel scan_kernel(radix_sort_program, "scan");
    kernel scatter_kernel(radix_sort_program, "scatter");

    size_t count = detail::iterator_range_size(first, last);

    // number of work-groups, rounded up to cover a partial final block
    uint_ block_count = static_cast<uint_>(count / block_size);
    if(block_count * block_size != count){
        block_count++;
    }

    // setup temporary buffers: ping-pong output(s), per-bucket global
    // offsets, and per-block per-bucket counts
    vector<value_type> output(count, context);
    vector<T2> values_output(sort_by_key ? count : 0, context);
    vector<uint_> offsets(k2, context);
    vector<uint_> counts(block_count * k2, context);

    const buffer *input_buffer = &first.get_buffer();
    uint_ input_offset = static_cast<uint_>(first.get_index());
    const buffer *output_buffer = &output.get_buffer();
    uint_ output_offset = 0;
    const buffer *values_input_buffer = &values_first.get_buffer();
    uint_ values_input_offset = static_cast<uint_>(values_first.get_index());
    const buffer *values_output_buffer = &values_output.get_buffer();
    uint_ values_output_offset = 0;

    // One pass per k bits of the key. For every supported combination the
    // pass count (sizeof(sort_type) * CHAR_BIT / k) is even, so after the
    // final buffer swap the sorted data ends up back in the caller's
    // buffer(s).
    for(uint_ i = 0; i < sizeof(sort_type) * CHAR_BIT / k; i++){
        // write counts
        count_kernel.set_arg(0, *input_buffer);
        count_kernel.set_arg(1, input_offset);
        count_kernel.set_arg(2, static_cast<uint_>(count));
        count_kernel.set_arg(3, counts);
        count_kernel.set_arg(4, offsets);
        // local histogram; allocates block_size counters, which covers the
        // K2_BITS (<= 16) actually used since block_size defaults to 128
        count_kernel.set_arg(5, block_size * sizeof(uint_), 0);
        count_kernel.set_arg(6, i * k);
        queue.enqueue_1d_range_kernel(count_kernel,
                                      0,
                                      block_count * block_size,
                                      block_size);

        // Exclusive scan over the per-block counts. Each uint{2,4,16}
        // element packs one block's k2-entry histogram, so a vector-wise
        // scan yields, for every bucket, the prefix sum over preceding
        // blocks in a single pass. This is why k is restricted to 1, 2, 4.
        if(k == 1){
            typedef uint2_ counter_type;
            ::boost::compute::exclusive_scan(
                make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
                make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 2),
                make_buffer_iterator<counter_type>(counts.get_buffer()),
                queue
            );
        }
        else if(k == 2){
            typedef uint4_ counter_type;
            ::boost::compute::exclusive_scan(
                make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
                make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 4),
                make_buffer_iterator<counter_type>(counts.get_buffer()),
                queue
            );
        }
        else if(k == 4){
            typedef uint16_ counter_type;
            ::boost::compute::exclusive_scan(
                make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
                make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 16),
                make_buffer_iterator<counter_type>(counts.get_buffer()),
                queue
            );
        }
        else {
            // unreachable for the supported parameter values; guards
            // against a corrupted/out-of-range tuning parameter
            BOOST_ASSERT(false && "unknown k");
            break;
        }

        // scan bucket totals into global offsets (single task)
        scan_kernel.set_arg(0, counts);
        scan_kernel.set_arg(1, offsets);
        scan_kernel.set_arg(2, block_count);
        queue.enqueue_task(scan_kernel);

        // stable scatter of keys (and values) into the output buffer(s)
        scatter_kernel.set_arg(0, *input_buffer);
        scatter_kernel.set_arg(1, input_offset);
        scatter_kernel.set_arg(2, static_cast<uint_>(count));
        scatter_kernel.set_arg(3, i * k);
        scatter_kernel.set_arg(4, counts);
        scatter_kernel.set_arg(5, offsets);
        scatter_kernel.set_arg(6, *output_buffer);
        scatter_kernel.set_arg(7, output_offset);
        if(sort_by_key){
            scatter_kernel.set_arg(8, *values_input_buffer);
            scatter_kernel.set_arg(9, values_input_offset);
            scatter_kernel.set_arg(10, *values_output_buffer);
            scatter_kernel.set_arg(11, values_output_offset);
        }
        queue.enqueue_1d_range_kernel(scatter_kernel,
                                      0,
                                      block_count * block_size,
                                      block_size);

        // swap buffers for the next pass (ping-pong)
        std::swap(input_buffer, output_buffer);
        std::swap(values_input_buffer, values_output_buffer);
        std::swap(input_offset, output_offset);
        std::swap(values_input_offset, values_output_offset);
    }
}
419 | ||
420 | template<class Iterator> | |
421 | inline void radix_sort(Iterator first, | |
422 | Iterator last, | |
423 | command_queue &queue) | |
424 | { | |
425 | radix_sort_impl(first, last, buffer_iterator<int>(), true, queue); | |
426 | } | |
427 | ||
428 | template<class KeyIterator, class ValueIterator> | |
429 | inline void radix_sort_by_key(KeyIterator keys_first, | |
430 | KeyIterator keys_last, | |
431 | ValueIterator values_first, | |
432 | command_queue &queue) | |
433 | { | |
434 | radix_sort_impl(keys_first, keys_last, values_first, true, queue); | |
435 | } | |
436 | ||
437 | template<class Iterator> | |
438 | inline void radix_sort(Iterator first, | |
439 | Iterator last, | |
440 | const bool ascending, | |
441 | command_queue &queue) | |
442 | { | |
443 | radix_sort_impl(first, last, buffer_iterator<int>(), ascending, queue); | |
444 | } | |
445 | ||
446 | template<class KeyIterator, class ValueIterator> | |
447 | inline void radix_sort_by_key(KeyIterator keys_first, | |
448 | KeyIterator keys_last, | |
449 | ValueIterator values_first, | |
450 | const bool ascending, | |
451 | command_queue &queue) | |
452 | { | |
453 | radix_sort_impl(keys_first, keys_last, values_first, ascending, queue); | |
454 | } | |
455 | ||
456 | ||
457 | } // end detail namespace | |
458 | } // end compute namespace | |
459 | } // end boost namespace | |
460 | ||
461 | #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP |