]> git.proxmox.com Git - rustc.git/blob - compiler/rustc_mir/src/transform/early_otherwise_branch.rs
New upstream version 1.52.0~beta.3+dfsg1
[rustc.git] / compiler / rustc_mir / src / transform / early_otherwise_branch.rs
1 use crate::{transform::MirPass, util::patch::MirPatch};
2 use rustc_middle::mir::*;
3 use rustc_middle::ty::{Ty, TyCtxt};
4 use std::fmt::Debug;
5
6 use super::simplify::simplify_cfg;
7
/// MIR pass that hoists the shared "otherwise" decision of two nested
/// discriminant switches in front of the second switch.
///
/// Source such as
/// ```text
/// let x: Option<()>;
/// let y: Option<()>;
/// match (x,y) {
///     (Some(_), Some(_)) => {0},
///     _ => {1}
/// }
/// ```
/// is rewritten into something like
/// ```text
/// let x: Option<()>;
/// let y: Option<()>;
/// let discriminant_x = // get discriminant of x
/// let discriminant_y = // get discriminant of y
/// if discriminant_x != discriminant_y || discriminant_x == None {1} else {0}
/// ```
/// so that a single comparison of the two discriminants decides whether the
/// common fallback arm is taken.
pub struct EarlyOtherwiseBranch;
26
27 impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
28 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
29 // FIXME(#78496)
30 if !tcx.sess.opts.debugging_opts.unsound_mir_opts {
31 return;
32 }
33
34 if tcx.sess.mir_opt_level() < 3 {
35 return;
36 }
37 trace!("running EarlyOtherwiseBranch on {:?}", body.source);
38 // we are only interested in this bb if the terminator is a switchInt
39 let bbs_with_switch =
40 body.basic_blocks().iter_enumerated().filter(|(_, bb)| is_switch(bb.terminator()));
41
42 let opts_to_apply: Vec<OptimizationToApply<'tcx>> = bbs_with_switch
43 .flat_map(|(bb_idx, bb)| {
44 let switch = bb.terminator();
45 let helper = Helper { body, tcx };
46 let infos = helper.go(bb, switch)?;
47 Some(OptimizationToApply { infos, basic_block_first_switch: bb_idx })
48 })
49 .collect();
50
51 let should_cleanup = !opts_to_apply.is_empty();
52
53 for opt_to_apply in opts_to_apply {
54 if !tcx.consider_optimizing(|| format!("EarlyOtherwiseBranch {:?}", &opt_to_apply)) {
55 break;
56 }
57
58 trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_to_apply);
59
60 let statements_before =
61 body.basic_blocks()[opt_to_apply.basic_block_first_switch].statements.len();
62 let end_of_block_location = Location {
63 block: opt_to_apply.basic_block_first_switch,
64 statement_index: statements_before,
65 };
66
67 let mut patch = MirPatch::new(body);
68
69 // create temp to store second discriminant in
70 let discr_type = opt_to_apply.infos[0].second_switch_info.discr_ty;
71 let discr_span = opt_to_apply.infos[0].second_switch_info.discr_source_info.span;
72 let second_discriminant_temp = patch.new_temp(discr_type, discr_span);
73
74 patch.add_statement(
75 end_of_block_location,
76 StatementKind::StorageLive(second_discriminant_temp),
77 );
78
79 // create assignment of discriminant
80 let place_of_adt_to_get_discriminant_of =
81 opt_to_apply.infos[0].second_switch_info.place_of_adt_discr_read;
82 patch.add_assign(
83 end_of_block_location,
84 Place::from(second_discriminant_temp),
85 Rvalue::Discriminant(place_of_adt_to_get_discriminant_of),
86 );
87
88 // create temp to store NotEqual comparison between the two discriminants
89 let not_equal = BinOp::Ne;
90 let not_equal_res_type = not_equal.ty(tcx, discr_type, discr_type);
91 let not_equal_temp = patch.new_temp(not_equal_res_type, discr_span);
92 patch.add_statement(end_of_block_location, StatementKind::StorageLive(not_equal_temp));
93
94 // create NotEqual comparison between the two discriminants
95 let first_descriminant_place =
96 opt_to_apply.infos[0].first_switch_info.discr_used_in_switch;
97 let not_equal_rvalue = Rvalue::BinaryOp(
98 not_equal,
99 box (
100 Operand::Copy(Place::from(second_discriminant_temp)),
101 Operand::Copy(first_descriminant_place),
102 ),
103 );
104 patch.add_statement(
105 end_of_block_location,
106 StatementKind::Assign(box (Place::from(not_equal_temp), not_equal_rvalue)),
107 );
108
109 let new_targets = opt_to_apply
110 .infos
111 .iter()
112 .flat_map(|x| x.second_switch_info.targets_with_values.iter())
113 .cloned();
114
115 let targets = SwitchTargets::new(
116 new_targets,
117 opt_to_apply.infos[0].first_switch_info.otherwise_bb,
118 );
119
120 // new block that jumps to the correct discriminant case. This block is switched to if the discriminants are equal
121 let new_switch_data = BasicBlockData::new(Some(Terminator {
122 source_info: opt_to_apply.infos[0].second_switch_info.discr_source_info,
123 kind: TerminatorKind::SwitchInt {
124 // the first and second discriminants are equal, so just pick one
125 discr: Operand::Copy(first_descriminant_place),
126 switch_ty: discr_type,
127 targets,
128 },
129 }));
130
131 let new_switch_bb = patch.new_block(new_switch_data);
132
133 // switch on the NotEqual. If true, then jump to the `otherwise` case.
134 // If false, then jump to a basic block that then jumps to the correct disciminant case
135 let true_case = opt_to_apply.infos[0].first_switch_info.otherwise_bb;
136 let false_case = new_switch_bb;
137 patch.patch_terminator(
138 opt_to_apply.basic_block_first_switch,
139 TerminatorKind::if_(
140 tcx,
141 Operand::Move(Place::from(not_equal_temp)),
142 true_case,
143 false_case,
144 ),
145 );
146
147 // generate StorageDead for the second_discriminant_temp not in use anymore
148 patch.add_statement(
149 end_of_block_location,
150 StatementKind::StorageDead(second_discriminant_temp),
151 );
152
153 // Generate a StorageDead for not_equal_temp in each of the targets, since we moved it into the switch
154 for bb in [false_case, true_case].iter() {
155 patch.add_statement(
156 Location { block: *bb, statement_index: 0 },
157 StatementKind::StorageDead(not_equal_temp),
158 );
159 }
160
161 patch.apply(body);
162 }
163
164 // Since this optimization adds new basic blocks and invalidates others,
165 // clean up the cfg to make it nicer for other passes
166 if should_cleanup {
167 simplify_cfg(body);
168 }
169 }
170 }
171
172 fn is_switch<'tcx>(terminator: &Terminator<'tcx>) -> bool {
173 match terminator.kind {
174 TerminatorKind::SwitchInt { .. } => true,
175 _ => false,
176 }
177 }
178
179 struct Helper<'a, 'tcx> {
180 body: &'a Body<'tcx>,
181 tcx: TyCtxt<'tcx>,
182 }
183
184 #[derive(Debug, Clone)]
185 struct SwitchDiscriminantInfo<'tcx> {
186 /// Type of the discriminant being switched on
187 discr_ty: Ty<'tcx>,
188 /// The basic block that the otherwise branch points to
189 otherwise_bb: BasicBlock,
190 /// Target along with the value being branched from. Otherwise is not included
191 targets_with_values: Vec<(u128, BasicBlock)>,
192 discr_source_info: SourceInfo,
193 /// The place of the discriminant used in the switch
194 discr_used_in_switch: Place<'tcx>,
195 /// The place of the adt that has its discriminant read
196 place_of_adt_discr_read: Place<'tcx>,
197 /// The type of the adt that has its discriminant read
198 type_adt_matched_on: Ty<'tcx>,
199 }
200
201 #[derive(Debug)]
202 struct OptimizationToApply<'tcx> {
203 infos: Vec<OptimizationInfo<'tcx>>,
204 /// Basic block of the original first switch
205 basic_block_first_switch: BasicBlock,
206 }
207
208 #[derive(Debug)]
209 struct OptimizationInfo<'tcx> {
210 /// Info about the first switch and discriminant
211 first_switch_info: SwitchDiscriminantInfo<'tcx>,
212 /// Info about the second switch and discriminant
213 second_switch_info: SwitchDiscriminantInfo<'tcx>,
214 }
215
216 impl<'a, 'tcx> Helper<'a, 'tcx> {
217 pub fn go(
218 &self,
219 bb: &BasicBlockData<'tcx>,
220 switch: &Terminator<'tcx>,
221 ) -> Option<Vec<OptimizationInfo<'tcx>>> {
222 // try to find the statement that defines the discriminant that is used for the switch
223 let discr = self.find_switch_discriminant_info(bb, switch)?;
224
225 // go through each target, finding a discriminant read, and a switch
226 let results = discr
227 .targets_with_values
228 .iter()
229 .map(|(value, target)| self.find_discriminant_switch_pairing(&discr, *target, *value));
230
231 // if the optimization did not apply for one of the targets, then abort
232 if results.clone().any(|x| x.is_none()) || results.len() == 0 {
233 trace!("NO: not all of the targets matched the pattern for optimization");
234 return None;
235 }
236
237 Some(results.flatten().collect())
238 }
239
240 fn find_discriminant_switch_pairing(
241 &self,
242 discr_info: &SwitchDiscriminantInfo<'tcx>,
243 target: BasicBlock,
244 value: u128,
245 ) -> Option<OptimizationInfo<'tcx>> {
246 let bb = &self.body.basic_blocks()[target];
247 // find switch
248 let terminator = bb.terminator();
249 if is_switch(terminator) {
250 let this_bb_discr_info = self.find_switch_discriminant_info(bb, terminator)?;
251
252 // the types of the two adts matched on have to be equalfor this optimization to apply
253 if discr_info.type_adt_matched_on != this_bb_discr_info.type_adt_matched_on {
254 trace!(
255 "NO: types do not match. LHS: {:?}, RHS: {:?}",
256 discr_info.type_adt_matched_on,
257 this_bb_discr_info.type_adt_matched_on
258 );
259 return None;
260 }
261
262 // the otherwise branch of the two switches have to point to the same bb
263 if discr_info.otherwise_bb != this_bb_discr_info.otherwise_bb {
264 trace!("NO: otherwise target is not the same");
265 return None;
266 }
267
268 // check that the value being matched on is the same. The
269 if this_bb_discr_info.targets_with_values.iter().find(|x| x.0 == value).is_none() {
270 trace!("NO: values being matched on are not the same");
271 return None;
272 }
273
274 // only allow optimization if the left and right of the tuple being matched are the same variants.
275 // so the following should not optimize
276 // ```rust
277 // let x: Option<()>;
278 // let y: Option<()>;
279 // match (x,y) {
280 // (Some(_), None) => {},
281 // _ => {}
282 // }
283 // ```
284 // We check this by seeing that the value of the first discriminant is the only other discriminant value being used as a target in the second switch
285 if !(this_bb_discr_info.targets_with_values.len() == 1
286 && this_bb_discr_info.targets_with_values[0].0 == value)
287 {
288 trace!(
289 "NO: The second switch did not have only 1 target (besides otherwise) that had the same value as the value from the first switch that got us here"
290 );
291 return None;
292 }
293
294 // when the second place is a projection of the first one, it's not safe to calculate their discriminant values sequentially.
295 // for example, this should not be optimized:
296 //
297 // ```rust
298 // enum E<'a> { Empty, Some(&'a E<'a>), }
299 // let Some(Some(_)) = e;
300 // ```
301 //
302 // ```mir
303 // bb0: {
304 // _2 = discriminant(*_1)
305 // switchInt(_2) -> [...]
306 // }
307 // bb1: {
308 // _3 = discriminant(*(((*_1) as Some).0: &E))
309 // switchInt(_3) -> [...]
310 // }
311 // ```
312 let discr_place = discr_info.place_of_adt_discr_read;
313 let this_discr_place = this_bb_discr_info.place_of_adt_discr_read;
314 if discr_place.local == this_discr_place.local
315 && this_discr_place.projection.starts_with(discr_place.projection)
316 {
317 trace!("NO: one target is the projection of another");
318 return None;
319 }
320
321 // if we reach this point, the optimization applies, and we should be able to optimize this case
322 // store the info that is needed to apply the optimization
323
324 Some(OptimizationInfo {
325 first_switch_info: discr_info.clone(),
326 second_switch_info: this_bb_discr_info,
327 })
328 } else {
329 None
330 }
331 }
332
333 fn find_switch_discriminant_info(
334 &self,
335 bb: &BasicBlockData<'tcx>,
336 switch: &Terminator<'tcx>,
337 ) -> Option<SwitchDiscriminantInfo<'tcx>> {
338 match &switch.kind {
339 TerminatorKind::SwitchInt { discr, targets, .. } => {
340 let discr_local = discr.place()?.as_local()?;
341 // the declaration of the discriminant read. Place of this read is being used in the switch
342 let discr_decl = &self.body.local_decls()[discr_local];
343 let discr_ty = discr_decl.ty;
344 // the otherwise target lies as the last element
345 let otherwise_bb = targets.otherwise();
346 let targets_with_values = targets.iter().collect();
347
348 // find the place of the adt where the discriminant is being read from
349 // assume this is the last statement of the block
350 let place_of_adt_discr_read = match bb.statements.last()?.kind {
351 StatementKind::Assign(box (_, Rvalue::Discriminant(adt_place))) => {
352 Some(adt_place)
353 }
354 _ => None,
355 }?;
356
357 let type_adt_matched_on = place_of_adt_discr_read.ty(self.body, self.tcx).ty;
358
359 Some(SwitchDiscriminantInfo {
360 discr_used_in_switch: discr.place()?,
361 discr_ty,
362 otherwise_bb,
363 targets_with_values,
364 discr_source_info: discr_decl.source_info,
365 place_of_adt_discr_read,
366 type_adt_matched_on,
367 })
368 }
369 _ => unreachable!("must only be passed terminator that is a switch"),
370 }
371 }
372 }