]>
Commit | Line | Data |
---|---|---|
5e7ed085 FG |
1 | use proc_macro2::{Span, TokenStream}; |
2 | use quote::quote; | |
3 | use syn::parse::*; | |
4 | use syn::punctuated::Punctuated; | |
5 | use syn::*; | |
6 | ||
7 | mod kw { | |
8 | syn::custom_keyword!(derive); | |
9 | syn::custom_keyword!(DEBUG_FORMAT); | |
10 | syn::custom_keyword!(MAX); | |
11 | syn::custom_keyword!(ENCODABLE); | |
12 | syn::custom_keyword!(custom); | |
13 | syn::custom_keyword!(ORD_IMPL); | |
14 | } | |
15 | ||
/// How (or whether) the macro should emit a `Debug` impl for the newtype.
#[derive(Debug)]
enum DebugFormat {
    /// The user writes their own `Debug` impl; nothing is generated.
    Custom,
    /// The generated `Debug` impl formats the index's `u32` value with this
    /// format string. When no `DEBUG_FORMAT` option is given, `"{}"` is used.
    Format(String),
}
25 | ||
26 | // We parse the input and emit the output in a single step. | |
27 | // This field stores the final macro output | |
28 | struct Newtype(TokenStream); | |
29 | ||
30 | impl Parse for Newtype { | |
31 | fn parse(input: ParseStream<'_>) -> Result<Self> { | |
32 | let attrs = input.call(Attribute::parse_outer)?; | |
33 | let vis: Visibility = input.parse()?; | |
34 | input.parse::<Token![struct]>()?; | |
35 | let name: Ident = input.parse()?; | |
36 | ||
37 | let body; | |
38 | braced!(body in input); | |
39 | ||
40 | // Any additional `#[derive]` macro paths to apply | |
41 | let mut derive_paths: Vec<Path> = Vec::new(); | |
42 | let mut debug_format: Option<DebugFormat> = None; | |
43 | let mut max = None; | |
44 | let mut consts = Vec::new(); | |
45 | let mut encodable = true; | |
46 | let mut ord = true; | |
47 | ||
48 | // Parse an optional trailing comma | |
49 | let try_comma = || -> Result<()> { | |
50 | if body.lookahead1().peek(Token![,]) { | |
51 | body.parse::<Token![,]>()?; | |
52 | } | |
53 | Ok(()) | |
54 | }; | |
55 | ||
56 | if body.lookahead1().peek(Token![..]) { | |
57 | body.parse::<Token![..]>()?; | |
58 | } else { | |
59 | loop { | |
60 | if body.lookahead1().peek(kw::derive) { | |
61 | body.parse::<kw::derive>()?; | |
62 | let derives; | |
63 | bracketed!(derives in body); | |
64 | let derives: Punctuated<Path, Token![,]> = | |
65 | derives.parse_terminated(Path::parse)?; | |
66 | try_comma()?; | |
67 | derive_paths.extend(derives); | |
68 | continue; | |
69 | } | |
70 | if body.lookahead1().peek(kw::DEBUG_FORMAT) { | |
71 | body.parse::<kw::DEBUG_FORMAT>()?; | |
72 | body.parse::<Token![=]>()?; | |
73 | let new_debug_format = if body.lookahead1().peek(kw::custom) { | |
74 | body.parse::<kw::custom>()?; | |
75 | DebugFormat::Custom | |
76 | } else { | |
77 | let format_str: LitStr = body.parse()?; | |
78 | DebugFormat::Format(format_str.value()) | |
79 | }; | |
80 | try_comma()?; | |
81 | if let Some(old) = debug_format.replace(new_debug_format) { | |
82 | panic!("Specified multiple debug format options: {:?}", old); | |
83 | } | |
84 | continue; | |
85 | } | |
86 | if body.lookahead1().peek(kw::MAX) { | |
87 | body.parse::<kw::MAX>()?; | |
88 | body.parse::<Token![=]>()?; | |
89 | let val: Lit = body.parse()?; | |
90 | try_comma()?; | |
91 | if let Some(old) = max.replace(val) { | |
92 | panic!("Specified multiple MAX: {:?}", old); | |
93 | } | |
94 | continue; | |
95 | } | |
96 | if body.lookahead1().peek(kw::ENCODABLE) { | |
97 | body.parse::<kw::ENCODABLE>()?; | |
98 | body.parse::<Token![=]>()?; | |
99 | body.parse::<kw::custom>()?; | |
100 | try_comma()?; | |
101 | encodable = false; | |
102 | continue; | |
103 | } | |
104 | if body.lookahead1().peek(kw::ORD_IMPL) { | |
105 | body.parse::<kw::ORD_IMPL>()?; | |
106 | body.parse::<Token![=]>()?; | |
107 | body.parse::<kw::custom>()?; | |
108 | ord = false; | |
109 | continue; | |
110 | } | |
111 | ||
112 | // We've parsed everything that the user provided, so we're done | |
113 | if body.is_empty() { | |
114 | break; | |
115 | } | |
116 | ||
117 | // Otherwise, we are parsing a user-defined constant | |
118 | let const_attrs = body.call(Attribute::parse_outer)?; | |
119 | body.parse::<Token![const]>()?; | |
120 | let const_name: Ident = body.parse()?; | |
121 | body.parse::<Token![=]>()?; | |
122 | let const_val: Expr = body.parse()?; | |
123 | try_comma()?; | |
124 | consts.push(quote! { #(#const_attrs)* #vis const #const_name: #name = #name::from_u32(#const_val); }); | |
125 | } | |
126 | } | |
127 | ||
128 | let debug_format = debug_format.unwrap_or(DebugFormat::Format("{}".to_string())); | |
129 | // shave off 256 indices at the end to allow space for packing these indices into enums | |
130 | let max = max.unwrap_or_else(|| Lit::Int(LitInt::new("0xFFFF_FF00", Span::call_site()))); | |
131 | ||
132 | let encodable_impls = if encodable { | |
133 | quote! { | |
134 | impl<D: ::rustc_serialize::Decoder> ::rustc_serialize::Decodable<D> for #name { | |
135 | fn decode(d: &mut D) -> Self { | |
136 | Self::from_u32(d.read_u32()) | |
137 | } | |
138 | } | |
139 | impl<E: ::rustc_serialize::Encoder> ::rustc_serialize::Encodable<E> for #name { | |
923072b8 FG |
140 | fn encode(&self, e: &mut E) { |
141 | e.emit_u32(self.private); | |
5e7ed085 FG |
142 | } |
143 | } | |
144 | } | |
145 | } else { | |
146 | quote! {} | |
147 | }; | |
148 | ||
149 | if ord { | |
150 | derive_paths.push(parse_quote!(Ord)); | |
151 | derive_paths.push(parse_quote!(PartialOrd)); | |
152 | } | |
153 | ||
154 | let step = if ord { | |
155 | quote! { | |
156 | impl ::std::iter::Step for #name { | |
157 | #[inline] | |
158 | fn steps_between(start: &Self, end: &Self) -> Option<usize> { | |
159 | <usize as ::std::iter::Step>::steps_between( | |
160 | &Self::index(*start), | |
161 | &Self::index(*end), | |
162 | ) | |
163 | } | |
164 | ||
165 | #[inline] | |
166 | fn forward_checked(start: Self, u: usize) -> Option<Self> { | |
167 | Self::index(start).checked_add(u).map(Self::from_usize) | |
168 | } | |
169 | ||
170 | #[inline] | |
171 | fn backward_checked(start: Self, u: usize) -> Option<Self> { | |
172 | Self::index(start).checked_sub(u).map(Self::from_usize) | |
173 | } | |
174 | } | |
175 | ||
176 | // Safety: The implementation of `Step` upholds all invariants. | |
177 | unsafe impl ::std::iter::TrustedStep for #name {} | |
178 | } | |
179 | } else { | |
180 | quote! {} | |
181 | }; | |
182 | ||
183 | let debug_impl = match debug_format { | |
184 | DebugFormat::Custom => quote! {}, | |
185 | DebugFormat::Format(format) => { | |
186 | quote! { | |
187 | impl ::std::fmt::Debug for #name { | |
188 | fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result { | |
189 | write!(fmt, #format, self.as_u32()) | |
190 | } | |
191 | } | |
192 | } | |
193 | } | |
194 | }; | |
195 | ||
196 | Ok(Self(quote! { | |
197 | #(#attrs)* | |
198 | #[derive(Clone, Copy, PartialEq, Eq, Hash, #(#derive_paths),*)] | |
199 | #[rustc_layout_scalar_valid_range_end(#max)] | |
200 | #[rustc_pass_by_value] | |
201 | #vis struct #name { | |
202 | private: u32, | |
203 | } | |
204 | ||
205 | #(#consts)* | |
206 | ||
207 | impl #name { | |
208 | /// Maximum value the index can take, as a `u32`. | |
209 | #vis const MAX_AS_U32: u32 = #max; | |
210 | ||
211 | /// Maximum value the index can take. | |
212 | #vis const MAX: Self = Self::from_u32(#max); | |
213 | ||
214 | /// Creates a new index from a given `usize`. | |
215 | /// | |
216 | /// # Panics | |
217 | /// | |
218 | /// Will panic if `value` exceeds `MAX`. | |
219 | #[inline] | |
220 | #vis const fn from_usize(value: usize) -> Self { | |
221 | assert!(value <= (#max as usize)); | |
222 | // SAFETY: We just checked that `value <= max`. | |
223 | unsafe { | |
224 | Self::from_u32_unchecked(value as u32) | |
225 | } | |
226 | } | |
227 | ||
228 | /// Creates a new index from a given `u32`. | |
229 | /// | |
230 | /// # Panics | |
231 | /// | |
232 | /// Will panic if `value` exceeds `MAX`. | |
233 | #[inline] | |
234 | #vis const fn from_u32(value: u32) -> Self { | |
235 | assert!(value <= #max); | |
236 | // SAFETY: We just checked that `value <= max`. | |
237 | unsafe { | |
238 | Self::from_u32_unchecked(value) | |
239 | } | |
240 | } | |
241 | ||
242 | /// Creates a new index from a given `u32`. | |
243 | /// | |
244 | /// # Safety | |
245 | /// | |
246 | /// The provided value must be less than or equal to the maximum value for the newtype. | |
247 | /// Providing a value outside this range is undefined due to layout restrictions. | |
248 | /// | |
249 | /// Prefer using `from_u32`. | |
250 | #[inline] | |
251 | #vis const unsafe fn from_u32_unchecked(value: u32) -> Self { | |
252 | Self { private: value } | |
253 | } | |
254 | ||
255 | /// Extracts the value of this index as a `usize`. | |
256 | #[inline] | |
257 | #vis const fn index(self) -> usize { | |
258 | self.as_usize() | |
259 | } | |
260 | ||
261 | /// Extracts the value of this index as a `u32`. | |
262 | #[inline] | |
263 | #vis const fn as_u32(self) -> u32 { | |
264 | self.private | |
265 | } | |
266 | ||
267 | /// Extracts the value of this index as a `usize`. | |
268 | #[inline] | |
269 | #vis const fn as_usize(self) -> usize { | |
270 | self.as_u32() as usize | |
271 | } | |
272 | } | |
273 | ||
274 | impl std::ops::Add<usize> for #name { | |
275 | type Output = Self; | |
276 | ||
277 | fn add(self, other: usize) -> Self { | |
278 | Self::from_usize(self.index() + other) | |
279 | } | |
280 | } | |
281 | ||
282 | impl rustc_index::vec::Idx for #name { | |
283 | #[inline] | |
284 | fn new(value: usize) -> Self { | |
285 | Self::from_usize(value) | |
286 | } | |
287 | ||
288 | #[inline] | |
289 | fn index(self) -> usize { | |
290 | self.as_usize() | |
291 | } | |
292 | } | |
293 | ||
294 | #step | |
295 | ||
296 | impl From<#name> for u32 { | |
297 | #[inline] | |
298 | fn from(v: #name) -> u32 { | |
299 | v.as_u32() | |
300 | } | |
301 | } | |
302 | ||
303 | impl From<#name> for usize { | |
304 | #[inline] | |
305 | fn from(v: #name) -> usize { | |
306 | v.as_usize() | |
307 | } | |
308 | } | |
309 | ||
310 | impl From<usize> for #name { | |
311 | #[inline] | |
312 | fn from(value: usize) -> Self { | |
313 | Self::from_usize(value) | |
314 | } | |
315 | } | |
316 | ||
317 | impl From<u32> for #name { | |
318 | #[inline] | |
319 | fn from(value: u32) -> Self { | |
320 | Self::from_u32(value) | |
321 | } | |
322 | } | |
323 | ||
324 | #encodable_impls | |
325 | #debug_impl | |
326 | })) | |
327 | } | |
328 | } | |
329 | ||
330 | pub fn newtype(input: proc_macro::TokenStream) -> proc_macro::TokenStream { | |
331 | let input = parse_macro_input!(input as Newtype); | |
332 | input.0.into() | |
333 | } |