]> git.proxmox.com Git - rustc.git/blame - compiler/rustc_mir_transform/src/simplify_try.rs
New upstream version 1.57.0+dfsg1
[rustc.git] / compiler / rustc_mir_transform / src / simplify_try.rs
CommitLineData
60c5eb7d
XL
//! The general point of the optimizations provided here is to simplify something like:
//!
//! ```rust
//! match x {
//!     Ok(x) => Ok(x),
//!     Err(x) => Err(x)
//! }
//! ```
//!
//! into just `x`.
11
c295e0f8 12use crate::{simplify, MirPass};
dfeec247 13use itertools::Itertools as _;
f035d41b
XL
14use rustc_index::{bit_set::BitSet, vec::IndexVec};
15use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor};
ba9703b0 16use rustc_middle::mir::*;
3dfed10e 17use rustc_middle::ty::{self, List, Ty, TyCtxt};
60c5eb7d 18use rustc_target::abi::VariantIdx;
1b1a35ee 19use std::iter::{once, Enumerate, Peekable};
f9f354fc 20use std::slice::Iter;
60c5eb7d
XL
21
/// Simplifies arms of form `Variant(x) => Variant(x)` to just a move.
///
/// This is done by transforming basic blocks where the statements match:
///
/// ```text
/// _LOCAL_TMP = ((_LOCAL_1 as Variant ).FIELD: TY );
/// _TMP_2 = _LOCAL_TMP;
/// ((_LOCAL_0 as Variant).FIELD: TY) = move _TMP_2;
/// discriminant(_LOCAL_0) = VAR_IDX;
/// ```
///
/// into:
///
/// ```text
/// _LOCAL_0 = move _LOCAL_1
/// ```
///
/// The examples above are MIR, not Rust, so they are fenced as `text` to keep rustdoc
/// from compiling them as doctests. The pass itself only runs when the
/// `unsound_mir_opts` debugging option is enabled (see FIXME(77359)).
pub struct SimplifyArmIdentity;
39
f9f354fc
XL
40#[derive(Debug)]
41struct ArmIdentityInfo<'tcx> {
42 /// Storage location for the variant's field
43 local_temp_0: Local,
44 /// Storage location holding the variant being read from
45 local_1: Local,
46 /// The variant field being read from
47 vf_s0: VarField<'tcx>,
48 /// Index of the statement which loads the variant being read
49 get_variant_field_stmt: usize,
50
51 /// Tracks each assignment to a temporary of the variant's field
52 field_tmp_assignments: Vec<(Local, Local)>,
53
54 /// Storage location holding the variant's field that was read from
55 local_tmp_s1: Local,
56 /// Storage location holding the enum that we are writing to
57 local_0: Local,
58 /// The variant field being written to
59 vf_s1: VarField<'tcx>,
60
61 /// Storage location that the discriminant is being written to
62 set_discr_local: Local,
63 /// The variant being written
64 set_discr_var_idx: VariantIdx,
65
66 /// Index of the statement that should be overwritten as a move
67 stmt_to_overwrite: usize,
68 /// SourceInfo for the new move
69 source_info: SourceInfo,
70
71 /// Indices of matching Storage{Live,Dead} statements encountered.
72 /// (StorageLive index,, StorageDead index, Local)
73 storage_stmts: Vec<(usize, usize, Local)>,
74
75 /// The statements that should be removed (turned into nops)
76 stmts_to_remove: Vec<usize>,
f035d41b
XL
77
78 /// Indices of debug variables that need to be adjusted to point to
79 // `{local_0}.{dbg_projection}`.
80 dbg_info_to_adjust: Vec<usize>,
81
82 /// The projection used to rewrite debug info.
83 dbg_projection: &'tcx List<PlaceElem<'tcx>>,
f9f354fc
XL
84}
85
f035d41b
XL
86fn get_arm_identity_info<'a, 'tcx>(
87 stmts: &'a [Statement<'tcx>],
88 locals_count: usize,
89 debug_info: &'a [VarDebugInfo<'tcx>],
90) -> Option<ArmIdentityInfo<'tcx>> {
f9f354fc
XL
91 // This can't possibly match unless there are at least 3 statements in the block
92 // so fail fast on tiny blocks.
93 if stmts.len() < 3 {
94 return None;
95 }
96
97 let mut tmp_assigns = Vec::new();
98 let mut nop_stmts = Vec::new();
99 let mut storage_stmts = Vec::new();
100 let mut storage_live_stmts = Vec::new();
101 let mut storage_dead_stmts = Vec::new();
102
103 type StmtIter<'a, 'tcx> = Peekable<Enumerate<Iter<'a, Statement<'tcx>>>>;
104
105 fn is_storage_stmt<'tcx>(stmt: &Statement<'tcx>) -> bool {
106 matches!(stmt.kind, StatementKind::StorageLive(_) | StatementKind::StorageDead(_))
107 }
108
109 /// Eats consecutive Statements which match `test`, performing the specified `action` for each.
110 /// The iterator `stmt_iter` is not advanced if none were matched.
111 fn try_eat<'a, 'tcx>(
112 stmt_iter: &mut StmtIter<'a, 'tcx>,
113 test: impl Fn(&'a Statement<'tcx>) -> bool,
f035d41b 114 mut action: impl FnMut(usize, &'a Statement<'tcx>),
f9f354fc 115 ) {
5869c6ff 116 while stmt_iter.peek().map_or(false, |(_, stmt)| test(stmt)) {
f9f354fc
XL
117 let (idx, stmt) = stmt_iter.next().unwrap();
118
119 action(idx, stmt);
120 }
121 }
122
123 /// Eats consecutive `StorageLive` and `StorageDead` Statements.
124 /// The iterator `stmt_iter` is not advanced if none were found.
125 fn try_eat_storage_stmts<'a, 'tcx>(
126 stmt_iter: &mut StmtIter<'a, 'tcx>,
127 storage_live_stmts: &mut Vec<(usize, Local)>,
128 storage_dead_stmts: &mut Vec<(usize, Local)>,
129 ) {
130 try_eat(stmt_iter, is_storage_stmt, |idx, stmt| {
131 if let StatementKind::StorageLive(l) = stmt.kind {
132 storage_live_stmts.push((idx, l));
133 } else if let StatementKind::StorageDead(l) = stmt.kind {
134 storage_dead_stmts.push((idx, l));
135 }
136 })
137 }
138
139 fn is_tmp_storage_stmt<'tcx>(stmt: &Statement<'tcx>) -> bool {
140 use rustc_middle::mir::StatementKind::Assign;
141 if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) = &stmt.kind {
142 place.as_local().is_some() && p.as_local().is_some()
143 } else {
144 false
145 }
146 }
147
148 /// Eats consecutive `Assign` Statements.
149 // The iterator `stmt_iter` is not advanced if none were found.
150 fn try_eat_assign_tmp_stmts<'a, 'tcx>(
151 stmt_iter: &mut StmtIter<'a, 'tcx>,
152 tmp_assigns: &mut Vec<(Local, Local)>,
153 nop_stmts: &mut Vec<usize>,
154 ) {
155 try_eat(stmt_iter, is_tmp_storage_stmt, |idx, stmt| {
156 use rustc_middle::mir::StatementKind::Assign;
157 if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) =
158 &stmt.kind
159 {
160 tmp_assigns.push((place.as_local().unwrap(), p.as_local().unwrap()));
161 nop_stmts.push(idx);
162 }
163 })
164 }
165
166 fn find_storage_live_dead_stmts_for_local<'tcx>(
167 local: Local,
168 stmts: &[Statement<'tcx>],
169 ) -> Option<(usize, usize)> {
170 trace!("looking for {:?}", local);
171 let mut storage_live_stmt = None;
172 let mut storage_dead_stmt = None;
173 for (idx, stmt) in stmts.iter().enumerate() {
174 if stmt.kind == StatementKind::StorageLive(local) {
175 storage_live_stmt = Some(idx);
176 } else if stmt.kind == StatementKind::StorageDead(local) {
177 storage_dead_stmt = Some(idx);
178 }
179 }
180
181 Some((storage_live_stmt?, storage_dead_stmt.unwrap_or(usize::MAX)))
182 }
183
184 // Try to match the expected MIR structure with the basic block we're processing.
185 // We want to see something that looks like:
186 // ```
187 // (StorageLive(_) | StorageDead(_));*
188 // _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);
189 // (StorageLive(_) | StorageDead(_));*
190 // (tmp_n+1 = tmp_n);*
191 // (StorageLive(_) | StorageDead(_));*
192 // (tmp_n+1 = tmp_n);*
193 // ((LOCAL_FROM as Variant).FIELD: TY) = move tmp;
194 // discriminant(LOCAL_FROM) = VariantIdx;
195 // (StorageLive(_) | StorageDead(_));*
196 // ```
197 let mut stmt_iter = stmts.iter().enumerate().peekable();
198
199 try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
200
201 let (get_variant_field_stmt, stmt) = stmt_iter.next()?;
f035d41b 202 let (local_tmp_s0, local_1, vf_s0, dbg_projection) = match_get_variant_field(stmt)?;
f9f354fc
XL
203
204 try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
205
206 try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts);
207
208 try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
209
210 try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts);
211
212 let (idx, stmt) = stmt_iter.next()?;
213 let (local_tmp_s1, local_0, vf_s1) = match_set_variant_field(stmt)?;
214 nop_stmts.push(idx);
215
216 let (idx, stmt) = stmt_iter.next()?;
217 let (set_discr_local, set_discr_var_idx) = match_set_discr(stmt)?;
218 let discr_stmt_source_info = stmt.source_info;
219 nop_stmts.push(idx);
220
221 try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts);
222
223 for (live_idx, live_local) in storage_live_stmts {
224 if let Some(i) = storage_dead_stmts.iter().rposition(|(_, l)| *l == live_local) {
225 let (dead_idx, _) = storage_dead_stmts.swap_remove(i);
226 storage_stmts.push((live_idx, dead_idx, live_local));
227
228 if live_local == local_tmp_s0 {
229 nop_stmts.push(get_variant_field_stmt);
230 }
231 }
232 }
1b1a35ee
XL
233 // We sort primitive usize here so we can use unstable sort
234 nop_stmts.sort_unstable();
f9f354fc
XL
235
236 // Use one of the statements we're going to discard between the point
237 // where the storage location for the variant field becomes live and
238 // is killed.
239 let (live_idx, dead_idx) = find_storage_live_dead_stmts_for_local(local_tmp_s0, stmts)?;
240 let stmt_to_overwrite =
241 nop_stmts.iter().find(|stmt_idx| live_idx < **stmt_idx && **stmt_idx < dead_idx);
242
f035d41b
XL
243 let mut tmp_assigned_vars = BitSet::new_empty(locals_count);
244 for (l, r) in &tmp_assigns {
245 tmp_assigned_vars.insert(*l);
246 tmp_assigned_vars.insert(*r);
247 }
248
fc512014
XL
249 let dbg_info_to_adjust: Vec<_> = debug_info
250 .iter()
251 .enumerate()
252 .filter_map(|(i, var_info)| {
253 if let VarDebugInfoContents::Place(p) = var_info.value {
254 if tmp_assigned_vars.contains(p.local) {
255 return Some(i);
256 }
257 }
258
259 None
260 })
261 .collect();
f035d41b 262
f9f354fc
XL
263 Some(ArmIdentityInfo {
264 local_temp_0: local_tmp_s0,
265 local_1,
266 vf_s0,
267 get_variant_field_stmt,
268 field_tmp_assignments: tmp_assigns,
269 local_tmp_s1,
270 local_0,
271 vf_s1,
272 set_discr_local,
273 set_discr_var_idx,
274 stmt_to_overwrite: *stmt_to_overwrite?,
275 source_info: discr_stmt_source_info,
276 storage_stmts,
277 stmts_to_remove: nop_stmts,
f035d41b
XL
278 dbg_info_to_adjust,
279 dbg_projection,
f9f354fc
XL
280 })
281}
282
283fn optimization_applies<'tcx>(
284 opt_info: &ArmIdentityInfo<'tcx>,
285 local_decls: &IndexVec<Local, LocalDecl<'tcx>>,
f035d41b
XL
286 local_uses: &IndexVec<Local, usize>,
287 var_debug_info: &[VarDebugInfo<'tcx>],
f9f354fc
XL
288) -> bool {
289 trace!("testing if optimization applies...");
290
291 // FIXME(wesleywiser): possibly relax this restriction?
292 if opt_info.local_0 == opt_info.local_1 {
293 trace!("NO: moving into ourselves");
294 return false;
295 } else if opt_info.vf_s0 != opt_info.vf_s1 {
296 trace!("NO: the field-and-variant information do not match");
297 return false;
298 } else if local_decls[opt_info.local_0].ty != local_decls[opt_info.local_1].ty {
299 // FIXME(Centril,oli-obk): possibly relax to same layout?
300 trace!("NO: source and target locals have different types");
301 return false;
302 } else if (opt_info.local_0, opt_info.vf_s0.var_idx)
303 != (opt_info.set_discr_local, opt_info.set_discr_var_idx)
304 {
305 trace!("NO: the discriminants do not match");
306 return false;
307 }
308
fc512014 309 // Verify the assignment chain consists of the form b = a; c = b; d = c; etc...
f035d41b 310 if opt_info.field_tmp_assignments.is_empty() {
f9f354fc 311 trace!("NO: no assignments found");
f035d41b 312 return false;
f9f354fc
XL
313 }
314 let mut last_assigned_to = opt_info.field_tmp_assignments[0].1;
315 let source_local = last_assigned_to;
316 for (l, r) in &opt_info.field_tmp_assignments {
317 if *r != last_assigned_to {
318 trace!("NO: found unexpected assignment {:?} = {:?}", l, r);
319 return false;
320 }
321
322 last_assigned_to = *l;
323 }
324
f035d41b
XL
325 // Check that the first and last used locals are only used twice
326 // since they are of the form:
327 //
328 // ```
329 // _first = ((_x as Variant).n: ty);
330 // _n = _first;
331 // ...
332 // ((_y as Variant).n: ty) = _n;
333 // discriminant(_y) = z;
334 // ```
335 for (l, r) in &opt_info.field_tmp_assignments {
336 if local_uses[*l] != 2 {
337 warn!("NO: FAILED assignment chain local {:?} was used more than twice", l);
338 return false;
339 } else if local_uses[*r] != 2 {
340 warn!("NO: FAILED assignment chain local {:?} was used more than twice", r);
341 return false;
342 }
343 }
344
345 // Check that debug info only points to full Locals and not projections.
346 for dbg_idx in &opt_info.dbg_info_to_adjust {
347 let dbg_info = &var_debug_info[*dbg_idx];
fc512014
XL
348 if let VarDebugInfoContents::Place(p) = dbg_info.value {
349 if !p.projection.is_empty() {
350 trace!("NO: debug info for {:?} had a projection {:?}", dbg_info.name, p);
351 return false;
352 }
f035d41b
XL
353 }
354 }
355
f9f354fc
XL
356 if source_local != opt_info.local_temp_0 {
357 trace!(
358 "NO: start of assignment chain does not match enum variant temp: {:?} != {:?}",
359 source_local,
360 opt_info.local_temp_0
361 );
362 return false;
363 } else if last_assigned_to != opt_info.local_tmp_s1 {
364 trace!(
365 "NO: end of assignemnt chain does not match written enum temp: {:?} != {:?}",
366 last_assigned_to,
367 opt_info.local_tmp_s1
368 );
369 return false;
370 }
371
372 trace!("SUCCESS: optimization applies!");
3dfed10e 373 true
f9f354fc
XL
374}
375
60c5eb7d 376impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity {
29967ef6 377 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
1b1a35ee
XL
378 // FIXME(77359): This optimization can result in unsoundness.
379 if !tcx.sess.opts.debugging_opts.unsound_mir_opts {
f9f354fc
XL
380 return;
381 }
382
29967ef6 383 let source = body.source;
f9f354fc 384 trace!("running SimplifyArmIdentity on {:?}", source);
29967ef6 385
f035d41b
XL
386 let local_uses = LocalUseCounter::get_local_uses(body);
387 let (basic_blocks, local_decls, debug_info) =
388 body.basic_blocks_local_decls_mut_and_var_debug_info();
60c5eb7d 389 for bb in basic_blocks {
f035d41b
XL
390 if let Some(opt_info) =
391 get_arm_identity_info(&bb.statements, local_decls.len(), debug_info)
392 {
f9f354fc 393 trace!("got opt_info = {:#?}", opt_info);
f035d41b 394 if !optimization_applies(&opt_info, local_decls, &local_uses, &debug_info) {
f9f354fc
XL
395 debug!("optimization skipped for {:?}", source);
396 continue;
397 }
60c5eb7d 398
f9f354fc
XL
399 // Also remove unused Storage{Live,Dead} statements which correspond
400 // to temps used previously.
401 for (live_idx, dead_idx, local) in &opt_info.storage_stmts {
402 // The temporary that we've read the variant field into is scoped to this block,
403 // so we can remove the assignment.
404 if *local == opt_info.local_temp_0 {
405 bb.statements[opt_info.get_variant_field_stmt].make_nop();
406 }
407
408 for (left, right) in &opt_info.field_tmp_assignments {
409 if local == left || local == right {
410 bb.statements[*live_idx].make_nop();
411 bb.statements[*dead_idx].make_nop();
412 }
413 }
60c5eb7d 414 }
f9f354fc
XL
415
416 // Right shape; transform
417 for stmt_idx in opt_info.stmts_to_remove {
418 bb.statements[stmt_idx].make_nop();
419 }
420
421 let stmt = &mut bb.statements[opt_info.stmt_to_overwrite];
422 stmt.source_info = opt_info.source_info;
94222f64 423 stmt.kind = StatementKind::Assign(Box::new((
f9f354fc
XL
424 opt_info.local_0.into(),
425 Rvalue::Use(Operand::Move(opt_info.local_1.into())),
94222f64 426 )));
f9f354fc
XL
427
428 bb.statements.retain(|stmt| stmt.kind != StatementKind::Nop);
429
f035d41b
XL
430 // Fix the debug info to point to the right local
431 for dbg_index in opt_info.dbg_info_to_adjust {
432 let dbg_info = &mut debug_info[dbg_index];
fc512014
XL
433 assert!(
434 matches!(dbg_info.value, VarDebugInfoContents::Place(_)),
435 "value was not a Place"
436 );
437 if let VarDebugInfoContents::Place(p) = &mut dbg_info.value {
438 assert!(p.projection.is_empty());
439 p.local = opt_info.local_0;
440 p.projection = opt_info.dbg_projection;
441 }
f035d41b
XL
442 }
443
f9f354fc 444 trace!("block is now {:?}", bb.statements);
60c5eb7d 445 }
60c5eb7d
XL
446 }
447 }
448}
449
f035d41b
XL
450struct LocalUseCounter {
451 local_uses: IndexVec<Local, usize>,
452}
453
454impl LocalUseCounter {
455 fn get_local_uses<'tcx>(body: &Body<'tcx>) -> IndexVec<Local, usize> {
456 let mut counter = LocalUseCounter { local_uses: IndexVec::from_elem(0, &body.local_decls) };
457 counter.visit_body(body);
458 counter.local_uses
459 }
460}
461
462impl<'tcx> Visitor<'tcx> for LocalUseCounter {
463 fn visit_local(&mut self, local: &Local, context: PlaceContext, _location: Location) {
464 if context.is_storage_marker()
465 || context == PlaceContext::NonUse(NonUseContext::VarDebugInfo)
466 {
467 return;
468 }
469
470 self.local_uses[*local] += 1;
471 }
472}
473
60c5eb7d
XL
474/// Match on:
475/// ```rust
476/// _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY);
477/// ```
f035d41b
XL
478fn match_get_variant_field<'tcx>(
479 stmt: &Statement<'tcx>,
480) -> Option<(Local, Local, VarField<'tcx>, &'tcx List<PlaceElem<'tcx>>)> {
60c5eb7d 481 match &stmt.kind {
3dfed10e
XL
482 StatementKind::Assign(box (
483 place_into,
484 Rvalue::Use(Operand::Copy(pf) | Operand::Move(pf)),
485 )) => {
486 let local_into = place_into.as_local()?;
487 let (local_from, vf) = match_variant_field_place(*pf)?;
488 Some((local_into, local_from, vf, pf.projection))
489 }
60c5eb7d
XL
490 _ => None,
491 }
492}
493
494/// Match on:
495/// ```rust
496/// ((_LOCAL_FROM as Variant).FIELD: TY) = move _LOCAL_INTO;
497/// ```
498fn match_set_variant_field<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, Local, VarField<'tcx>)> {
499 match &stmt.kind {
3dfed10e
XL
500 StatementKind::Assign(box (place_from, Rvalue::Use(Operand::Move(place_into)))) => {
501 let local_into = place_into.as_local()?;
502 let (local_from, vf) = match_variant_field_place(*place_from)?;
503 Some((local_into, local_from, vf))
504 }
60c5eb7d
XL
505 _ => None,
506 }
507}
508
509/// Match on:
510/// ```rust
511/// discriminant(_LOCAL_TO_SET) = VAR_IDX;
512/// ```
513fn match_set_discr<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, VariantIdx)> {
514 match &stmt.kind {
dfeec247
XL
515 StatementKind::SetDiscriminant { place, variant_index } => {
516 Some((place.as_local()?, *variant_index))
517 }
60c5eb7d
XL
518 _ => None,
519 }
520}
521
f9f354fc 522#[derive(PartialEq, Debug)]
60c5eb7d
XL
523struct VarField<'tcx> {
524 field: Field,
525 field_ty: Ty<'tcx>,
526 var_idx: VariantIdx,
527}
528
529/// Match on `((_LOCAL as Variant).FIELD: TY)`.
ba9703b0 530fn match_variant_field_place<'tcx>(place: Place<'tcx>) -> Option<(Local, VarField<'tcx>)> {
60c5eb7d
XL
531 match place.as_ref() {
532 PlaceRef {
dfeec247 533 local,
60c5eb7d 534 projection: &[ProjectionElem::Downcast(_, var_idx), ProjectionElem::Field(field, ty)],
74b04a01 535 } => Some((local, VarField { field, field_ty: ty, var_idx })),
60c5eb7d
XL
536 _ => None,
537 }
538}
539
/// Replaces a `SwitchInt(_) -> [targets]` terminator with a plain `goto` to one of its
/// targets when every reachable target block has the same form.
///
/// Disabled unless the `unsound_mir_opts` debugging option is set (issue #89485).
pub struct SimplifyBranchSame;
544
545impl<'tcx> MirPass<'tcx> for SimplifyBranchSame {
29967ef6 546 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
dc3f5686
XL
547 // This optimization is disabled by default for now due to
548 // soundness concerns; see issue #89485 and PR #89489.
549 if !tcx.sess.opts.debugging_opts.unsound_mir_opts {
550 return;
551 }
552
29967ef6 553 trace!("Running SimplifyBranchSame on {:?}", body.source);
3dfed10e
XL
554 let finder = SimplifyBranchSameOptimizationFinder { body, tcx };
555 let opts = finder.find();
556
557 let did_remove_blocks = opts.len() > 0;
558 for opt in opts.iter() {
559 trace!("SUCCESS: Applying optimization {:?}", opt);
560 // Replace `SwitchInt(..) -> [bb_first, ..];` with a `goto -> bb_first;`.
561 body.basic_blocks_mut()[opt.bb_to_opt_terminator].terminator_mut().kind =
562 TerminatorKind::Goto { target: opt.bb_to_goto };
563 }
60c5eb7d 564
3dfed10e
XL
565 if did_remove_blocks {
566 // We have dead blocks now, so remove those.
17df50a5 567 simplify::remove_dead_blocks(tcx, body);
3dfed10e
XL
568 }
569 }
570}
571
572#[derive(Debug)]
573struct SimplifyBranchSameOptimization {
574 /// All basic blocks are equal so go to this one
575 bb_to_goto: BasicBlock,
576 /// Basic block where the terminator can be simplified to a goto
577 bb_to_opt_terminator: BasicBlock,
578}
579
1b1a35ee
XL
580struct SwitchTargetAndValue {
581 target: BasicBlock,
582 // None in case of the `otherwise` case
583 value: Option<u128>,
584}
585
3dfed10e
XL
586struct SimplifyBranchSameOptimizationFinder<'a, 'tcx> {
587 body: &'a Body<'tcx>,
588 tcx: TyCtxt<'tcx>,
589}
590
591impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> {
592 fn find(&self) -> Vec<SimplifyBranchSameOptimization> {
593 self.body
594 .basic_blocks()
595 .iter_enumerated()
596 .filter_map(|(bb_idx, bb)| {
1b1a35ee 597 let (discr_switched_on, targets_and_values) = match &bb.terminator().kind {
29967ef6
XL
598 TerminatorKind::SwitchInt { targets, discr, .. } => {
599 let targets_and_values: Vec<_> = targets.iter()
600 .map(|(val, target)| SwitchTargetAndValue { target, value: Some(val) })
601 .chain(once(SwitchTargetAndValue { target: targets.otherwise(), value: None }))
1b1a35ee 602 .collect();
29967ef6
XL
603 (discr, targets_and_values)
604 },
3dfed10e
XL
605 _ => return None,
606 };
607
608 // find the adt that has its discriminant read
609 // assuming this must be the last statement of the block
610 let adt_matched_on = match &bb.statements.last()?.kind {
611 StatementKind::Assign(box (place, rhs))
612 if Some(*place) == discr_switched_on.place() =>
613 {
614 match rhs {
615 Rvalue::Discriminant(adt_place) if adt_place.ty(self.body, self.tcx).ty.is_enum() => adt_place,
616 _ => {
617 trace!("NO: expected a discriminant read of an enum instead of: {:?}", rhs);
618 return None;
619 }
620 }
621 }
622 other => {
623 trace!("NO: expected an assignment of a discriminant read to a place. Found: {:?}", other);
624 return None
625 },
626 };
627
1b1a35ee 628 let mut iter_bbs_reachable = targets_and_values
3dfed10e 629 .iter()
1b1a35ee 630 .map(|target_and_value| (target_and_value, &self.body.basic_blocks()[target_and_value.target]))
3dfed10e
XL
631 .filter(|(_, bb)| {
632 // Reaching `unreachable` is UB so assume it doesn't happen.
633 bb.terminator().kind != TerminatorKind::Unreachable
60c5eb7d
XL
634 // But `asm!(...)` could abort the program,
635 // so we cannot assume that the `unreachable` terminator itself is reachable.
636 // FIXME(Centril): use a normalization pass instead of a check.
17df50a5 637 || bb.statements.iter().any(|stmt| matches!(stmt.kind, StatementKind::LlvmInlineAsm(..)))
3dfed10e
XL
638 })
639 .peekable();
640
5869c6ff 641 let bb_first = iter_bbs_reachable.peek().map_or(&targets_and_values[0], |(idx, _)| *idx);
3dfed10e
XL
642 let mut all_successors_equivalent = StatementEquality::TrivialEqual;
643
644 // All successor basic blocks must be equal or contain statements that are pairwise considered equal.
1b1a35ee 645 for ((target_and_value_l,bb_l), (target_and_value_r,bb_r)) in iter_bbs_reachable.tuple_windows() {
3dfed10e 646 let trivial_checks = bb_l.is_cleanup == bb_r.is_cleanup
6c58768f
XL
647 && bb_l.terminator().kind == bb_r.terminator().kind
648 && bb_l.statements.len() == bb_r.statements.len();
3dfed10e
XL
649 let statement_check = || {
650 bb_l.statements.iter().zip(&bb_r.statements).try_fold(StatementEquality::TrivialEqual, |acc,(l,r)| {
1b1a35ee 651 let stmt_equality = self.statement_equality(*adt_matched_on, &l, target_and_value_l, &r, target_and_value_r);
3dfed10e
XL
652 if matches!(stmt_equality, StatementEquality::NotEqual) {
653 // short circuit
654 None
655 } else {
656 Some(acc.combine(&stmt_equality))
657 }
658 })
659 .unwrap_or(StatementEquality::NotEqual)
660 };
661 if !trivial_checks {
662 all_successors_equivalent = StatementEquality::NotEqual;
663 break;
664 }
665 all_successors_equivalent = all_successors_equivalent.combine(&statement_check());
666 };
667
668 match all_successors_equivalent{
669 StatementEquality::TrivialEqual => {
670 // statements are trivially equal, so just take first
671 trace!("Statements are trivially equal");
672 Some(SimplifyBranchSameOptimization {
1b1a35ee 673 bb_to_goto: bb_first.target,
3dfed10e
XL
674 bb_to_opt_terminator: bb_idx,
675 })
676 }
677 StatementEquality::ConsideredEqual(bb_to_choose) => {
678 trace!("Statements are considered equal");
679 Some(SimplifyBranchSameOptimization {
680 bb_to_goto: bb_to_choose,
681 bb_to_opt_terminator: bb_idx,
682 })
683 }
684 StatementEquality::NotEqual => {
685 trace!("NO: not all successors of basic block {:?} were equivalent", bb_idx);
686 None
687 }
688 }
689 })
690 .collect()
691 }
692
693 /// Tests if two statements can be considered equal
694 ///
695 /// Statements can be trivially equal if the kinds match.
696 /// But they can also be considered equal in the following case A:
697 /// ```
698 /// discriminant(_0) = 0; // bb1
699 /// _0 = move _1; // bb2
700 /// ```
701 /// In this case the two statements are equal iff
6a06907d
XL
702 /// - `_0` is an enum where the variant index 0 is fieldless, and
703 /// - bb1 was targeted by a switch where the discriminant of `_1` was switched on
3dfed10e
XL
704 fn statement_equality(
705 &self,
706 adt_matched_on: Place<'tcx>,
707 x: &Statement<'tcx>,
1b1a35ee 708 x_target_and_value: &SwitchTargetAndValue,
3dfed10e 709 y: &Statement<'tcx>,
1b1a35ee 710 y_target_and_value: &SwitchTargetAndValue,
3dfed10e
XL
711 ) -> StatementEquality {
712 let helper = |rhs: &Rvalue<'tcx>,
1b1a35ee 713 place: &Place<'tcx>,
3dfed10e 714 variant_index: &VariantIdx,
dc3f5686 715 switch_value: u128,
3dfed10e
XL
716 side_to_choose| {
717 let place_type = place.ty(self.body, self.tcx).ty;
1b1a35ee 718 let adt = match *place_type.kind() {
3dfed10e
XL
719 ty::Adt(adt, _) if adt.is_enum() => adt,
720 _ => return StatementEquality::NotEqual,
721 };
dc3f5686
XL
722 // We need to make sure that the switch value that targets the bb with
723 // SetDiscriminant is the same as the variant discriminant.
724 let variant_discr = adt.discriminant_for_variant(self.tcx, *variant_index).val;
725 if variant_discr != switch_value {
726 trace!(
727 "NO: variant discriminant {} does not equal switch value {}",
728 variant_discr,
729 switch_value
730 );
731 return StatementEquality::NotEqual;
732 }
3dfed10e
XL
733 let variant_is_fieldless = adt.variants[*variant_index].fields.is_empty();
734 if !variant_is_fieldless {
735 trace!("NO: variant {:?} was not fieldless", variant_index);
736 return StatementEquality::NotEqual;
737 }
738
739 match rhs {
740 Rvalue::Use(operand) if operand.place() == Some(adt_matched_on) => {
1b1a35ee 741 StatementEquality::ConsideredEqual(side_to_choose)
3dfed10e
XL
742 }
743 _ => {
744 trace!(
745 "NO: RHS of assignment was {:?}, but expected it to match the adt being matched on in the switch, which is {:?}",
746 rhs,
747 adt_matched_on
748 );
749 StatementEquality::NotEqual
750 }
751 }
752 };
753 match (&x.kind, &y.kind) {
754 // trivial case
755 (x, y) if x == y => StatementEquality::TrivialEqual,
756
757 // check for case A
758 (
759 StatementKind::Assign(box (_, rhs)),
760 StatementKind::SetDiscriminant { place, variant_index },
dc3f5686 761 ) if y_target_and_value.value.is_some() => {
3dfed10e 762 // choose basic block of x, as that has the assign
dc3f5686
XL
763 helper(
764 rhs,
765 place,
766 variant_index,
767 y_target_and_value.value.unwrap(),
768 x_target_and_value.target,
769 )
3dfed10e
XL
770 }
771 (
772 StatementKind::SetDiscriminant { place, variant_index },
773 StatementKind::Assign(box (_, rhs)),
dc3f5686 774 ) if x_target_and_value.value.is_some() => {
3dfed10e 775 // choose basic block of y, as that has the assign
dc3f5686
XL
776 helper(
777 rhs,
778 place,
779 variant_index,
780 x_target_and_value.value.unwrap(),
781 y_target_and_value.target,
782 )
3dfed10e
XL
783 }
784 _ => {
785 trace!("NO: statements `{:?}` and `{:?}` not considered equal", x, y);
786 StatementEquality::NotEqual
60c5eb7d
XL
787 }
788 }
3dfed10e
XL
789 }
790}
60c5eb7d 791
3dfed10e
XL
792#[derive(Copy, Clone, Eq, PartialEq)]
793enum StatementEquality {
794 /// The two statements are trivially equal; same kind
795 TrivialEqual,
796 /// The two statements are considered equal, but may be of different kinds. The BasicBlock field is the basic block to jump to when performing the branch-same optimization.
797 /// For example, `_0 = _1` and `discriminant(_0) = discriminant(0)` are considered equal if 0 is a fieldless variant of an enum. But we don't want to jump to the basic block with the SetDiscriminant, as that is not legal if _1 is not the 0 variant index
798 ConsideredEqual(BasicBlock),
799 /// The two statements are not equal
800 NotEqual,
801}
802
803impl StatementEquality {
804 fn combine(&self, other: &StatementEquality) -> StatementEquality {
805 use StatementEquality::*;
806 match (self, other) {
807 (TrivialEqual, TrivialEqual) => TrivialEqual,
808 (TrivialEqual, ConsideredEqual(b)) | (ConsideredEqual(b), TrivialEqual) => {
809 ConsideredEqual(*b)
810 }
811 (ConsideredEqual(b1), ConsideredEqual(b2)) => {
812 if b1 == b2 {
813 ConsideredEqual(*b1)
814 } else {
815 NotEqual
816 }
817 }
818 (_, NotEqual) | (NotEqual, _) => NotEqual,
60c5eb7d
XL
819 }
820 }
821}