]>
Commit | Line | Data |
---|---|---|
5099ac24 FG |
1 | //! The futures-rs `join!` and `try_join!` macro implementations. |
2 | ||
3 | use proc_macro::TokenStream; | |
4 | use proc_macro2::{Span, TokenStream as TokenStream2}; | |
5 | use quote::{format_ident, quote}; | |
6 | use syn::parse::{Parse, ParseStream}; | |
7 | use syn::{Expr, Ident, Token}; | |
8 | ||
9 | #[derive(Default)] | |
10 | struct Join { | |
11 | fut_exprs: Vec<Expr>, | |
12 | } | |
13 | ||
14 | impl Parse for Join { | |
15 | fn parse(input: ParseStream<'_>) -> syn::Result<Self> { | |
16 | let mut join = Self::default(); | |
17 | ||
18 | while !input.is_empty() { | |
19 | join.fut_exprs.push(input.parse::<Expr>()?); | |
20 | ||
21 | if !input.is_empty() { | |
22 | input.parse::<Token![,]>()?; | |
23 | } | |
24 | } | |
25 | ||
26 | Ok(join) | |
27 | } | |
28 | } | |
29 | ||
30 | fn bind_futures(fut_exprs: Vec<Expr>, span: Span) -> (Vec<TokenStream2>, Vec<Ident>) { | |
31 | let mut future_let_bindings = Vec::with_capacity(fut_exprs.len()); | |
32 | let future_names: Vec<_> = fut_exprs | |
33 | .into_iter() | |
34 | .enumerate() | |
35 | .map(|(i, expr)| { | |
36 | let name = format_ident!("_fut{}", i, span = span); | |
37 | future_let_bindings.push(quote! { | |
38 | // Move future into a local so that it is pinned in one place and | |
39 | // is no longer accessible by the end user. | |
40 | let mut #name = __futures_crate::future::maybe_done(#expr); | |
353b0b11 | 41 | let mut #name = unsafe { __futures_crate::Pin::new_unchecked(&mut #name) }; |
5099ac24 FG |
42 | }); |
43 | name | |
44 | }) | |
45 | .collect(); | |
46 | ||
47 | (future_let_bindings, future_names) | |
48 | } | |
49 | ||
50 | /// The `join!` macro. | |
51 | pub(crate) fn join(input: TokenStream) -> TokenStream { | |
52 | let parsed = syn::parse_macro_input!(input as Join); | |
53 | ||
54 | // should be def_site, but that's unstable | |
55 | let span = Span::call_site(); | |
56 | ||
57 | let (future_let_bindings, future_names) = bind_futures(parsed.fut_exprs, span); | |
58 | ||
59 | let poll_futures = future_names.iter().map(|fut| { | |
60 | quote! { | |
61 | __all_done &= __futures_crate::future::Future::poll( | |
353b0b11 | 62 | #fut.as_mut(), __cx).is_ready(); |
5099ac24 FG |
63 | } |
64 | }); | |
65 | let take_outputs = future_names.iter().map(|fut| { | |
66 | quote! { | |
353b0b11 | 67 | #fut.as_mut().take_output().unwrap(), |
5099ac24 FG |
68 | } |
69 | }); | |
70 | ||
71 | TokenStream::from(quote! { { | |
72 | #( #future_let_bindings )* | |
73 | ||
74 | __futures_crate::future::poll_fn(move |__cx: &mut __futures_crate::task::Context<'_>| { | |
75 | let mut __all_done = true; | |
76 | #( #poll_futures )* | |
77 | if __all_done { | |
78 | __futures_crate::task::Poll::Ready(( | |
79 | #( #take_outputs )* | |
80 | )) | |
81 | } else { | |
82 | __futures_crate::task::Poll::Pending | |
83 | } | |
84 | }).await | |
85 | } }) | |
86 | } | |
87 | ||
88 | /// The `try_join!` macro. | |
89 | pub(crate) fn try_join(input: TokenStream) -> TokenStream { | |
90 | let parsed = syn::parse_macro_input!(input as Join); | |
91 | ||
92 | // should be def_site, but that's unstable | |
93 | let span = Span::call_site(); | |
94 | ||
95 | let (future_let_bindings, future_names) = bind_futures(parsed.fut_exprs, span); | |
96 | ||
97 | let poll_futures = future_names.iter().map(|fut| { | |
98 | quote! { | |
99 | if __futures_crate::future::Future::poll( | |
353b0b11 | 100 | #fut.as_mut(), __cx).is_pending() |
5099ac24 FG |
101 | { |
102 | __all_done = false; | |
353b0b11 | 103 | } else if #fut.as_mut().output_mut().unwrap().is_err() { |
5099ac24 FG |
104 | // `.err().unwrap()` rather than `.unwrap_err()` so that we don't introduce |
105 | // a `T: Debug` bound. | |
106 | // Also, for an error type of ! any code after `err().unwrap()` is unreachable. | |
107 | #[allow(unreachable_code)] | |
108 | return __futures_crate::task::Poll::Ready( | |
109 | __futures_crate::Err( | |
353b0b11 | 110 | #fut.as_mut().take_output().unwrap().err().unwrap() |
5099ac24 FG |
111 | ) |
112 | ); | |
113 | } | |
114 | } | |
115 | }); | |
116 | let take_outputs = future_names.iter().map(|fut| { | |
117 | quote! { | |
118 | // `.ok().unwrap()` rather than `.unwrap()` so that we don't introduce | |
119 | // an `E: Debug` bound. | |
120 | // Also, for an ok type of ! any code after `ok().unwrap()` is unreachable. | |
121 | #[allow(unreachable_code)] | |
353b0b11 | 122 | #fut.as_mut().take_output().unwrap().ok().unwrap(), |
5099ac24 FG |
123 | } |
124 | }); | |
125 | ||
126 | TokenStream::from(quote! { { | |
127 | #( #future_let_bindings )* | |
128 | ||
129 | #[allow(clippy::diverging_sub_expression)] | |
130 | __futures_crate::future::poll_fn(move |__cx: &mut __futures_crate::task::Context<'_>| { | |
131 | let mut __all_done = true; | |
132 | #( #poll_futures )* | |
133 | if __all_done { | |
134 | __futures_crate::task::Poll::Ready( | |
135 | __futures_crate::Ok(( | |
136 | #( #take_outputs )* | |
137 | )) | |
138 | ) | |
139 | } else { | |
140 | __futures_crate::task::Poll::Pending | |
141 | } | |
142 | }).await | |
143 | } }) | |
144 | } |