1 //---------------------------------------------------------------------------//
2 // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
4 // Distributed under the Boost Software License, Version 1.0
5 // See accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt
8 // See http://boostorg.github.com/compute for more information.
9 //---------------------------------------------------------------------------//
11 #ifndef BOOST_COMPUTE_LAMBDA_CONTEXT_HPP
12 #define BOOST_COMPUTE_LAMBDA_CONTEXT_HPP
14 #include <boost/proto/core.hpp>
15 #include <boost/proto/context.hpp>
16 #include <boost/type_traits.hpp>
17 #include <boost/preprocessor/repetition.hpp>
19 #include <boost/compute/config.hpp>
20 #include <boost/compute/function.hpp>
21 #include <boost/compute/lambda/result_of.hpp>
22 #include <boost/compute/lambda/functional.hpp>
23 #include <boost/compute/type_traits/result_of.hpp>
24 #include <boost/compute/type_traits/type_name.hpp>
25 #include <boost/compute/detail/meta_kernel.hpp>
31 namespace mpl = boost::mpl;
32 namespace proto = boost::proto;
34 #define BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(tag, op) \
35 template<class LHS, class RHS> \
36 void operator()(tag, const LHS &lhs, const RHS &rhs) \
38 if(proto::arity_of<LHS>::value > 0){ \
40 proto::eval(lhs, *this); \
44 proto::eval(lhs, *this); \
49 if(proto::arity_of<RHS>::value > 0){ \
51 proto::eval(rhs, *this); \
55 proto::eval(rhs, *this); \
59 // lambda expression context
61 struct context : proto::callable_context<context<Args> >
63 typedef void result_type;
64 typedef Args args_tuple;
66 // create a lambda context for kernel with args
67 context(boost::compute::detail::meta_kernel &kernel, const Args &args_)
75 void operator()(proto::tag::terminal, const T &x)
77 // terminal values in lambda expressions are always literals
78 stream << stream.lit(x);
81 void operator()(proto::tag::terminal, const uchar_ &x)
83 stream << "(uchar)(" << stream.lit(uint_(x)) << "u)";
86 void operator()(proto::tag::terminal, const char_ &x)
88 stream << "(char)(" << stream.lit(int_(x)) << ")";
91 void operator()(proto::tag::terminal, const ushort_ &x)
93 stream << "(ushort)(" << stream.lit(x) << "u)";
96 void operator()(proto::tag::terminal, const short_ &x)
98 stream << "(short)(" << stream.lit(x) << ")";
101 void operator()(proto::tag::terminal, const uint_ &x)
103 stream << "(" << stream.lit(x) << "u)";
106 void operator()(proto::tag::terminal, const ulong_ &x)
108 stream << "(" << stream.lit(x) << "ul)";
111 void operator()(proto::tag::terminal, const long_ &x)
113 stream << "(" << stream.lit(x) << "l)";
116 // handle placeholders
118 void operator()(proto::tag::terminal, placeholder<I>)
120 stream << boost::get<I>(args);
124 #define BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION_ARG(z, n, unused) \
125 BOOST_PP_COMMA_IF(n) BOOST_PP_CAT(const Arg, n) BOOST_PP_CAT(&arg, n)
127 #define BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION(z, n, unused) \
128 template<class F, BOOST_PP_ENUM_PARAMS(n, class Arg)> \
130 proto::tag::function, \
132 BOOST_PP_REPEAT(n, BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION_ARG, ~) \
135 proto::value(function).apply(*this, BOOST_PP_ENUM_PARAMS(n, arg)); \
138 BOOST_PP_REPEAT_FROM_TO(1, BOOST_COMPUTE_MAX_ARITY, BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION, ~)
140 #undef BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION
143 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::plus, '+')
144 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::minus, '-')
145 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::multiplies, '*')
146 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::divides, '/')
147 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::modulus, '%')
148 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::less, '<')
149 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::greater, '>')
150 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::less_equal, "<=")
151 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::greater_equal, ">=")
152 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::equal_to, "==")
153 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::not_equal_to, "!=")
154 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::logical_and, "&&")
155 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::logical_or, "||")
156 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_and, '&')
157 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_or, '|')
158 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_xor, '^')
159 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::assign, '=')
161 // subscript operator
162 template<class LHS, class RHS>
163 void operator()(proto::tag::subscript, const LHS &lhs, const RHS &rhs)
165 proto::eval(lhs, *this);
167 proto::eval(rhs, *this);
171 // ternary conditional operator
172 template<class Pred, class Arg1, class Arg2>
173 void operator()(proto::tag::if_else_, const Pred &p, const Arg1 &x, const Arg2 &y)
175 proto::eval(p, *this);
177 proto::eval(x, *this);
179 proto::eval(y, *this);
182 boost::compute::detail::meta_kernel &stream;
188 template<class Expr, class Arg>
189 struct invoked_unary_expression
191 typedef typename ::boost::compute::result_of<Expr(Arg)>::type result_type;
193 invoked_unary_expression(const Expr &expr, const Arg &arg)
203 template<class Expr, class Arg>
204 boost::compute::detail::meta_kernel&
205 operator<<(boost::compute::detail::meta_kernel &kernel,
206 const invoked_unary_expression<Expr, Arg> &expr)
208 context<boost::tuple<Arg> > ctx(kernel, boost::make_tuple(expr.m_arg));
209 proto::eval(expr.m_expr, ctx);
214 template<class Expr, class Arg1, class Arg2>
215 struct invoked_binary_expression
217 typedef typename ::boost::compute::result_of<Expr(Arg1, Arg2)>::type result_type;
219 invoked_binary_expression(const Expr &expr,
233 template<class Expr, class Arg1, class Arg2>
234 boost::compute::detail::meta_kernel&
235 operator<<(boost::compute::detail::meta_kernel &kernel,
236 const invoked_binary_expression<Expr, Arg1, Arg2> &expr)
238 context<boost::tuple<Arg1, Arg2> > ctx(
240 boost::make_tuple(expr.m_arg1, expr.m_arg2)
242 proto::eval(expr.m_expr, ctx);
247 } // end detail namespace
249 // forward declare domain
252 // lambda expression wrapper
254 struct expression : proto::extends<Expr, expression<Expr>, domain>
256 typedef proto::extends<Expr, expression<Expr>, domain> base_type;
258 BOOST_PROTO_EXTENDS_USING_ASSIGN(expression)
260 expression(const Expr &expr = Expr())
265 // result_of protocol
266 template<class Signature>
272 struct result<This()>
275 typename ::boost::compute::lambda::result_of<Expr>::type type;
278 template<class This, class Arg>
279 struct result<This(Arg)>
282 typename ::boost::compute::lambda::result_of<
284 typename boost::tuple<Arg>
288 template<class This, class Arg1, class Arg2>
289 struct result<This(Arg1, Arg2)>
292 ::boost::compute::lambda::result_of<
294 typename boost::tuple<Arg1, Arg2>
299 detail::invoked_unary_expression<expression<Expr>, Arg>
300 operator()(const Arg &x) const
302 return detail::invoked_unary_expression<expression<Expr>, Arg>(*this, x);
305 template<class Arg1, class Arg2>
306 detail::invoked_binary_expression<expression<Expr>, Arg1, Arg2>
307 operator()(const Arg1 &x, const Arg2 &y) const
309 return detail::invoked_binary_expression<
316 // function<> conversion operator
317 template<class R, class A1>
318 operator function<R(A1)>() const
320 using ::boost::compute::detail::meta_kernel;
322 std::stringstream source;
324 ::boost::compute::detail::meta_kernel_variable<A1> arg1("x");
326 source << "inline " << type_name<R>() << " lambda"
327 << ::boost::compute::detail::generate_argument_list<R(A1)>('x')
329 << " return " << meta_kernel::expr_to_string((*this)(arg1)) << ";\n"
332 return make_function_from_source<R(A1)>("lambda", source.str());
335 template<class R, class A1, class A2>
336 operator function<R(A1, A2)>() const // converts this lambda into a two-argument function built from generated source named "lambda"
338 using ::boost::compute::detail::meta_kernel;
340 std::stringstream source;
342 ::boost::compute::detail::meta_kernel_variable<A1> arg1("x"); // first argument, typed A1
343 ::boost::compute::detail::meta_kernel_variable<A2> arg2("y"); // second argument must be typed A2 (was A1), matching the R(A1, A2) signature below
345 source << "inline " << type_name<R>() << " lambda"
346 << ::boost::compute::detail::generate_argument_list<R(A1, A2)>('x')
348 << " return " << meta_kernel::expr_to_string((*this)(arg1, arg2)) << ";\n"
351 return make_function_from_source<R(A1, A2)>("lambda", source.str());
355 // lambda expression domain
356 struct domain : proto::domain<proto::generator<expression> >
360 } // end lambda namespace
361 } // end compute namespace
362 } // end boost namespace
364 #endif // BOOST_COMPUTE_LAMBDA_CONTEXT_HPP