1 use rustc_middle
::mir
::patch
::MirPatch
;
2 use rustc_middle
::mir
::*;
3 use rustc_middle
::ty
::{self, Ty, TyCtxt}
;
6 use super::simplify
::simplify_cfg
;
8 /// This pass optimizes something like
9 /// ```ignore (syntax-highlighting-only)
10 /// let x: Option<()>;
11 /// let y: Option<()>;
13 /// (Some(_), Some(_)) => {0},
17 /// into something like
18 /// ```ignore (syntax-highlighting-only)
19 /// let x: Option<()>;
20 /// let y: Option<()>;
21 /// let discriminant_x = std::mem::discriminant(x);
22 /// let discriminant_y = std::mem::discriminant(y);
23 /// if discriminant_x == discriminant_y {
27 /// } // | Actually the same bb
29 /// 1 // <--------------
33 /// Specifically, it looks for instances of control flow like this:
38 /// |---------------| ============================
39 /// | ... | /------> | BBC |
40 /// |---------------| | |--------------------------|
41 /// | switchInt(Q) | | | _cl = discriminant(P) |
42 /// | c | --------/ |--------------------------|
43 /// | d | -------\ | switchInt(_cl) |
44 /// | ... | | | c | ---> BBC.2
45 /// | otherwise | --\ | /--- | otherwise |
46 /// ================= | | | ============================
48 /// ================= | | |
49 /// | BBU | <-| | | ============================
50 /// |---------------| | \-------> | BBD |
51 /// |---------------| | | |--------------------------|
52 /// | unreachable | | | | _dl = discriminant(P) |
53 /// ================= | | |--------------------------|
54 /// | | | switchInt(_dl) |
55 /// ================= | | | d | ---> BBD.2
56 /// | BB9 | <--------------- | otherwise |
57 /// |---------------| ============================
61 /// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU` or to `BB9`. In the
63 /// - `BB1` is `parent` and `BBC, BBD` are children
64 /// - `P` is `child_place`
65 /// - `child_ty` is the type of `_cl`.
66 /// - `Q` is `parent_op`.
67 /// - `parent_ty` is the type of `Q`.
68 /// - `BB9` is `destination`
69 /// All this is then transformed into:
72 /// =======================
74 /// |---------------------| ============================
75 /// | ... | /------> | BBEq |
76 /// | _s = discriminant(P)| | |--------------------------|
77 /// | _t = Ne(Q, _s) | | |--------------------------|
78 /// |---------------------| | | switchInt(Q) |
79 /// | switchInt(_t) | | | c | ---> BBC.2
80 /// | false | --------/ | d | ---> BBD.2
81 /// | otherwise | ---------------- | otherwise |
82 /// ======================= | ============================
84 /// ================= |
85 /// | BB9 | <-----------/
91 /// This is only correct for some `P`, since `P` is now computed outside the original `switchInt`.
92 /// The filter on which `P` are allowed (together with discussion of its correctness) is found in
94 pub struct EarlyOtherwiseBranch
;
96 impl<'tcx
> MirPass
<'tcx
> for EarlyOtherwiseBranch
{
97 fn is_enabled(&self, sess
: &rustc_session
::Session
) -> bool
{
98 sess
.mir_opt_level() >= 3 && sess
.opts
.unstable_opts
.unsound_mir_opts
101 fn run_pass(&self, tcx
: TyCtxt
<'tcx
>, body
: &mut Body
<'tcx
>) {
102 trace
!("running EarlyOtherwiseBranch on {:?}", body
.source
);
104 let mut should_cleanup
= false;
106 // Also consider newly generated bbs in the same pass
107 for i
in 0..body
.basic_blocks().len() {
108 let bbs
= body
.basic_blocks();
109 let parent
= BasicBlock
::from_usize(i
);
110 let Some(opt_data
) = evaluate_candidate(tcx
, body
, parent
) else {
114 if !tcx
.consider_optimizing(|| format
!("EarlyOtherwiseBranch {:?}", &opt_data
)) {
118 trace
!("SUCCESS: found optimization possibility to apply: {:?}", &opt_data
);
120 should_cleanup
= true;
122 let TerminatorKind
::SwitchInt
{
124 switch_ty
: parent_ty
,
125 targets
: parent_targets
126 } = &bbs
[parent
].terminator().kind
else {
129 // Always correct since we can only switch on `Copy` types
130 let parent_op
= match parent_op
{
131 Operand
::Move(x
) => Operand
::Copy(*x
),
132 Operand
::Copy(x
) => Operand
::Copy(*x
),
133 Operand
::Constant(x
) => Operand
::Constant(x
.clone()),
135 let statements_before
= bbs
[parent
].statements
.len();
136 let parent_end
= Location { block: parent, statement_index: statements_before }
;
138 let mut patch
= MirPatch
::new(body
);
140 // create temp to store second discriminant in, `_s` in example above
141 let second_discriminant_temp
=
142 patch
.new_temp(opt_data
.child_ty
, opt_data
.child_source
.span
);
144 patch
.add_statement(parent_end
, StatementKind
::StorageLive(second_discriminant_temp
));
146 // create assignment of discriminant
149 Place
::from(second_discriminant_temp
),
150 Rvalue
::Discriminant(opt_data
.child_place
),
153 // create temp to store inequality comparison between the two discriminants, `_t` in
155 let nequal
= BinOp
::Ne
;
156 let comp_res_type
= nequal
.ty(tcx
, *parent_ty
, opt_data
.child_ty
);
157 let comp_temp
= patch
.new_temp(comp_res_type
, opt_data
.child_source
.span
);
158 patch
.add_statement(parent_end
, StatementKind
::StorageLive(comp_temp
));
160 // create inequality comparison between the two discriminants
161 let comp_rvalue
= Rvalue
::BinaryOp(
163 Box
::new((parent_op
.clone(), Operand
::Move(Place
::from(second_discriminant_temp
)))),
167 StatementKind
::Assign(Box
::new((Place
::from(comp_temp
), comp_rvalue
))),
170 let eq_new_targets
= parent_targets
.iter().map(|(value
, child
)| {
171 let TerminatorKind
::SwitchInt{ targets, .. }
= &bbs
[child
].terminator().kind
else {
174 (value
, targets
.target_for_value(value
))
176 let eq_targets
= SwitchTargets
::new(eq_new_targets
, opt_data
.destination
);
178 // Create `bbEq` in example above
179 let eq_switch
= BasicBlockData
::new(Some(Terminator
{
180 source_info
: bbs
[parent
].terminator().source_info
,
181 kind
: TerminatorKind
::SwitchInt
{
182 // switch on the first discriminant, so we can mark the second one as dead
184 switch_ty
: opt_data
.child_ty
,
189 let eq_bb
= patch
.new_block(eq_switch
);
191 // Jump to it on the basis of the inequality comparison
192 let true_case
= opt_data
.destination
;
193 let false_case
= eq_bb
;
194 patch
.patch_terminator(
198 Operand
::Move(Place
::from(comp_temp
)),
204 // generate StorageDead for the second_discriminant_temp not in use anymore
205 patch
.add_statement(parent_end
, StatementKind
::StorageDead(second_discriminant_temp
));
207 // Generate a StorageDead for comp_temp in each of the targets, since we moved it into
209 for bb
in [false_case
, true_case
].iter() {
211 Location { block: *bb, statement_index: 0 }
,
212 StatementKind
::StorageDead(comp_temp
),
219 // Since this optimization adds new basic blocks and invalidates others,
220 // clean up the cfg to make it nicer for other passes
222 simplify_cfg(tcx
, body
);
227 /// Returns true if computing the discriminant of `place` may be hoisted out of the branch
228 fn may_hoist
<'tcx
>(tcx
: TyCtxt
<'tcx
>, body
: &Body
<'tcx
>, place
: Place
<'tcx
>) -> bool
{
229 // FIXME(JakobDegen): This is unsound. Someone could write code like this:
232 // if discriminant(P) == otherwise {
233 // let ptr = &mut Q as *mut _ as *mut u8;
234 // unsafe { *ptr = 10; } // Any invalid value for the type
252 // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
253 // invalid value, which is UB.
255 // In order to fix this, we would either need to show that the discriminant computation of
256 // `place` is computed in all branches, including the `otherwise` branch, or we would need
257 // another analysis pass to determine that the place is fully initialized. It might even be best
258 // to have the hoisting be performed in a different pass and just do the CFG changing in this
260 for (place
, proj
) in place
.iter_projections() {
262 // Dereferencing in the computation of `place` might cause issues from one of two
263 // categories. First, the referent might be invalid. We protect against this by
264 // dereferencing references only (not pointers). Second, the use of a reference may
265 // invalidate other references that are used later (for aliasing reasons). Consider
266 // where such an invalidated reference may appear:
267 // - In `Q`: Not possible since `Q` is used as the operand of a `SwitchInt` and so
268 // cannot contain referenced data.
269 // - In `BBU`: Not possible since that block contains only the `unreachable` terminator
270 // - In `BBC.2, BBD.2`: Not possible, since `discriminant(P)` was computed prior to
271 // reaching that block in the input to our transformation, and so any data
272 // invalidated by that computation could not have been used there.
273 // - In `BB9`: Not possible since control flow might have reached `BB9` via the
274 // `otherwise` branch in `BBC, BBD` in the input to our transformation, which would
275 // have invalidated the data when computing `discriminant(P)`
276 // So dereferencing here is correct.
277 ProjectionElem
::Deref
=> match place
.ty(body
.local_decls(), tcx
).ty
.kind() {
281 // Field projections are always valid
282 ProjectionElem
::Field(..) => {}
284 // downcasts either, since the correctness of the downcast may depend on the parent
285 // branch being taken. An easy example of this is
287 // Q = discriminant(_3)
288 // P = (_3 as Variant)
290 // However, checking if the child and parent place are the same and only erroring then
291 // is not sufficient either, since the `discriminant(_3) == 1` (or whatever) check may
292 // be replaced by another optimization pass with any other condition that can be proven
294 ProjectionElem
::Downcast(..) => {
297 // We cannot allow indexing since the index may be out of bounds.
307 struct OptimizationData
<'tcx
> {
308 destination
: BasicBlock
,
309 child_place
: Place
<'tcx
>,
311 child_source
: SourceInfo
,
314 fn evaluate_candidate
<'tcx
>(
318 ) -> Option
<OptimizationData
<'tcx
>> {
319 let bbs
= body
.basic_blocks();
320 let TerminatorKind
::SwitchInt
{
322 switch_ty
: parent_ty
,
324 } = &bbs
[parent
].terminator().kind
else {
328 let poss
= targets
.otherwise();
329 // If the fallthrough on the parent is trivially unreachable, we can let the
330 // children choose the destination
331 if bbs
[poss
].statements
.len() == 0
332 && bbs
[poss
].terminator().kind
== TerminatorKind
::Unreachable
339 let (_
, child
) = targets
.iter().next()?
;
340 let child_terminator
= &bbs
[child
].terminator();
341 let TerminatorKind
::SwitchInt
{
343 targets
: child_targets
,
345 } = &child_terminator
.kind
else {
348 if child_ty
!= parent_ty
{
351 let Some(StatementKind
::Assign(boxed
))
352 = &bbs
[child
].statements
.first().map(|x
| &x
.kind
) else {
355 let (_
, Rvalue
::Discriminant(child_place
)) = &**boxed
else {
358 let destination
= parent_dest
.unwrap_or(child_targets
.otherwise());
360 // Verify that the optimization is legal in general
361 // We can hoist evaluating the child discriminant out of the branch
362 if !may_hoist(tcx
, body
, *child_place
) {
366 // Verify that the optimization is legal for each branch
367 for (value
, child
) in targets
.iter() {
368 if !verify_candidate_branch(&bbs
[child
], value
, *child_place
, destination
) {
372 Some(OptimizationData
{
374 child_place
: *child_place
,
376 child_source
: child_terminator
.source_info
,
380 fn verify_candidate_branch
<'tcx
>(
381 branch
: &BasicBlockData
<'tcx
>,
384 destination
: BasicBlock
,
386 // In order for the optimization to be correct, the branch must...
387 // ...have exactly one statement
388 if branch
.statements
.len() != 1 {
391 // ...assign the discriminant of `place` in that statement
392 let StatementKind
::Assign(boxed
) = &branch
.statements
[0].kind
else {
395 let (discr_place
, Rvalue
::Discriminant(from_place
)) = &**boxed
else {
398 if *from_place
!= place
{
401 // ...make that assignment to a local
402 if discr_place
.projection
.len() != 0 {
405 // ...terminate on a `SwitchInt` that invalidates that local
406 let TerminatorKind
::SwitchInt{ discr: switch_op, targets, .. }
= &branch
.terminator().kind
else {
409 if *switch_op
!= Operand
::Move(*discr_place
) {
412 // ...fall through to `destination` if the switch misses
413 if destination
!= targets
.otherwise() {
416 // ...have a branch for value `value`
417 let mut iter
= targets
.iter();
418 let Some((target_value
, _
)) = iter
.next() else {
421 if target_value
!= value
{
424 // ...and have no more branches
425 if let Some(_
) = iter
.next() {