1 extern crate proc_macro
;
3 use proc_macro2
::{Span, TokenStream}
;
7 mod container_attributes
;
9 use container_attributes
::ContainerAttributes
;
10 use field_attributes
::{determine_field_constructor, FieldConstructor}
;
/// Textual form of the lifetime parameter threaded through every generated impl.
///
/// The derive turns this into a `syn::Lifetime` and uses it both as the trait's
/// lifetime argument (`arbitrary::Arbitrary<'arbitrary>`) and as the lifetime of the
/// `arbitrary::Unstructured` input in the generated `arbitrary` /
/// `arbitrary_take_rest` methods. The leading apostrophe is part of the name because
/// `syn::Lifetime::new` expects the full lifetime token text.
static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";

/// Name of the derive's helper attribute, i.e. the `arbitrary` in `#[arbitrary(...)]`.
///
/// It is also interpolated into diagnostics, e.g. the "too many bounds" error raised
/// when user-supplied bounds do not all match a type parameter.
static ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
15 #[proc_macro_derive(Arbitrary, attributes(arbitrary))]
16 pub fn derive_arbitrary(tokens
: proc_macro
::TokenStream
) -> proc_macro
::TokenStream
{
17 let input
= syn
::parse_macro_input
!(tokens
as syn
::DeriveInput
);
18 expand_derive_arbitrary(input
)
19 .unwrap_or_else(syn
::Error
::into_compile_error
)
23 fn expand_derive_arbitrary(input
: syn
::DeriveInput
) -> Result
<TokenStream
> {
24 let container_attrs
= ContainerAttributes
::from_derive_input(&input
)?
;
26 let (lifetime_without_bounds
, lifetime_with_bounds
) =
27 build_arbitrary_lifetime(input
.generics
.clone());
29 let recursive_count
= syn
::Ident
::new(
30 &format
!("RECURSIVE_COUNT_{}", input
.ident
),
34 let arbitrary_method
=
35 gen_arbitrary_method(&input
, lifetime_without_bounds
.clone(), &recursive_count
)?
;
36 let size_hint_method
= gen_size_hint_method(&input
)?
;
37 let name
= input
.ident
;
39 // Apply user-supplied bounds or automatic `T: ArbitraryBounds`.
40 let generics
= apply_trait_bounds(
42 lifetime_without_bounds
.clone(),
46 // Build ImplGeneric with a lifetime (https://github.com/dtolnay/syn/issues/90)
47 let mut generics_with_lifetime
= generics
.clone();
48 generics_with_lifetime
50 .push(GenericParam
::Lifetime(lifetime_with_bounds
));
51 let (impl_generics
, _
, _
) = generics_with_lifetime
.split_for_impl();
53 // Build TypeGenerics and WhereClause without a lifetime
54 let (_
, ty_generics
, where_clause
) = generics
.split_for_impl();
59 #[allow(non_upper_case_globals)]
60 static #recursive_count: std::cell::Cell<u32> = std::cell::Cell::new(0);
63 #[automatically_derived]
64 impl #impl_generics arbitrary::Arbitrary<#lifetime_without_bounds> for #name #ty_generics #where_clause {
72 // Returns: (lifetime without bounds, lifetime with bounds)
73 // Example: ("'arbitrary", "'arbitrary: 'a + 'b")
74 fn build_arbitrary_lifetime(generics
: Generics
) -> (LifetimeParam
, LifetimeParam
) {
75 let lifetime_without_bounds
=
76 LifetimeParam
::new(Lifetime
::new(ARBITRARY_LIFETIME_NAME
, Span
::call_site()));
77 let mut lifetime_with_bounds
= lifetime_without_bounds
.clone();
79 for param
in generics
.params
.iter() {
80 if let GenericParam
::Lifetime(lifetime_def
) = param
{
83 .push(lifetime_def
.lifetime
.clone());
87 (lifetime_without_bounds
, lifetime_with_bounds
)
90 fn apply_trait_bounds(
91 mut generics
: Generics
,
92 lifetime
: LifetimeParam
,
93 container_attrs
: &ContainerAttributes
,
94 ) -> Result
<Generics
> {
95 // If user-supplied bounds exist, apply them to their matching type parameters.
96 if let Some(config_bounds
) = &container_attrs
.bounds
{
97 let mut config_bounds_applied
= 0;
98 for param
in generics
.params
.iter_mut() {
99 if let GenericParam
::Type(type_param
) = param
{
100 if let Some(replacement
) = config_bounds
103 .find(|p
| p
.ident
== type_param
.ident
)
105 *type_param
= replacement
.clone();
106 config_bounds_applied
+= 1;
108 // If no user-supplied bounds exist for this type, delete the original bounds.
109 // This mimics serde.
110 type_param
.bounds
= Default
::default();
111 type_param
.default = None
;
115 let config_bounds_supplied
= config_bounds
117 .map(|bounds
| bounds
.len())
119 if config_bounds_applied
!= config_bounds_supplied
{
120 return Err(Error
::new(
123 "invalid `{}` attribute. too many bounds, only {} out of {} are applicable",
124 ARBITRARY_ATTRIBUTE_NAME
, config_bounds_applied
, config_bounds_supplied
,
130 // Otherwise, inject a `T: Arbitrary` bound for every parameter.
131 Ok(add_trait_bounds(generics
, lifetime
))
135 // Add a bound `T: Arbitrary` to every type parameter T.
136 fn add_trait_bounds(mut generics
: Generics
, lifetime
: LifetimeParam
) -> Generics
{
137 for param
in generics
.params
.iter_mut() {
138 if let GenericParam
::Type(type_param
) = param
{
141 .push(parse_quote
!(arbitrary
::Arbitrary
<#lifetime>));
147 fn with_recursive_count_guard(
148 recursive_count
: &syn
::Ident
,
149 expr
: impl quote
::ToTokens
,
150 ) -> impl quote
::ToTokens
{
152 let guard_against_recursion
= u
.is_empty();
153 if guard_against_recursion
{
154 #recursive_count.with(|count| {
156 return Err(arbitrary
::Error
::NotEnoughData
);
158 count
.set(count
.get() + 1);
163 let result
= (|| { #expr }
)();
165 if guard_against_recursion
{
166 #recursive_count.with(|count| {
167 count
.set(count
.get() - 1);
175 fn gen_arbitrary_method(
177 lifetime
: LifetimeParam
,
178 recursive_count
: &syn
::Ident
,
179 ) -> Result
<TokenStream
> {
180 fn arbitrary_structlike(
183 lifetime
: LifetimeParam
,
184 recursive_count
: &syn
::Ident
,
185 ) -> Result
<TokenStream
> {
186 let arbitrary
= construct(fields
, |_idx
, field
| gen_constructor_for_field(field
))?
;
187 let body
= with_recursive_count_guard(recursive_count
, quote
! { Ok(#ident #arbitrary) }
);
189 let arbitrary_take_rest
= construct_take_rest(fields
)?
;
191 with_recursive_count_guard(recursive_count
, quote
! { Ok(#ident #arbitrary_take_rest) }
);
194 fn arbitrary(u
: &mut arbitrary
::Unstructured
<#lifetime>) -> arbitrary::Result<Self> {
198 fn arbitrary_take_rest(mut u
: arbitrary
::Unstructured
<#lifetime>) -> arbitrary::Result<Self> {
204 let ident
= &input
.ident
;
205 let output
= match &input
.data
{
206 Data
::Struct(data
) => arbitrary_structlike(&data
.fields
, ident
, lifetime
, recursive_count
)?
,
207 Data
::Union(data
) => arbitrary_structlike(
208 &Fields
::Named(data
.fields
.clone()),
213 Data
::Enum(data
) => {
214 let variants
: Vec
<TokenStream
> = data
218 .map(|(i
, variant
)| {
220 let variant_name
= &variant
.ident
;
221 construct(&variant
.fields
, |_
, field
| gen_constructor_for_field(field
))
222 .map(|ctor
| quote
! { #idx => #ident::#variant_name #ctor }
)
224 .collect
::<Result
<_
>>()?
;
226 let variants_take_rest
: Vec
<TokenStream
> = data
230 .map(|(i
, variant
)| {
232 let variant_name
= &variant
.ident
;
233 construct_take_rest(&variant
.fields
)
234 .map(|ctor
| quote
! { #idx => #ident::#variant_name #ctor }
)
236 .collect
::<Result
<_
>>()?
;
238 let count
= data
.variants
.len() as u64;
240 let arbitrary
= with_recursive_count_guard(
243 // Use a multiply + shift to generate a ranged random number
244 // with slight bias. For details, see:
245 // https://lemire.me/blog/2016/06/30/fast-random-shuffling
246 Ok(match (u64::from(<u32 as arbitrary
::Arbitrary
>::arbitrary(u
)?
) * #count) >> 32 {
253 let arbitrary_take_rest
= with_recursive_count_guard(
256 // Use a multiply + shift to generate a ranged random number
257 // with slight bias. For details, see:
258 // https://lemire.me/blog/2016/06/30/fast-random-shuffling
259 Ok(match (u64::from(<u32 as arbitrary
::Arbitrary
>::arbitrary(&mut u
)?
) * #count) >> 32 {
260 #(#variants_take_rest,)*
267 fn arbitrary(u
: &mut arbitrary
::Unstructured
<#lifetime>) -> arbitrary::Result<Self> {
271 fn arbitrary_take_rest(mut u
: arbitrary
::Unstructured
<#lifetime>) -> arbitrary::Result<Self> {
282 ctor
: impl Fn(usize, &Field
) -> Result
<TokenStream
>,
283 ) -> Result
<TokenStream
> {
284 let output
= match fields
{
285 Fields
::Named(names
) => {
286 let names
: Vec
<TokenStream
> = names
291 let name
= f
.ident
.as_ref().unwrap();
292 ctor(i
, f
).map(|ctor
| quote
! { #name: #ctor }
)
294 .collect
::<Result
<_
>>()?
;
295 quote
! { { #(#names,)* }
}
297 Fields
::Unnamed(names
) => {
298 let names
: Vec
<TokenStream
> = names
302 .map(|(i
, f
)| ctor(i
, f
).map(|ctor
| quote
! { #ctor }
))
303 .collect
::<Result
<_
>>()?
;
304 quote
! { ( #(#names),* ) }
306 Fields
::Unit
=> quote
!(),
311 fn construct_take_rest(fields
: &Fields
) -> Result
<TokenStream
> {
312 construct(fields
, |idx
, field
| {
313 determine_field_constructor(field
).map(|field_constructor
| match field_constructor
{
314 FieldConstructor
::Default
=> quote
!(Default
::default()),
315 FieldConstructor
::Arbitrary
=> {
316 if idx
+ 1 == fields
.len() {
317 quote
! { arbitrary::Arbitrary::arbitrary_take_rest(u)? }
319 quote
! { arbitrary::Arbitrary::arbitrary(&mut u)? }
322 FieldConstructor
::With(function_or_closure
) => quote
!((#function_or_closure)(&mut u)?),
323 FieldConstructor
::Value(value
) => quote
!(#value),
328 fn gen_size_hint_method(input
: &DeriveInput
) -> Result
<TokenStream
> {
329 let size_hint_fields
= |fields
: &Fields
| {
334 determine_field_constructor(f
).map(|field_constructor
| {
335 match field_constructor
{
336 FieldConstructor
::Default
| FieldConstructor
::Value(_
) => {
339 FieldConstructor
::Arbitrary
=> {
340 quote
! { <#ty as arbitrary::Arbitrary>::size_hint(depth) }
343 // Note that in this case it's hard to determine what size_hint must be, so size_of::<T>() is
344 // just an educated guess, although it's gonna be inaccurate for dynamically
345 // allocated types (Vec, HashMap, etc.).
346 FieldConstructor
::With(_
) => {
347 quote
! { (::core::mem::size_of::<#ty>(), None) }
352 .collect
::<Result
<Vec
<TokenStream
>>>()
355 arbitrary
::size_hint
::and_all(&[
361 let size_hint_structlike
= |fields
: &Fields
| {
362 size_hint_fields(fields
).map(|hint
| {
365 fn size_hint(depth
: usize) -> (usize, Option
<usize>) {
366 arbitrary
::size_hint
::recursion_guard(depth
, |depth
| #hint)
372 Data
::Struct(data
) => size_hint_structlike(&data
.fields
),
373 Data
::Union(data
) => size_hint_structlike(&Fields
::Named(data
.fields
.clone())),
374 Data
::Enum(data
) => data
377 .map(|v
| size_hint_fields(&v
.fields
))
378 .collect
::<Result
<Vec
<TokenStream
>>>()
382 fn size_hint(depth
: usize) -> (usize, Option
<usize>) {
383 arbitrary
::size_hint
::and(
384 <u32 as arbitrary
::Arbitrary
>::size_hint(depth
),
385 arbitrary
::size_hint
::recursion_guard(depth
, |depth
| {
386 arbitrary
::size_hint
::or_all(&[ #( #variants ),* ])
395 fn gen_constructor_for_field(field
: &Field
) -> Result
<TokenStream
> {
396 let ctor
= match determine_field_constructor(field
)?
{
397 FieldConstructor
::Default
=> quote
!(Default
::default()),
398 FieldConstructor
::Arbitrary
=> quote
!(arbitrary
::Arbitrary
::arbitrary(u
)?
),
399 FieldConstructor
::With(function_or_closure
) => quote
!((#function_or_closure)(u)?),
400 FieldConstructor
::Value(value
) => quote
!(#value),