]>
Commit | Line | Data |
---|---|---|
7c673cae FG |
1 | //---------------------------------------------------------------------------// |
2 | // Copyright (c) 2013-2015 Kyle Lutz <kyle.r.lutz@gmail.com> | |
3 | // | |
4 | // Distributed under the Boost Software License, Version 1.0 | |
5 | // See accompanying file LICENSE_1_0.txt or copy at | |
6 | // http://www.boost.org/LICENSE_1_0.txt | |
7 | // | |
8 | // See http://boostorg.github.com/compute for more information. | |
9 | //---------------------------------------------------------------------------// | |
10 | ||
11 | #ifndef BOOST_COMPUTE_ALGORITHM_TRANSFORM_IF_HPP | |
12 | #define BOOST_COMPUTE_ALGORITHM_TRANSFORM_IF_HPP | |
13 | ||
92f5a8d4 TL |
14 | #include <boost/static_assert.hpp> |
15 | ||
7c673cae FG |
16 | #include <boost/compute/cl.hpp> |
17 | #include <boost/compute/system.hpp> | |
18 | #include <boost/compute/command_queue.hpp> | |
19 | #include <boost/compute/algorithm/count.hpp> | |
20 | #include <boost/compute/algorithm/count_if.hpp> | |
21 | #include <boost/compute/algorithm/exclusive_scan.hpp> | |
22 | #include <boost/compute/container/vector.hpp> | |
23 | #include <boost/compute/detail/meta_kernel.hpp> | |
24 | #include <boost/compute/detail/iterator_range_size.hpp> | |
25 | #include <boost/compute/iterator/discard_iterator.hpp> | |
92f5a8d4 | 26 | #include <boost/compute/type_traits/is_device_iterator.hpp> |
7c673cae FG |
27 | |
28 | namespace boost { | |
29 | namespace compute { | |
30 | namespace detail { | |
31 | ||
b32b8144 | 32 | // Space complexity: O(2n) |
7c673cae FG |
33 | template<class InputIterator, class OutputIterator, class UnaryFunction, class Predicate> |
34 | inline OutputIterator transform_if_impl(InputIterator first, | |
35 | InputIterator last, | |
36 | OutputIterator result, | |
37 | UnaryFunction function, | |
38 | Predicate predicate, | |
39 | bool copyIndex, | |
40 | command_queue &queue) | |
41 | { | |
42 | typedef typename std::iterator_traits<OutputIterator>::difference_type difference_type; | |
43 | ||
44 | size_t count = detail::iterator_range_size(first, last); | |
45 | if(count == 0){ | |
46 | return result; | |
47 | } | |
48 | ||
49 | const context &context = queue.get_context(); | |
50 | ||
51 | // storage for destination indices | |
52 | ::boost::compute::vector<cl_uint> indices(count, context); | |
53 | ||
54 | // write counts | |
55 | ::boost::compute::detail::meta_kernel k1("transform_if_write_counts"); | |
56 | k1 << indices.begin()[k1.get_global_id(0)] << " = " | |
57 | << predicate(first[k1.get_global_id(0)]) << " ? 1 : 0;\n"; | |
58 | k1.exec_1d(queue, 0, count); | |
59 | ||
7c673cae | 60 | // scan indices |
b32b8144 | 61 | size_t copied_element_count = (indices.cend() - 1).read(queue); |
7c673cae FG |
62 | ::boost::compute::exclusive_scan( |
63 | indices.begin(), indices.end(), indices.begin(), queue | |
64 | ); | |
b32b8144 | 65 | copied_element_count += (indices.cend() - 1).read(queue); // last scan element plus last mask element |
7c673cae FG |
66 | |
67 | // copy values | |
68 | ::boost::compute::detail::meta_kernel k2("transform_if_do_copy"); | |
69 | k2 << "if(" << predicate(first[k2.get_global_id(0)]) << ")" << | |
70 | " " << result[indices.begin()[k2.get_global_id(0)]] << "="; | |
71 | ||
72 | if(copyIndex){ | |
73 | k2 << k2.get_global_id(0) << ";\n"; | |
74 | } | |
75 | else { | |
76 | k2 << function(first[k2.get_global_id(0)]) << ";\n"; | |
77 | } | |
78 | ||
79 | k2.exec_1d(queue, 0, count); | |
80 | ||
81 | return result + static_cast<difference_type>(copied_element_count); | |
82 | } | |
83 | ||
84 | template<class InputIterator, class UnaryFunction, class Predicate> | |
85 | inline discard_iterator transform_if_impl(InputIterator first, | |
86 | InputIterator last, | |
87 | discard_iterator result, | |
88 | UnaryFunction function, | |
89 | Predicate predicate, | |
90 | bool copyIndex, | |
91 | command_queue &queue) | |
92 | { | |
93 | (void) function; | |
94 | (void) copyIndex; | |
95 | ||
96 | return result + count_if(first, last, predicate, queue); | |
97 | } | |
98 | ||
99 | } // end detail namespace | |
100 | ||
101 | /// Copies each element in the range [\p first, \p last) for which | |
102 | /// \p predicate returns \c true to the range beginning at \p result. | |
b32b8144 FG |
103 | /// |
104 | /// Space complexity: O(2n) | |
7c673cae FG |
105 | template<class InputIterator, class OutputIterator, class UnaryFunction, class Predicate> |
106 | inline OutputIterator transform_if(InputIterator first, | |
107 | InputIterator last, | |
108 | OutputIterator result, | |
109 | UnaryFunction function, | |
110 | Predicate predicate, | |
111 | command_queue &queue = system::default_queue()) | |
112 | { | |
92f5a8d4 TL |
113 | BOOST_STATIC_ASSERT(is_device_iterator<InputIterator>::value); |
114 | BOOST_STATIC_ASSERT(is_device_iterator<OutputIterator>::value); | |
7c673cae FG |
115 | return detail::transform_if_impl( |
116 | first, last, result, function, predicate, false, queue | |
117 | ); | |
118 | } | |
119 | ||
120 | } // end compute namespace | |
121 | } // end boost namespace | |
122 | ||
123 | #endif // BOOST_COMPUTE_ALGORITHM_TRANSFORM_IF_HPP |