]>
Commit | Line | Data |
---|---|---|
29967ef6 | 1 | use crate::{transform::MirPass, util::patch::MirPatch}; |
1b1a35ee XL |
2 | use rustc_middle::mir::*; |
3 | use rustc_middle::ty::{Ty, TyCtxt}; | |
29967ef6 | 4 | use std::fmt::Debug; |
1b1a35ee XL |
5 | |
6 | use super::simplify::simplify_cfg; | |
7 | ||
/// MIR pass that merges a switch on one ADT's discriminant with the switches on a second
/// discriminant that each of its targets performs.
///
/// A match such as
/// ```text
/// let x: Option<()>;
/// let y: Option<()>;
/// match (x,y) {
///     (Some(_), Some(_)) => {0},
///     _ => {1}
/// }
/// ```
/// is turned into something like
/// ```text
/// let x: Option<()>;
/// let y: Option<()>;
/// let discriminant_x = // get discriminant of x
/// let discriminant_y = // get discriminant of y
/// if discriminant_x != discriminant_y || discriminant_x == None {1} else {0}
/// ```
///
/// The pass only runs with `-Z unsound-mir-opts` and a MIR optimization level of at least 3
/// (see FIXME(#78496) in `run_pass`).
pub struct EarlyOtherwiseBranch;
26 | ||
27 | impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch { | |
29967ef6 | 28 | fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { |
6a06907d XL |
29 | // FIXME(#78496) |
30 | if !tcx.sess.opts.debugging_opts.unsound_mir_opts { | |
31 | return; | |
32 | } | |
33 | ||
34 | if tcx.sess.mir_opt_level() < 3 { | |
1b1a35ee XL |
35 | return; |
36 | } | |
29967ef6 | 37 | trace!("running EarlyOtherwiseBranch on {:?}", body.source); |
1b1a35ee XL |
38 | // we are only interested in this bb if the terminator is a switchInt |
39 | let bbs_with_switch = | |
40 | body.basic_blocks().iter_enumerated().filter(|(_, bb)| is_switch(bb.terminator())); | |
41 | ||
42 | let opts_to_apply: Vec<OptimizationToApply<'tcx>> = bbs_with_switch | |
43 | .flat_map(|(bb_idx, bb)| { | |
44 | let switch = bb.terminator(); | |
45 | let helper = Helper { body, tcx }; | |
46 | let infos = helper.go(bb, switch)?; | |
47 | Some(OptimizationToApply { infos, basic_block_first_switch: bb_idx }) | |
48 | }) | |
49 | .collect(); | |
50 | ||
51 | let should_cleanup = !opts_to_apply.is_empty(); | |
52 | ||
53 | for opt_to_apply in opts_to_apply { | |
fc512014 XL |
54 | if !tcx.consider_optimizing(|| format!("EarlyOtherwiseBranch {:?}", &opt_to_apply)) { |
55 | break; | |
56 | } | |
57 | ||
1b1a35ee XL |
58 | trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_to_apply); |
59 | ||
60 | let statements_before = | |
61 | body.basic_blocks()[opt_to_apply.basic_block_first_switch].statements.len(); | |
62 | let end_of_block_location = Location { | |
63 | block: opt_to_apply.basic_block_first_switch, | |
64 | statement_index: statements_before, | |
65 | }; | |
66 | ||
67 | let mut patch = MirPatch::new(body); | |
68 | ||
69 | // create temp to store second discriminant in | |
70 | let discr_type = opt_to_apply.infos[0].second_switch_info.discr_ty; | |
71 | let discr_span = opt_to_apply.infos[0].second_switch_info.discr_source_info.span; | |
72 | let second_discriminant_temp = patch.new_temp(discr_type, discr_span); | |
73 | ||
74 | patch.add_statement( | |
75 | end_of_block_location, | |
76 | StatementKind::StorageLive(second_discriminant_temp), | |
77 | ); | |
78 | ||
79 | // create assignment of discriminant | |
80 | let place_of_adt_to_get_discriminant_of = | |
81 | opt_to_apply.infos[0].second_switch_info.place_of_adt_discr_read; | |
82 | patch.add_assign( | |
83 | end_of_block_location, | |
84 | Place::from(second_discriminant_temp), | |
85 | Rvalue::Discriminant(place_of_adt_to_get_discriminant_of), | |
86 | ); | |
87 | ||
88 | // create temp to store NotEqual comparison between the two discriminants | |
89 | let not_equal = BinOp::Ne; | |
90 | let not_equal_res_type = not_equal.ty(tcx, discr_type, discr_type); | |
91 | let not_equal_temp = patch.new_temp(not_equal_res_type, discr_span); | |
92 | patch.add_statement(end_of_block_location, StatementKind::StorageLive(not_equal_temp)); | |
93 | ||
94 | // create NotEqual comparison between the two discriminants | |
95 | let first_descriminant_place = | |
96 | opt_to_apply.infos[0].first_switch_info.discr_used_in_switch; | |
97 | let not_equal_rvalue = Rvalue::BinaryOp( | |
98 | not_equal, | |
6a06907d XL |
99 | box ( |
100 | Operand::Copy(Place::from(second_discriminant_temp)), | |
101 | Operand::Copy(first_descriminant_place), | |
102 | ), | |
1b1a35ee XL |
103 | ); |
104 | patch.add_statement( | |
105 | end_of_block_location, | |
106 | StatementKind::Assign(box (Place::from(not_equal_temp), not_equal_rvalue)), | |
107 | ); | |
108 | ||
29967ef6 | 109 | let new_targets = opt_to_apply |
1b1a35ee XL |
110 | .infos |
111 | .iter() | |
112 | .flat_map(|x| x.second_switch_info.targets_with_values.iter()) | |
29967ef6 XL |
113 | .cloned(); |
114 | ||
115 | let targets = SwitchTargets::new( | |
116 | new_targets, | |
117 | opt_to_apply.infos[0].first_switch_info.otherwise_bb, | |
118 | ); | |
1b1a35ee | 119 | |
1b1a35ee XL |
120 | // new block that jumps to the correct discriminant case. This block is switched to if the discriminants are equal |
121 | let new_switch_data = BasicBlockData::new(Some(Terminator { | |
122 | source_info: opt_to_apply.infos[0].second_switch_info.discr_source_info, | |
123 | kind: TerminatorKind::SwitchInt { | |
124 | // the first and second discriminants are equal, so just pick one | |
125 | discr: Operand::Copy(first_descriminant_place), | |
126 | switch_ty: discr_type, | |
29967ef6 | 127 | targets, |
1b1a35ee XL |
128 | }, |
129 | })); | |
130 | ||
131 | let new_switch_bb = patch.new_block(new_switch_data); | |
132 | ||
133 | // switch on the NotEqual. If true, then jump to the `otherwise` case. | |
134 | // If false, then jump to a basic block that then jumps to the correct disciminant case | |
135 | let true_case = opt_to_apply.infos[0].first_switch_info.otherwise_bb; | |
136 | let false_case = new_switch_bb; | |
137 | patch.patch_terminator( | |
138 | opt_to_apply.basic_block_first_switch, | |
139 | TerminatorKind::if_( | |
140 | tcx, | |
141 | Operand::Move(Place::from(not_equal_temp)), | |
142 | true_case, | |
143 | false_case, | |
144 | ), | |
145 | ); | |
146 | ||
147 | // generate StorageDead for the second_discriminant_temp not in use anymore | |
148 | patch.add_statement( | |
149 | end_of_block_location, | |
150 | StatementKind::StorageDead(second_discriminant_temp), | |
151 | ); | |
152 | ||
153 | // Generate a StorageDead for not_equal_temp in each of the targets, since we moved it into the switch | |
154 | for bb in [false_case, true_case].iter() { | |
155 | patch.add_statement( | |
156 | Location { block: *bb, statement_index: 0 }, | |
157 | StatementKind::StorageDead(not_equal_temp), | |
158 | ); | |
159 | } | |
160 | ||
161 | patch.apply(body); | |
162 | } | |
163 | ||
164 | // Since this optimization adds new basic blocks and invalidates others, | |
165 | // clean up the cfg to make it nicer for other passes | |
166 | if should_cleanup { | |
167 | simplify_cfg(body); | |
168 | } | |
169 | } | |
170 | } | |
171 | ||
172 | fn is_switch<'tcx>(terminator: &Terminator<'tcx>) -> bool { | |
173 | match terminator.kind { | |
174 | TerminatorKind::SwitchInt { .. } => true, | |
175 | _ => false, | |
176 | } | |
177 | } | |
178 | ||
179 | struct Helper<'a, 'tcx> { | |
180 | body: &'a Body<'tcx>, | |
181 | tcx: TyCtxt<'tcx>, | |
182 | } | |
183 | ||
184 | #[derive(Debug, Clone)] | |
185 | struct SwitchDiscriminantInfo<'tcx> { | |
186 | /// Type of the discriminant being switched on | |
187 | discr_ty: Ty<'tcx>, | |
188 | /// The basic block that the otherwise branch points to | |
189 | otherwise_bb: BasicBlock, | |
190 | /// Target along with the value being branched from. Otherwise is not included | |
29967ef6 | 191 | targets_with_values: Vec<(u128, BasicBlock)>, |
1b1a35ee XL |
192 | discr_source_info: SourceInfo, |
193 | /// The place of the discriminant used in the switch | |
194 | discr_used_in_switch: Place<'tcx>, | |
195 | /// The place of the adt that has its discriminant read | |
196 | place_of_adt_discr_read: Place<'tcx>, | |
197 | /// The type of the adt that has its discriminant read | |
198 | type_adt_matched_on: Ty<'tcx>, | |
199 | } | |
200 | ||
201 | #[derive(Debug)] | |
202 | struct OptimizationToApply<'tcx> { | |
203 | infos: Vec<OptimizationInfo<'tcx>>, | |
204 | /// Basic block of the original first switch | |
205 | basic_block_first_switch: BasicBlock, | |
206 | } | |
207 | ||
208 | #[derive(Debug)] | |
209 | struct OptimizationInfo<'tcx> { | |
210 | /// Info about the first switch and discriminant | |
211 | first_switch_info: SwitchDiscriminantInfo<'tcx>, | |
212 | /// Info about the second switch and discriminant | |
213 | second_switch_info: SwitchDiscriminantInfo<'tcx>, | |
214 | } | |
215 | ||
216 | impl<'a, 'tcx> Helper<'a, 'tcx> { | |
217 | pub fn go( | |
218 | &self, | |
219 | bb: &BasicBlockData<'tcx>, | |
220 | switch: &Terminator<'tcx>, | |
221 | ) -> Option<Vec<OptimizationInfo<'tcx>>> { | |
222 | // try to find the statement that defines the discriminant that is used for the switch | |
223 | let discr = self.find_switch_discriminant_info(bb, switch)?; | |
224 | ||
225 | // go through each target, finding a discriminant read, and a switch | |
fc512014 XL |
226 | let results = discr |
227 | .targets_with_values | |
228 | .iter() | |
229 | .map(|(value, target)| self.find_discriminant_switch_pairing(&discr, *target, *value)); | |
1b1a35ee XL |
230 | |
231 | // if the optimization did not apply for one of the targets, then abort | |
232 | if results.clone().any(|x| x.is_none()) || results.len() == 0 { | |
233 | trace!("NO: not all of the targets matched the pattern for optimization"); | |
234 | return None; | |
235 | } | |
236 | ||
237 | Some(results.flatten().collect()) | |
238 | } | |
239 | ||
240 | fn find_discriminant_switch_pairing( | |
241 | &self, | |
242 | discr_info: &SwitchDiscriminantInfo<'tcx>, | |
243 | target: BasicBlock, | |
244 | value: u128, | |
245 | ) -> Option<OptimizationInfo<'tcx>> { | |
246 | let bb = &self.body.basic_blocks()[target]; | |
247 | // find switch | |
248 | let terminator = bb.terminator(); | |
249 | if is_switch(terminator) { | |
250 | let this_bb_discr_info = self.find_switch_discriminant_info(bb, terminator)?; | |
251 | ||
252 | // the types of the two adts matched on have to be equalfor this optimization to apply | |
253 | if discr_info.type_adt_matched_on != this_bb_discr_info.type_adt_matched_on { | |
254 | trace!( | |
255 | "NO: types do not match. LHS: {:?}, RHS: {:?}", | |
256 | discr_info.type_adt_matched_on, | |
257 | this_bb_discr_info.type_adt_matched_on | |
258 | ); | |
259 | return None; | |
260 | } | |
261 | ||
262 | // the otherwise branch of the two switches have to point to the same bb | |
263 | if discr_info.otherwise_bb != this_bb_discr_info.otherwise_bb { | |
264 | trace!("NO: otherwise target is not the same"); | |
265 | return None; | |
266 | } | |
267 | ||
268 | // check that the value being matched on is the same. The | |
29967ef6 | 269 | if this_bb_discr_info.targets_with_values.iter().find(|x| x.0 == value).is_none() { |
1b1a35ee XL |
270 | trace!("NO: values being matched on are not the same"); |
271 | return None; | |
272 | } | |
273 | ||
274 | // only allow optimization if the left and right of the tuple being matched are the same variants. | |
275 | // so the following should not optimize | |
276 | // ```rust | |
277 | // let x: Option<()>; | |
278 | // let y: Option<()>; | |
279 | // match (x,y) { | |
280 | // (Some(_), None) => {}, | |
281 | // _ => {} | |
282 | // } | |
283 | // ``` | |
284 | // We check this by seeing that the value of the first discriminant is the only other discriminant value being used as a target in the second switch | |
285 | if !(this_bb_discr_info.targets_with_values.len() == 1 | |
29967ef6 | 286 | && this_bb_discr_info.targets_with_values[0].0 == value) |
1b1a35ee XL |
287 | { |
288 | trace!( | |
289 | "NO: The second switch did not have only 1 target (besides otherwise) that had the same value as the value from the first switch that got us here" | |
290 | ); | |
291 | return None; | |
292 | } | |
293 | ||
fc512014 XL |
294 | // when the second place is a projection of the first one, it's not safe to calculate their discriminant values sequentially. |
295 | // for example, this should not be optimized: | |
296 | // | |
297 | // ```rust | |
298 | // enum E<'a> { Empty, Some(&'a E<'a>), } | |
299 | // let Some(Some(_)) = e; | |
300 | // ``` | |
301 | // | |
302 | // ```mir | |
303 | // bb0: { | |
304 | // _2 = discriminant(*_1) | |
305 | // switchInt(_2) -> [...] | |
306 | // } | |
307 | // bb1: { | |
308 | // _3 = discriminant(*(((*_1) as Some).0: &E)) | |
309 | // switchInt(_3) -> [...] | |
310 | // } | |
311 | // ``` | |
312 | let discr_place = discr_info.place_of_adt_discr_read; | |
313 | let this_discr_place = this_bb_discr_info.place_of_adt_discr_read; | |
314 | if discr_place.local == this_discr_place.local | |
315 | && this_discr_place.projection.starts_with(discr_place.projection) | |
316 | { | |
317 | trace!("NO: one target is the projection of another"); | |
318 | return None; | |
319 | } | |
320 | ||
1b1a35ee XL |
321 | // if we reach this point, the optimization applies, and we should be able to optimize this case |
322 | // store the info that is needed to apply the optimization | |
323 | ||
324 | Some(OptimizationInfo { | |
325 | first_switch_info: discr_info.clone(), | |
326 | second_switch_info: this_bb_discr_info, | |
327 | }) | |
328 | } else { | |
329 | None | |
330 | } | |
331 | } | |
332 | ||
333 | fn find_switch_discriminant_info( | |
334 | &self, | |
335 | bb: &BasicBlockData<'tcx>, | |
336 | switch: &Terminator<'tcx>, | |
337 | ) -> Option<SwitchDiscriminantInfo<'tcx>> { | |
338 | match &switch.kind { | |
29967ef6 | 339 | TerminatorKind::SwitchInt { discr, targets, .. } => { |
1b1a35ee XL |
340 | let discr_local = discr.place()?.as_local()?; |
341 | // the declaration of the discriminant read. Place of this read is being used in the switch | |
342 | let discr_decl = &self.body.local_decls()[discr_local]; | |
343 | let discr_ty = discr_decl.ty; | |
344 | // the otherwise target lies as the last element | |
29967ef6 XL |
345 | let otherwise_bb = targets.otherwise(); |
346 | let targets_with_values = targets.iter().collect(); | |
1b1a35ee XL |
347 | |
348 | // find the place of the adt where the discriminant is being read from | |
349 | // assume this is the last statement of the block | |
350 | let place_of_adt_discr_read = match bb.statements.last()?.kind { | |
351 | StatementKind::Assign(box (_, Rvalue::Discriminant(adt_place))) => { | |
352 | Some(adt_place) | |
353 | } | |
354 | _ => None, | |
355 | }?; | |
356 | ||
357 | let type_adt_matched_on = place_of_adt_discr_read.ty(self.body, self.tcx).ty; | |
358 | ||
359 | Some(SwitchDiscriminantInfo { | |
360 | discr_used_in_switch: discr.place()?, | |
361 | discr_ty, | |
362 | otherwise_bb, | |
363 | targets_with_values, | |
364 | discr_source_info: discr_decl.source_info, | |
365 | place_of_adt_discr_read, | |
366 | type_adt_matched_on, | |
367 | }) | |
368 | } | |
369 | _ => unreachable!("must only be passed terminator that is a switch"), | |
370 | } | |
371 | } | |
372 | } |