]>
Commit | Line | Data |
---|---|---|
7c673cae FG |
1 | //---------------------------------------------------------------------------// |
2 | // Copyright (c) 2014 Roshan <thisisroshansmail@gmail.com> | |
3 | // | |
4 | // Distributed under the Boost Software License, Version 1.0 | |
5 | // See accompanying file LICENSE_1_0.txt or copy at | |
6 | // http://www.boost.org/LICENSE_1_0.txt | |
7 | // | |
8 | // See http://boostorg.github.com/compute for more information. | |
9 | //---------------------------------------------------------------------------// | |
10 | ||
11 | #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_BINARY_FIND_HPP | |
12 | #define BOOST_COMPUTE_ALGORITHM_DETAIL_BINARY_FIND_HPP | |
13 | ||
14 | #include <boost/compute/functional.hpp> | |
15 | #include <boost/compute/algorithm/find_if.hpp> | |
16 | #include <boost/compute/algorithm/transform.hpp> | |
17 | #include <boost/compute/command_queue.hpp> | |
18 | #include <boost/compute/detail/parameter_cache.hpp> | |
19 | ||
20 | namespace boost { | |
21 | namespace compute { | |
22 | namespace detail{ | |
23 | ||
24 | /// | |
25 | /// \brief Binary find kernel class | |
26 | /// | |
27 | /// Subclass of meta_kernel to perform single step in binary find. | |
28 | /// | |
29 | template<class InputIterator, class UnaryPredicate> | |
30 | class binary_find_kernel : public meta_kernel | |
31 | { | |
32 | public: | |
33 | binary_find_kernel(InputIterator first, | |
34 | InputIterator last, | |
35 | UnaryPredicate predicate) | |
36 | : meta_kernel("binary_find") | |
37 | { | |
38 | typedef typename std::iterator_traits<InputIterator>::value_type value_type; | |
39 | ||
40 | m_index_arg = add_arg<uint_ *>(memory_object::global_memory, "index"); | |
41 | m_block_arg = add_arg<uint_>("block"); | |
42 | ||
43 | atomic_min<uint_> atomic_min_uint; | |
44 | ||
45 | *this << | |
46 | "uint i = get_global_id(0) * block;\n" << | |
47 | decl<value_type>("value") << "=" << first[var<uint_>("i")] << ";\n" << | |
48 | "if(" << predicate(var<value_type>("value")) << ") {\n" << | |
49 | atomic_min_uint(var<uint_ *>("index"), var<uint_>("i")) << ";\n" << | |
50 | "}\n"; | |
51 | } | |
52 | ||
53 | size_t m_index_arg; | |
54 | size_t m_block_arg; | |
55 | }; | |
56 | ||
57 | /// | |
58 | /// \brief Binary find algorithm | |
59 | /// | |
60 | /// Finds the end of true values in the partitioned range [first, last). | |
61 | /// \return Iterator pointing to end of true values | |
62 | /// | |
63 | /// \param first Iterator pointing to start of range | |
64 | /// \param last Iterator pointing to end of range | |
65 | /// \param predicate Predicate according to which the range is partitioned | |
66 | /// \param queue Queue on which to execute | |
67 | /// | |
68 | template<class InputIterator, class UnaryPredicate> | |
69 | inline InputIterator binary_find(InputIterator first, | |
70 | InputIterator last, | |
71 | UnaryPredicate predicate, | |
72 | command_queue &queue = system::default_queue()) | |
73 | { | |
74 | const device &device = queue.get_device(); | |
75 | ||
76 | boost::shared_ptr<parameter_cache> parameters = | |
77 | detail::parameter_cache::get_global_cache(device); | |
78 | ||
79 | const std::string cache_key = "__boost_binary_find"; | |
80 | ||
81 | size_t find_if_limit = 128; | |
82 | size_t threads = parameters->get(cache_key, "tpb", 128); | |
83 | size_t count = iterator_range_size(first, last); | |
84 | ||
85 | InputIterator search_first = first; | |
86 | InputIterator search_last = last; | |
87 | ||
88 | scalar<uint_> index(queue.get_context()); | |
89 | ||
90 | // construct and compile binary_find kernel | |
91 | binary_find_kernel<InputIterator, UnaryPredicate> | |
92 | binary_find_kernel(search_first, search_last, predicate); | |
93 | ::boost::compute::kernel kernel = binary_find_kernel.compile(queue.get_context()); | |
94 | ||
95 | // set buffer for index | |
96 | kernel.set_arg(binary_find_kernel.m_index_arg, index.get_buffer()); | |
97 | ||
98 | while(count > find_if_limit) { | |
99 | index.write(static_cast<uint_>(count), queue); | |
100 | ||
101 | // set block and run binary_find kernel | |
102 | uint_ block = static_cast<uint_>((count - 1)/(threads - 1)); | |
103 | kernel.set_arg(binary_find_kernel.m_block_arg, block); | |
104 | queue.enqueue_1d_range_kernel(kernel, 0, threads, 0); | |
105 | ||
106 | size_t i = index.read(queue); | |
107 | ||
108 | if(i == count) { | |
109 | search_first = search_last - ((count - 1)%(threads - 1)); | |
110 | break; | |
111 | } else { | |
112 | search_last = search_first + i; | |
113 | search_first = search_last - ((count - 1)/(threads - 1)); | |
114 | } | |
115 | ||
116 | // Make sure that first and last stay within the input range | |
117 | search_last = (std::min)(search_last, last); | |
118 | search_last = (std::max)(search_last, first); | |
119 | ||
120 | search_first = (std::max)(search_first, first); | |
121 | search_first = (std::min)(search_first, last); | |
122 | ||
123 | count = iterator_range_size(search_first, search_last); | |
124 | } | |
125 | ||
126 | return find_if(search_first, search_last, predicate, queue); | |
127 | } | |
128 | ||
129 | } // end detail namespace | |
130 | } // end compute namespace | |
131 | } // end boost namespace | |
132 | ||
133 | #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_BINARY_FIND_HPP |