]>
Commit | Line | Data |
---|---|---|
f20569fa XL |
1 | use proc_macro2::TokenStream; |
2 | use quote::quote; | |
3 | ||
4 | use crate::attrs::*; | |
5 | use crate::utils::*; | |
6 | ||
7 | type Variants = syn::punctuated::Punctuated<syn::Variant, syn::Token![,]>; | |
8 | ||
9 | /// Defines and implements `config_type` enum. | |
10 | pub fn define_config_type_on_enum(em: &syn::ItemEnum) -> syn::Result<TokenStream> { | |
11 | let syn::ItemEnum { | |
12 | vis, | |
13 | enum_token, | |
14 | ident, | |
15 | generics, | |
16 | variants, | |
17 | .. | |
18 | } = em; | |
19 | ||
20 | let mod_name_str = format!("__define_config_type_on_enum_{}", ident); | |
21 | let mod_name = syn::Ident::new(&mod_name_str, ident.span()); | |
22 | let variants = fold_quote(variants.iter().map(process_variant), |meta| quote!(#meta,)); | |
23 | ||
24 | let impl_doc_hint = impl_doc_hint(&em.ident, &em.variants); | |
25 | let impl_from_str = impl_from_str(&em.ident, &em.variants); | |
26 | let impl_display = impl_display(&em.ident, &em.variants); | |
27 | let impl_serde = impl_serde(&em.ident, &em.variants); | |
28 | let impl_deserialize = impl_deserialize(&em.ident, &em.variants); | |
29 | ||
30 | Ok(quote! { | |
31 | #[allow(non_snake_case)] | |
32 | mod #mod_name { | |
33 | #[derive(Debug, Copy, Clone, Eq, PartialEq)] | |
34 | pub #enum_token #ident #generics { #variants } | |
35 | #impl_display | |
36 | #impl_doc_hint | |
37 | #impl_from_str | |
38 | #impl_serde | |
39 | #impl_deserialize | |
40 | } | |
41 | #vis use #mod_name::#ident; | |
42 | }) | |
43 | } | |
44 | ||
45 | /// Remove attributes specific to `config_proc_macro` from enum variant fields. | |
46 | fn process_variant(variant: &syn::Variant) -> TokenStream { | |
47 | let metas = variant | |
48 | .attrs | |
49 | .iter() | |
50 | .filter(|attr| !is_doc_hint(attr) && !is_config_value(attr)); | |
51 | let attrs = fold_quote(metas, |meta| quote!(#meta)); | |
52 | let syn::Variant { ident, fields, .. } = variant; | |
53 | quote!(#attrs #ident #fields) | |
54 | } | |
55 | ||
56 | fn impl_doc_hint(ident: &syn::Ident, variants: &Variants) -> TokenStream { | |
57 | let doc_hint = variants | |
58 | .iter() | |
59 | .map(doc_hint_of_variant) | |
60 | .collect::<Vec<_>>() | |
61 | .join("|"); | |
62 | let doc_hint = format!("[{}]", doc_hint); | |
63 | quote! { | |
64 | use crate::config::ConfigType; | |
65 | impl ConfigType for #ident { | |
66 | fn doc_hint() -> String { | |
67 | #doc_hint.to_owned() | |
68 | } | |
69 | } | |
70 | } | |
71 | } | |
72 | ||
73 | fn impl_display(ident: &syn::Ident, variants: &Variants) -> TokenStream { | |
74 | let vs = variants | |
75 | .iter() | |
76 | .filter(|v| is_unit(v)) | |
77 | .map(|v| (config_value_of_variant(v), &v.ident)); | |
78 | let match_patterns = fold_quote(vs, |(s, v)| { | |
79 | quote! { | |
80 | #ident::#v => write!(f, "{}", #s), | |
81 | } | |
82 | }); | |
83 | quote! { | |
84 | use std::fmt; | |
85 | impl fmt::Display for #ident { | |
86 | fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { | |
87 | match self { | |
88 | #match_patterns | |
89 | _ => unimplemented!(), | |
90 | } | |
91 | } | |
92 | } | |
93 | } | |
94 | } | |
95 | ||
96 | fn impl_from_str(ident: &syn::Ident, variants: &Variants) -> TokenStream { | |
97 | let vs = variants | |
98 | .iter() | |
99 | .filter(|v| is_unit(v)) | |
100 | .map(|v| (config_value_of_variant(v), &v.ident)); | |
101 | let if_patterns = fold_quote(vs, |(s, v)| { | |
102 | quote! { | |
103 | if #s.eq_ignore_ascii_case(s) { | |
104 | return Ok(#ident::#v); | |
105 | } | |
106 | } | |
107 | }); | |
108 | let mut err_msg = String::from("Bad variant, expected one of:"); | |
109 | for v in variants.iter().filter(|v| is_unit(v)) { | |
110 | err_msg.push_str(&format!(" `{}`", v.ident)); | |
111 | } | |
112 | ||
113 | quote! { | |
114 | impl ::std::str::FromStr for #ident { | |
115 | type Err = &'static str; | |
116 | ||
117 | fn from_str(s: &str) -> Result<Self, Self::Err> { | |
118 | #if_patterns | |
119 | return Err(#err_msg); | |
120 | } | |
121 | } | |
122 | } | |
123 | } | |
124 | ||
125 | fn doc_hint_of_variant(variant: &syn::Variant) -> String { | |
126 | find_doc_hint(&variant.attrs).unwrap_or(variant.ident.to_string()) | |
127 | } | |
128 | ||
129 | fn config_value_of_variant(variant: &syn::Variant) -> String { | |
130 | find_config_value(&variant.attrs).unwrap_or(variant.ident.to_string()) | |
131 | } | |
132 | ||
133 | fn impl_serde(ident: &syn::Ident, variants: &Variants) -> TokenStream { | |
134 | let arms = fold_quote(variants.iter(), |v| { | |
135 | let v_ident = &v.ident; | |
136 | let pattern = match v.fields { | |
137 | syn::Fields::Named(..) => quote!(#ident::v_ident{..}), | |
138 | syn::Fields::Unnamed(..) => quote!(#ident::#v_ident(..)), | |
139 | syn::Fields::Unit => quote!(#ident::#v_ident), | |
140 | }; | |
141 | let option_value = config_value_of_variant(v); | |
142 | quote! { | |
143 | #pattern => serializer.serialize_str(&#option_value), | |
144 | } | |
145 | }); | |
146 | ||
147 | quote! { | |
148 | impl ::serde::ser::Serialize for #ident { | |
149 | fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> | |
150 | where | |
151 | S: ::serde::ser::Serializer, | |
152 | { | |
153 | use serde::ser::Error; | |
154 | match self { | |
155 | #arms | |
156 | _ => Err(S::Error::custom(format!("Cannot serialize {:?}", self))), | |
157 | } | |
158 | } | |
159 | } | |
160 | } | |
161 | } | |
162 | ||
163 | // Currently only unit variants are supported. | |
164 | fn impl_deserialize(ident: &syn::Ident, variants: &Variants) -> TokenStream { | |
165 | let supported_vs = variants.iter().filter(|v| is_unit(v)); | |
166 | let if_patterns = fold_quote(supported_vs, |v| { | |
167 | let config_value = config_value_of_variant(v); | |
168 | let variant_ident = &v.ident; | |
169 | quote! { | |
170 | if #config_value.eq_ignore_ascii_case(s) { | |
171 | return Ok(#ident::#variant_ident); | |
172 | } | |
173 | } | |
174 | }); | |
175 | ||
176 | let supported_vs = variants.iter().filter(|v| is_unit(v)); | |
177 | let allowed = fold_quote(supported_vs.map(config_value_of_variant), |s| quote!(#s,)); | |
178 | ||
179 | quote! { | |
180 | impl<'de> serde::de::Deserialize<'de> for #ident { | |
181 | fn deserialize<D>(d: D) -> Result<Self, D::Error> | |
182 | where | |
183 | D: serde::Deserializer<'de>, | |
184 | { | |
185 | use serde::de::{Error, Visitor}; | |
186 | use std::marker::PhantomData; | |
187 | use std::fmt; | |
188 | struct StringOnly<T>(PhantomData<T>); | |
189 | impl<'de, T> Visitor<'de> for StringOnly<T> | |
190 | where T: serde::Deserializer<'de> { | |
191 | type Value = String; | |
192 | fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result { | |
193 | formatter.write_str("string") | |
194 | } | |
195 | fn visit_str<E>(self, value: &str) -> Result<String, E> { | |
196 | Ok(String::from(value)) | |
197 | } | |
198 | } | |
199 | let s = &d.deserialize_string(StringOnly::<D>(PhantomData))?; | |
200 | ||
201 | #if_patterns | |
202 | ||
203 | static ALLOWED: &'static[&str] = &[#allowed]; | |
204 | Err(D::Error::unknown_variant(&s, ALLOWED)) | |
205 | } | |
206 | } | |
207 | } | |
208 | } |