]> git.proxmox.com Git - rustc.git/blame - compiler/rustc_macros/src/newtype.rs
New upstream version 1.63.0+dfsg1
[rustc.git] / compiler / rustc_macros / src / newtype.rs
CommitLineData
5e7ed085
FG
1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::parse::*;
4use syn::punctuated::Punctuated;
5use syn::*;
6
7mod kw {
8 syn::custom_keyword!(derive);
9 syn::custom_keyword!(DEBUG_FORMAT);
10 syn::custom_keyword!(MAX);
11 syn::custom_keyword!(ENCODABLE);
12 syn::custom_keyword!(custom);
13 syn::custom_keyword!(ORD_IMPL);
14}
15
/// Selects how (or whether) a `Debug` impl is generated for the index type.
#[derive(Debug)]
enum DebugFormat {
    /// `DEBUG_FORMAT = custom`: no `Debug` impl is generated; the user is
    /// expected to provide their own.
    Custom,
    /// `DEBUG_FORMAT = "..."`: the generated `Debug` impl formats the raw `u32`
    /// value with this format string. When the option is omitted, `"{}"` is used.
    Format(String),
}
25
26// We parse the input and emit the output in a single step.
27// This field stores the final macro output
28struct Newtype(TokenStream);
29
30impl Parse for Newtype {
31 fn parse(input: ParseStream<'_>) -> Result<Self> {
32 let attrs = input.call(Attribute::parse_outer)?;
33 let vis: Visibility = input.parse()?;
34 input.parse::<Token![struct]>()?;
35 let name: Ident = input.parse()?;
36
37 let body;
38 braced!(body in input);
39
40 // Any additional `#[derive]` macro paths to apply
41 let mut derive_paths: Vec<Path> = Vec::new();
42 let mut debug_format: Option<DebugFormat> = None;
43 let mut max = None;
44 let mut consts = Vec::new();
45 let mut encodable = true;
46 let mut ord = true;
47
48 // Parse an optional trailing comma
49 let try_comma = || -> Result<()> {
50 if body.lookahead1().peek(Token![,]) {
51 body.parse::<Token![,]>()?;
52 }
53 Ok(())
54 };
55
56 if body.lookahead1().peek(Token![..]) {
57 body.parse::<Token![..]>()?;
58 } else {
59 loop {
60 if body.lookahead1().peek(kw::derive) {
61 body.parse::<kw::derive>()?;
62 let derives;
63 bracketed!(derives in body);
64 let derives: Punctuated<Path, Token![,]> =
65 derives.parse_terminated(Path::parse)?;
66 try_comma()?;
67 derive_paths.extend(derives);
68 continue;
69 }
70 if body.lookahead1().peek(kw::DEBUG_FORMAT) {
71 body.parse::<kw::DEBUG_FORMAT>()?;
72 body.parse::<Token![=]>()?;
73 let new_debug_format = if body.lookahead1().peek(kw::custom) {
74 body.parse::<kw::custom>()?;
75 DebugFormat::Custom
76 } else {
77 let format_str: LitStr = body.parse()?;
78 DebugFormat::Format(format_str.value())
79 };
80 try_comma()?;
81 if let Some(old) = debug_format.replace(new_debug_format) {
82 panic!("Specified multiple debug format options: {:?}", old);
83 }
84 continue;
85 }
86 if body.lookahead1().peek(kw::MAX) {
87 body.parse::<kw::MAX>()?;
88 body.parse::<Token![=]>()?;
89 let val: Lit = body.parse()?;
90 try_comma()?;
91 if let Some(old) = max.replace(val) {
92 panic!("Specified multiple MAX: {:?}", old);
93 }
94 continue;
95 }
96 if body.lookahead1().peek(kw::ENCODABLE) {
97 body.parse::<kw::ENCODABLE>()?;
98 body.parse::<Token![=]>()?;
99 body.parse::<kw::custom>()?;
100 try_comma()?;
101 encodable = false;
102 continue;
103 }
104 if body.lookahead1().peek(kw::ORD_IMPL) {
105 body.parse::<kw::ORD_IMPL>()?;
106 body.parse::<Token![=]>()?;
107 body.parse::<kw::custom>()?;
108 ord = false;
109 continue;
110 }
111
112 // We've parsed everything that the user provided, so we're done
113 if body.is_empty() {
114 break;
115 }
116
117 // Otherwise, we are parsing a user-defined constant
118 let const_attrs = body.call(Attribute::parse_outer)?;
119 body.parse::<Token![const]>()?;
120 let const_name: Ident = body.parse()?;
121 body.parse::<Token![=]>()?;
122 let const_val: Expr = body.parse()?;
123 try_comma()?;
124 consts.push(quote! { #(#const_attrs)* #vis const #const_name: #name = #name::from_u32(#const_val); });
125 }
126 }
127
128 let debug_format = debug_format.unwrap_or(DebugFormat::Format("{}".to_string()));
129 // shave off 256 indices at the end to allow space for packing these indices into enums
130 let max = max.unwrap_or_else(|| Lit::Int(LitInt::new("0xFFFF_FF00", Span::call_site())));
131
132 let encodable_impls = if encodable {
133 quote! {
134 impl<D: ::rustc_serialize::Decoder> ::rustc_serialize::Decodable<D> for #name {
135 fn decode(d: &mut D) -> Self {
136 Self::from_u32(d.read_u32())
137 }
138 }
139 impl<E: ::rustc_serialize::Encoder> ::rustc_serialize::Encodable<E> for #name {
923072b8
FG
140 fn encode(&self, e: &mut E) {
141 e.emit_u32(self.private);
5e7ed085
FG
142 }
143 }
144 }
145 } else {
146 quote! {}
147 };
148
149 if ord {
150 derive_paths.push(parse_quote!(Ord));
151 derive_paths.push(parse_quote!(PartialOrd));
152 }
153
154 let step = if ord {
155 quote! {
156 impl ::std::iter::Step for #name {
157 #[inline]
158 fn steps_between(start: &Self, end: &Self) -> Option<usize> {
159 <usize as ::std::iter::Step>::steps_between(
160 &Self::index(*start),
161 &Self::index(*end),
162 )
163 }
164
165 #[inline]
166 fn forward_checked(start: Self, u: usize) -> Option<Self> {
167 Self::index(start).checked_add(u).map(Self::from_usize)
168 }
169
170 #[inline]
171 fn backward_checked(start: Self, u: usize) -> Option<Self> {
172 Self::index(start).checked_sub(u).map(Self::from_usize)
173 }
174 }
175
176 // Safety: The implementation of `Step` upholds all invariants.
177 unsafe impl ::std::iter::TrustedStep for #name {}
178 }
179 } else {
180 quote! {}
181 };
182
183 let debug_impl = match debug_format {
184 DebugFormat::Custom => quote! {},
185 DebugFormat::Format(format) => {
186 quote! {
187 impl ::std::fmt::Debug for #name {
188 fn fmt(&self, fmt: &mut ::std::fmt::Formatter<'_>) -> ::std::fmt::Result {
189 write!(fmt, #format, self.as_u32())
190 }
191 }
192 }
193 }
194 };
195
196 Ok(Self(quote! {
197 #(#attrs)*
198 #[derive(Clone, Copy, PartialEq, Eq, Hash, #(#derive_paths),*)]
199 #[rustc_layout_scalar_valid_range_end(#max)]
200 #[rustc_pass_by_value]
201 #vis struct #name {
202 private: u32,
203 }
204
205 #(#consts)*
206
207 impl #name {
208 /// Maximum value the index can take, as a `u32`.
209 #vis const MAX_AS_U32: u32 = #max;
210
211 /// Maximum value the index can take.
212 #vis const MAX: Self = Self::from_u32(#max);
213
214 /// Creates a new index from a given `usize`.
215 ///
216 /// # Panics
217 ///
218 /// Will panic if `value` exceeds `MAX`.
219 #[inline]
220 #vis const fn from_usize(value: usize) -> Self {
221 assert!(value <= (#max as usize));
222 // SAFETY: We just checked that `value <= max`.
223 unsafe {
224 Self::from_u32_unchecked(value as u32)
225 }
226 }
227
228 /// Creates a new index from a given `u32`.
229 ///
230 /// # Panics
231 ///
232 /// Will panic if `value` exceeds `MAX`.
233 #[inline]
234 #vis const fn from_u32(value: u32) -> Self {
235 assert!(value <= #max);
236 // SAFETY: We just checked that `value <= max`.
237 unsafe {
238 Self::from_u32_unchecked(value)
239 }
240 }
241
242 /// Creates a new index from a given `u32`.
243 ///
244 /// # Safety
245 ///
246 /// The provided value must be less than or equal to the maximum value for the newtype.
247 /// Providing a value outside this range is undefined due to layout restrictions.
248 ///
249 /// Prefer using `from_u32`.
250 #[inline]
251 #vis const unsafe fn from_u32_unchecked(value: u32) -> Self {
252 Self { private: value }
253 }
254
255 /// Extracts the value of this index as a `usize`.
256 #[inline]
257 #vis const fn index(self) -> usize {
258 self.as_usize()
259 }
260
261 /// Extracts the value of this index as a `u32`.
262 #[inline]
263 #vis const fn as_u32(self) -> u32 {
264 self.private
265 }
266
267 /// Extracts the value of this index as a `usize`.
268 #[inline]
269 #vis const fn as_usize(self) -> usize {
270 self.as_u32() as usize
271 }
272 }
273
274 impl std::ops::Add<usize> for #name {
275 type Output = Self;
276
277 fn add(self, other: usize) -> Self {
278 Self::from_usize(self.index() + other)
279 }
280 }
281
282 impl rustc_index::vec::Idx for #name {
283 #[inline]
284 fn new(value: usize) -> Self {
285 Self::from_usize(value)
286 }
287
288 #[inline]
289 fn index(self) -> usize {
290 self.as_usize()
291 }
292 }
293
294 #step
295
296 impl From<#name> for u32 {
297 #[inline]
298 fn from(v: #name) -> u32 {
299 v.as_u32()
300 }
301 }
302
303 impl From<#name> for usize {
304 #[inline]
305 fn from(v: #name) -> usize {
306 v.as_usize()
307 }
308 }
309
310 impl From<usize> for #name {
311 #[inline]
312 fn from(value: usize) -> Self {
313 Self::from_usize(value)
314 }
315 }
316
317 impl From<u32> for #name {
318 #[inline]
319 fn from(value: u32) -> Self {
320 Self::from_u32(value)
321 }
322 }
323
324 #encodable_impls
325 #debug_impl
326 }))
327 }
328}
329
330pub fn newtype(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
331 let input = parse_macro_input!(input as Newtype);
332 input.0.into()
333}