]> git.proxmox.com Git - rustc.git/blob - vendor/derive_arbitrary/src/lib.rs
New upstream version 1.76.0+dfsg1
[rustc.git] / vendor / derive_arbitrary / src / lib.rs
1 extern crate proc_macro;
2
3 use proc_macro2::{Span, TokenStream};
4 use quote::quote;
5 use syn::*;
6
7 mod container_attributes;
8 mod field_attributes;
9 use container_attributes::ContainerAttributes;
10 use field_attributes::{determine_field_constructor, FieldConstructor};
11
/// Name of the helper attribute recognized by this derive (`#[arbitrary(...)]`).
const ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
/// Name of the lifetime parameter added to every generated impl (`'arbitrary`).
const ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
14
15 #[proc_macro_derive(Arbitrary, attributes(arbitrary))]
16 pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
17 let input = syn::parse_macro_input!(tokens as syn::DeriveInput);
18 expand_derive_arbitrary(input)
19 .unwrap_or_else(syn::Error::into_compile_error)
20 .into()
21 }
22
23 fn expand_derive_arbitrary(input: syn::DeriveInput) -> Result<TokenStream> {
24 let container_attrs = ContainerAttributes::from_derive_input(&input)?;
25
26 let (lifetime_without_bounds, lifetime_with_bounds) =
27 build_arbitrary_lifetime(input.generics.clone());
28
29 let recursive_count = syn::Ident::new(
30 &format!("RECURSIVE_COUNT_{}", input.ident),
31 Span::call_site(),
32 );
33
34 let arbitrary_method =
35 gen_arbitrary_method(&input, lifetime_without_bounds.clone(), &recursive_count)?;
36 let size_hint_method = gen_size_hint_method(&input)?;
37 let name = input.ident;
38
39 // Apply user-supplied bounds or automatic `T: ArbitraryBounds`.
40 let generics = apply_trait_bounds(
41 input.generics,
42 lifetime_without_bounds.clone(),
43 &container_attrs,
44 )?;
45
46 // Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90)
47 let mut generics_with_lifetime = generics.clone();
48 generics_with_lifetime
49 .params
50 .push(GenericParam::Lifetime(lifetime_with_bounds));
51 let (impl_generics, _, _) = generics_with_lifetime.split_for_impl();
52
53 // Build TypeGenerics and WhereClause without a lifetime
54 let (_, ty_generics, where_clause) = generics.split_for_impl();
55
56 Ok(quote! {
57 const _: () = {
58 std::thread_local! {
59 #[allow(non_upper_case_globals)]
60 static #recursive_count: std::cell::Cell<u32> = std::cell::Cell::new(0);
61 }
62
63 #[automatically_derived]
64 impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
65 #arbitrary_method
66 #size_hint_method
67 }
68 };
69 })
70 }
71
72 // Returns: (lifetime without bounds, lifetime with bounds)
73 // Example: ("'arbitrary", "'arbitrary: 'a + 'b")
74 fn build_arbitrary_lifetime(generics: Generics) -> (LifetimeParam, LifetimeParam) {
75 let lifetime_without_bounds =
76 LifetimeParam::new(Lifetime::new(ARBITRARY_LIFETIME_NAME, Span::call_site()));
77 let mut lifetime_with_bounds = lifetime_without_bounds.clone();
78
79 for param in generics.params.iter() {
80 if let GenericParam::Lifetime(lifetime_def) = param {
81 lifetime_with_bounds
82 .bounds
83 .push(lifetime_def.lifetime.clone());
84 }
85 }
86
87 (lifetime_without_bounds, lifetime_with_bounds)
88 }
89
90 fn apply_trait_bounds(
91 mut generics: Generics,
92 lifetime: LifetimeParam,
93 container_attrs: &ContainerAttributes,
94 ) -> Result<Generics> {
95 // If user-supplied bounds exist, apply them to their matching type parameters.
96 if let Some(config_bounds) = &container_attrs.bounds {
97 let mut config_bounds_applied = 0;
98 for param in generics.params.iter_mut() {
99 if let GenericParam::Type(type_param) = param {
100 if let Some(replacement) = config_bounds
101 .iter()
102 .flatten()
103 .find(|p| p.ident == type_param.ident)
104 {
105 *type_param = replacement.clone();
106 config_bounds_applied += 1;
107 } else {
108 // If no user-supplied bounds exist for this type, delete the original bounds.
109 // This mimics serde.
110 type_param.bounds = Default::default();
111 type_param.default = None;
112 }
113 }
114 }
115 let config_bounds_supplied = config_bounds
116 .iter()
117 .map(|bounds| bounds.len())
118 .sum::<usize>();
119 if config_bounds_applied != config_bounds_supplied {
120 return Err(Error::new(
121 Span::call_site(),
122 format!(
123 "invalid `{}` attribute. too many bounds, only {} out of {} are applicable",
124 ARBITRARY_ATTRIBUTE_NAME, config_bounds_applied, config_bounds_supplied,
125 ),
126 ));
127 }
128 Ok(generics)
129 } else {
130 // Otherwise, inject a `T: Arbitrary` bound for every parameter.
131 Ok(add_trait_bounds(generics, lifetime))
132 }
133 }
134
135 // Add a bound `T: Arbitrary` to every type parameter T.
136 fn add_trait_bounds(mut generics: Generics, lifetime: LifetimeParam) -> Generics {
137 for param in generics.params.iter_mut() {
138 if let GenericParam::Type(type_param) = param {
139 type_param
140 .bounds
141 .push(parse_quote!(arbitrary::Arbitrary<#lifetime>));
142 }
143 }
144 generics
145 }
146
/// Wraps the generated constructor expression `expr` in recursion-guard code.
///
/// The emitted tokens expect a variable named `u` with an `is_empty()` method
/// (the `Unstructured` of the surrounding generated method) to be in scope.
/// They also expect a thread-local `Cell<u32>` named by `recursive_count`,
/// which `expand_derive_arbitrary` declares.
///
/// Behavior of the emitted code:
/// * If `u` is non-empty on entry, the counter is not touched.
/// * If `u` is empty on entry and the counter is already non-zero,
///   `Err(arbitrary::Error::NotEnoughData)` is returned. This happens when an
///   enclosing call that shares this counter also started with empty input.
/// * Otherwise the counter is incremented, `expr` is evaluated, and the counter
///   is decremented again. The decrement happens whether `expr` produced `Ok`
///   or `Err`.
///
/// NOTE(review): there is no drop guard, so if `expr` panics the counter stays
/// incremented for this thread.
fn with_recursive_count_guard(
    recursive_count: &syn::Ident,
    expr: impl quote::ToTokens,
) -> impl quote::ToTokens {
    quote! {
        let guard_against_recursion = u.is_empty();
        if guard_against_recursion {
            #recursive_count.with(|count| {
                if count.get() > 0 {
                    return Err(arbitrary::Error::NotEnoughData);
                }
                count.set(count.get() + 1);
                Ok(())
            })?;
        }

        // Evaluate `expr` inside a closure so that its `?` early returns land
        // here, allowing the counter to be decremented below.
        let result = (|| { #expr })();

        if guard_against_recursion {
            #recursive_count.with(|count| {
                count.set(count.get() - 1);
            });
        }

        result
    }
}
174
175 fn gen_arbitrary_method(
176 input: &DeriveInput,
177 lifetime: LifetimeParam,
178 recursive_count: &syn::Ident,
179 ) -> Result<TokenStream> {
180 fn arbitrary_structlike(
181 fields: &Fields,
182 ident: &syn::Ident,
183 lifetime: LifetimeParam,
184 recursive_count: &syn::Ident,
185 ) -> Result<TokenStream> {
186 let arbitrary = construct(fields, |_idx, field| gen_constructor_for_field(field))?;
187 let body = with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary) });
188
189 let arbitrary_take_rest = construct_take_rest(fields)?;
190 let take_rest_body =
191 with_recursive_count_guard(recursive_count, quote! { Ok(#ident #arbitrary_take_rest) });
192
193 Ok(quote! {
194 fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
195 #body
196 }
197
198 fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
199 #take_rest_body
200 }
201 })
202 }
203
204 let ident = &input.ident;
205 let output = match &input.data {
206 Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count)?,
207 Data::Union(data) => arbitrary_structlike(
208 &Fields::Named(data.fields.clone()),
209 ident,
210 lifetime,
211 recursive_count,
212 )?,
213 Data::Enum(data) => {
214 let variants: Vec<TokenStream> = data
215 .variants
216 .iter()
217 .enumerate()
218 .map(|(i, variant)| {
219 let idx = i as u64;
220 let variant_name = &variant.ident;
221 construct(&variant.fields, |_, field| gen_constructor_for_field(field))
222 .map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
223 })
224 .collect::<Result<_>>()?;
225
226 let variants_take_rest: Vec<TokenStream> = data
227 .variants
228 .iter()
229 .enumerate()
230 .map(|(i, variant)| {
231 let idx = i as u64;
232 let variant_name = &variant.ident;
233 construct_take_rest(&variant.fields)
234 .map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
235 })
236 .collect::<Result<_>>()?;
237
238 let count = data.variants.len() as u64;
239
240 let arbitrary = with_recursive_count_guard(
241 recursive_count,
242 quote! {
243 // Use a multiply + shift to generate a ranged random number
244 // with slight bias. For details, see:
245 // https://lemire.me/blog/2016/06/30/fast-random-shuffling
246 Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(u)?) * #count) >> 32 {
247 #(#variants,)*
248 _ => unreachable!()
249 })
250 },
251 );
252
253 let arbitrary_take_rest = with_recursive_count_guard(
254 recursive_count,
255 quote! {
256 // Use a multiply + shift to generate a ranged random number
257 // with slight bias. For details, see:
258 // https://lemire.me/blog/2016/06/30/fast-random-shuffling
259 Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(&mut u)?) * #count) >> 32 {
260 #(#variants_take_rest,)*
261 _ => unreachable!()
262 })
263 },
264 );
265
266 quote! {
267 fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
268 #arbitrary
269 }
270
271 fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
272 #arbitrary_take_rest
273 }
274 }
275 }
276 };
277 Ok(output)
278 }
279
280 fn construct(
281 fields: &Fields,
282 ctor: impl Fn(usize, &Field) -> Result<TokenStream>,
283 ) -> Result<TokenStream> {
284 let output = match fields {
285 Fields::Named(names) => {
286 let names: Vec<TokenStream> = names
287 .named
288 .iter()
289 .enumerate()
290 .map(|(i, f)| {
291 let name = f.ident.as_ref().unwrap();
292 ctor(i, f).map(|ctor| quote! { #name: #ctor })
293 })
294 .collect::<Result<_>>()?;
295 quote! { { #(#names,)* } }
296 }
297 Fields::Unnamed(names) => {
298 let names: Vec<TokenStream> = names
299 .unnamed
300 .iter()
301 .enumerate()
302 .map(|(i, f)| ctor(i, f).map(|ctor| quote! { #ctor }))
303 .collect::<Result<_>>()?;
304 quote! { ( #(#names),* ) }
305 }
306 Fields::Unit => quote!(),
307 };
308 Ok(output)
309 }
310
311 fn construct_take_rest(fields: &Fields) -> Result<TokenStream> {
312 construct(fields, |idx, field| {
313 determine_field_constructor(field).map(|field_constructor| match field_constructor {
314 FieldConstructor::Default => quote!(Default::default()),
315 FieldConstructor::Arbitrary => {
316 if idx + 1 == fields.len() {
317 quote! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
318 } else {
319 quote! { arbitrary::Arbitrary::arbitrary(&mut u)? }
320 }
321 }
322 FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(&mut u)?),
323 FieldConstructor::Value(value) => quote!(#value),
324 })
325 })
326 }
327
328 fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
329 let size_hint_fields = |fields: &Fields| {
330 fields
331 .iter()
332 .map(|f| {
333 let ty = &f.ty;
334 determine_field_constructor(f).map(|field_constructor| {
335 match field_constructor {
336 FieldConstructor::Default | FieldConstructor::Value(_) => {
337 quote!((0, Some(0)))
338 }
339 FieldConstructor::Arbitrary => {
340 quote! { <#ty as arbitrary::Arbitrary>::size_hint(depth) }
341 }
342
343 // Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is
344 // just an educated guess, although it's gonna be inaccurate for dynamically
345 // allocated types (Vec, HashMap, etc.).
346 FieldConstructor::With(_) => {
347 quote! { (::core::mem::size_of::<#ty>(), None) }
348 }
349 }
350 })
351 })
352 .collect::<Result<Vec<TokenStream>>>()
353 .map(|hints| {
354 quote! {
355 arbitrary::size_hint::and_all(&[
356 #( #hints ),*
357 ])
358 }
359 })
360 };
361 let size_hint_structlike = |fields: &Fields| {
362 size_hint_fields(fields).map(|hint| {
363 quote! {
364 #[inline]
365 fn size_hint(depth: usize) -> (usize, Option<usize>) {
366 arbitrary::size_hint::recursion_guard(depth, |depth| #hint)
367 }
368 }
369 })
370 };
371 match &input.data {
372 Data::Struct(data) => size_hint_structlike(&data.fields),
373 Data::Union(data) => size_hint_structlike(&Fields::Named(data.fields.clone())),
374 Data::Enum(data) => data
375 .variants
376 .iter()
377 .map(|v| size_hint_fields(&v.fields))
378 .collect::<Result<Vec<TokenStream>>>()
379 .map(|variants| {
380 quote! {
381 #[inline]
382 fn size_hint(depth: usize) -> (usize, Option<usize>) {
383 arbitrary::size_hint::and(
384 <u32 as arbitrary::Arbitrary>::size_hint(depth),
385 arbitrary::size_hint::recursion_guard(depth, |depth| {
386 arbitrary::size_hint::or_all(&[ #( #variants ),* ])
387 }),
388 )
389 }
390 }
391 }),
392 }
393 }
394
395 fn gen_constructor_for_field(field: &Field) -> Result<TokenStream> {
396 let ctor = match determine_field_constructor(field)? {
397 FieldConstructor::Default => quote!(Default::default()),
398 FieldConstructor::Arbitrary => quote!(arbitrary::Arbitrary::arbitrary(u)?),
399 FieldConstructor::With(function_or_closure) => quote!((#function_or_closure)(u)?),
400 FieldConstructor::Value(value) => quote!(#value),
401 };
402 Ok(ctor)
403 }