]>
Commit | Line | Data |
---|---|---|
7c673cae FG |
1 | //---------------------------------------------------------------------------// |
2 | // Copyright (c) 2015 Jakub Szuppe <j.szuppe@gmail.com> | |
3 | // | |
4 | // Distributed under the Boost Software License, Version 1.0 | |
5 | // See accompanying file LICENSE_1_0.txt or copy at | |
6 | // http://www.boost.org/LICENSE_1_0.txt | |
7 | // | |
8 | // See http://boostorg.github.com/compute for more information. | |
9 | //---------------------------------------------------------------------------// | |
10 | ||
11 | #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_SERIAL_REDUCE_BY_KEY_HPP | |
12 | #define BOOST_COMPUTE_ALGORITHM_DETAIL_SERIAL_REDUCE_BY_KEY_HPP | |
13 | ||
14 | #include <iterator> | |
15 | ||
16 | #include <boost/compute/command_queue.hpp> | |
17 | #include <boost/compute/functional.hpp> | |
18 | #include <boost/compute/container/vector.hpp> | |
19 | #include <boost/compute/container/detail/scalar.hpp> | |
20 | #include <boost/compute/detail/meta_kernel.hpp> | |
21 | #include <boost/compute/detail/iterator_range_size.hpp> | |
22 | #include <boost/compute/type_traits/result_of.hpp> | |
23 | ||
24 | namespace boost { | |
25 | namespace compute { | |
26 | namespace detail { | |
27 | ||
28 | template<class InputKeyIterator, class InputValueIterator, | |
29 | class OutputKeyIterator, class OutputValueIterator, | |
30 | class BinaryFunction, class BinaryPredicate> | |
31 | inline size_t serial_reduce_by_key(InputKeyIterator keys_first, | |
32 | InputKeyIterator keys_last, | |
33 | InputValueIterator values_first, | |
34 | OutputKeyIterator keys_result, | |
35 | OutputValueIterator values_result, | |
36 | BinaryFunction function, | |
37 | BinaryPredicate predicate, | |
38 | command_queue &queue) | |
39 | { | |
40 | typedef typename | |
41 | std::iterator_traits<InputValueIterator>::value_type value_type; | |
42 | typedef typename | |
43 | std::iterator_traits<InputKeyIterator>::value_type key_type; | |
44 | typedef typename | |
45 | ::boost::compute::result_of<BinaryFunction(value_type, value_type)>::type result_type; | |
46 | ||
47 | const context &context = queue.get_context(); | |
48 | size_t count = detail::iterator_range_size(keys_first, keys_last); | |
49 | if(count < 1){ | |
50 | return count; | |
51 | } | |
52 | ||
53 | meta_kernel k("serial_reduce_by_key"); | |
54 | size_t count_arg = k.add_arg<uint_>("count"); | |
55 | size_t result_size_arg = k.add_arg<uint_ *>(memory_object::global_memory, | |
56 | "result_size"); | |
57 | ||
58 | convert<result_type> to_result_type; | |
59 | ||
60 | k << | |
61 | k.decl<result_type>("result") << | |
62 | " = " << to_result_type(values_first[0]) << ";\n" << | |
63 | k.decl<key_type>("previous_key") << " = " << keys_first[0] << ";\n" << | |
64 | k.decl<result_type>("value") << ";\n" << | |
65 | k.decl<key_type>("key") << ";\n" << | |
66 | ||
67 | k.decl<uint_>("size") << " = 1;\n" << | |
68 | ||
69 | keys_result[0] << " = previous_key;\n" << | |
70 | values_result[0] << " = result;\n" << | |
71 | ||
72 | "for(ulong i = 1; i < count; i++) {\n" << | |
73 | " value = " << to_result_type(values_first[k.var<uint_>("i")]) << ";\n" << | |
74 | " key = " << keys_first[k.var<uint_>("i")] << ";\n" << | |
75 | " if (" << predicate(k.var<key_type>("previous_key"), | |
76 | k.var<key_type>("key")) << ") {\n" << | |
77 | ||
78 | " result = " << function(k.var<result_type>("result"), | |
79 | k.var<result_type>("value")) << ";\n" << | |
80 | " }\n " << | |
81 | " else { \n" << | |
82 | keys_result[k.var<uint_>("size - 1")] << " = previous_key;\n" << | |
83 | values_result[k.var<uint_>("size - 1")] << " = result;\n" << | |
84 | " result = value;\n" << | |
85 | " size++;\n" << | |
86 | " } \n" << | |
87 | " previous_key = key;\n" << | |
88 | "}\n" << | |
89 | keys_result[k.var<uint_>("size - 1")] << " = previous_key;\n" << | |
90 | values_result[k.var<uint_>("size - 1")] << " = result;\n" << | |
91 | "*result_size = size;"; | |
92 | ||
93 | kernel kernel = k.compile(context); | |
94 | ||
95 | scalar<uint_> result_size(context); | |
96 | kernel.set_arg(result_size_arg, result_size.get_buffer()); | |
97 | kernel.set_arg(count_arg, static_cast<uint_>(count)); | |
98 | ||
99 | queue.enqueue_task(kernel); | |
100 | ||
101 | return static_cast<size_t>(result_size.read(queue)); | |
102 | } | |
103 | ||
104 | } // end detail namespace | |
105 | } // end compute namespace | |
106 | } // end boost namespace | |
107 | ||
108 | #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_SERIAL_REDUCE_BY_KEY_HPP |