]>
Commit | Line | Data |
---|---|---|
7c673cae FG |
1 | //---------------------------------------------------------------------------// |
2 | // Copyright (c) 2013-2015 Kyle Lutz <kyle.r.lutz@gmail.com> | |
3 | // | |
4 | // Distributed under the Boost Software License, Version 1.0 | |
5 | // See accompanying file LICENSE_1_0.txt or copy at | |
6 | // http://www.boost.org/LICENSE_1_0.txt | |
7 | // | |
8 | // See http://boostorg.github.com/compute for more information. | |
9 | //---------------------------------------------------------------------------// | |
10 | ||
11 | #ifndef BOOST_COMPUTE_ALGORITHM_TRANSFORM_IF_HPP | |
12 | #define BOOST_COMPUTE_ALGORITHM_TRANSFORM_IF_HPP | |
13 | ||
14 | #include <boost/compute/cl.hpp> | |
15 | #include <boost/compute/system.hpp> | |
16 | #include <boost/compute/command_queue.hpp> | |
17 | #include <boost/compute/algorithm/count.hpp> | |
18 | #include <boost/compute/algorithm/count_if.hpp> | |
19 | #include <boost/compute/algorithm/exclusive_scan.hpp> | |
20 | #include <boost/compute/container/vector.hpp> | |
21 | #include <boost/compute/detail/meta_kernel.hpp> | |
22 | #include <boost/compute/detail/iterator_range_size.hpp> | |
23 | #include <boost/compute/iterator/discard_iterator.hpp> | |
24 | ||
25 | namespace boost { | |
26 | namespace compute { | |
27 | namespace detail { | |
28 | ||
b32b8144 | 29 | // Space complexity: O(2n) |
7c673cae FG |
30 | template<class InputIterator, class OutputIterator, class UnaryFunction, class Predicate> |
31 | inline OutputIterator transform_if_impl(InputIterator first, | |
32 | InputIterator last, | |
33 | OutputIterator result, | |
34 | UnaryFunction function, | |
35 | Predicate predicate, | |
36 | bool copyIndex, | |
37 | command_queue &queue) | |
38 | { | |
39 | typedef typename std::iterator_traits<OutputIterator>::difference_type difference_type; | |
40 | ||
41 | size_t count = detail::iterator_range_size(first, last); | |
42 | if(count == 0){ | |
43 | return result; | |
44 | } | |
45 | ||
46 | const context &context = queue.get_context(); | |
47 | ||
48 | // storage for destination indices | |
49 | ::boost::compute::vector<cl_uint> indices(count, context); | |
50 | ||
51 | // write counts | |
52 | ::boost::compute::detail::meta_kernel k1("transform_if_write_counts"); | |
53 | k1 << indices.begin()[k1.get_global_id(0)] << " = " | |
54 | << predicate(first[k1.get_global_id(0)]) << " ? 1 : 0;\n"; | |
55 | k1.exec_1d(queue, 0, count); | |
56 | ||
7c673cae | 57 | // scan indices |
b32b8144 | 58 | size_t copied_element_count = (indices.cend() - 1).read(queue); |
7c673cae FG |
59 | ::boost::compute::exclusive_scan( |
60 | indices.begin(), indices.end(), indices.begin(), queue | |
61 | ); | |
b32b8144 | 62 | copied_element_count += (indices.cend() - 1).read(queue); // last scan element plus last mask element |
7c673cae FG |
63 | |
64 | // copy values | |
65 | ::boost::compute::detail::meta_kernel k2("transform_if_do_copy"); | |
66 | k2 << "if(" << predicate(first[k2.get_global_id(0)]) << ")" << | |
67 | " " << result[indices.begin()[k2.get_global_id(0)]] << "="; | |
68 | ||
69 | if(copyIndex){ | |
70 | k2 << k2.get_global_id(0) << ";\n"; | |
71 | } | |
72 | else { | |
73 | k2 << function(first[k2.get_global_id(0)]) << ";\n"; | |
74 | } | |
75 | ||
76 | k2.exec_1d(queue, 0, count); | |
77 | ||
78 | return result + static_cast<difference_type>(copied_element_count); | |
79 | } | |
80 | ||
81 | template<class InputIterator, class UnaryFunction, class Predicate> | |
82 | inline discard_iterator transform_if_impl(InputIterator first, | |
83 | InputIterator last, | |
84 | discard_iterator result, | |
85 | UnaryFunction function, | |
86 | Predicate predicate, | |
87 | bool copyIndex, | |
88 | command_queue &queue) | |
89 | { | |
90 | (void) function; | |
91 | (void) copyIndex; | |
92 | ||
93 | return result + count_if(first, last, predicate, queue); | |
94 | } | |
95 | ||
96 | } // end detail namespace | |
97 | ||
98 | /// Copies each element in the range [\p first, \p last) for which | |
99 | /// \p predicate returns \c true to the range beginning at \p result. | |
b32b8144 FG |
100 | /// |
101 | /// Space complexity: O(2n) | |
7c673cae FG |
102 | template<class InputIterator, class OutputIterator, class UnaryFunction, class Predicate> |
103 | inline OutputIterator transform_if(InputIterator first, | |
104 | InputIterator last, | |
105 | OutputIterator result, | |
106 | UnaryFunction function, | |
107 | Predicate predicate, | |
108 | command_queue &queue = system::default_queue()) | |
109 | { | |
110 | return detail::transform_if_impl( | |
111 | first, last, result, function, predicate, false, queue | |
112 | ); | |
113 | } | |
114 | ||
115 | } // end compute namespace | |
116 | } // end boost namespace | |
117 | ||
118 | #endif // BOOST_COMPUTE_ALGORITHM_TRANSFORM_IF_HPP |