]>
Commit | Line | Data |
---|---|---|
60c5eb7d XL |
1 | //! The general point of the optimizations provided here is to simplify something like: |
2 | //! | |
3 | //! ```rust | |
4 | //! match x { | |
5 | //! Ok(x) => Ok(x), | |
6 | //! Err(x) => Err(x) | |
7 | //! } | |
8 | //! ``` | |
9 | //! | |
10 | //! into just `x`. | |
11 | ||
29967ef6 | 12 | use crate::transform::{simplify, MirPass}; |
dfeec247 | 13 | use itertools::Itertools as _; |
f035d41b XL |
14 | use rustc_index::{bit_set::BitSet, vec::IndexVec}; |
15 | use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor}; | |
ba9703b0 | 16 | use rustc_middle::mir::*; |
3dfed10e | 17 | use rustc_middle::ty::{self, List, Ty, TyCtxt}; |
60c5eb7d | 18 | use rustc_target::abi::VariantIdx; |
1b1a35ee | 19 | use std::iter::{once, Enumerate, Peekable}; |
f9f354fc | 20 | use std::slice::Iter; |
60c5eb7d XL |
21 | |
/// Simplifies arms of form `Variant(x) => Variant(x)` to just a move.
///
/// This is done by transforming basic blocks where the statements match:
///
/// ```ignore (MIR)
/// _LOCAL_TMP = ((_LOCAL_1 as Variant).FIELD: TY);
/// _TMP_2 = _LOCAL_TMP;
/// ((_LOCAL_0 as Variant).FIELD: TY) = move _TMP_2;
/// discriminant(_LOCAL_0) = VAR_IDX;
/// ```
///
/// into:
///
/// ```ignore (MIR)
/// _LOCAL_0 = move _LOCAL_1
/// ```
///
/// The pass is a no-op unless the `unsound_mir_opts` debugging option is enabled
/// (see FIXME(77359) in `run_pass`).
pub struct SimplifyArmIdentity;
39 | ||
f9f354fc XL |
40 | #[derive(Debug)] |
41 | struct ArmIdentityInfo<'tcx> { | |
42 | /// Storage location for the variant's field | |
43 | local_temp_0: Local, | |
44 | /// Storage location holding the variant being read from | |
45 | local_1: Local, | |
46 | /// The variant field being read from | |
47 | vf_s0: VarField<'tcx>, | |
48 | /// Index of the statement which loads the variant being read | |
49 | get_variant_field_stmt: usize, | |
50 | ||
51 | /// Tracks each assignment to a temporary of the variant's field | |
52 | field_tmp_assignments: Vec<(Local, Local)>, | |
53 | ||
54 | /// Storage location holding the variant's field that was read from | |
55 | local_tmp_s1: Local, | |
56 | /// Storage location holding the enum that we are writing to | |
57 | local_0: Local, | |
58 | /// The variant field being written to | |
59 | vf_s1: VarField<'tcx>, | |
60 | ||
61 | /// Storage location that the discriminant is being written to | |
62 | set_discr_local: Local, | |
63 | /// The variant being written | |
64 | set_discr_var_idx: VariantIdx, | |
65 | ||
66 | /// Index of the statement that should be overwritten as a move | |
67 | stmt_to_overwrite: usize, | |
68 | /// SourceInfo for the new move | |
69 | source_info: SourceInfo, | |
70 | ||
71 | /// Indices of matching Storage{Live,Dead} statements encountered. | |
72 | /// (StorageLive index,, StorageDead index, Local) | |
73 | storage_stmts: Vec<(usize, usize, Local)>, | |
74 | ||
75 | /// The statements that should be removed (turned into nops) | |
76 | stmts_to_remove: Vec<usize>, | |
f035d41b XL |
77 | |
78 | /// Indices of debug variables that need to be adjusted to point to | |
79 | // `{local_0}.{dbg_projection}`. | |
80 | dbg_info_to_adjust: Vec<usize>, | |
81 | ||
82 | /// The projection used to rewrite debug info. | |
83 | dbg_projection: &'tcx List<PlaceElem<'tcx>>, | |
f9f354fc XL |
84 | } |
85 | ||
f035d41b XL |
86 | fn get_arm_identity_info<'a, 'tcx>( |
87 | stmts: &'a [Statement<'tcx>], | |
88 | locals_count: usize, | |
89 | debug_info: &'a [VarDebugInfo<'tcx>], | |
90 | ) -> Option<ArmIdentityInfo<'tcx>> { | |
f9f354fc XL |
91 | // This can't possibly match unless there are at least 3 statements in the block |
92 | // so fail fast on tiny blocks. | |
93 | if stmts.len() < 3 { | |
94 | return None; | |
95 | } | |
96 | ||
97 | let mut tmp_assigns = Vec::new(); | |
98 | let mut nop_stmts = Vec::new(); | |
99 | let mut storage_stmts = Vec::new(); | |
100 | let mut storage_live_stmts = Vec::new(); | |
101 | let mut storage_dead_stmts = Vec::new(); | |
102 | ||
103 | type StmtIter<'a, 'tcx> = Peekable<Enumerate<Iter<'a, Statement<'tcx>>>>; | |
104 | ||
105 | fn is_storage_stmt<'tcx>(stmt: &Statement<'tcx>) -> bool { | |
106 | matches!(stmt.kind, StatementKind::StorageLive(_) | StatementKind::StorageDead(_)) | |
107 | } | |
108 | ||
109 | /// Eats consecutive Statements which match `test`, performing the specified `action` for each. | |
110 | /// The iterator `stmt_iter` is not advanced if none were matched. | |
111 | fn try_eat<'a, 'tcx>( | |
112 | stmt_iter: &mut StmtIter<'a, 'tcx>, | |
113 | test: impl Fn(&'a Statement<'tcx>) -> bool, | |
f035d41b | 114 | mut action: impl FnMut(usize, &'a Statement<'tcx>), |
f9f354fc | 115 | ) { |
5869c6ff | 116 | while stmt_iter.peek().map_or(false, |(_, stmt)| test(stmt)) { |
f9f354fc XL |
117 | let (idx, stmt) = stmt_iter.next().unwrap(); |
118 | ||
119 | action(idx, stmt); | |
120 | } | |
121 | } | |
122 | ||
123 | /// Eats consecutive `StorageLive` and `StorageDead` Statements. | |
124 | /// The iterator `stmt_iter` is not advanced if none were found. | |
125 | fn try_eat_storage_stmts<'a, 'tcx>( | |
126 | stmt_iter: &mut StmtIter<'a, 'tcx>, | |
127 | storage_live_stmts: &mut Vec<(usize, Local)>, | |
128 | storage_dead_stmts: &mut Vec<(usize, Local)>, | |
129 | ) { | |
130 | try_eat(stmt_iter, is_storage_stmt, |idx, stmt| { | |
131 | if let StatementKind::StorageLive(l) = stmt.kind { | |
132 | storage_live_stmts.push((idx, l)); | |
133 | } else if let StatementKind::StorageDead(l) = stmt.kind { | |
134 | storage_dead_stmts.push((idx, l)); | |
135 | } | |
136 | }) | |
137 | } | |
138 | ||
139 | fn is_tmp_storage_stmt<'tcx>(stmt: &Statement<'tcx>) -> bool { | |
140 | use rustc_middle::mir::StatementKind::Assign; | |
141 | if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) = &stmt.kind { | |
142 | place.as_local().is_some() && p.as_local().is_some() | |
143 | } else { | |
144 | false | |
145 | } | |
146 | } | |
147 | ||
148 | /// Eats consecutive `Assign` Statements. | |
149 | // The iterator `stmt_iter` is not advanced if none were found. | |
150 | fn try_eat_assign_tmp_stmts<'a, 'tcx>( | |
151 | stmt_iter: &mut StmtIter<'a, 'tcx>, | |
152 | tmp_assigns: &mut Vec<(Local, Local)>, | |
153 | nop_stmts: &mut Vec<usize>, | |
154 | ) { | |
155 | try_eat(stmt_iter, is_tmp_storage_stmt, |idx, stmt| { | |
156 | use rustc_middle::mir::StatementKind::Assign; | |
157 | if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) = | |
158 | &stmt.kind | |
159 | { | |
160 | tmp_assigns.push((place.as_local().unwrap(), p.as_local().unwrap())); | |
161 | nop_stmts.push(idx); | |
162 | } | |
163 | }) | |
164 | } | |
165 | ||
166 | fn find_storage_live_dead_stmts_for_local<'tcx>( | |
167 | local: Local, | |
168 | stmts: &[Statement<'tcx>], | |
169 | ) -> Option<(usize, usize)> { | |
170 | trace!("looking for {:?}", local); | |
171 | let mut storage_live_stmt = None; | |
172 | let mut storage_dead_stmt = None; | |
173 | for (idx, stmt) in stmts.iter().enumerate() { | |
174 | if stmt.kind == StatementKind::StorageLive(local) { | |
175 | storage_live_stmt = Some(idx); | |
176 | } else if stmt.kind == StatementKind::StorageDead(local) { | |
177 | storage_dead_stmt = Some(idx); | |
178 | } | |
179 | } | |
180 | ||
181 | Some((storage_live_stmt?, storage_dead_stmt.unwrap_or(usize::MAX))) | |
182 | } | |
183 | ||
184 | // Try to match the expected MIR structure with the basic block we're processing. | |
185 | // We want to see something that looks like: | |
186 | // ``` | |
187 | // (StorageLive(_) | StorageDead(_));* | |
188 | // _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY); | |
189 | // (StorageLive(_) | StorageDead(_));* | |
190 | // (tmp_n+1 = tmp_n);* | |
191 | // (StorageLive(_) | StorageDead(_));* | |
192 | // (tmp_n+1 = tmp_n);* | |
193 | // ((LOCAL_FROM as Variant).FIELD: TY) = move tmp; | |
194 | // discriminant(LOCAL_FROM) = VariantIdx; | |
195 | // (StorageLive(_) | StorageDead(_));* | |
196 | // ``` | |
197 | let mut stmt_iter = stmts.iter().enumerate().peekable(); | |
198 | ||
199 | try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); | |
200 | ||
201 | let (get_variant_field_stmt, stmt) = stmt_iter.next()?; | |
f035d41b | 202 | let (local_tmp_s0, local_1, vf_s0, dbg_projection) = match_get_variant_field(stmt)?; |
f9f354fc XL |
203 | |
204 | try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); | |
205 | ||
206 | try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts); | |
207 | ||
208 | try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); | |
209 | ||
210 | try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts); | |
211 | ||
212 | let (idx, stmt) = stmt_iter.next()?; | |
213 | let (local_tmp_s1, local_0, vf_s1) = match_set_variant_field(stmt)?; | |
214 | nop_stmts.push(idx); | |
215 | ||
216 | let (idx, stmt) = stmt_iter.next()?; | |
217 | let (set_discr_local, set_discr_var_idx) = match_set_discr(stmt)?; | |
218 | let discr_stmt_source_info = stmt.source_info; | |
219 | nop_stmts.push(idx); | |
220 | ||
221 | try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); | |
222 | ||
223 | for (live_idx, live_local) in storage_live_stmts { | |
224 | if let Some(i) = storage_dead_stmts.iter().rposition(|(_, l)| *l == live_local) { | |
225 | let (dead_idx, _) = storage_dead_stmts.swap_remove(i); | |
226 | storage_stmts.push((live_idx, dead_idx, live_local)); | |
227 | ||
228 | if live_local == local_tmp_s0 { | |
229 | nop_stmts.push(get_variant_field_stmt); | |
230 | } | |
231 | } | |
232 | } | |
1b1a35ee XL |
233 | // We sort primitive usize here so we can use unstable sort |
234 | nop_stmts.sort_unstable(); | |
f9f354fc XL |
235 | |
236 | // Use one of the statements we're going to discard between the point | |
237 | // where the storage location for the variant field becomes live and | |
238 | // is killed. | |
239 | let (live_idx, dead_idx) = find_storage_live_dead_stmts_for_local(local_tmp_s0, stmts)?; | |
240 | let stmt_to_overwrite = | |
241 | nop_stmts.iter().find(|stmt_idx| live_idx < **stmt_idx && **stmt_idx < dead_idx); | |
242 | ||
f035d41b XL |
243 | let mut tmp_assigned_vars = BitSet::new_empty(locals_count); |
244 | for (l, r) in &tmp_assigns { | |
245 | tmp_assigned_vars.insert(*l); | |
246 | tmp_assigned_vars.insert(*r); | |
247 | } | |
248 | ||
fc512014 XL |
249 | let dbg_info_to_adjust: Vec<_> = debug_info |
250 | .iter() | |
251 | .enumerate() | |
252 | .filter_map(|(i, var_info)| { | |
253 | if let VarDebugInfoContents::Place(p) = var_info.value { | |
254 | if tmp_assigned_vars.contains(p.local) { | |
255 | return Some(i); | |
256 | } | |
257 | } | |
258 | ||
259 | None | |
260 | }) | |
261 | .collect(); | |
f035d41b | 262 | |
f9f354fc XL |
263 | Some(ArmIdentityInfo { |
264 | local_temp_0: local_tmp_s0, | |
265 | local_1, | |
266 | vf_s0, | |
267 | get_variant_field_stmt, | |
268 | field_tmp_assignments: tmp_assigns, | |
269 | local_tmp_s1, | |
270 | local_0, | |
271 | vf_s1, | |
272 | set_discr_local, | |
273 | set_discr_var_idx, | |
274 | stmt_to_overwrite: *stmt_to_overwrite?, | |
275 | source_info: discr_stmt_source_info, | |
276 | storage_stmts, | |
277 | stmts_to_remove: nop_stmts, | |
f035d41b XL |
278 | dbg_info_to_adjust, |
279 | dbg_projection, | |
f9f354fc XL |
280 | }) |
281 | } | |
282 | ||
283 | fn optimization_applies<'tcx>( | |
284 | opt_info: &ArmIdentityInfo<'tcx>, | |
285 | local_decls: &IndexVec<Local, LocalDecl<'tcx>>, | |
f035d41b XL |
286 | local_uses: &IndexVec<Local, usize>, |
287 | var_debug_info: &[VarDebugInfo<'tcx>], | |
f9f354fc XL |
288 | ) -> bool { |
289 | trace!("testing if optimization applies..."); | |
290 | ||
291 | // FIXME(wesleywiser): possibly relax this restriction? | |
292 | if opt_info.local_0 == opt_info.local_1 { | |
293 | trace!("NO: moving into ourselves"); | |
294 | return false; | |
295 | } else if opt_info.vf_s0 != opt_info.vf_s1 { | |
296 | trace!("NO: the field-and-variant information do not match"); | |
297 | return false; | |
298 | } else if local_decls[opt_info.local_0].ty != local_decls[opt_info.local_1].ty { | |
299 | // FIXME(Centril,oli-obk): possibly relax to same layout? | |
300 | trace!("NO: source and target locals have different types"); | |
301 | return false; | |
302 | } else if (opt_info.local_0, opt_info.vf_s0.var_idx) | |
303 | != (opt_info.set_discr_local, opt_info.set_discr_var_idx) | |
304 | { | |
305 | trace!("NO: the discriminants do not match"); | |
306 | return false; | |
307 | } | |
308 | ||
fc512014 | 309 | // Verify the assignment chain consists of the form b = a; c = b; d = c; etc... |
f035d41b | 310 | if opt_info.field_tmp_assignments.is_empty() { |
f9f354fc | 311 | trace!("NO: no assignments found"); |
f035d41b | 312 | return false; |
f9f354fc XL |
313 | } |
314 | let mut last_assigned_to = opt_info.field_tmp_assignments[0].1; | |
315 | let source_local = last_assigned_to; | |
316 | for (l, r) in &opt_info.field_tmp_assignments { | |
317 | if *r != last_assigned_to { | |
318 | trace!("NO: found unexpected assignment {:?} = {:?}", l, r); | |
319 | return false; | |
320 | } | |
321 | ||
322 | last_assigned_to = *l; | |
323 | } | |
324 | ||
f035d41b XL |
325 | // Check that the first and last used locals are only used twice |
326 | // since they are of the form: | |
327 | // | |
328 | // ``` | |
329 | // _first = ((_x as Variant).n: ty); | |
330 | // _n = _first; | |
331 | // ... | |
332 | // ((_y as Variant).n: ty) = _n; | |
333 | // discriminant(_y) = z; | |
334 | // ``` | |
335 | for (l, r) in &opt_info.field_tmp_assignments { | |
336 | if local_uses[*l] != 2 { | |
337 | warn!("NO: FAILED assignment chain local {:?} was used more than twice", l); | |
338 | return false; | |
339 | } else if local_uses[*r] != 2 { | |
340 | warn!("NO: FAILED assignment chain local {:?} was used more than twice", r); | |
341 | return false; | |
342 | } | |
343 | } | |
344 | ||
345 | // Check that debug info only points to full Locals and not projections. | |
346 | for dbg_idx in &opt_info.dbg_info_to_adjust { | |
347 | let dbg_info = &var_debug_info[*dbg_idx]; | |
fc512014 XL |
348 | if let VarDebugInfoContents::Place(p) = dbg_info.value { |
349 | if !p.projection.is_empty() { | |
350 | trace!("NO: debug info for {:?} had a projection {:?}", dbg_info.name, p); | |
351 | return false; | |
352 | } | |
f035d41b XL |
353 | } |
354 | } | |
355 | ||
f9f354fc XL |
356 | if source_local != opt_info.local_temp_0 { |
357 | trace!( | |
358 | "NO: start of assignment chain does not match enum variant temp: {:?} != {:?}", | |
359 | source_local, | |
360 | opt_info.local_temp_0 | |
361 | ); | |
362 | return false; | |
363 | } else if last_assigned_to != opt_info.local_tmp_s1 { | |
364 | trace!( | |
365 | "NO: end of assignemnt chain does not match written enum temp: {:?} != {:?}", | |
366 | last_assigned_to, | |
367 | opt_info.local_tmp_s1 | |
368 | ); | |
369 | return false; | |
370 | } | |
371 | ||
372 | trace!("SUCCESS: optimization applies!"); | |
3dfed10e | 373 | true |
f9f354fc XL |
374 | } |
375 | ||
60c5eb7d | 376 | impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity { |
29967ef6 | 377 | fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { |
1b1a35ee XL |
378 | // FIXME(77359): This optimization can result in unsoundness. |
379 | if !tcx.sess.opts.debugging_opts.unsound_mir_opts { | |
f9f354fc XL |
380 | return; |
381 | } | |
382 | ||
29967ef6 | 383 | let source = body.source; |
f9f354fc | 384 | trace!("running SimplifyArmIdentity on {:?}", source); |
29967ef6 | 385 | |
f035d41b XL |
386 | let local_uses = LocalUseCounter::get_local_uses(body); |
387 | let (basic_blocks, local_decls, debug_info) = | |
388 | body.basic_blocks_local_decls_mut_and_var_debug_info(); | |
60c5eb7d | 389 | for bb in basic_blocks { |
f035d41b XL |
390 | if let Some(opt_info) = |
391 | get_arm_identity_info(&bb.statements, local_decls.len(), debug_info) | |
392 | { | |
f9f354fc | 393 | trace!("got opt_info = {:#?}", opt_info); |
f035d41b | 394 | if !optimization_applies(&opt_info, local_decls, &local_uses, &debug_info) { |
f9f354fc XL |
395 | debug!("optimization skipped for {:?}", source); |
396 | continue; | |
397 | } | |
60c5eb7d | 398 | |
f9f354fc XL |
399 | // Also remove unused Storage{Live,Dead} statements which correspond |
400 | // to temps used previously. | |
401 | for (live_idx, dead_idx, local) in &opt_info.storage_stmts { | |
402 | // The temporary that we've read the variant field into is scoped to this block, | |
403 | // so we can remove the assignment. | |
404 | if *local == opt_info.local_temp_0 { | |
405 | bb.statements[opt_info.get_variant_field_stmt].make_nop(); | |
406 | } | |
407 | ||
408 | for (left, right) in &opt_info.field_tmp_assignments { | |
409 | if local == left || local == right { | |
410 | bb.statements[*live_idx].make_nop(); | |
411 | bb.statements[*dead_idx].make_nop(); | |
412 | } | |
413 | } | |
60c5eb7d | 414 | } |
f9f354fc XL |
415 | |
416 | // Right shape; transform | |
417 | for stmt_idx in opt_info.stmts_to_remove { | |
418 | bb.statements[stmt_idx].make_nop(); | |
419 | } | |
420 | ||
421 | let stmt = &mut bb.statements[opt_info.stmt_to_overwrite]; | |
422 | stmt.source_info = opt_info.source_info; | |
423 | stmt.kind = StatementKind::Assign(box ( | |
424 | opt_info.local_0.into(), | |
425 | Rvalue::Use(Operand::Move(opt_info.local_1.into())), | |
426 | )); | |
427 | ||
428 | bb.statements.retain(|stmt| stmt.kind != StatementKind::Nop); | |
429 | ||
f035d41b XL |
430 | // Fix the debug info to point to the right local |
431 | for dbg_index in opt_info.dbg_info_to_adjust { | |
432 | let dbg_info = &mut debug_info[dbg_index]; | |
fc512014 XL |
433 | assert!( |
434 | matches!(dbg_info.value, VarDebugInfoContents::Place(_)), | |
435 | "value was not a Place" | |
436 | ); | |
437 | if let VarDebugInfoContents::Place(p) = &mut dbg_info.value { | |
438 | assert!(p.projection.is_empty()); | |
439 | p.local = opt_info.local_0; | |
440 | p.projection = opt_info.dbg_projection; | |
441 | } | |
f035d41b XL |
442 | } |
443 | ||
f9f354fc | 444 | trace!("block is now {:?}", bb.statements); |
60c5eb7d | 445 | } |
60c5eb7d XL |
446 | } |
447 | } | |
448 | } | |
449 | ||
f035d41b XL |
450 | struct LocalUseCounter { |
451 | local_uses: IndexVec<Local, usize>, | |
452 | } | |
453 | ||
454 | impl LocalUseCounter { | |
455 | fn get_local_uses<'tcx>(body: &Body<'tcx>) -> IndexVec<Local, usize> { | |
456 | let mut counter = LocalUseCounter { local_uses: IndexVec::from_elem(0, &body.local_decls) }; | |
457 | counter.visit_body(body); | |
458 | counter.local_uses | |
459 | } | |
460 | } | |
461 | ||
462 | impl<'tcx> Visitor<'tcx> for LocalUseCounter { | |
463 | fn visit_local(&mut self, local: &Local, context: PlaceContext, _location: Location) { | |
464 | if context.is_storage_marker() | |
465 | || context == PlaceContext::NonUse(NonUseContext::VarDebugInfo) | |
466 | { | |
467 | return; | |
468 | } | |
469 | ||
470 | self.local_uses[*local] += 1; | |
471 | } | |
472 | } | |
473 | ||
60c5eb7d XL |
474 | /// Match on: |
475 | /// ```rust | |
476 | /// _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY); | |
477 | /// ``` | |
f035d41b XL |
478 | fn match_get_variant_field<'tcx>( |
479 | stmt: &Statement<'tcx>, | |
480 | ) -> Option<(Local, Local, VarField<'tcx>, &'tcx List<PlaceElem<'tcx>>)> { | |
60c5eb7d | 481 | match &stmt.kind { |
3dfed10e XL |
482 | StatementKind::Assign(box ( |
483 | place_into, | |
484 | Rvalue::Use(Operand::Copy(pf) | Operand::Move(pf)), | |
485 | )) => { | |
486 | let local_into = place_into.as_local()?; | |
487 | let (local_from, vf) = match_variant_field_place(*pf)?; | |
488 | Some((local_into, local_from, vf, pf.projection)) | |
489 | } | |
60c5eb7d XL |
490 | _ => None, |
491 | } | |
492 | } | |
493 | ||
494 | /// Match on: | |
495 | /// ```rust | |
496 | /// ((_LOCAL_FROM as Variant).FIELD: TY) = move _LOCAL_INTO; | |
497 | /// ``` | |
498 | fn match_set_variant_field<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, Local, VarField<'tcx>)> { | |
499 | match &stmt.kind { | |
3dfed10e XL |
500 | StatementKind::Assign(box (place_from, Rvalue::Use(Operand::Move(place_into)))) => { |
501 | let local_into = place_into.as_local()?; | |
502 | let (local_from, vf) = match_variant_field_place(*place_from)?; | |
503 | Some((local_into, local_from, vf)) | |
504 | } | |
60c5eb7d XL |
505 | _ => None, |
506 | } | |
507 | } | |
508 | ||
509 | /// Match on: | |
510 | /// ```rust | |
511 | /// discriminant(_LOCAL_TO_SET) = VAR_IDX; | |
512 | /// ``` | |
513 | fn match_set_discr<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, VariantIdx)> { | |
514 | match &stmt.kind { | |
dfeec247 XL |
515 | StatementKind::SetDiscriminant { place, variant_index } => { |
516 | Some((place.as_local()?, *variant_index)) | |
517 | } | |
60c5eb7d XL |
518 | _ => None, |
519 | } | |
520 | } | |
521 | ||
f9f354fc | 522 | #[derive(PartialEq, Debug)] |
60c5eb7d XL |
523 | struct VarField<'tcx> { |
524 | field: Field, | |
525 | field_ty: Ty<'tcx>, | |
526 | var_idx: VariantIdx, | |
527 | } | |
528 | ||
529 | /// Match on `((_LOCAL as Variant).FIELD: TY)`. | |
ba9703b0 | 530 | fn match_variant_field_place<'tcx>(place: Place<'tcx>) -> Option<(Local, VarField<'tcx>)> { |
60c5eb7d XL |
531 | match place.as_ref() { |
532 | PlaceRef { | |
dfeec247 | 533 | local, |
60c5eb7d | 534 | projection: &[ProjectionElem::Downcast(_, var_idx), ProjectionElem::Field(field, ty)], |
74b04a01 | 535 | } => Some((local, VarField { field, field_ty: ty, var_idx })), |
60c5eb7d XL |
536 | _ => None, |
537 | } | |
538 | } | |
539 | ||
540 | /// Simplifies `SwitchInt(_) -> [targets]`, | |
541 | /// where all the `targets` have the same form, | |
542 | /// into `goto -> target_first`. | |
543 | pub struct SimplifyBranchSame; | |
544 | ||
545 | impl<'tcx> MirPass<'tcx> for SimplifyBranchSame { | |
29967ef6 XL |
546 | fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { |
547 | trace!("Running SimplifyBranchSame on {:?}", body.source); | |
3dfed10e XL |
548 | let finder = SimplifyBranchSameOptimizationFinder { body, tcx }; |
549 | let opts = finder.find(); | |
550 | ||
551 | let did_remove_blocks = opts.len() > 0; | |
552 | for opt in opts.iter() { | |
553 | trace!("SUCCESS: Applying optimization {:?}", opt); | |
554 | // Replace `SwitchInt(..) -> [bb_first, ..];` with a `goto -> bb_first;`. | |
555 | body.basic_blocks_mut()[opt.bb_to_opt_terminator].terminator_mut().kind = | |
556 | TerminatorKind::Goto { target: opt.bb_to_goto }; | |
557 | } | |
60c5eb7d | 558 | |
3dfed10e XL |
559 | if did_remove_blocks { |
560 | // We have dead blocks now, so remove those. | |
561 | simplify::remove_dead_blocks(body); | |
562 | } | |
563 | } | |
564 | } | |
565 | ||
566 | #[derive(Debug)] | |
567 | struct SimplifyBranchSameOptimization { | |
568 | /// All basic blocks are equal so go to this one | |
569 | bb_to_goto: BasicBlock, | |
570 | /// Basic block where the terminator can be simplified to a goto | |
571 | bb_to_opt_terminator: BasicBlock, | |
572 | } | |
573 | ||
1b1a35ee XL |
574 | struct SwitchTargetAndValue { |
575 | target: BasicBlock, | |
576 | // None in case of the `otherwise` case | |
577 | value: Option<u128>, | |
578 | } | |
579 | ||
3dfed10e XL |
580 | struct SimplifyBranchSameOptimizationFinder<'a, 'tcx> { |
581 | body: &'a Body<'tcx>, | |
582 | tcx: TyCtxt<'tcx>, | |
583 | } | |
584 | ||
585 | impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> { | |
586 | fn find(&self) -> Vec<SimplifyBranchSameOptimization> { | |
587 | self.body | |
588 | .basic_blocks() | |
589 | .iter_enumerated() | |
590 | .filter_map(|(bb_idx, bb)| { | |
1b1a35ee | 591 | let (discr_switched_on, targets_and_values) = match &bb.terminator().kind { |
29967ef6 XL |
592 | TerminatorKind::SwitchInt { targets, discr, .. } => { |
593 | let targets_and_values: Vec<_> = targets.iter() | |
594 | .map(|(val, target)| SwitchTargetAndValue { target, value: Some(val) }) | |
595 | .chain(once(SwitchTargetAndValue { target: targets.otherwise(), value: None })) | |
1b1a35ee | 596 | .collect(); |
29967ef6 XL |
597 | (discr, targets_and_values) |
598 | }, | |
3dfed10e XL |
599 | _ => return None, |
600 | }; | |
601 | ||
602 | // find the adt that has its discriminant read | |
603 | // assuming this must be the last statement of the block | |
604 | let adt_matched_on = match &bb.statements.last()?.kind { | |
605 | StatementKind::Assign(box (place, rhs)) | |
606 | if Some(*place) == discr_switched_on.place() => | |
607 | { | |
608 | match rhs { | |
609 | Rvalue::Discriminant(adt_place) if adt_place.ty(self.body, self.tcx).ty.is_enum() => adt_place, | |
610 | _ => { | |
611 | trace!("NO: expected a discriminant read of an enum instead of: {:?}", rhs); | |
612 | return None; | |
613 | } | |
614 | } | |
615 | } | |
616 | other => { | |
617 | trace!("NO: expected an assignment of a discriminant read to a place. Found: {:?}", other); | |
618 | return None | |
619 | }, | |
620 | }; | |
621 | ||
1b1a35ee | 622 | let mut iter_bbs_reachable = targets_and_values |
3dfed10e | 623 | .iter() |
1b1a35ee | 624 | .map(|target_and_value| (target_and_value, &self.body.basic_blocks()[target_and_value.target])) |
3dfed10e XL |
625 | .filter(|(_, bb)| { |
626 | // Reaching `unreachable` is UB so assume it doesn't happen. | |
627 | bb.terminator().kind != TerminatorKind::Unreachable | |
60c5eb7d XL |
628 | // But `asm!(...)` could abort the program, |
629 | // so we cannot assume that the `unreachable` terminator itself is reachable. | |
630 | // FIXME(Centril): use a normalization pass instead of a check. | |
631 | || bb.statements.iter().any(|stmt| match stmt.kind { | |
ba9703b0 | 632 | StatementKind::LlvmInlineAsm(..) => true, |
60c5eb7d XL |
633 | _ => false, |
634 | }) | |
3dfed10e XL |
635 | }) |
636 | .peekable(); | |
637 | ||
5869c6ff | 638 | let bb_first = iter_bbs_reachable.peek().map_or(&targets_and_values[0], |(idx, _)| *idx); |
3dfed10e XL |
639 | let mut all_successors_equivalent = StatementEquality::TrivialEqual; |
640 | ||
641 | // All successor basic blocks must be equal or contain statements that are pairwise considered equal. | |
1b1a35ee | 642 | for ((target_and_value_l,bb_l), (target_and_value_r,bb_r)) in iter_bbs_reachable.tuple_windows() { |
3dfed10e | 643 | let trivial_checks = bb_l.is_cleanup == bb_r.is_cleanup |
6c58768f XL |
644 | && bb_l.terminator().kind == bb_r.terminator().kind |
645 | && bb_l.statements.len() == bb_r.statements.len(); | |
3dfed10e XL |
646 | let statement_check = || { |
647 | bb_l.statements.iter().zip(&bb_r.statements).try_fold(StatementEquality::TrivialEqual, |acc,(l,r)| { | |
1b1a35ee | 648 | let stmt_equality = self.statement_equality(*adt_matched_on, &l, target_and_value_l, &r, target_and_value_r); |
3dfed10e XL |
649 | if matches!(stmt_equality, StatementEquality::NotEqual) { |
650 | // short circuit | |
651 | None | |
652 | } else { | |
653 | Some(acc.combine(&stmt_equality)) | |
654 | } | |
655 | }) | |
656 | .unwrap_or(StatementEquality::NotEqual) | |
657 | }; | |
658 | if !trivial_checks { | |
659 | all_successors_equivalent = StatementEquality::NotEqual; | |
660 | break; | |
661 | } | |
662 | all_successors_equivalent = all_successors_equivalent.combine(&statement_check()); | |
663 | }; | |
664 | ||
665 | match all_successors_equivalent{ | |
666 | StatementEquality::TrivialEqual => { | |
667 | // statements are trivially equal, so just take first | |
668 | trace!("Statements are trivially equal"); | |
669 | Some(SimplifyBranchSameOptimization { | |
1b1a35ee | 670 | bb_to_goto: bb_first.target, |
3dfed10e XL |
671 | bb_to_opt_terminator: bb_idx, |
672 | }) | |
673 | } | |
674 | StatementEquality::ConsideredEqual(bb_to_choose) => { | |
675 | trace!("Statements are considered equal"); | |
676 | Some(SimplifyBranchSameOptimization { | |
677 | bb_to_goto: bb_to_choose, | |
678 | bb_to_opt_terminator: bb_idx, | |
679 | }) | |
680 | } | |
681 | StatementEquality::NotEqual => { | |
682 | trace!("NO: not all successors of basic block {:?} were equivalent", bb_idx); | |
683 | None | |
684 | } | |
685 | } | |
686 | }) | |
687 | .collect() | |
688 | } | |
689 | ||
690 | /// Tests if two statements can be considered equal | |
691 | /// | |
692 | /// Statements can be trivially equal if the kinds match. | |
693 | /// But they can also be considered equal in the following case A: | |
694 | /// ``` | |
695 | /// discriminant(_0) = 0; // bb1 | |
696 | /// _0 = move _1; // bb2 | |
697 | /// ``` | |
698 | /// In this case the two statements are equal iff | |
6a06907d XL |
699 | /// - `_0` is an enum where the variant index 0 is fieldless, and |
700 | /// - bb1 was targeted by a switch where the discriminant of `_1` was switched on | |
3dfed10e XL |
701 | fn statement_equality( |
702 | &self, | |
703 | adt_matched_on: Place<'tcx>, | |
704 | x: &Statement<'tcx>, | |
1b1a35ee | 705 | x_target_and_value: &SwitchTargetAndValue, |
3dfed10e | 706 | y: &Statement<'tcx>, |
1b1a35ee | 707 | y_target_and_value: &SwitchTargetAndValue, |
3dfed10e XL |
708 | ) -> StatementEquality { |
709 | let helper = |rhs: &Rvalue<'tcx>, | |
1b1a35ee | 710 | place: &Place<'tcx>, |
3dfed10e XL |
711 | variant_index: &VariantIdx, |
712 | side_to_choose| { | |
713 | let place_type = place.ty(self.body, self.tcx).ty; | |
1b1a35ee | 714 | let adt = match *place_type.kind() { |
3dfed10e XL |
715 | ty::Adt(adt, _) if adt.is_enum() => adt, |
716 | _ => return StatementEquality::NotEqual, | |
717 | }; | |
718 | let variant_is_fieldless = adt.variants[*variant_index].fields.is_empty(); | |
719 | if !variant_is_fieldless { | |
720 | trace!("NO: variant {:?} was not fieldless", variant_index); | |
721 | return StatementEquality::NotEqual; | |
722 | } | |
723 | ||
724 | match rhs { | |
725 | Rvalue::Use(operand) if operand.place() == Some(adt_matched_on) => { | |
1b1a35ee | 726 | StatementEquality::ConsideredEqual(side_to_choose) |
3dfed10e XL |
727 | } |
728 | _ => { | |
729 | trace!( | |
730 | "NO: RHS of assignment was {:?}, but expected it to match the adt being matched on in the switch, which is {:?}", | |
731 | rhs, | |
732 | adt_matched_on | |
733 | ); | |
734 | StatementEquality::NotEqual | |
735 | } | |
736 | } | |
737 | }; | |
738 | match (&x.kind, &y.kind) { | |
739 | // trivial case | |
740 | (x, y) if x == y => StatementEquality::TrivialEqual, | |
741 | ||
742 | // check for case A | |
743 | ( | |
744 | StatementKind::Assign(box (_, rhs)), | |
745 | StatementKind::SetDiscriminant { place, variant_index }, | |
1b1a35ee XL |
746 | ) |
747 | // we need to make sure that the switch value that targets the bb with SetDiscriminant (y), is the same as the variant index | |
748 | if Some(variant_index.index() as u128) == y_target_and_value.value => { | |
3dfed10e | 749 | // choose basic block of x, as that has the assign |
1b1a35ee | 750 | helper(rhs, place, variant_index, x_target_and_value.target) |
3dfed10e XL |
751 | } |
752 | ( | |
753 | StatementKind::SetDiscriminant { place, variant_index }, | |
754 | StatementKind::Assign(box (_, rhs)), | |
1b1a35ee XL |
755 | ) |
756 | // we need to make sure that the switch value that targets the bb with SetDiscriminant (x), is the same as the variant index | |
757 | if Some(variant_index.index() as u128) == x_target_and_value.value => { | |
3dfed10e | 758 | // choose basic block of y, as that has the assign |
1b1a35ee | 759 | helper(rhs, place, variant_index, y_target_and_value.target) |
3dfed10e XL |
760 | } |
761 | _ => { | |
762 | trace!("NO: statements `{:?}` and `{:?}` not considered equal", x, y); | |
763 | StatementEquality::NotEqual | |
60c5eb7d XL |
764 | } |
765 | } | |
3dfed10e XL |
766 | } |
767 | } | |
60c5eb7d | 768 | |
3dfed10e XL |
769 | #[derive(Copy, Clone, Eq, PartialEq)] |
770 | enum StatementEquality { | |
771 | /// The two statements are trivially equal; same kind | |
772 | TrivialEqual, | |
773 | /// The two statements are considered equal, but may be of different kinds. The BasicBlock field is the basic block to jump to when performing the branch-same optimization. | |
774 | /// For example, `_0 = _1` and `discriminant(_0) = discriminant(0)` are considered equal if 0 is a fieldless variant of an enum. But we don't want to jump to the basic block with the SetDiscriminant, as that is not legal if _1 is not the 0 variant index | |
775 | ConsideredEqual(BasicBlock), | |
776 | /// The two statements are not equal | |
777 | NotEqual, | |
778 | } | |
779 | ||
780 | impl StatementEquality { | |
781 | fn combine(&self, other: &StatementEquality) -> StatementEquality { | |
782 | use StatementEquality::*; | |
783 | match (self, other) { | |
784 | (TrivialEqual, TrivialEqual) => TrivialEqual, | |
785 | (TrivialEqual, ConsideredEqual(b)) | (ConsideredEqual(b), TrivialEqual) => { | |
786 | ConsideredEqual(*b) | |
787 | } | |
788 | (ConsideredEqual(b1), ConsideredEqual(b2)) => { | |
789 | if b1 == b2 { | |
790 | ConsideredEqual(*b1) | |
791 | } else { | |
792 | NotEqual | |
793 | } | |
794 | } | |
795 | (_, NotEqual) | (NotEqual, _) => NotEqual, | |
60c5eb7d XL |
796 | } |
797 | } | |
798 | } |