1 //---------------------------------------------------------------------------//
2 // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
4 // Distributed under the Boost Software License, Version 1.0
5 // See accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt
8 // See http://boostorg.github.com/compute for more information.
9 //---------------------------------------------------------------------------//
11 #ifndef BOOST_COMPUTE_LAMBDA_CONTEXT_HPP
12 #define BOOST_COMPUTE_LAMBDA_CONTEXT_HPP
14 #include <boost/proto/core.hpp>
15 #include <boost/proto/context.hpp>
16 #include <boost/type_traits.hpp>
17 #include <boost/preprocessor/repetition.hpp>
19 #include <boost/compute/config.hpp>
20 #include <boost/compute/function.hpp>
21 #include <boost/compute/lambda/result_of.hpp>
22 #include <boost/compute/lambda/functional.hpp>
23 #include <boost/compute/type_traits/result_of.hpp>
24 #include <boost/compute/type_traits/type_name.hpp>
25 #include <boost/compute/detail/meta_kernel.hpp>
31 namespace mpl = boost::mpl;
32 namespace proto = boost::proto;
// Defines a context::operator() overload for the binary proto node type
// `tag`. `op` is the operator token for that node: a char such as '+' or a
// string such as "<=" (see the instantiations inside `context`).
// The generated overload evaluates the left operand, then the right operand,
// against this context. Each operand whose proto::arity_of is > 0 (a
// non-terminal sub-expression) takes a different branch from a terminal
// operand. The statements that distinguish the two branches, and the one
// that emits `op`, are not visible in this view. Presumably the non-terminal
// branch wraps the operand in parentheses -- confirm upstream.
// NOTE(review): this macro is not #undef'd anywhere visible in this file.
// NOTE(review): never put line comments inside the definition; every line
// ends in a backslash continuation.
#define BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(tag, op) \
template<class LHS, class RHS> \
void operator()(tag, const LHS &lhs, const RHS &rhs) \
if(proto::arity_of<LHS>::value > 0){ \
proto::eval(lhs, *this); \
proto::eval(lhs, *this); \
if(proto::arity_of<RHS>::value > 0){ \
proto::eval(rhs, *this); \
proto::eval(rhs, *this); \
59 // lambda expression context
61 struct context : proto::callable_context<context<Args> >
63 typedef void result_type;
64 typedef Args args_tuple;
66 // create a lambda context for kernel with args
67 context(boost::compute::detail::meta_kernel &kernel, const Args &args_)
75 void operator()(proto::tag::terminal, const T &x)
77 // terminal values in lambda expressions are always literals
78 stream << stream.lit(x);
81 // handle placeholders
83 void operator()(proto::tag::terminal, placeholder<I>)
85 stream << boost::get<I>(args);
89 #define BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION_ARG(z, n, unused) \
90 BOOST_PP_COMMA_IF(n) BOOST_PP_CAT(const Arg, n) BOOST_PP_CAT(&arg, n)
92 #define BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION(z, n, unused) \
93 template<class F, BOOST_PP_ENUM_PARAMS(n, class Arg)> \
95 proto::tag::function, \
97 BOOST_PP_REPEAT(n, BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION_ARG, ~) \
100 proto::value(function).apply(*this, BOOST_PP_ENUM_PARAMS(n, arg)); \
103 BOOST_PP_REPEAT_FROM_TO(1, BOOST_COMPUTE_MAX_ARITY, BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION, ~)
105 #undef BOOST_COMPUTE_LAMBDA_CONTEXT_FUNCTION
108 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::plus, '+')
109 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::minus, '-')
110 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::multiplies, '*')
111 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::divides, '/')
112 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::modulus, '%')
113 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::less, '<')
114 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::greater, '>')
115 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::less_equal, "<=")
116 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::greater_equal, ">=")
117 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::equal_to, "==")
118 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::not_equal_to, "!=")
119 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::logical_and, "&&")
120 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::logical_or, "||")
121 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_and, '&')
122 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_or, '|')
123 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::bitwise_xor, '^')
124 BOOST_COMPUTE_LAMBDA_CONTEXT_DEFINE_BINARY_OPERATOR(proto::tag::assign, '=')
126 // subscript operator
127 template<class LHS, class RHS>
128 void operator()(proto::tag::subscript, const LHS &lhs, const RHS &rhs)
130 proto::eval(lhs, *this);
132 proto::eval(rhs, *this);
136 // ternary conditional operator
137 template<class Pred, class Arg1, class Arg2>
138 void operator()(proto::tag::if_else_, const Pred &p, const Arg1 &x, const Arg2 &y)
140 proto::eval(p, *this);
142 proto::eval(x, *this);
144 proto::eval(y, *this);
147 boost::compute::detail::meta_kernel &stream;
// Result of invoking a lambda expression with one argument (see
// expression::operator()(const Arg&)). It only records the expression and
// the argument. The OpenCL source is produced when the object is streamed
// into a meta_kernel via the matching operator<< overload.
template<class Expr, class Arg>
struct invoked_unary_expression
// Boost.Compute result type of calling Expr with an Arg.
typedef typename ::boost::compute::result_of<Expr(Arg)>::type result_type;
// Captures `expr` and `arg`. The initializer list is not visible here; the
// matching operator<< reads them back as m_expr and m_arg.
invoked_unary_expression(const Expr &expr, const Arg &arg)
// Writes the OpenCL source for an invoked unary lambda into `kernel`.
// It builds a lambda context whose argument tuple holds only expr.m_arg, so
// placeholder<0> maps to that argument, and evaluates the stored proto
// expression against it.
// Returns a meta_kernel&. The return statement is not visible here;
// presumably it returns `kernel` to allow chaining.
template<class Expr, class Arg>
boost::compute::detail::meta_kernel&
operator<<(boost::compute::detail::meta_kernel &kernel,
const invoked_unary_expression<Expr, Arg> &expr)
context<boost::tuple<Arg> > ctx(kernel, boost::make_tuple(expr.m_arg));
proto::eval(expr.m_expr, ctx);
// Result of invoking a lambda expression with two arguments (see
// expression::operator()(const Arg1&, const Arg2&)). It only records the
// expression and both arguments. The OpenCL source is produced when the
// object is streamed into a meta_kernel via the matching operator<< overload.
template<class Expr, class Arg1, class Arg2>
struct invoked_binary_expression
// Boost.Compute result type of calling Expr with (Arg1, Arg2).
typedef typename ::boost::compute::result_of<Expr(Arg1, Arg2)>::type result_type;
// Captures the expression and both arguments. The rest of the parameter
// list and the initializers are not visible here; the matching operator<<
// reads them as m_expr, m_arg1 and m_arg2.
invoked_binary_expression(const Expr &expr,
// Writes the OpenCL source for an invoked binary lambda into a kernel.
// The context's argument tuple is (expr.m_arg1, expr.m_arg2), so
// placeholder<0> and placeholder<1> map to the first and second argument.
// The context's first constructor argument is on a line not visible here
// (presumably `kernel`).
template<class Expr, class Arg1, class Arg2>
boost::compute::detail::meta_kernel&
operator<<(boost::compute::detail::meta_kernel &kernel,
const invoked_binary_expression<Expr, Arg1, Arg2> &expr)
context<boost::tuple<Arg1, Arg2> > ctx(
boost::make_tuple(expr.m_arg1, expr.m_arg2)
proto::eval(expr.m_expr, ctx);
212 } // end detail namespace
214 // forward declare domain
// lambda expression wrapper
//
// Proto expression extension (proto::extends) used for expressions in the
// lambda `domain`. On top of the wrapped Expr it provides:
// - the result_of protocol (the nested result<> template),
// - call operators that bind one or two arguments,
// - conversion operators to boost::compute::function<>.
struct expression : proto::extends<Expr, expression<Expr>, domain>
typedef proto::extends<Expr, expression<Expr>, domain> base_type;
// Per Boost.Proto, brings in an operator= that builds an assign expression
// node instead of performing ordinary copy assignment.
BOOST_PROTO_EXTENDS_USING_ASSIGN(expression)
// Wraps `expr`; the default argument is a value-initialized Expr.
expression(const Expr &expr = Expr())
// result_of protocol
// Each specialization names, as `type`, lambda::result_of for Expr with the
// call's argument types packed into a boost::tuple. The This() form passes
// no argument tuple.
template<class Signature>
struct result<This()>
typename ::boost::compute::lambda::result_of<Expr>::type type;
template<class This, class Arg>
struct result<This(Arg)>
typename ::boost::compute::lambda::result_of<
typename boost::tuple<Arg>
template<class This, class Arg1, class Arg2>
struct result<This(Arg1, Arg2)>
::boost::compute::lambda::result_of<
typename boost::tuple<Arg1, Arg2>
// Binds a single argument. Nothing is evaluated here: the returned
// invoked_unary_expression holds this expression and `x`, and renders the
// OpenCL source only when streamed into a meta_kernel.
detail::invoked_unary_expression<expression<Expr>, Arg>
operator()(const Arg &x) const
return detail::invoked_unary_expression<expression<Expr>, Arg>(*this, x);
// Binds two arguments. This is the two-argument counterpart of the overload
// above and returns an invoked_binary_expression.
template<class Arg1, class Arg2>
detail::invoked_binary_expression<expression<Expr>, Arg1, Arg2>
operator()(const Arg1 &x, const Arg2 &y) const
return detail::invoked_binary_expression<
// function<> conversion operator
//
// Converts a one-argument lambda into a boost::compute::function<R(A1)>.
// The generated OpenCL-C source is an `inline` function named "lambda":
// - its return type is type_name<R>(),
// - its parameter list comes from generate_argument_list<R(A1)>('x'),
// - its body returns this expression, rendered with the argument bound to
//   a meta_kernel_variable<A1> named "x".
// The source string is passed to make_function_from_source under the name
// "lambda".
// NOTE(review): relies on generate_argument_list naming its first parameter
// "x" to match arg1 -- confirm in its definition.
template<class R, class A1>
operator function<R(A1)>() const
using ::boost::compute::detail::meta_kernel;
std::stringstream source;
// Stand-in for the function's parameter while rendering the expression.
::boost::compute::detail::meta_kernel_variable<A1> arg1("x");
source << "inline " << type_name<R>() << " lambda"
<< ::boost::compute::detail::generate_argument_list<R(A1)>('x')
<< " return " << meta_kernel::expr_to_string((*this)(arg1)) << ";\n"
return make_function_from_source<R(A1)>("lambda", source.str());
300 template<class R, class A1, class A2>
301 operator function<R(A1, A2)>() const
303 using ::boost::compute::detail::meta_kernel;
305 std::stringstream source;
307 ::boost::compute::detail::meta_kernel_variable<A1> arg1("x");
308 ::boost::compute::detail::meta_kernel_variable<A1> arg2("y");
310 source << "inline " << type_name<R>() << " lambda"
311 << ::boost::compute::detail::generate_argument_list<R(A1, A2)>('x')
313 << " return " << meta_kernel::expr_to_string((*this)(arg1, arg2)) << ";\n"
316 return make_function_from_source<R(A1, A2)>("lambda", source.str());
// lambda expression domain
//
// Proto domain whose generator wraps each expression built in this domain in
// lambda::expression<>. As a result, composite lambda expressions also get
// the call and function<> conversion operators defined on expression.
struct domain : proto::domain<proto::generator<expression> >
325 } // end lambda namespace
326 } // end compute namespace
327 } // end boost namespace
329 #endif // BOOST_COMPUTE_LAMBDA_CONTEXT_HPP