]>
Commit | Line | Data |
---|---|---|
7c673cae FG |
1 | //---------------------------------------------------------------------------// |
2 | // Copyright (c) 2016 Jakub Szuppe <j.szuppe@gmail.com> | |
3 | // | |
4 | // Distributed under the Boost Software License, Version 1.0 | |
5 | // See accompanying file LICENSE_1_0.txt or copy at | |
6 | // http://www.boost.org/LICENSE_1_0.txt | |
7 | // | |
8 | // See http://boostorg.github.com/compute for more information. | |
9 | //---------------------------------------------------------------------------// | |
10 | ||
11 | #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_FIND_EXTREMA_ON_CPU_HPP | |
12 | #define BOOST_COMPUTE_ALGORITHM_DETAIL_FIND_EXTREMA_ON_CPU_HPP | |
13 | ||
14 | #include <algorithm> | |
15 | ||
16 | #include <boost/compute/algorithm/detail/find_extrema_with_reduce.hpp> | |
17 | #include <boost/compute/algorithm/detail/find_extrema_with_atomics.hpp> | |
18 | #include <boost/compute/algorithm/detail/serial_find_extrema.hpp> | |
19 | #include <boost/compute/detail/iterator_range_size.hpp> | |
20 | #include <boost/compute/iterator/buffer_iterator.hpp> | |
21 | ||
22 | namespace boost { | |
23 | namespace compute { | |
24 | namespace detail { | |
25 | ||
26 | template<class InputIterator, class Compare> | |
27 | inline InputIterator find_extrema_on_cpu(InputIterator first, | |
28 | InputIterator last, | |
29 | Compare compare, | |
30 | const bool find_minimum, | |
31 | command_queue &queue) | |
32 | { | |
33 | typedef typename std::iterator_traits<InputIterator>::value_type input_type; | |
34 | typedef typename std::iterator_traits<InputIterator>::difference_type difference_type; | |
35 | size_t count = iterator_range_size(first, last); | |
36 | ||
37 | const device &device = queue.get_device(); | |
38 | const uint_ compute_units = queue.get_device().compute_units(); | |
39 | ||
40 | boost::shared_ptr<parameter_cache> parameters = | |
41 | detail::parameter_cache::get_global_cache(device); | |
42 | std::string cache_key = | |
43 | "__boost_find_extrema_cpu_" | |
44 | + boost::lexical_cast<std::string>(sizeof(input_type)); | |
45 | ||
46 | // for inputs smaller than serial_find_extrema_threshold | |
47 | // serial_find_extrema algorithm is used | |
48 | uint_ serial_find_extrema_threshold = parameters->get( | |
49 | cache_key, | |
50 | "serial_find_extrema_threshold", | |
51 | 16384 * sizeof(input_type) | |
52 | ); | |
53 | serial_find_extrema_threshold = | |
54 | (std::max)(serial_find_extrema_threshold, uint_(2 * compute_units)); | |
55 | ||
56 | const context &context = queue.get_context(); | |
57 | if(count < serial_find_extrema_threshold) { | |
58 | return serial_find_extrema(first, last, compare, find_minimum, queue); | |
59 | } | |
60 | ||
61 | meta_kernel k("find_extrema_on_cpu"); | |
62 | buffer output(context, sizeof(input_type) * compute_units); | |
63 | buffer output_idx( | |
64 | context, sizeof(uint_) * compute_units, | |
65 | buffer::read_write | buffer::alloc_host_ptr | |
66 | ); | |
67 | ||
68 | size_t count_arg = k.add_arg<uint_>("count"); | |
69 | size_t output_arg = | |
70 | k.add_arg<input_type *>(memory_object::global_memory, "output"); | |
71 | size_t output_idx_arg = | |
72 | k.add_arg<uint_ *>(memory_object::global_memory, "output_idx"); | |
73 | ||
74 | k << | |
75 | "uint block = " << | |
76 | "(uint)ceil(((float)count)/get_global_size(0));\n" << | |
77 | "uint index = get_global_id(0) * block;\n" << | |
78 | "uint end = min(count, index + block);\n" << | |
79 | ||
80 | "uint value_index = index;\n" << | |
81 | k.decl<input_type>("value") << " = " << first[k.var<uint_>("index")] << ";\n" << | |
82 | ||
83 | "index++;\n" << | |
84 | "while(index < end){\n" << | |
85 | k.decl<input_type>("candidate") << | |
86 | " = " << first[k.var<uint_>("index")] << ";\n" << | |
87 | "#ifndef BOOST_COMPUTE_FIND_MAXIMUM\n" << | |
88 | "bool compare = " << compare(k.var<input_type>("candidate"), | |
89 | k.var<input_type>("value")) << ";\n" << | |
90 | "#else\n" << | |
91 | "bool compare = " << compare(k.var<input_type>("value"), | |
92 | k.var<input_type>("candidate")) << ";\n" << | |
93 | "#endif\n" << | |
94 | "value = compare ? candidate : value;\n" << | |
95 | "value_index = compare ? index : value_index;\n" << | |
96 | "index++;\n" << | |
97 | "}\n" << | |
98 | "output[get_global_id(0)] = value;\n" << | |
99 | "output_idx[get_global_id(0)] = value_index;\n"; | |
100 | ||
101 | size_t global_work_size = compute_units; | |
102 | std::string options; | |
103 | if(!find_minimum){ | |
104 | options = "-DBOOST_COMPUTE_FIND_MAXIMUM"; | |
105 | } | |
106 | kernel kernel = k.compile(context, options); | |
107 | ||
108 | kernel.set_arg(count_arg, static_cast<uint_>(count)); | |
109 | kernel.set_arg(output_arg, output); | |
110 | kernel.set_arg(output_idx_arg, output_idx); | |
111 | queue.enqueue_1d_range_kernel(kernel, 0, global_work_size, 0); | |
112 | ||
113 | buffer_iterator<input_type> result = serial_find_extrema( | |
114 | make_buffer_iterator<input_type>(output), | |
115 | make_buffer_iterator<input_type>(output, global_work_size), | |
116 | compare, | |
117 | find_minimum, | |
118 | queue | |
119 | ); | |
120 | ||
121 | uint_* output_idx_host_ptr = | |
122 | static_cast<uint_*>( | |
123 | queue.enqueue_map_buffer( | |
124 | output_idx, command_queue::map_read, | |
125 | 0, global_work_size * sizeof(uint_) | |
126 | ) | |
127 | ); | |
128 | ||
129 | difference_type extremum_idx = | |
130 | static_cast<difference_type>(*(output_idx_host_ptr + result.get_index())); | |
131 | return first + extremum_idx; | |
132 | } | |
133 | ||
134 | } // end detail namespace | |
135 | } // end compute namespace | |
136 | } // end boost namespace | |
137 | ||
138 | #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_FIND_EXTREMA_ON_CPU_HPP |