]>
Commit | Line | Data |
---|---|---|
7c673cae FG |
1 | //---------------------------------------------------------------------------// |
2 | // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com> | |
3 | // | |
4 | // Distributed under the Boost Software License, Version 1.0 | |
5 | // See accompanying file LICENSE_1_0.txt or copy at | |
6 | // http://www.boost.org/LICENSE_1_0.txt | |
7 | // | |
8 | // See http://boostorg.github.com/compute for more information. | |
9 | //---------------------------------------------------------------------------// | |
10 | ||
11 | #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_GPU_HPP | |
12 | #define BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_GPU_HPP | |
13 | ||
14 | #include <boost/compute/kernel.hpp> | |
15 | #include <boost/compute/detail/meta_kernel.hpp> | |
16 | #include <boost/compute/command_queue.hpp> | |
17 | #include <boost/compute/container/vector.hpp> | |
18 | #include <boost/compute/detail/iterator_range_size.hpp> | |
19 | #include <boost/compute/memory/local_buffer.hpp> | |
20 | #include <boost/compute/iterator/buffer_iterator.hpp> | |
21 | ||
22 | namespace boost { | |
23 | namespace compute { | |
24 | namespace detail { | |
25 | ||
26 | template<class InputIterator, class OutputIterator, class BinaryOperator> | |
27 | class local_scan_kernel : public meta_kernel | |
28 | { | |
29 | public: | |
30 | local_scan_kernel(InputIterator first, | |
31 | InputIterator last, | |
32 | OutputIterator result, | |
33 | bool exclusive, | |
34 | BinaryOperator op) | |
35 | : meta_kernel("local_scan") | |
36 | { | |
37 | typedef typename std::iterator_traits<InputIterator>::value_type T; | |
38 | ||
39 | (void) last; | |
40 | ||
41 | bool checked = true; | |
42 | ||
43 | m_block_sums_arg = add_arg<T *>(memory_object::global_memory, "block_sums"); | |
44 | m_scratch_arg = add_arg<T *>(memory_object::local_memory, "scratch"); | |
45 | m_block_size_arg = add_arg<const cl_uint>("block_size"); | |
46 | m_count_arg = add_arg<const cl_uint>("count"); | |
47 | m_init_value_arg = add_arg<const T>("init"); | |
48 | ||
49 | // work-item parameters | |
50 | *this << | |
51 | "const uint gid = get_global_id(0);\n" << | |
52 | "const uint lid = get_local_id(0);\n"; | |
53 | ||
54 | // check against data size | |
55 | if(checked){ | |
56 | *this << | |
57 | "if(gid < count){\n"; | |
58 | } | |
59 | ||
60 | // copy values from input to local memory | |
61 | if(exclusive){ | |
62 | *this << | |
63 | decl<const T>("local_init") << "= (gid == 0) ? init : 0;\n" << | |
64 | "if(lid == 0){ scratch[lid] = local_init; }\n" << | |
65 | "else { scratch[lid] = " << first[expr<cl_uint>("gid-1")] << "; }\n"; | |
66 | } | |
67 | else{ | |
68 | *this << | |
69 | "scratch[lid] = " << first[expr<cl_uint>("gid")] << ";\n"; | |
70 | } | |
71 | ||
72 | if(checked){ | |
73 | *this << | |
74 | "}\n" | |
75 | "else {\n" << | |
76 | " scratch[lid] = 0;\n" << | |
77 | "}\n"; | |
78 | } | |
79 | ||
80 | // wait for all threads to read from input | |
81 | *this << | |
82 | "barrier(CLK_LOCAL_MEM_FENCE);\n"; | |
83 | ||
84 | // perform scan | |
85 | *this << | |
86 | "for(uint i = 1; i < block_size; i <<= 1){\n" << | |
87 | " " << decl<const T>("x") << " = lid >= i ? scratch[lid-i] : 0;\n" << | |
88 | " barrier(CLK_LOCAL_MEM_FENCE);\n" << | |
89 | " if(lid >= i){\n" << | |
90 | " scratch[lid] = " << op(var<T>("scratch[lid]"), var<T>("x")) << ";\n" << | |
91 | " }\n" << | |
92 | " barrier(CLK_LOCAL_MEM_FENCE);\n" << | |
93 | "}\n"; | |
94 | ||
95 | // copy results to output | |
96 | if(checked){ | |
97 | *this << | |
98 | "if(gid < count){\n"; | |
99 | } | |
100 | ||
101 | *this << | |
102 | result[expr<cl_uint>("gid")] << " = scratch[lid];\n"; | |
103 | ||
104 | if(checked){ | |
105 | *this << "}\n"; | |
106 | } | |
107 | ||
108 | // store sum for the block | |
109 | if(exclusive){ | |
110 | *this << | |
92f5a8d4 | 111 | "if(lid == block_size - 1 && gid < count) {\n" << |
7c673cae FG |
112 | " block_sums[get_group_id(0)] = " << |
113 | op(first[expr<cl_uint>("gid")], var<T>("scratch[lid]")) << | |
114 | ";\n" << | |
115 | "}\n"; | |
116 | } | |
117 | else { | |
118 | *this << | |
119 | "if(lid == block_size - 1){\n" << | |
120 | " block_sums[get_group_id(0)] = scratch[lid];\n" << | |
121 | "}\n"; | |
122 | } | |
123 | } | |
124 | ||
125 | size_t m_block_sums_arg; | |
126 | size_t m_scratch_arg; | |
127 | size_t m_block_size_arg; | |
128 | size_t m_count_arg; | |
129 | size_t m_init_value_arg; | |
130 | }; | |
131 | ||
132 | template<class T, class BinaryOperator> | |
133 | class write_scanned_output_kernel : public meta_kernel | |
134 | { | |
135 | public: | |
136 | write_scanned_output_kernel(BinaryOperator op) | |
137 | : meta_kernel("write_scanned_output") | |
138 | { | |
139 | bool checked = true; | |
140 | ||
141 | m_output_arg = add_arg<T *>(memory_object::global_memory, "output"); | |
142 | m_block_sums_arg = add_arg<const T *>(memory_object::global_memory, "block_sums"); | |
143 | m_count_arg = add_arg<const cl_uint>("count"); | |
144 | ||
145 | // work-item parameters | |
146 | *this << | |
147 | "const uint gid = get_global_id(0);\n" << | |
148 | "const uint block_id = get_group_id(0);\n"; | |
149 | ||
150 | // check against data size | |
151 | if(checked){ | |
152 | *this << "if(gid < count){\n"; | |
153 | } | |
154 | ||
155 | // write output | |
156 | *this << | |
157 | "output[gid] = " << | |
158 | op(var<T>("block_sums[block_id]"), var<T>("output[gid] ")) << ";\n"; | |
159 | ||
160 | if(checked){ | |
161 | *this << "}\n"; | |
162 | } | |
163 | } | |
164 | ||
165 | size_t m_output_arg; | |
166 | size_t m_block_sums_arg; | |
167 | size_t m_count_arg; | |
168 | }; | |
169 | ||
/// \internal
/// Chooses the work-group size used to scan [first, last).
///
/// \return 0 for an empty range; otherwise the smallest power of two that
///         is greater than or equal to the element count, capped at 256.
template<class InputIterator>
inline size_t pick_scan_block_size(InputIterator first, InputIterator last)
{
    const size_t element_count = iterator_range_size(first, last);
    if(element_count == 0){
        return 0;
    }

    // double until the whole range fits or the cap is reached
    size_t block = 1;
    while(block < element_count && block < 256){
        block <<= 1;
    }
    return block;
}
186 | ||
187 | template<class InputIterator, class OutputIterator, class T, class BinaryOperator> | |
188 | inline OutputIterator scan_impl(InputIterator first, | |
189 | InputIterator last, | |
190 | OutputIterator result, | |
191 | bool exclusive, | |
192 | T init, | |
193 | BinaryOperator op, | |
194 | command_queue &queue) | |
195 | { | |
196 | typedef typename | |
197 | std::iterator_traits<InputIterator>::value_type | |
198 | input_type; | |
199 | typedef typename | |
200 | std::iterator_traits<InputIterator>::difference_type | |
201 | difference_type; | |
202 | typedef typename | |
203 | std::iterator_traits<OutputIterator>::value_type | |
204 | output_type; | |
205 | ||
206 | const context &context = queue.get_context(); | |
207 | const size_t count = detail::iterator_range_size(first, last); | |
208 | ||
209 | size_t block_size = pick_scan_block_size(first, last); | |
210 | size_t block_count = count / block_size; | |
211 | ||
212 | if(block_count * block_size < count){ | |
213 | block_count++; | |
214 | } | |
215 | ||
216 | ::boost::compute::vector<input_type> block_sums(block_count, context); | |
217 | ||
218 | // zero block sums | |
219 | input_type zero; | |
220 | std::memset(&zero, 0, sizeof(input_type)); | |
221 | ::boost::compute::fill(block_sums.begin(), block_sums.end(), zero, queue); | |
222 | ||
223 | // local scan | |
224 | local_scan_kernel<InputIterator, OutputIterator, BinaryOperator> | |
225 | local_scan_kernel(first, last, result, exclusive, op); | |
226 | ||
227 | ::boost::compute::kernel kernel = local_scan_kernel.compile(context); | |
228 | kernel.set_arg(local_scan_kernel.m_scratch_arg, local_buffer<input_type>(block_size)); | |
229 | kernel.set_arg(local_scan_kernel.m_block_sums_arg, block_sums); | |
230 | kernel.set_arg(local_scan_kernel.m_block_size_arg, static_cast<cl_uint>(block_size)); | |
231 | kernel.set_arg(local_scan_kernel.m_count_arg, static_cast<cl_uint>(count)); | |
232 | kernel.set_arg(local_scan_kernel.m_init_value_arg, static_cast<output_type>(init)); | |
233 | ||
234 | queue.enqueue_1d_range_kernel(kernel, | |
235 | 0, | |
236 | block_count * block_size, | |
237 | block_size); | |
238 | ||
239 | // inclusive scan block sums | |
240 | if(block_count > 1){ | |
241 | scan_impl(block_sums.begin(), | |
242 | block_sums.end(), | |
243 | block_sums.begin(), | |
244 | false, | |
245 | init, | |
246 | op, | |
247 | queue | |
248 | ); | |
249 | } | |
250 | ||
251 | // add block sums to each block | |
252 | if(block_count > 1){ | |
253 | write_scanned_output_kernel<input_type, BinaryOperator> | |
254 | write_output_kernel(op); | |
255 | kernel = write_output_kernel.compile(context); | |
256 | kernel.set_arg(write_output_kernel.m_output_arg, result.get_buffer()); | |
257 | kernel.set_arg(write_output_kernel.m_block_sums_arg, block_sums); | |
258 | kernel.set_arg(write_output_kernel.m_count_arg, static_cast<cl_uint>(count)); | |
259 | ||
260 | queue.enqueue_1d_range_kernel(kernel, | |
261 | block_size, | |
262 | block_count * block_size, | |
263 | block_size); | |
264 | } | |
265 | ||
266 | return result + static_cast<difference_type>(count); | |
267 | } | |
268 | ||
269 | template<class InputIterator, class OutputIterator, class T, class BinaryOperator> | |
270 | inline OutputIterator dispatch_scan(InputIterator first, | |
271 | InputIterator last, | |
272 | OutputIterator result, | |
273 | bool exclusive, | |
274 | T init, | |
275 | BinaryOperator op, | |
276 | command_queue &queue) | |
277 | { | |
278 | return scan_impl(first, last, result, exclusive, init, op, queue); | |
279 | } | |
280 | ||
281 | template<class InputIterator, class T, class BinaryOperator> | |
282 | inline InputIterator dispatch_scan(InputIterator first, | |
283 | InputIterator last, | |
284 | InputIterator result, | |
285 | bool exclusive, | |
286 | T init, | |
287 | BinaryOperator op, | |
288 | command_queue &queue) | |
289 | { | |
290 | typedef typename std::iterator_traits<InputIterator>::value_type value_type; | |
291 | ||
292 | if(first == result){ | |
293 | // scan input in-place | |
294 | const context &context = queue.get_context(); | |
295 | ||
296 | // make a temporary copy the input | |
297 | size_t count = iterator_range_size(first, last); | |
298 | vector<value_type> tmp(count, context); | |
299 | copy(first, last, tmp.begin(), queue); | |
300 | ||
301 | // scan from temporary values | |
302 | return scan_impl(tmp.begin(), tmp.end(), first, exclusive, init, op, queue); | |
303 | } | |
304 | else { | |
305 | // scan input to output | |
306 | return scan_impl(first, last, result, exclusive, init, op, queue); | |
307 | } | |
308 | } | |
309 | ||
310 | template<class InputIterator, class OutputIterator, class T, class BinaryOperator> | |
311 | inline OutputIterator scan_on_gpu(InputIterator first, | |
312 | InputIterator last, | |
313 | OutputIterator result, | |
314 | bool exclusive, | |
315 | T init, | |
316 | BinaryOperator op, | |
317 | command_queue &queue) | |
318 | { | |
319 | if(first == last){ | |
320 | return result; | |
321 | } | |
322 | ||
323 | return dispatch_scan(first, last, result, exclusive, init, op, queue); | |
324 | } | |
325 | ||
326 | } // end detail namespace | |
327 | } // end compute namespace | |
328 | } // end boost namespace | |
329 | ||
330 | #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_GPU_HPP |