]>
Commit | Line | Data |
---|---|---|
7c673cae FG |
1 | //---------------------------------------------------------------------------// |
2 | // Copyright (c) 2013-2015 Kyle Lutz <kyle.r.lutz@gmail.com> | |
3 | // | |
4 | // Distributed under the Boost Software License, Version 1.0 | |
5 | // See accompanying file LICENSE_1_0.txt or copy at | |
6 | // http://www.boost.org/LICENSE_1_0.txt | |
7 | // | |
8 | // See http://boostorg.github.com/compute for more information. | |
9 | //---------------------------------------------------------------------------// | |
10 | ||
11 | #ifndef BOOST_COMPUTE_ALGORITHM_TRANSFORM_IF_HPP | |
12 | #define BOOST_COMPUTE_ALGORITHM_TRANSFORM_IF_HPP | |
13 | ||
14 | #include <boost/compute/cl.hpp> | |
15 | #include <boost/compute/system.hpp> | |
16 | #include <boost/compute/command_queue.hpp> | |
17 | #include <boost/compute/algorithm/count.hpp> | |
18 | #include <boost/compute/algorithm/count_if.hpp> | |
19 | #include <boost/compute/algorithm/exclusive_scan.hpp> | |
20 | #include <boost/compute/container/vector.hpp> | |
21 | #include <boost/compute/detail/meta_kernel.hpp> | |
22 | #include <boost/compute/detail/iterator_range_size.hpp> | |
23 | #include <boost/compute/iterator/discard_iterator.hpp> | |
24 | ||
25 | namespace boost { | |
26 | namespace compute { | |
27 | namespace detail { | |
28 | ||
29 | template<class InputIterator, class OutputIterator, class UnaryFunction, class Predicate> | |
30 | inline OutputIterator transform_if_impl(InputIterator first, | |
31 | InputIterator last, | |
32 | OutputIterator result, | |
33 | UnaryFunction function, | |
34 | Predicate predicate, | |
35 | bool copyIndex, | |
36 | command_queue &queue) | |
37 | { | |
38 | typedef typename std::iterator_traits<OutputIterator>::difference_type difference_type; | |
39 | ||
40 | size_t count = detail::iterator_range_size(first, last); | |
41 | if(count == 0){ | |
42 | return result; | |
43 | } | |
44 | ||
45 | const context &context = queue.get_context(); | |
46 | ||
47 | // storage for destination indices | |
48 | ::boost::compute::vector<cl_uint> indices(count, context); | |
49 | ||
50 | // write counts | |
51 | ::boost::compute::detail::meta_kernel k1("transform_if_write_counts"); | |
52 | k1 << indices.begin()[k1.get_global_id(0)] << " = " | |
53 | << predicate(first[k1.get_global_id(0)]) << " ? 1 : 0;\n"; | |
54 | k1.exec_1d(queue, 0, count); | |
55 | ||
56 | // count number of elements to be copied | |
57 | size_t copied_element_count = | |
58 | ::boost::compute::count(indices.begin(), indices.end(), 1, queue); | |
59 | ||
60 | // scan indices | |
61 | ::boost::compute::exclusive_scan( | |
62 | indices.begin(), indices.end(), indices.begin(), queue | |
63 | ); | |
64 | ||
65 | // copy values | |
66 | ::boost::compute::detail::meta_kernel k2("transform_if_do_copy"); | |
67 | k2 << "if(" << predicate(first[k2.get_global_id(0)]) << ")" << | |
68 | " " << result[indices.begin()[k2.get_global_id(0)]] << "="; | |
69 | ||
70 | if(copyIndex){ | |
71 | k2 << k2.get_global_id(0) << ";\n"; | |
72 | } | |
73 | else { | |
74 | k2 << function(first[k2.get_global_id(0)]) << ";\n"; | |
75 | } | |
76 | ||
77 | k2.exec_1d(queue, 0, count); | |
78 | ||
79 | return result + static_cast<difference_type>(copied_element_count); | |
80 | } | |
81 | ||
82 | template<class InputIterator, class UnaryFunction, class Predicate> | |
83 | inline discard_iterator transform_if_impl(InputIterator first, | |
84 | InputIterator last, | |
85 | discard_iterator result, | |
86 | UnaryFunction function, | |
87 | Predicate predicate, | |
88 | bool copyIndex, | |
89 | command_queue &queue) | |
90 | { | |
91 | (void) function; | |
92 | (void) copyIndex; | |
93 | ||
94 | return result + count_if(first, last, predicate, queue); | |
95 | } | |
96 | ||
97 | } // end detail namespace | |
98 | ||
99 | /// Copies each element in the range [\p first, \p last) for which | |
100 | /// \p predicate returns \c true to the range beginning at \p result. | |
101 | template<class InputIterator, class OutputIterator, class UnaryFunction, class Predicate> | |
102 | inline OutputIterator transform_if(InputIterator first, | |
103 | InputIterator last, | |
104 | OutputIterator result, | |
105 | UnaryFunction function, | |
106 | Predicate predicate, | |
107 | command_queue &queue = system::default_queue()) | |
108 | { | |
109 | return detail::transform_if_impl( | |
110 | first, last, result, function, predicate, false, queue | |
111 | ); | |
112 | } | |
113 | ||
114 | } // end compute namespace | |
115 | } // end boost namespace | |
116 | ||
117 | #endif // BOOST_COMPUTE_ALGORITHM_TRANSFORM_IF_HPP |