]> git.proxmox.com Git - rustc.git/blame - src/tools/rustfmt/config_proc_macro/src/item_enum.rs
New upstream version 1.52.1+dfsg1
[rustc.git] / src / tools / rustfmt / config_proc_macro / src / item_enum.rs
CommitLineData
f20569fa
XL
1use proc_macro2::TokenStream;
2use quote::quote;
3
4use crate::attrs::*;
5use crate::utils::*;
6
7type Variants = syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>;
8
9/// Defines and implements `config_type` enum.
10pub fn define_config_type_on_enum(em: &syn::ItemEnum) -> syn::Result<TokenStream> {
11 let syn::ItemEnum {
12 vis,
13 enum_token,
14 ident,
15 generics,
16 variants,
17 ..
18 } = em;
19
20 let mod_name_str = format!("__define_config_type_on_enum_{}", ident);
21 let mod_name = syn::Ident::new(&mod_name_str, ident.span());
22 let variants = fold_quote(variants.iter().map(process_variant), |meta| quote!(#meta,));
23
24 let impl_doc_hint = impl_doc_hint(&em.ident, &em.variants);
25 let impl_from_str = impl_from_str(&em.ident, &em.variants);
26 let impl_display = impl_display(&em.ident, &em.variants);
27 let impl_serde = impl_serde(&em.ident, &em.variants);
28 let impl_deserialize = impl_deserialize(&em.ident, &em.variants);
29
30 Ok(quote! {
31 #[allow(non_snake_case)]
32 mod #mod_name {
33 #[derive(Debug, Copy, Clone, Eq, PartialEq)]
34 pub #enum_token #ident #generics { #variants }
35 #impl_display
36 #impl_doc_hint
37 #impl_from_str
38 #impl_serde
39 #impl_deserialize
40 }
41 #vis use #mod_name::#ident;
42 })
43}
44
45/// Remove attributes specific to `config_proc_macro` from enum variant fields.
46fn process_variant(variant: &syn::Variant) -> TokenStream {
47 let metas = variant
48 .attrs
49 .iter()
50 .filter(|attr| !is_doc_hint(attr) && !is_config_value(attr));
51 let attrs = fold_quote(metas, |meta| quote!(#meta));
52 let syn::Variant { ident, fields, .. } = variant;
53 quote!(#attrs #ident #fields)
54}
55
56fn impl_doc_hint(ident: &syn::Ident, variants: &Variants) -> TokenStream {
57 let doc_hint = variants
58 .iter()
59 .map(doc_hint_of_variant)
60 .collect::<Vec<_>>()
61 .join("|");
62 let doc_hint = format!("[{}]", doc_hint);
63 quote! {
64 use crate::config::ConfigType;
65 impl ConfigType for #ident {
66 fn doc_hint() -> String {
67 #doc_hint.to_owned()
68 }
69 }
70 }
71}
72
73fn impl_display(ident: &syn::Ident, variants: &Variants) -> TokenStream {
74 let vs = variants
75 .iter()
76 .filter(|v| is_unit(v))
77 .map(|v| (config_value_of_variant(v), &v.ident));
78 let match_patterns = fold_quote(vs, |(s, v)| {
79 quote! {
80 #ident::#v => write!(f, "{}", #s),
81 }
82 });
83 quote! {
84 use std::fmt;
85 impl fmt::Display for #ident {
86 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
87 match self {
88 #match_patterns
89 _ => unimplemented!(),
90 }
91 }
92 }
93 }
94}
95
96fn impl_from_str(ident: &syn::Ident, variants: &Variants) -> TokenStream {
97 let vs = variants
98 .iter()
99 .filter(|v| is_unit(v))
100 .map(|v| (config_value_of_variant(v), &v.ident));
101 let if_patterns = fold_quote(vs, |(s, v)| {
102 quote! {
103 if #s.eq_ignore_ascii_case(s) {
104 return Ok(#ident::#v);
105 }
106 }
107 });
108 let mut err_msg = String::from("Bad variant, expected one of:");
109 for v in variants.iter().filter(|v| is_unit(v)) {
110 err_msg.push_str(&format!(" `{}`", v.ident));
111 }
112
113 quote! {
114 impl ::std::str::FromStr for #ident {
115 type Err = &'static str;
116
117 fn from_str(s: &str) -> Result<Self, Self::Err> {
118 #if_patterns
119 return Err(#err_msg);
120 }
121 }
122 }
123}
124
125fn doc_hint_of_variant(variant: &syn::Variant) -> String {
126 find_doc_hint(&variant.attrs).unwrap_or(variant.ident.to_string())
127}
128
129fn config_value_of_variant(variant: &syn::Variant) -> String {
130 find_config_value(&variant.attrs).unwrap_or(variant.ident.to_string())
131}
132
133fn impl_serde(ident: &syn::Ident, variants: &Variants) -> TokenStream {
134 let arms = fold_quote(variants.iter(), |v| {
135 let v_ident = &v.ident;
136 let pattern = match v.fields {
137 syn::Fields::Named(..) => quote!(#ident::v_ident{..}),
138 syn::Fields::Unnamed(..) => quote!(#ident::#v_ident(..)),
139 syn::Fields::Unit => quote!(#ident::#v_ident),
140 };
141 let option_value = config_value_of_variant(v);
142 quote! {
143 #pattern => serializer.serialize_str(&#option_value),
144 }
145 });
146
147 quote! {
148 impl ::serde::ser::Serialize for #ident {
149 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
150 where
151 S: ::serde::ser::Serializer,
152 {
153 use serde::ser::Error;
154 match self {
155 #arms
156 _ => Err(S::Error::custom(format!("Cannot serialize {:?}", self))),
157 }
158 }
159 }
160 }
161}
162
163// Currently only unit variants are supported.
164fn impl_deserialize(ident: &syn::Ident, variants: &Variants) -> TokenStream {
165 let supported_vs = variants.iter().filter(|v| is_unit(v));
166 let if_patterns = fold_quote(supported_vs, |v| {
167 let config_value = config_value_of_variant(v);
168 let variant_ident = &v.ident;
169 quote! {
170 if #config_value.eq_ignore_ascii_case(s) {
171 return Ok(#ident::#variant_ident);
172 }
173 }
174 });
175
176 let supported_vs = variants.iter().filter(|v| is_unit(v));
177 let allowed = fold_quote(supported_vs.map(config_value_of_variant), |s| quote!(#s,));
178
179 quote! {
180 impl<'de> serde::de::Deserialize<'de> for #ident {
181 fn deserialize<D>(d: D) -> Result<Self, D::Error>
182 where
183 D: serde::Deserializer<'de>,
184 {
185 use serde::de::{Error, Visitor};
186 use std::marker::PhantomData;
187 use std::fmt;
188 struct StringOnly<T>(PhantomData<T>);
189 impl<'de, T> Visitor<'de> for StringOnly<T>
190 where T: serde::Deserializer<'de> {
191 type Value = String;
192 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
193 formatter.write_str("string")
194 }
195 fn visit_str<E>(self, value: &str) -> Result<String, E> {
196 Ok(String::from(value))
197 }
198 }
199 let s = &d.deserialize_string(StringOnly::<D>(PhantomData))?;
200
201 #if_patterns
202
203 static ALLOWED: &'static[&str] = &[#allowed];
204 Err(D::Error::unknown_variant(&s, ALLOWED))
205 }
206 }
207 }
208}