use rustc_middle::mir::patch::MirPatch;
use rustc_middle::mir::*;
-use rustc_middle::ty::{Ty, TyCtxt};
+use rustc_middle::ty::{self, Ty, TyCtxt};
use std::fmt::Debug;
use super::simplify::simplify_cfg;
/// This pass optimizes something like
-/// ```text
+/// ```ignore (syntax-highlighting-only)
/// let x: Option<()>;
/// let y: Option<()>;
/// match (x,y) {
/// }
/// ```
/// into something like
-/// ```text
+/// ```ignore (syntax-highlighting-only)
/// let x: Option<()>;
/// let y: Option<()>;
-/// let discriminant_x = // get discriminant of x
-/// let discriminant_y = // get discriminant of y
-/// if discriminant_x != discriminant_y || discriminant_x == None {1} else {0}
+/// let discriminant_x = std::mem::discriminant(x);
+/// let discriminant_y = std::mem::discriminant(y);
+/// if discriminant_x == discriminant_y {
+/// match x {
+/// Some(_) => 0,
+/// _ => 1, // <----
+/// } // | Actually the same bb
+/// } else { // |
+/// 1 // <--------------
+/// }
+/// ```
+///
+/// Specifically, it looks for instances of control flow like this:
+/// ```text
+///
+/// =================
+/// | BB1 |
+/// |---------------| ============================
+/// | ... | /------> | BBC |
+/// |---------------| | |--------------------------|
+/// | switchInt(Q) | | | _cl = discriminant(P) |
+/// | c | --------/ |--------------------------|
+/// | d | -------\ | switchInt(_cl) |
+/// | ... | | | c | ---> BBC.2
+/// | otherwise | --\ | /--- | otherwise |
+/// ================= | | | ============================
+/// | | |
+/// ================= | | |
+/// | BBU | <-| | | ============================
+/// |---------------| | \-------> | BBD |
+/// |---------------| | | |--------------------------|
+/// | unreachable | | | | _dl = discriminant(P) |
+/// ================= | | |--------------------------|
+/// | | | switchInt(_dl) |
+/// ================= | | | d | ---> BBD.2
+/// | BB9 | <--------------- | otherwise |
+/// |---------------| ============================
+/// | ... |
+/// =================
/// ```
+/// Where the `otherwise` branch on `BB1` is permitted to either go to `BBU` or to `BB9`. In the
+/// code:
+/// - `BB1` is `parent` and `BBC, BBD` are children
+/// - `P` is `child_place`
+/// - `child_ty` is the type of `_cl`.
+/// - `Q` is `parent_op`.
+/// - `parent_ty` is the type of `Q`.
+/// - `BB9` is `destination`
+/// All this is then transformed into:
+/// ```text
+///
+/// =======================
+/// | BB1 |
+/// |---------------------| ============================
+/// | ... | /------> | BBEq |
+/// | _s = discriminant(P)| | |--------------------------|
+/// | _t = Ne(Q, _s) | | |--------------------------|
+/// |---------------------| | | switchInt(Q) |
+/// | switchInt(_t) | | | c | ---> BBC.2
+/// | false | --------/ | d | ---> BBD.2
+/// | otherwise | ---------------- | otherwise |
+/// ======================= | ============================
+/// |
+/// ================= |
+/// | BB9 | <-----------/
+/// |---------------|
+/// | ... |
+/// =================
+/// ```
+///
+/// This is only correct for some `P`, since `P` is now computed outside the original `switchInt`.
+/// The filter on which `P` are allowed (together with discussion of its correctness) is found in
+/// `may_hoist`.
pub struct EarlyOtherwiseBranch;
impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
- // FIXME(#78496)
- sess.opts.debugging_opts.unsound_mir_opts && sess.mir_opt_level() >= 3
+ sess.mir_opt_level() >= 2
}
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
trace!("running EarlyOtherwiseBranch on {:?}", body.source);
- // we are only interested in this bb if the terminator is a switchInt
- let bbs_with_switch =
- body.basic_blocks().iter_enumerated().filter(|(_, bb)| is_switch(bb.terminator()));
+ let mut should_cleanup = false;
- let opts_to_apply: Vec<OptimizationToApply<'tcx>> = bbs_with_switch
- .flat_map(|(bb_idx, bb)| {
- let switch = bb.terminator();
- let helper = Helper { body, tcx };
- let infos = helper.go(bb, switch)?;
- Some(OptimizationToApply { infos, basic_block_first_switch: bb_idx })
- })
- .collect();
-
- let should_cleanup = !opts_to_apply.is_empty();
+ // Also consider newly generated bbs in the same pass
+ for i in 0..body.basic_blocks().len() {
+ let bbs = body.basic_blocks();
+ let parent = BasicBlock::from_usize(i);
+ let Some(opt_data) = evaluate_candidate(tcx, body, parent) else {
+ continue
+ };
- for opt_to_apply in opts_to_apply {
- if !tcx.consider_optimizing(|| format!("EarlyOtherwiseBranch {:?}", &opt_to_apply)) {
+ if !tcx.consider_optimizing(|| format!("EarlyOtherwiseBranch {:?}", &opt_data)) {
break;
}
- trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_to_apply);
+ trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_data);
- let statements_before =
- body.basic_blocks()[opt_to_apply.basic_block_first_switch].statements.len();
- let end_of_block_location = Location {
- block: opt_to_apply.basic_block_first_switch,
- statement_index: statements_before,
+ should_cleanup = true;
+
+ let TerminatorKind::SwitchInt {
+ discr: parent_op,
+ switch_ty: parent_ty,
+ targets: parent_targets
+ } = &bbs[parent].terminator().kind else {
+ unreachable!()
+ };
+ // Always correct since we can only switch on `Copy` types
+ let parent_op = match parent_op {
+ Operand::Move(x) => Operand::Copy(*x),
+ Operand::Copy(x) => Operand::Copy(*x),
+ Operand::Constant(x) => Operand::Constant(x.clone()),
};
+ let statements_before = bbs[parent].statements.len();
+ let parent_end = Location { block: parent, statement_index: statements_before };
let mut patch = MirPatch::new(body);
- // create temp to store second discriminant in
- let discr_type = opt_to_apply.infos[0].second_switch_info.discr_ty;
- let discr_span = opt_to_apply.infos[0].second_switch_info.discr_source_info.span;
- let second_discriminant_temp = patch.new_temp(discr_type, discr_span);
+ // create temp to store second discriminant in, `_s` in example above
+ let second_discriminant_temp =
+ patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
- patch.add_statement(
- end_of_block_location,
- StatementKind::StorageLive(second_discriminant_temp),
- );
+ patch.add_statement(parent_end, StatementKind::StorageLive(second_discriminant_temp));
// create assignment of discriminant
- let place_of_adt_to_get_discriminant_of =
- opt_to_apply.infos[0].second_switch_info.place_of_adt_discr_read;
patch.add_assign(
- end_of_block_location,
+ parent_end,
Place::from(second_discriminant_temp),
- Rvalue::Discriminant(place_of_adt_to_get_discriminant_of),
+ Rvalue::Discriminant(opt_data.child_place),
);
- // create temp to store NotEqual comparison between the two discriminants
- let not_equal = BinOp::Ne;
- let not_equal_res_type = not_equal.ty(tcx, discr_type, discr_type);
- let not_equal_temp = patch.new_temp(not_equal_res_type, discr_span);
- patch.add_statement(end_of_block_location, StatementKind::StorageLive(not_equal_temp));
-
- // create NotEqual comparison between the two discriminants
- let first_descriminant_place =
- opt_to_apply.infos[0].first_switch_info.discr_used_in_switch;
- let not_equal_rvalue = Rvalue::BinaryOp(
- not_equal,
- Box::new((
- Operand::Copy(Place::from(second_discriminant_temp)),
- Operand::Copy(first_descriminant_place),
- )),
+ // create temp to store inequality comparison between the two discriminants, `_t` in
+ // example above
+ let nequal = BinOp::Ne;
+ let comp_res_type = nequal.ty(tcx, *parent_ty, opt_data.child_ty);
+ let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
+ patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));
+
+ // create inequality comparison between the two discriminants
+ let comp_rvalue = Rvalue::BinaryOp(
+ nequal,
+ Box::new((parent_op.clone(), Operand::Move(Place::from(second_discriminant_temp)))),
);
patch.add_statement(
- end_of_block_location,
- StatementKind::Assign(Box::new((Place::from(not_equal_temp), not_equal_rvalue))),
+ parent_end,
+ StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
);
- let new_targets = opt_to_apply
- .infos
- .iter()
- .flat_map(|x| x.second_switch_info.targets_with_values.iter())
- .cloned();
-
- let targets = SwitchTargets::new(
- new_targets,
- opt_to_apply.infos[0].first_switch_info.otherwise_bb,
- );
-
- // new block that jumps to the correct discriminant case. This block is switched to if the discriminants are equal
- let new_switch_data = BasicBlockData::new(Some(Terminator {
- source_info: opt_to_apply.infos[0].second_switch_info.discr_source_info,
+ let eq_new_targets = parent_targets.iter().map(|(value, child)| {
+ let TerminatorKind::SwitchInt{ targets, .. } = &bbs[child].terminator().kind else {
+ unreachable!()
+ };
+ (value, targets.target_for_value(value))
+ });
+ let eq_targets = SwitchTargets::new(eq_new_targets, opt_data.destination);
+
+ // Create `BBEq` in example above
+ let eq_switch = BasicBlockData::new(Some(Terminator {
+ source_info: bbs[parent].terminator().source_info,
kind: TerminatorKind::SwitchInt {
- // the first and second discriminants are equal, so just pick one
- discr: Operand::Copy(first_descriminant_place),
- switch_ty: discr_type,
- targets,
+ // switch on the first discriminant, so we can mark the second one as dead
+ discr: parent_op,
+ switch_ty: opt_data.child_ty,
+ targets: eq_targets,
},
}));
- let new_switch_bb = patch.new_block(new_switch_data);
+ let eq_bb = patch.new_block(eq_switch);
- // switch on the NotEqual. If true, then jump to the `otherwise` case.
- // If false, then jump to a basic block that then jumps to the correct disciminant case
- let true_case = opt_to_apply.infos[0].first_switch_info.otherwise_bb;
- let false_case = new_switch_bb;
+ // Jump to it on the basis of the inequality comparison
+ let true_case = opt_data.destination;
+ let false_case = eq_bb;
patch.patch_terminator(
- opt_to_apply.basic_block_first_switch,
+ parent,
TerminatorKind::if_(
tcx,
- Operand::Move(Place::from(not_equal_temp)),
+ Operand::Move(Place::from(comp_temp)),
true_case,
false_case,
),
);
// generate StorageDead for the second_discriminant_temp not in use anymore
- patch.add_statement(
- end_of_block_location,
- StatementKind::StorageDead(second_discriminant_temp),
- );
+ patch.add_statement(parent_end, StatementKind::StorageDead(second_discriminant_temp));
- // Generate a StorageDead for not_equal_temp in each of the targets, since we moved it into the switch
+ // Generate a StorageDead for comp_temp in each of the targets, since we moved it into
+ // the switch
for bb in [false_case, true_case].iter() {
patch.add_statement(
Location { block: *bb, statement_index: 0 },
- StatementKind::StorageDead(not_equal_temp),
+ StatementKind::StorageDead(comp_temp),
);
}
}
}
-fn is_switch(terminator: &Terminator<'_>) -> bool {
- matches!(terminator.kind, TerminatorKind::SwitchInt { .. })
-}
-
-struct Helper<'a, 'tcx> {
- body: &'a Body<'tcx>,
- tcx: TyCtxt<'tcx>,
-}
-
-#[derive(Debug, Clone)]
-struct SwitchDiscriminantInfo<'tcx> {
- /// Type of the discriminant being switched on
- discr_ty: Ty<'tcx>,
- /// The basic block that the otherwise branch points to
- otherwise_bb: BasicBlock,
- /// Target along with the value being branched from. Otherwise is not included
- targets_with_values: Vec<(u128, BasicBlock)>,
- discr_source_info: SourceInfo,
- /// The place of the discriminant used in the switch
- discr_used_in_switch: Place<'tcx>,
- /// The place of the adt that has its discriminant read
- place_of_adt_discr_read: Place<'tcx>,
- /// The type of the adt that has its discriminant read
- type_adt_matched_on: Ty<'tcx>,
-}
-
-#[derive(Debug)]
-struct OptimizationToApply<'tcx> {
- infos: Vec<OptimizationInfo<'tcx>>,
- /// Basic block of the original first switch
- basic_block_first_switch: BasicBlock,
+/// Returns true if computing the discriminant of `place` may be hoisted out of the branch, i.e. if every projection in `place` is a field access or a dereference of a reference
+fn may_hoist<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, place: Place<'tcx>) -> bool {
+ for (place, proj) in place.iter_projections() {
+ match proj {
+ // Dereferencing in the computation of `place` might cause issues from one of two
+ // categories. First, the referent might be invalid. We protect against this by
+ // dereferencing references only (not pointers). Second, the use of a reference may
+ // invalidate other references that are used later (for aliasing reasons). Consider
+ // where such an invalidated reference may appear:
+ // - In `Q`: Not possible since `Q` is used as the operand of a `SwitchInt` and so
+ // cannot contain referenced data.
+ // - In `BBU`: Not possible since that block contains only the `unreachable` terminator
+ // - In `BBC.2, BBD.2`: Not possible, since `discriminant(P)` was computed prior to
+ // reaching that block in the input to our transformation, and so any data
+ // invalidated by that computation could not have been used there.
+ // - In `BB9`: Not possible since control flow might have reached `BB9` via the
+ // `otherwise` branch in `BBC, BBD` in the input to our transformation, which would
+ // have invalidated the data when computing `discriminant(P)`
+ // So dereferencing here is correct.
+ ProjectionElem::Deref => match place.ty(body.local_decls(), tcx).ty.kind() {
+ ty::Ref(..) => {}
+ _ => return false,
+ },
+ // Field projections are always valid
+ ProjectionElem::Field(..) => {}
+ // We cannot allow downcasts either, since the correctness of the downcast may depend on
+ // the parent branch being taken. An easy example of this is
+ // ```
+ // Q = discriminant(_3)
+ // P = (_3 as Variant)
+ // ```
+ // However, checking if the child and parent place are the same and only erroring then
+ // is not sufficient either, since the `discriminant(_3) == 1` (or whatever) check may
+ // be replaced by another optimization pass with any other condition that can be proven
+ // equivalent.
+ ProjectionElem::Downcast(..) => {
+ return false;
+ }
+ // Every remaining projection kind is rejected too; e.g. we cannot allow indexing since the index may be out of bounds.
+ _ => {
+ return false;
+ }
+ }
+ }
+ true
}
#[derive(Debug)]
-struct OptimizationInfo<'tcx> {
- /// Info about the first switch and discriminant
- first_switch_info: SwitchDiscriminantInfo<'tcx>,
- /// Info about the second switch and discriminant
- second_switch_info: SwitchDiscriminantInfo<'tcx>,
+struct OptimizationData<'tcx> {
+ destination: BasicBlock,
+ child_place: Place<'tcx>,
+ child_ty: Ty<'tcx>,
+ child_source: SourceInfo,
}
-impl<'tcx> Helper<'_, 'tcx> {
- pub fn go(
- &self,
- bb: &BasicBlockData<'tcx>,
- switch: &Terminator<'tcx>,
- ) -> Option<Vec<OptimizationInfo<'tcx>>> {
- // try to find the statement that defines the discriminant that is used for the switch
- let discr = self.find_switch_discriminant_info(bb, switch)?;
-
- // go through each target, finding a discriminant read, and a switch
- let results = discr
- .targets_with_values
- .iter()
- .map(|(value, target)| self.find_discriminant_switch_pairing(&discr, *target, *value));
-
- // if the optimization did not apply for one of the targets, then abort
- if results.clone().any(|x| x.is_none()) || results.len() == 0 {
- trace!("NO: not all of the targets matched the pattern for optimization");
- return None;
+fn evaluate_candidate<'tcx>(
+ tcx: TyCtxt<'tcx>,
+ body: &Body<'tcx>,
+ parent: BasicBlock,
+) -> Option<OptimizationData<'tcx>> {
+ let bbs = body.basic_blocks();
+ let TerminatorKind::SwitchInt {
+ targets,
+ switch_ty: parent_ty,
+ ..
+ } = &bbs[parent].terminator().kind else {
+ return None
+ };
+ let parent_dest = {
+ let poss = targets.otherwise();
+ // If the fallthrough on the parent is trivially unreachable, we can let the
+ // children choose the destination
+ if bbs[poss].statements.len() == 0
+ && bbs[poss].terminator().kind == TerminatorKind::Unreachable
+ {
+ None
+ } else {
+ Some(poss)
}
-
- Some(results.flatten().collect())
+ };
+ let Some((_, child)) = targets.iter().next() else {
+ return None
+ };
+ let child_terminator = &bbs[child].terminator();
+ let TerminatorKind::SwitchInt {
+ switch_ty: child_ty,
+ targets: child_targets,
+ ..
+ } = &child_terminator.kind else {
+ return None
+ };
+ if child_ty != parent_ty {
+ return None;
+ }
+ let Some(StatementKind::Assign(boxed))
+ = &bbs[child].statements.first().map(|x| &x.kind) else {
+ return None;
+ };
+ let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
+ return None;
+ };
+ let destination = parent_dest.unwrap_or(child_targets.otherwise());
+
+ // Verify that the optimization is legal in general:
+ // it must be possible to hoist evaluating the child discriminant out of the branch
+ if !may_hoist(tcx, body, *child_place) {
+ return None;
}
- fn find_discriminant_switch_pairing(
- &self,
- discr_info: &SwitchDiscriminantInfo<'tcx>,
- target: BasicBlock,
- value: u128,
- ) -> Option<OptimizationInfo<'tcx>> {
- let bb = &self.body.basic_blocks()[target];
- // find switch
- let terminator = bb.terminator();
- if is_switch(terminator) {
- let this_bb_discr_info = self.find_switch_discriminant_info(bb, terminator)?;
-
- // the types of the two adts matched on have to be equalfor this optimization to apply
- if discr_info.type_adt_matched_on != this_bb_discr_info.type_adt_matched_on {
- trace!(
- "NO: types do not match. LHS: {:?}, RHS: {:?}",
- discr_info.type_adt_matched_on,
- this_bb_discr_info.type_adt_matched_on
- );
- return None;
- }
-
- // the otherwise branch of the two switches have to point to the same bb
- if discr_info.otherwise_bb != this_bb_discr_info.otherwise_bb {
- trace!("NO: otherwise target is not the same");
- return None;
- }
-
- // check that the value being matched on is the same. The
- if !this_bb_discr_info.targets_with_values.iter().any(|x| x.0 == value) {
- trace!("NO: values being matched on are not the same");
- return None;
- }
-
- // only allow optimization if the left and right of the tuple being matched are the same variants.
- // so the following should not optimize
- // ```rust
- // let x: Option<()>;
- // let y: Option<()>;
- // match (x,y) {
- // (Some(_), None) => {},
- // _ => {}
- // }
- // ```
- // We check this by seeing that the value of the first discriminant is the only other discriminant value being used as a target in the second switch
- if !(this_bb_discr_info.targets_with_values.len() == 1
- && this_bb_discr_info.targets_with_values[0].0 == value)
- {
- trace!(
- "NO: The second switch did not have only 1 target (besides otherwise) that had the same value as the value from the first switch that got us here"
- );
- return None;
- }
-
- // when the second place is a projection of the first one, it's not safe to calculate their discriminant values sequentially.
- // for example, this should not be optimized:
- //
- // ```rust
- // enum E<'a> { Empty, Some(&'a E<'a>), }
- // let Some(Some(_)) = e;
- // ```
- //
- // ```mir
- // bb0: {
- // _2 = discriminant(*_1)
- // switchInt(_2) -> [...]
- // }
- // bb1: {
- // _3 = discriminant(*(((*_1) as Some).0: &E))
- // switchInt(_3) -> [...]
- // }
- // ```
- let discr_place = discr_info.place_of_adt_discr_read;
- let this_discr_place = this_bb_discr_info.place_of_adt_discr_read;
- if discr_place.local == this_discr_place.local
- && this_discr_place.projection.starts_with(discr_place.projection)
- {
- trace!("NO: one target is the projection of another");
- return None;
- }
-
- // if we reach this point, the optimization applies, and we should be able to optimize this case
- // store the info that is needed to apply the optimization
-
- Some(OptimizationInfo {
- first_switch_info: discr_info.clone(),
- second_switch_info: this_bb_discr_info,
- })
- } else {
- None
+ // Verify that the optimization is legal for each branch
+ for (value, child) in targets.iter() {
+ if !verify_candidate_branch(&bbs[child], value, *child_place, destination) {
+ return None;
}
}
+ Some(OptimizationData {
+ destination,
+ child_place: *child_place,
+ child_ty: *child_ty,
+ child_source: child_terminator.source_info,
+ })
+}
- fn find_switch_discriminant_info(
- &self,
- bb: &BasicBlockData<'tcx>,
- switch: &Terminator<'tcx>,
- ) -> Option<SwitchDiscriminantInfo<'tcx>> {
- match &switch.kind {
- TerminatorKind::SwitchInt { discr, targets, .. } => {
- let discr_local = discr.place()?.as_local()?;
- // the declaration of the discriminant read. Place of this read is being used in the switch
- let discr_decl = &self.body.local_decls()[discr_local];
- let discr_ty = discr_decl.ty;
- // the otherwise target lies as the last element
- let otherwise_bb = targets.otherwise();
- let targets_with_values = targets.iter().collect();
-
- // find the place of the adt where the discriminant is being read from
- // assume this is the last statement of the block
- let place_of_adt_discr_read = match bb.statements.last()?.kind {
- StatementKind::Assign(box (_, Rvalue::Discriminant(adt_place))) => {
- Some(adt_place)
- }
- _ => None,
- }?;
-
- let type_adt_matched_on = place_of_adt_discr_read.ty(self.body, self.tcx).ty;
-
- Some(SwitchDiscriminantInfo {
- discr_used_in_switch: discr.place()?,
- discr_ty,
- otherwise_bb,
- targets_with_values,
- discr_source_info: discr_decl.source_info,
- place_of_adt_discr_read,
- type_adt_matched_on,
- })
- }
- _ => unreachable!("must only be passed terminator that is a switch"),
- }
+fn verify_candidate_branch<'tcx>(
+ branch: &BasicBlockData<'tcx>,
+ value: u128,
+ place: Place<'tcx>,
+ destination: BasicBlock,
+) -> bool {
+ // In order for the optimization to be correct, the branch must...
+ // ...have exactly one statement
+ if branch.statements.len() != 1 {
+ return false;
+ }
+ // ...assign the discriminant of `place` in that statement
+ let StatementKind::Assign(boxed) = &branch.statements[0].kind else {
+ return false
+ };
+ let (discr_place, Rvalue::Discriminant(from_place)) = &**boxed else {
+ return false
+ };
+ if *from_place != place {
+ return false;
+ }
+ // ...make that assignment to a local
+ if discr_place.projection.len() != 0 {
+ return false;
+ }
+ // ...terminate on a `SwitchInt` whose operand is a move out of that local (which invalidates it)
+ let TerminatorKind::SwitchInt{ discr: switch_op, targets, .. } = &branch.terminator().kind else {
+ return false
+ };
+ if *switch_op != Operand::Move(*discr_place) {
+ return false;
+ }
+ // ...fall through to `destination` if the switch misses
+ if destination != targets.otherwise() {
+ return false;
+ }
+ // ...have a branch for value `value`
+ let mut iter = targets.iter();
+ let Some((target_value, _)) = iter.next() else {
+ return false;
+ };
+ if target_value != value {
+ return false;
+ }
+ // ...and have no more branches
+ if let Some(_) = iter.next() {
+ return false;
}
+ return true;
}