]> git.proxmox.com Git - ceph.git/blame - ceph/src/boost/boost/compute/algorithm/detail/scan_on_gpu.hpp
import new upstream nautilus stable release 14.2.8
[ceph.git] / ceph / src / boost / boost / compute / algorithm / detail / scan_on_gpu.hpp
CommitLineData
7c673cae
FG
1//---------------------------------------------------------------------------//
2// Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
3//
4// Distributed under the Boost Software License, Version 1.0
5// See accompanying file LICENSE_1_0.txt or copy at
6// http://www.boost.org/LICENSE_1_0.txt
7//
8// See http://boostorg.github.com/compute for more information.
9//---------------------------------------------------------------------------//
10
11#ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_GPU_HPP
12#define BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_GPU_HPP
13
14#include <boost/compute/kernel.hpp>
15#include <boost/compute/detail/meta_kernel.hpp>
16#include <boost/compute/command_queue.hpp>
17#include <boost/compute/container/vector.hpp>
18#include <boost/compute/detail/iterator_range_size.hpp>
19#include <boost/compute/memory/local_buffer.hpp>
20#include <boost/compute/iterator/buffer_iterator.hpp>
21
22namespace boost {
23namespace compute {
24namespace detail {
25
26template<class InputIterator, class OutputIterator, class BinaryOperator>
27class local_scan_kernel : public meta_kernel
28{
29public:
30 local_scan_kernel(InputIterator first,
31 InputIterator last,
32 OutputIterator result,
33 bool exclusive,
34 BinaryOperator op)
35 : meta_kernel("local_scan")
36 {
37 typedef typename std::iterator_traits<InputIterator>::value_type T;
38
39 (void) last;
40
41 bool checked = true;
42
43 m_block_sums_arg = add_arg<T *>(memory_object::global_memory, "block_sums");
44 m_scratch_arg = add_arg<T *>(memory_object::local_memory, "scratch");
45 m_block_size_arg = add_arg<const cl_uint>("block_size");
46 m_count_arg = add_arg<const cl_uint>("count");
47 m_init_value_arg = add_arg<const T>("init");
48
49 // work-item parameters
50 *this <<
51 "const uint gid = get_global_id(0);\n" <<
52 "const uint lid = get_local_id(0);\n";
53
54 // check against data size
55 if(checked){
56 *this <<
57 "if(gid < count){\n";
58 }
59
60 // copy values from input to local memory
61 if(exclusive){
62 *this <<
63 decl<const T>("local_init") << "= (gid == 0) ? init : 0;\n" <<
64 "if(lid == 0){ scratch[lid] = local_init; }\n" <<
65 "else { scratch[lid] = " << first[expr<cl_uint>("gid-1")] << "; }\n";
66 }
67 else{
68 *this <<
69 "scratch[lid] = " << first[expr<cl_uint>("gid")] << ";\n";
70 }
71
72 if(checked){
73 *this <<
74 "}\n"
75 "else {\n" <<
76 " scratch[lid] = 0;\n" <<
77 "}\n";
78 }
79
80 // wait for all threads to read from input
81 *this <<
82 "barrier(CLK_LOCAL_MEM_FENCE);\n";
83
84 // perform scan
85 *this <<
86 "for(uint i = 1; i < block_size; i <<= 1){\n" <<
87 " " << decl<const T>("x") << " = lid >= i ? scratch[lid-i] : 0;\n" <<
88 " barrier(CLK_LOCAL_MEM_FENCE);\n" <<
89 " if(lid >= i){\n" <<
90 " scratch[lid] = " << op(var<T>("scratch[lid]"), var<T>("x")) << ";\n" <<
91 " }\n" <<
92 " barrier(CLK_LOCAL_MEM_FENCE);\n" <<
93 "}\n";
94
95 // copy results to output
96 if(checked){
97 *this <<
98 "if(gid < count){\n";
99 }
100
101 *this <<
102 result[expr<cl_uint>("gid")] << " = scratch[lid];\n";
103
104 if(checked){
105 *this << "}\n";
106 }
107
108 // store sum for the block
109 if(exclusive){
110 *this <<
92f5a8d4 111 "if(lid == block_size - 1 && gid < count) {\n" <<
7c673cae
FG
112 " block_sums[get_group_id(0)] = " <<
113 op(first[expr<cl_uint>("gid")], var<T>("scratch[lid]")) <<
114 ";\n" <<
115 "}\n";
116 }
117 else {
118 *this <<
119 "if(lid == block_size - 1){\n" <<
120 " block_sums[get_group_id(0)] = scratch[lid];\n" <<
121 "}\n";
122 }
123 }
124
125 size_t m_block_sums_arg;
126 size_t m_scratch_arg;
127 size_t m_block_size_arg;
128 size_t m_count_arg;
129 size_t m_init_value_arg;
130};
131
132template<class T, class BinaryOperator>
133class write_scanned_output_kernel : public meta_kernel
134{
135public:
136 write_scanned_output_kernel(BinaryOperator op)
137 : meta_kernel("write_scanned_output")
138 {
139 bool checked = true;
140
141 m_output_arg = add_arg<T *>(memory_object::global_memory, "output");
142 m_block_sums_arg = add_arg<const T *>(memory_object::global_memory, "block_sums");
143 m_count_arg = add_arg<const cl_uint>("count");
144
145 // work-item parameters
146 *this <<
147 "const uint gid = get_global_id(0);\n" <<
148 "const uint block_id = get_group_id(0);\n";
149
150 // check against data size
151 if(checked){
152 *this << "if(gid < count){\n";
153 }
154
155 // write output
156 *this <<
157 "output[gid] = " <<
158 op(var<T>("block_sums[block_id]"), var<T>("output[gid] ")) << ";\n";
159
160 if(checked){
161 *this << "}\n";
162 }
163 }
164
165 size_t m_output_arg;
166 size_t m_block_sums_arg;
167 size_t m_count_arg;
168};
169
// Chooses the work-group size used to scan the range [first, last).
//
// Returns 0 for an empty range. Otherwise returns the smallest power of two
// that is greater than or equal to the number of elements, limited to a
// maximum of 256.
template<class InputIterator>
inline size_t pick_scan_block_size(InputIterator first, InputIterator last)
{
    const size_t max_block_size = 256;
    const size_t count = iterator_range_size(first, last);

    if(count == 0){
        return 0;
    }

    // double the candidate until it covers the range or hits the limit
    size_t block_size = 1;
    while(block_size < count && block_size < max_block_size){
        block_size <<= 1;
    }

    return block_size;
}
186
187template<class InputIterator, class OutputIterator, class T, class BinaryOperator>
188inline OutputIterator scan_impl(InputIterator first,
189 InputIterator last,
190 OutputIterator result,
191 bool exclusive,
192 T init,
193 BinaryOperator op,
194 command_queue &queue)
195{
196 typedef typename
197 std::iterator_traits<InputIterator>::value_type
198 input_type;
199 typedef typename
200 std::iterator_traits<InputIterator>::difference_type
201 difference_type;
202 typedef typename
203 std::iterator_traits<OutputIterator>::value_type
204 output_type;
205
206 const context &context = queue.get_context();
207 const size_t count = detail::iterator_range_size(first, last);
208
209 size_t block_size = pick_scan_block_size(first, last);
210 size_t block_count = count / block_size;
211
212 if(block_count * block_size < count){
213 block_count++;
214 }
215
216 ::boost::compute::vector<input_type> block_sums(block_count, context);
217
218 // zero block sums
219 input_type zero;
220 std::memset(&zero, 0, sizeof(input_type));
221 ::boost::compute::fill(block_sums.begin(), block_sums.end(), zero, queue);
222
223 // local scan
224 local_scan_kernel<InputIterator, OutputIterator, BinaryOperator>
225 local_scan_kernel(first, last, result, exclusive, op);
226
227 ::boost::compute::kernel kernel = local_scan_kernel.compile(context);
228 kernel.set_arg(local_scan_kernel.m_scratch_arg, local_buffer<input_type>(block_size));
229 kernel.set_arg(local_scan_kernel.m_block_sums_arg, block_sums);
230 kernel.set_arg(local_scan_kernel.m_block_size_arg, static_cast<cl_uint>(block_size));
231 kernel.set_arg(local_scan_kernel.m_count_arg, static_cast<cl_uint>(count));
232 kernel.set_arg(local_scan_kernel.m_init_value_arg, static_cast<output_type>(init));
233
234 queue.enqueue_1d_range_kernel(kernel,
235 0,
236 block_count * block_size,
237 block_size);
238
239 // inclusive scan block sums
240 if(block_count > 1){
241 scan_impl(block_sums.begin(),
242 block_sums.end(),
243 block_sums.begin(),
244 false,
245 init,
246 op,
247 queue
248 );
249 }
250
251 // add block sums to each block
252 if(block_count > 1){
253 write_scanned_output_kernel<input_type, BinaryOperator>
254 write_output_kernel(op);
255 kernel = write_output_kernel.compile(context);
256 kernel.set_arg(write_output_kernel.m_output_arg, result.get_buffer());
257 kernel.set_arg(write_output_kernel.m_block_sums_arg, block_sums);
258 kernel.set_arg(write_output_kernel.m_count_arg, static_cast<cl_uint>(count));
259
260 queue.enqueue_1d_range_kernel(kernel,
261 block_size,
262 block_count * block_size,
263 block_size);
264 }
265
266 return result + static_cast<difference_type>(count);
267}
268
269template<class InputIterator, class OutputIterator, class T, class BinaryOperator>
270inline OutputIterator dispatch_scan(InputIterator first,
271 InputIterator last,
272 OutputIterator result,
273 bool exclusive,
274 T init,
275 BinaryOperator op,
276 command_queue &queue)
277{
278 return scan_impl(first, last, result, exclusive, init, op, queue);
279}
280
281template<class InputIterator, class T, class BinaryOperator>
282inline InputIterator dispatch_scan(InputIterator first,
283 InputIterator last,
284 InputIterator result,
285 bool exclusive,
286 T init,
287 BinaryOperator op,
288 command_queue &queue)
289{
290 typedef typename std::iterator_traits<InputIterator>::value_type value_type;
291
292 if(first == result){
293 // scan input in-place
294 const context &context = queue.get_context();
295
296 // make a temporary copy the input
297 size_t count = iterator_range_size(first, last);
298 vector<value_type> tmp(count, context);
299 copy(first, last, tmp.begin(), queue);
300
301 // scan from temporary values
302 return scan_impl(tmp.begin(), tmp.end(), first, exclusive, init, op, queue);
303 }
304 else {
305 // scan input to output
306 return scan_impl(first, last, result, exclusive, init, op, queue);
307 }
308}
309
310template<class InputIterator, class OutputIterator, class T, class BinaryOperator>
311inline OutputIterator scan_on_gpu(InputIterator first,
312 InputIterator last,
313 OutputIterator result,
314 bool exclusive,
315 T init,
316 BinaryOperator op,
317 command_queue &queue)
318{
319 if(first == last){
320 return result;
321 }
322
323 return dispatch_scan(first, last, result, exclusive, init, op, queue);
324}
325
326} // end detail namespace
327} // end compute namespace
328} // end boost namespace
329
330#endif // BOOST_COMPUTE_ALGORITHM_DETAIL_SCAN_ON_GPU_HPP