]> git.proxmox.com Git - rustc.git/blob - compiler/rustc_mir_transform/src/early_otherwise_branch.rs
New upstream version 1.61.0+dfsg1
[rustc.git] / compiler / rustc_mir_transform / src / early_otherwise_branch.rs
1 use rustc_middle::mir::patch::MirPatch;
2 use rustc_middle::mir::*;
3 use rustc_middle::ty::{self, Ty, TyCtxt};
4 use std::fmt::Debug;
5
6 use super::simplify::simplify_cfg;
7
/// This pass optimizes something like
/// ```ignore (syntax-highlighting-only)
/// let x: Option<()>;
/// let y: Option<()>;
/// match (x,y) {
///     (Some(_), Some(_)) => {0},
///     _ => {1}
/// }
/// ```
/// into something like
/// ```ignore (syntax-highlighting-only)
/// let x: Option<()>;
/// let y: Option<()>;
/// let discriminant_x = std::mem::discriminant(x);
/// let discriminant_y = std::mem::discriminant(y);
/// if discriminant_x == discriminant_y {
///     match x {
///         Some(_) => 0,
///         _ => 1, // <----
///     } //         | Actually the same bb
/// } else { //      |
///     1 // <--------------
/// }
/// ```
///
/// Specifically, it looks for instances of control flow like this:
/// ```text
///
///     =================
///     |      BB1      |
///     |---------------|                  ============================
///     |     ...       |         /------> |            BBC           |
///     |---------------|         |        |--------------------------|
///     |  switchInt(Q) |         |        |   _cl = discriminant(P)  |
///     |       c       | --------/        |--------------------------|
///     |       d       | -------\         |       switchInt(_cl)     |
///     |      ...      |        |         |            c             | ---> BBC.2
///     |    otherwise  | --\    |    /--- |         otherwise        |
///     =================   |    |    |    ============================
///                         |    |    |
///     =================   |    |    |
///     |      BBU      | <-|    |    |    ============================
///     |---------------|   |    \-------> |            BBD           |
///     |---------------|   |         |    |--------------------------|
///     |  unreachable  |   |         |    |   _dl = discriminant(P)  |
///     =================   |         |    |--------------------------|
///                         |         |    |       switchInt(_dl)     |
///     =================   |         |    |            d             | ---> BBD.2
///     |      BB9      | <--------------- |         otherwise        |
///     |---------------|                  ============================
///     |      ...      |
///     =================
/// ```
/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU` or to `BB9`. In the
/// code:
///  - `BB1` is `parent` and `BBC, BBD` are children
///  - `P` is `child_place`
///  - `child_ty` is the type of `_cl`.
///  - `Q` is `parent_op`.
///  - `parent_ty` is the type of `Q`.
///  - `BB9` is `destination`
///
/// All this is then transformed into:
/// ```text
///
///     =======================
///     |          BB1        |
///     |---------------------|                  ============================
///     |          ...        |         /------> |           BBEq           |
///     | _s = discriminant(P)|         |        |--------------------------|
///     | _t = Ne(Q, _s)      |         |        |--------------------------|
///     |---------------------|         |        |       switchInt(Q)       |
///     |     switchInt(_t)   |         |        |            c             | ---> BBC.2
///     |        false        | --------/        |            d             | ---> BBD.2
///     |       otherwise     | ---------------- |         otherwise        |
///     =======================       |          ============================
///                                   |
///     =================             |
///     |      BB9      | <-----------/
///     |---------------|
///     |      ...      |
///     =================
/// ```
///
/// This is only correct for some `P`, since `P` is now computed outside the original `switchInt`.
/// The filter on which `P` are allowed (together with discussion of its correctness) is found in
/// `may_hoist`.
pub struct EarlyOtherwiseBranch;
95
96 impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
97 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
98 sess.mir_opt_level() >= 3 && sess.opts.debugging_opts.unsound_mir_opts
99 }
100
101 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
102 trace!("running EarlyOtherwiseBranch on {:?}", body.source);
103
104 let mut should_cleanup = false;
105
106 // Also consider newly generated bbs in the same pass
107 for i in 0..body.basic_blocks().len() {
108 let bbs = body.basic_blocks();
109 let parent = BasicBlock::from_usize(i);
110 let Some(opt_data) = evaluate_candidate(tcx, body, parent) else {
111 continue
112 };
113
114 if !tcx.consider_optimizing(|| format!("EarlyOtherwiseBranch {:?}", &opt_data)) {
115 break;
116 }
117
118 trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_data);
119
120 should_cleanup = true;
121
122 let TerminatorKind::SwitchInt {
123 discr: parent_op,
124 switch_ty: parent_ty,
125 targets: parent_targets
126 } = &bbs[parent].terminator().kind else {
127 unreachable!()
128 };
129 // Always correct since we can only switch on `Copy` types
130 let parent_op = match parent_op {
131 Operand::Move(x) => Operand::Copy(*x),
132 Operand::Copy(x) => Operand::Copy(*x),
133 Operand::Constant(x) => Operand::Constant(x.clone()),
134 };
135 let statements_before = bbs[parent].statements.len();
136 let parent_end = Location { block: parent, statement_index: statements_before };
137
138 let mut patch = MirPatch::new(body);
139
140 // create temp to store second discriminant in, `_s` in example above
141 let second_discriminant_temp =
142 patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
143
144 patch.add_statement(parent_end, StatementKind::StorageLive(second_discriminant_temp));
145
146 // create assignment of discriminant
147 patch.add_assign(
148 parent_end,
149 Place::from(second_discriminant_temp),
150 Rvalue::Discriminant(opt_data.child_place),
151 );
152
153 // create temp to store inequality comparison between the two discriminants, `_t` in
154 // example above
155 let nequal = BinOp::Ne;
156 let comp_res_type = nequal.ty(tcx, *parent_ty, opt_data.child_ty);
157 let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
158 patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));
159
160 // create inequality comparison between the two discriminants
161 let comp_rvalue = Rvalue::BinaryOp(
162 nequal,
163 Box::new((parent_op.clone(), Operand::Move(Place::from(second_discriminant_temp)))),
164 );
165 patch.add_statement(
166 parent_end,
167 StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
168 );
169
170 let eq_new_targets = parent_targets.iter().map(|(value, child)| {
171 let TerminatorKind::SwitchInt{ targets, .. } = &bbs[child].terminator().kind else {
172 unreachable!()
173 };
174 (value, targets.target_for_value(value))
175 });
176 let eq_targets = SwitchTargets::new(eq_new_targets, opt_data.destination);
177
178 // Create `bbEq` in example above
179 let eq_switch = BasicBlockData::new(Some(Terminator {
180 source_info: bbs[parent].terminator().source_info,
181 kind: TerminatorKind::SwitchInt {
182 // switch on the first discriminant, so we can mark the second one as dead
183 discr: parent_op,
184 switch_ty: opt_data.child_ty,
185 targets: eq_targets,
186 },
187 }));
188
189 let eq_bb = patch.new_block(eq_switch);
190
191 // Jump to it on the basis of the inequality comparison
192 let true_case = opt_data.destination;
193 let false_case = eq_bb;
194 patch.patch_terminator(
195 parent,
196 TerminatorKind::if_(
197 tcx,
198 Operand::Move(Place::from(comp_temp)),
199 true_case,
200 false_case,
201 ),
202 );
203
204 // generate StorageDead for the second_discriminant_temp not in use anymore
205 patch.add_statement(parent_end, StatementKind::StorageDead(second_discriminant_temp));
206
207 // Generate a StorageDead for comp_temp in each of the targets, since we moved it into
208 // the switch
209 for bb in [false_case, true_case].iter() {
210 patch.add_statement(
211 Location { block: *bb, statement_index: 0 },
212 StatementKind::StorageDead(comp_temp),
213 );
214 }
215
216 patch.apply(body);
217 }
218
219 // Since this optimization adds new basic blocks and invalidates others,
220 // clean up the cfg to make it nicer for other passes
221 if should_cleanup {
222 simplify_cfg(tcx, body);
223 }
224 }
225 }
226
227 /// Returns true if computing the discriminant of `place` may be hoisted out of the branch
228 fn may_hoist<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, place: Place<'tcx>) -> bool {
229 // FIXME(JakobDegen): This is unsound. Someone could write code like this:
230 // ```rust
231 // let Q = val;
232 // if discriminant(P) == otherwise {
233 // let ptr = &mut Q as *mut _ as *mut u8;
234 // unsafe { *ptr = 10; } // Any invalid value for the type
235 // }
236 //
237 // match P {
238 // A => match Q {
239 // A => {
240 // // code
241 // }
242 // _ => {
243 // // don't use Q
244 // }
245 // }
246 // _ => {
247 // // don't use Q
248 // }
249 // };
250 // ```
251 //
252 // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
253 // invalid value, which is UB.
254 //
255 // In order to fix this, we would either need to show that the discriminant computation of
256 // `place` is computed in all branches, including the `otherwise` branch, or we would need
257 // another analysis pass to determine that the place is fully initialized. It might even be best
258 // to have the hoisting be performed in a different pass and just do the CFG changing in this
259 // pass.
260 for (place, proj) in place.iter_projections() {
261 match proj {
262 // Dereferencing in the computation of `place` might cause issues from one of two
263 // categories. First, the referent might be invalid. We protect against this by
264 // dereferencing references only (not pointers). Second, the use of a reference may
265 // invalidate other references that are used later (for aliasing reasons). Consider
266 // where such an invalidated reference may appear:
267 // - In `Q`: Not possible since `Q` is used as the operand of a `SwitchInt` and so
268 // cannot contain referenced data.
269 // - In `BBU`: Not possible since that block contains only the `unreachable` terminator
270 // - In `BBC.2, BBD.2`: Not possible, since `discriminant(P)` was computed prior to
271 // reaching that block in the input to our transformation, and so any data
272 // invalidated by that computation could not have been used there.
273 // - In `BB9`: Not possible since control flow might have reached `BB9` via the
274 // `otherwise` branch in `BBC, BBD` in the input to our transformation, which would
275 // have invalidated the data when computing `discriminant(P)`
276 // So dereferencing here is correct.
277 ProjectionElem::Deref => match place.ty(body.local_decls(), tcx).ty.kind() {
278 ty::Ref(..) => {}
279 _ => return false,
280 },
281 // Field projections are always valid
282 ProjectionElem::Field(..) => {}
283 // We cannot allow
284 // downcasts either, since the correctness of the downcast may depend on the parent
285 // branch being taken. An easy example of this is
286 // ```
287 // Q = discriminant(_3)
288 // P = (_3 as Variant)
289 // ```
290 // However, checking if the child and parent place are the same and only erroring then
291 // is not sufficient either, since the `discriminant(_3) == 1` (or whatever) check may
292 // be replaced by another optimization pass with any other condition that can be proven
293 // equivalent.
294 ProjectionElem::Downcast(..) => {
295 return false;
296 }
297 // We cannot allow indexing since the index may be out of bounds.
298 _ => {
299 return false;
300 }
301 }
302 }
303 true
304 }
305
/// Description of one rewrite opportunity, produced by `evaluate_candidate` and consumed by
/// `EarlyOtherwiseBranch::run_pass`.
#[derive(Debug)]
struct OptimizationData<'tcx> {
    /// Block that is jumped to when the two discriminants differ, and that also serves as the
    /// `otherwise` target of the newly created switch. It is the parent's `otherwise` target,
    /// unless that block is trivially unreachable, in which case it is the `otherwise` target of
    /// the first child. Every child is verified to fall through to this block.
    destination: BasicBlock,
    /// The place whose discriminant every child block reads in its single statement.
    child_place: Place<'tcx>,
    /// The `switch_ty` of the first child's `SwitchInt`; checked to be equal to the parent's
    /// `switch_ty`.
    child_ty: Ty<'tcx>,
    /// Source info of the first child's terminator; its span is attached to the temporaries
    /// introduced by the rewrite.
    child_source: SourceInfo,
}
313
314 fn evaluate_candidate<'tcx>(
315 tcx: TyCtxt<'tcx>,
316 body: &Body<'tcx>,
317 parent: BasicBlock,
318 ) -> Option<OptimizationData<'tcx>> {
319 let bbs = body.basic_blocks();
320 let TerminatorKind::SwitchInt {
321 targets,
322 switch_ty: parent_ty,
323 ..
324 } = &bbs[parent].terminator().kind else {
325 return None
326 };
327 let parent_dest = {
328 let poss = targets.otherwise();
329 // If the fallthrough on the parent is trivially unreachable, we can let the
330 // children choose the destination
331 if bbs[poss].statements.len() == 0
332 && bbs[poss].terminator().kind == TerminatorKind::Unreachable
333 {
334 None
335 } else {
336 Some(poss)
337 }
338 };
339 let Some((_, child)) = targets.iter().next() else {
340 return None
341 };
342 let child_terminator = &bbs[child].terminator();
343 let TerminatorKind::SwitchInt {
344 switch_ty: child_ty,
345 targets: child_targets,
346 ..
347 } = &child_terminator.kind else {
348 return None
349 };
350 if child_ty != parent_ty {
351 return None;
352 }
353 let Some(StatementKind::Assign(boxed))
354 = &bbs[child].statements.first().map(|x| &x.kind) else {
355 return None;
356 };
357 let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
358 return None;
359 };
360 let destination = parent_dest.unwrap_or(child_targets.otherwise());
361
362 // Verify that the optimization is legal in general
363 // We can hoist evaluating the child discriminant out of the branch
364 if !may_hoist(tcx, body, *child_place) {
365 return None;
366 }
367
368 // Verify that the optimization is legal for each branch
369 for (value, child) in targets.iter() {
370 if !verify_candidate_branch(&bbs[child], value, *child_place, destination) {
371 return None;
372 }
373 }
374 Some(OptimizationData {
375 destination,
376 child_place: *child_place,
377 child_ty: *child_ty,
378 child_source: child_terminator.source_info,
379 })
380 }
381
382 fn verify_candidate_branch<'tcx>(
383 branch: &BasicBlockData<'tcx>,
384 value: u128,
385 place: Place<'tcx>,
386 destination: BasicBlock,
387 ) -> bool {
388 // In order for the optimization to be correct, the branch must...
389 // ...have exactly one statement
390 if branch.statements.len() != 1 {
391 return false;
392 }
393 // ...assign the discriminant of `place` in that statement
394 let StatementKind::Assign(boxed) = &branch.statements[0].kind else {
395 return false
396 };
397 let (discr_place, Rvalue::Discriminant(from_place)) = &**boxed else {
398 return false
399 };
400 if *from_place != place {
401 return false;
402 }
403 // ...make that assignment to a local
404 if discr_place.projection.len() != 0 {
405 return false;
406 }
407 // ...terminate on a `SwitchInt` that invalidates that local
408 let TerminatorKind::SwitchInt{ discr: switch_op, targets, .. } = &branch.terminator().kind else {
409 return false
410 };
411 if *switch_op != Operand::Move(*discr_place) {
412 return false;
413 }
414 // ...fall through to `destination` if the switch misses
415 if destination != targets.otherwise() {
416 return false;
417 }
418 // ...have a branch for value `value`
419 let mut iter = targets.iter();
420 let Some((target_value, _)) = iter.next() else {
421 return false;
422 };
423 if target_value != value {
424 return false;
425 }
426 // ...and have no more branches
427 if let Some(_) = iter.next() {
428 return false;
429 }
430 return true;
431 }