]>
Commit | Line | Data |
---|---|---|
60c5eb7d XL |
1 | //! A pass that eliminates branches on uninhabited enum variants. |
2 | ||
3 | use crate::transform::{MirPass, MirSource}; | |
ba9703b0 | 4 | use rustc_middle::mir::{ |
60c5eb7d XL |
5 | BasicBlock, BasicBlockData, Body, BodyAndCache, Local, Operand, Rvalue, StatementKind, |
6 | TerminatorKind, | |
7 | }; | |
ba9703b0 XL |
8 | use rustc_middle::ty::layout::TyAndLayout; |
9 | use rustc_middle::ty::{Ty, TyCtxt}; | |
10 | use rustc_target::abi::{Abi, Variants}; | |
60c5eb7d XL |
11 | |
/// MIR pass that removes `SwitchInt` arms testing the discriminant of an enum variant whose
/// layout is uninhabited; no value of such a variant can exist, so those arms are dead.
pub struct UninhabitedEnumBranching;
13 | ||
14 | fn get_discriminant_local(terminator: &TerminatorKind<'_>) -> Option<Local> { | |
15 | if let TerminatorKind::SwitchInt { discr: Operand::Move(p), .. } = terminator { | |
16 | p.as_local() | |
17 | } else { | |
18 | None | |
19 | } | |
20 | } | |
21 | ||
22 | /// If the basic block terminates by switching on a discriminant, this returns the `Ty` the | |
23 | /// discriminant is read from. Otherwise, returns None. | |
24 | fn get_switched_on_type<'tcx>( | |
25 | block_data: &BasicBlockData<'tcx>, | |
26 | body: &Body<'tcx>, | |
27 | ) -> Option<Ty<'tcx>> { | |
28 | let terminator = block_data.terminator(); | |
29 | ||
30 | // Only bother checking blocks which terminate by switching on a local. | |
31 | if let Some(local) = get_discriminant_local(&terminator.kind) { | |
74b04a01 | 32 | let stmt_before_term = (!block_data.statements.is_empty()) |
60c5eb7d XL |
33 | .then(|| &block_data.statements[block_data.statements.len() - 1].kind); |
34 | ||
35 | if let Some(StatementKind::Assign(box (l, Rvalue::Discriminant(place)))) = stmt_before_term | |
36 | { | |
37 | if l.as_local() == Some(local) { | |
38 | if let Some(r_local) = place.as_local() { | |
39 | let ty = body.local_decls[r_local].ty; | |
40 | ||
41 | if ty.is_enum() { | |
42 | return Some(ty); | |
43 | } | |
44 | } | |
45 | } | |
46 | } | |
47 | } | |
48 | ||
49 | None | |
50 | } | |
51 | ||
52 | fn variant_discriminants<'tcx>( | |
ba9703b0 | 53 | layout: &TyAndLayout<'tcx>, |
60c5eb7d XL |
54 | ty: Ty<'tcx>, |
55 | tcx: TyCtxt<'tcx>, | |
56 | ) -> Vec<u128> { | |
ba9703b0 | 57 | match &layout.variants { |
60c5eb7d XL |
58 | Variants::Single { index } => vec![index.as_u32() as u128], |
59 | Variants::Multiple { variants, .. } => variants | |
60 | .iter_enumerated() | |
61 | .filter_map(|(idx, layout)| { | |
62 | (layout.abi != Abi::Uninhabited) | |
63 | .then(|| ty.discriminant_for_variant(tcx, idx).unwrap().val) | |
64 | }) | |
65 | .collect(), | |
66 | } | |
67 | } | |
68 | ||
69 | impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching { | |
70 | fn run_pass(&self, tcx: TyCtxt<'tcx>, source: MirSource<'tcx>, body: &mut BodyAndCache<'tcx>) { | |
71 | if source.promoted.is_some() { | |
72 | return; | |
73 | } | |
74 | ||
75 | trace!("UninhabitedEnumBranching starting for {:?}", source); | |
76 | ||
77 | let basic_block_count = body.basic_blocks().len(); | |
78 | ||
79 | for bb in 0..basic_block_count { | |
80 | let bb = BasicBlock::from_usize(bb); | |
81 | trace!("processing block {:?}", bb); | |
82 | ||
83 | let discriminant_ty = | |
84 | if let Some(ty) = get_switched_on_type(&body.basic_blocks()[bb], body) { | |
85 | ty | |
86 | } else { | |
87 | continue; | |
88 | }; | |
89 | ||
90 | let layout = tcx.layout_of(tcx.param_env(source.def_id()).and(discriminant_ty)); | |
91 | ||
92 | let allowed_variants = if let Ok(layout) = layout { | |
93 | variant_discriminants(&layout, discriminant_ty, tcx) | |
94 | } else { | |
95 | continue; | |
96 | }; | |
97 | ||
98 | trace!("allowed_variants = {:?}", allowed_variants); | |
99 | ||
100 | if let TerminatorKind::SwitchInt { values, targets, .. } = | |
101 | &mut body.basic_blocks_mut()[bb].terminator_mut().kind | |
102 | { | |
103 | let vals = &*values; | |
74b04a01 | 104 | let zipped = vals.iter().zip(targets.iter()); |
60c5eb7d XL |
105 | |
106 | let mut matched_values = Vec::with_capacity(allowed_variants.len()); | |
107 | let mut matched_targets = Vec::with_capacity(allowed_variants.len() + 1); | |
108 | ||
109 | for (val, target) in zipped { | |
110 | if allowed_variants.contains(val) { | |
111 | matched_values.push(*val); | |
112 | matched_targets.push(*target); | |
113 | } else { | |
114 | trace!("eliminating {:?} -> {:?}", val, target); | |
115 | } | |
116 | } | |
117 | ||
118 | // handle the "otherwise" branch | |
119 | matched_targets.push(targets.pop().unwrap()); | |
120 | ||
121 | *values = matched_values.into(); | |
122 | *targets = matched_targets; | |
123 | } else { | |
124 | unreachable!() | |
125 | } | |
126 | } | |
127 | } | |
128 | } |