]>
Commit | Line | Data |
---|---|---|
7c673cae FG |
1 | //---------------------------------------------------------------------------// |
2 | // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com> | |
3 | // | |
4 | // Distributed under the Boost Software License, Version 1.0 | |
5 | // See accompanying file LICENSE_1_0.txt or copy at | |
6 | // http://www.boost.org/LICENSE_1_0.txt | |
7 | // | |
8 | // See http://boostorg.github.com/compute for more information. | |
9 | //---------------------------------------------------------------------------// | |
10 | ||
11 | #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_INSERTION_SORT_HPP | |
12 | #define BOOST_COMPUTE_ALGORITHM_DETAIL_INSERTION_SORT_HPP | |
13 | ||
14 | #include <boost/compute/kernel.hpp> | |
15 | #include <boost/compute/program.hpp> | |
16 | #include <boost/compute/command_queue.hpp> | |
17 | #include <boost/compute/detail/meta_kernel.hpp> | |
18 | #include <boost/compute/detail/iterator_range_size.hpp> | |
19 | #include <boost/compute/memory/local_buffer.hpp> | |
20 | ||
21 | namespace boost { | |
22 | namespace compute { | |
23 | namespace detail { | |
24 | ||
25 | template<class Iterator, class Compare> | |
26 | inline void serial_insertion_sort(Iterator first, | |
27 | Iterator last, | |
28 | Compare compare, | |
29 | command_queue &queue) | |
30 | { | |
31 | typedef typename std::iterator_traits<Iterator>::value_type T; | |
32 | ||
33 | size_t count = iterator_range_size(first, last); | |
34 | if(count < 2){ | |
35 | return; | |
36 | } | |
37 | ||
38 | meta_kernel k("serial_insertion_sort"); | |
39 | size_t local_data_arg = k.add_arg<T *>(memory_object::local_memory, "data"); | |
40 | size_t count_arg = k.add_arg<uint_>("n"); | |
41 | ||
42 | k << | |
43 | // copy data to local memory | |
44 | "for(uint i = 0; i < n; i++){\n" << | |
45 | " data[i] = " << first[k.var<uint_>("i")] << ";\n" | |
46 | "}\n" | |
47 | ||
48 | // sort data in local memory | |
49 | "for(uint i = 1; i < n; i++){\n" << | |
50 | " " << k.decl<const T>("value") << " = data[i];\n" << | |
51 | " uint pos = i;\n" << | |
52 | " while(pos > 0 && " << | |
53 | compare(k.var<const T>("value"), | |
54 | k.var<const T>("data[pos-1]")) << "){\n" << | |
55 | " data[pos] = data[pos-1];\n" << | |
56 | " pos--;\n" << | |
57 | " }\n" << | |
58 | " data[pos] = value;\n" << | |
59 | "}\n" << | |
60 | ||
61 | // copy sorted data to output | |
62 | "for(uint i = 0; i < n; i++){\n" << | |
63 | " " << first[k.var<uint_>("i")] << " = data[i];\n" | |
64 | "}\n"; | |
65 | ||
66 | const context &context = queue.get_context(); | |
67 | ::boost::compute::kernel kernel = k.compile(context); | |
68 | kernel.set_arg(local_data_arg, local_buffer<T>(count)); | |
69 | kernel.set_arg(count_arg, static_cast<uint_>(count)); | |
70 | ||
71 | queue.enqueue_task(kernel); | |
72 | } | |
73 | ||
74 | template<class Iterator> | |
75 | inline void serial_insertion_sort(Iterator first, | |
76 | Iterator last, | |
77 | command_queue &queue) | |
78 | { | |
79 | typedef typename std::iterator_traits<Iterator>::value_type T; | |
80 | ||
81 | ::boost::compute::less<T> less; | |
82 | ||
83 | return serial_insertion_sort(first, last, less, queue); | |
84 | } | |
85 | ||
86 | template<class KeyIterator, class ValueIterator, class Compare> | |
87 | inline void serial_insertion_sort_by_key(KeyIterator keys_first, | |
88 | KeyIterator keys_last, | |
89 | ValueIterator values_first, | |
90 | Compare compare, | |
91 | command_queue &queue) | |
92 | { | |
93 | typedef typename std::iterator_traits<KeyIterator>::value_type key_type; | |
94 | typedef typename std::iterator_traits<ValueIterator>::value_type value_type; | |
95 | ||
96 | size_t count = iterator_range_size(keys_first, keys_last); | |
97 | if(count < 2){ | |
98 | return; | |
99 | } | |
100 | ||
101 | meta_kernel k("serial_insertion_sort_by_key"); | |
102 | size_t local_keys_arg = k.add_arg<key_type *>(memory_object::local_memory, "keys"); | |
103 | size_t local_data_arg = k.add_arg<value_type *>(memory_object::local_memory, "data"); | |
104 | size_t count_arg = k.add_arg<uint_>("n"); | |
105 | ||
106 | k << | |
107 | // copy data to local memory | |
108 | "for(uint i = 0; i < n; i++){\n" << | |
109 | " keys[i] = " << keys_first[k.var<uint_>("i")] << ";\n" | |
110 | " data[i] = " << values_first[k.var<uint_>("i")] << ";\n" | |
111 | "}\n" | |
112 | ||
113 | // sort data in local memory | |
114 | "for(uint i = 1; i < n; i++){\n" << | |
115 | " " << k.decl<const key_type>("key") << " = keys[i];\n" << | |
116 | " " << k.decl<const value_type>("value") << " = data[i];\n" << | |
117 | " uint pos = i;\n" << | |
118 | " while(pos > 0 && " << | |
119 | compare(k.var<const key_type>("key"), | |
120 | k.var<const key_type>("keys[pos-1]")) << "){\n" << | |
121 | " keys[pos] = keys[pos-1];\n" << | |
122 | " data[pos] = data[pos-1];\n" << | |
123 | " pos--;\n" << | |
124 | " }\n" << | |
125 | " keys[pos] = key;\n" << | |
126 | " data[pos] = value;\n" << | |
127 | "}\n" << | |
128 | ||
129 | // copy sorted data to output | |
130 | "for(uint i = 0; i < n; i++){\n" << | |
131 | " " << keys_first[k.var<uint_>("i")] << " = keys[i];\n" | |
132 | " " << values_first[k.var<uint_>("i")] << " = data[i];\n" | |
133 | "}\n"; | |
134 | ||
135 | const context &context = queue.get_context(); | |
136 | ::boost::compute::kernel kernel = k.compile(context); | |
137 | kernel.set_arg(local_keys_arg, static_cast<uint_>(count * sizeof(key_type)), 0); | |
138 | kernel.set_arg(local_data_arg, static_cast<uint_>(count * sizeof(value_type)), 0); | |
139 | kernel.set_arg(count_arg, static_cast<uint_>(count)); | |
140 | ||
141 | queue.enqueue_task(kernel); | |
142 | } | |
143 | ||
144 | template<class KeyIterator, class ValueIterator> | |
145 | inline void serial_insertion_sort_by_key(KeyIterator keys_first, | |
146 | KeyIterator keys_last, | |
147 | ValueIterator values_first, | |
148 | command_queue &queue) | |
149 | { | |
150 | typedef typename std::iterator_traits<KeyIterator>::value_type key_type; | |
151 | ||
152 | serial_insertion_sort_by_key( | |
153 | keys_first, | |
154 | keys_last, | |
155 | values_first, | |
156 | boost::compute::less<key_type>(), | |
157 | queue | |
158 | ); | |
159 | } | |
160 | ||
161 | } // end detail namespace | |
162 | } // end compute namespace | |
163 | } // end boost namespace | |
164 | ||
165 | #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_INSERTION_SORT_HPP |