]>
Commit | Line | Data |
---|---|---|
7c673cae FG |
1 | //---------------------------------------------------------------------------// |
2 | // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com> | |
3 | // | |
4 | // Distributed under the Boost Software License, Version 1.0 | |
5 | // See accompanying file LICENSE_1_0.txt or copy at | |
6 | // http://www.boost.org/LICENSE_1_0.txt | |
7 | // | |
8 | // See http://boostorg.github.com/compute for more information. | |
9 | //---------------------------------------------------------------------------// | |
10 | ||
11 | #ifndef BOOST_COMPUTE_LAMBDA_CONTEXT_HPP | |
12 | #define BOOST_COMPUTE_LAMBDA_CONTEXT_HPP | |
13 | ||
14 | #include <boost/proto/core.hpp> | |
15 | #include <boost/proto/context.hpp> | |
16 | #include <boost/type_traits.hpp> | |
17 | #include <boost/preprocessor/repetition.hpp> | |
18 | ||
19 | #include <boost/compute/config.hpp> | |
20 | #include <boost/compute/function.hpp> | |
21 | #include <boost/compute/lambda/result_of.hpp> | |
22 | #include <boost/compute/lambda/functional.hpp> | |
23 | #include <boost/compute/type_traits/result_of.hpp> | |
24 | #include <boost/compute/type_traits/type_name.hpp> | |
25 | #include <boost/compute/detail/meta_kernel.hpp> | |
26 | ||
27 | namespace boost { | |
28 | namespace compute { | |
29 | namespace lambda { | |
30 | ||
31 | namespace mpl = boost::mpl; | |
32 | namespace proto = boost::proto; | |
33 | ||
// Generates the context handler for one binary proto tag. The handler writes
// the left operand, then the operator text `op`, then the right operand to
// `stream`. An operand whose proto arity is greater than zero (i.e. it is not
// a terminal) is enclosed in parentheses so that nesting is preserved in the
// emitted text.
#define BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(tag, op) \
    template<class LHS, class RHS> \
    void operator()(tag, const LHS &lhs, const RHS &rhs) \
    { \
        const bool wrap_lhs = proto::arity_of<LHS>::value > 0; \
        const bool wrap_rhs = proto::arity_of<RHS>::value > 0; \
        \
        if(wrap_lhs){ \
            stream << '('; \
        } \
        proto::eval(lhs, *this); \
        if(wrap_lhs){ \
            stream << ')'; \
        } \
        \
        stream << op; \
        \
        if(wrap_rhs){ \
            stream << '('; \
        } \
        proto::eval(rhs, *this); \
        if(wrap_rhs){ \
            stream << ')'; \
        } \
    }
58 | ||
59 | // lambda expression context | |
60 | template<class Args> | |
61 | struct context : proto::callable_context<context<Args> > | |
62 | { | |
63 | typedef void result_type; | |
64 | typedef Args args_tuple; | |
65 | ||
66 | // create a lambda context for kernel with args | |
67 | context(boost::compute::detail::meta_kernel &kernel, const Args &args_) | |
68 | : stream(kernel), | |
69 | args(args_) | |
70 | { | |
71 | } | |
72 | ||
73 | // handle terminals | |
74 | template<class T> | |
75 | void operator()(proto::tag::terminal, const T &x) | |
76 | { | |
77 | // terminal values in lambda expressions are always literals | |
78 | stream << stream.lit(x); | |
79 | } | |
80 | ||
81 | // handle placeholders | |
82 | template<int I> | |
83 | void operator()(proto::tag::terminal, placeholder<I>) | |
84 | { | |
85 | stream << boost::get<I>(args); | |
86 | } | |
87 | ||
88 | // handle functions | |
89 | #define BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION_ARG(z, n, unused) \ | |
90 | BOOST_PP_COMMA_IF(n) BOOST_PP_CAT(const Arg, n) BOOST_PP_CAT(&arg, n) | |
91 | ||
92 | #define BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION(z, n, unused) \ | |
93 | template<class F, BOOST_PP_ENUM_PARAMS(n, class Arg)> \ | |
94 | void operator()( \ | |
95 | proto::tag::function, \ | |
96 | const F &function, \ | |
97 | BOOST_PP_REPEAT(n, BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION_ARG, ~) \ | |
98 | ) \ | |
99 | { \ | |
100 | proto::value(function).apply(*this, BOOST_PP_ENUM_PARAMS(n, arg)); \ | |
101 | } | |
102 | ||
103 | BOOST_PP_REPEAT_FROM_TO(1, BOOST_COMPUTE_MAX_ARITY, BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION, ~) | |
104 | ||
105 | #undef BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION | |
106 | ||
107 | // operators | |
108 | BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::plus, '+') | |
109 | BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::minus, '-') | |
110 | BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::multiplies, '*') | |
111 | BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::divides, '/') | |
112 | BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::modulus, '%') | |
113 | BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::less, '<') | |
114 | BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::greater, '>') | |
115 | BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::less_equal, "<=") | |
116 | BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::greater_equal, ">=") | |
117 | BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::equal_to, "==") | |
118 | BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::not_equal_to, "!=") | |
119 | BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::logical_and, "&&") | |
120 | BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::logical_or, "||") | |
121 | BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_and, '&') | |
122 | BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_or, '|') | |
123 | BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_xor, '^') | |
124 | BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::assign, '=') | |
125 | ||
126 | // subscript operator | |
127 | template<class LHS, class RHS> | |
128 | void operator()(proto::tag::subscript, const LHS &lhs, const RHS &rhs) | |
129 | { | |
130 | proto::eval(lhs, *this); | |
131 | stream << '['; | |
132 | proto::eval(rhs, *this); | |
133 | stream << ']'; | |
134 | } | |
135 | ||
136 | // ternary conditional operator | |
137 | template<class Pred, class Arg1, class Arg2> | |
138 | void operator()(proto::tag::if_else_, const Pred &p, const Arg1 &x, const Arg2 &y) | |
139 | { | |
140 | proto::eval(p, *this); | |
141 | stream << '?'; | |
142 | proto::eval(x, *this); | |
143 | stream << ':'; | |
144 | proto::eval(y, *this); | |
145 | } | |
146 | ||
147 | boost::compute::detail::meta_kernel &stream; | |
148 | Args args; | |
149 | }; | |
150 | ||
151 | namespace detail { | |
152 | ||
153 | template<class Expr, class Arg> | |
154 | struct invoked_unary_expression | |
155 | { | |
156 | typedef typename ::boost::compute::result_of<Expr(Arg)>::type result_type; | |
157 | ||
158 | invoked_unary_expression(const Expr &expr, const Arg &arg) | |
159 | : m_expr(expr), | |
160 | m_arg(arg) | |
161 | { | |
162 | } | |
163 | ||
164 | Expr m_expr; | |
165 | Arg m_arg; | |
166 | }; | |
167 | ||
168 | template<class Expr, class Arg> | |
169 | boost::compute::detail::meta_kernel& | |
170 | operator<<(boost::compute::detail::meta_kernel &kernel, | |
171 | const invoked_unary_expression<Expr, Arg> &expr) | |
172 | { | |
173 | context<boost::tuple<Arg> > ctx(kernel, boost::make_tuple(expr.m_arg)); | |
174 | proto::eval(expr.m_expr, ctx); | |
175 | ||
176 | return kernel; | |
177 | } | |
178 | ||
179 | template<class Expr, class Arg1, class Arg2> | |
180 | struct invoked_binary_expression | |
181 | { | |
182 | typedef typename ::boost::compute::result_of<Expr(Arg1, Arg2)>::type result_type; | |
183 | ||
184 | invoked_binary_expression(const Expr &expr, | |
185 | const Arg1 &arg1, | |
186 | const Arg2 &arg2) | |
187 | : m_expr(expr), | |
188 | m_arg1(arg1), | |
189 | m_arg2(arg2) | |
190 | { | |
191 | } | |
192 | ||
193 | Expr m_expr; | |
194 | Arg1 m_arg1; | |
195 | Arg2 m_arg2; | |
196 | }; | |
197 | ||
198 | template<class Expr, class Arg1, class Arg2> | |
199 | boost::compute::detail::meta_kernel& | |
200 | operator<<(boost::compute::detail::meta_kernel &kernel, | |
201 | const invoked_binary_expression<Expr, Arg1, Arg2> &expr) | |
202 | { | |
203 | context<boost::tuple<Arg1, Arg2> > ctx( | |
204 | kernel, | |
205 | boost::make_tuple(expr.m_arg1, expr.m_arg2) | |
206 | ); | |
207 | proto::eval(expr.m_expr, ctx); | |
208 | ||
209 | return kernel; | |
210 | } | |
211 | ||
212 | } // end detail namespace | |
213 | ||
214 | // forward declare domain | |
215 | struct domain; | |
216 | ||
217 | // lambda expression wrapper | |
218 | template<class Expr> | |
219 | struct expression : proto::extends<Expr, expression<Expr>, domain> | |
220 | { | |
221 | typedef proto::extends<Expr, expression<Expr>, domain> base_type; | |
222 | ||
223 | BOOST_PROTO_EXTENDS_USING_ASSIGN(expression) | |
224 | ||
225 | expression(const Expr &expr = Expr()) | |
226 | : base_type(expr) | |
227 | { | |
228 | } | |
229 | ||
230 | // result_of protocol | |
231 | template<class Signature> | |
232 | struct result | |
233 | { | |
234 | }; | |
235 | ||
236 | template<class This> | |
237 | struct result<This()> | |
238 | { | |
239 | typedef | |
240 | typename ::boost::compute::lambda::result_of<Expr>::type type; | |
241 | }; | |
242 | ||
243 | template<class This, class Arg> | |
244 | struct result<This(Arg)> | |
245 | { | |
246 | typedef | |
247 | typename ::boost::compute::lambda::result_of< | |
248 | Expr, | |
249 | typename boost::tuple<Arg> | |
250 | >::type type; | |
251 | }; | |
252 | ||
253 | template<class This, class Arg1, class Arg2> | |
254 | struct result<This(Arg1, Arg2)> | |
255 | { | |
256 | typedef typename | |
257 | ::boost::compute::lambda::result_of< | |
258 | Expr, | |
259 | typename boost::tuple<Arg1, Arg2> | |
260 | >::type type; | |
261 | }; | |
262 | ||
263 | template<class Arg> | |
264 | detail::invoked_unary_expression<expression<Expr>, Arg> | |
265 | operator()(const Arg &x) const | |
266 | { | |
267 | return detail::invoked_unary_expression<expression<Expr>, Arg>(*this, x); | |
268 | } | |
269 | ||
270 | template<class Arg1, class Arg2> | |
271 | detail::invoked_binary_expression<expression<Expr>, Arg1, Arg2> | |
272 | operator()(const Arg1 &x, const Arg2 &y) const | |
273 | { | |
274 | return detail::invoked_binary_expression< | |
275 | expression<Expr>, | |
276 | Arg1, | |
277 | Arg2 | |
278 | >(*this, x, y); | |
279 | } | |
280 | ||
281 | // function<> conversion operator | |
282 | template<class R, class A1> | |
283 | operator function<R(A1)>() const | |
284 | { | |
285 | using ::boost::compute::detail::meta_kernel; | |
286 | ||
287 | std::stringstream source; | |
288 | ||
289 | ::boost::compute::detail::meta_kernel_variable<A1> arg1("x"); | |
290 | ||
291 | source << "inline " << type_name<R>() << " lambda" | |
292 | << ::boost::compute::detail::generate_argument_list<R(A1)>('x') | |
293 | << "{\n" | |
294 | << " return " << meta_kernel::expr_to_string((*this)(arg1)) << ";\n" | |
295 | << "}\n"; | |
296 | ||
297 | return make_function_from_source<R(A1)>("lambda", source.str()); | |
298 | } | |
299 | ||
300 | template<class R, class A1, class A2> | |
301 | operator function<R(A1, A2)>() const | |
302 | { | |
303 | using ::boost::compute::detail::meta_kernel; | |
304 | ||
305 | std::stringstream source; | |
306 | ||
307 | ::boost::compute::detail::meta_kernel_variable<A1> arg1("x"); | |
308 | ::boost::compute::detail::meta_kernel_variable<A1> arg2("y"); | |
309 | ||
310 | source << "inline " << type_name<R>() << " lambda" | |
311 | << ::boost::compute::detail::generate_argument_list<R(A1, A2)>('x') | |
312 | << "{\n" | |
313 | << " return " << meta_kernel::expr_to_string((*this)(arg1, arg2)) << ";\n" | |
314 | << "}\n"; | |
315 | ||
316 | return make_function_from_source<R(A1, A2)>("lambda", source.str()); | |
317 | } | |
318 | }; | |
319 | ||
320 | // lambda expression domain | |
321 | struct domain : proto::domain<proto::generator<expression> > | |
322 | { | |
323 | }; | |
324 | ||
325 | } // end lambda namespace | |
326 | } // end compute namespace | |
327 | } // end boost namespace | |
328 | ||
329 | #endif // BOOST_COMPUTE_LAMBDA_CONTEXT_HPP |