]>
Commit | Line | Data |
---|---|---|
60c5eb7d XL |
1 | //! The general point of the optimizations provided here is to simplify something like: |
2 | //! | |
3 | //! ```rust | |
4 | //! match x { | |
5 | //! Ok(x) => Ok(x), | |
6 | //! Err(x) => Err(x) | |
7 | //! } | |
8 | //! ``` | |
9 | //! | |
10 | //! into just `x`. | |
11 | ||
29967ef6 | 12 | use crate::transform::{simplify, MirPass}; |
dfeec247 | 13 | use itertools::Itertools as _; |
f035d41b XL |
14 | use rustc_index::{bit_set::BitSet, vec::IndexVec}; |
15 | use rustc_middle::mir::visit::{NonUseContext, PlaceContext, Visitor}; | |
ba9703b0 | 16 | use rustc_middle::mir::*; |
3dfed10e | 17 | use rustc_middle::ty::{self, List, Ty, TyCtxt}; |
60c5eb7d | 18 | use rustc_target::abi::VariantIdx; |
1b1a35ee | 19 | use std::iter::{once, Enumerate, Peekable}; |
f9f354fc | 20 | use std::slice::Iter; |
60c5eb7d XL |
21 | |
/// MIR pass that collapses identity match arms such as `Variant(x) => Variant(x)`
/// into a single move of the whole enum value.
///
/// A basic block is rewritten when its statements have the shape
///
/// ```text
/// _LOCAL_TMP = ((_LOCAL_1 as Variant).FIELD: TY);
/// _TMP_2 = _LOCAL_TMP;
/// ((_LOCAL_0 as Variant).FIELD: TY) = move _TMP_2;
/// discriminant(_LOCAL_0) = VAR_IDX;
/// ```
///
/// in which case the block becomes
///
/// ```text
/// _LOCAL_0 = move _LOCAL_1
/// ```
///
/// The pass is a no-op unless the `unsound_mir_opts` debugging option is enabled
/// (see FIXME(77359) in `run_pass`).
pub struct SimplifyArmIdentity;
39 | ||
f9f354fc XL |
40 | #[derive(Debug)] |
41 | struct ArmIdentityInfo<'tcx> { | |
42 | /// Storage location for the variant's field | |
43 | local_temp_0: Local, | |
44 | /// Storage location holding the variant being read from | |
45 | local_1: Local, | |
46 | /// The variant field being read from | |
47 | vf_s0: VarField<'tcx>, | |
48 | /// Index of the statement which loads the variant being read | |
49 | get_variant_field_stmt: usize, | |
50 | ||
51 | /// Tracks each assignment to a temporary of the variant's field | |
52 | field_tmp_assignments: Vec<(Local, Local)>, | |
53 | ||
54 | /// Storage location holding the variant's field that was read from | |
55 | local_tmp_s1: Local, | |
56 | /// Storage location holding the enum that we are writing to | |
57 | local_0: Local, | |
58 | /// The variant field being written to | |
59 | vf_s1: VarField<'tcx>, | |
60 | ||
61 | /// Storage location that the discriminant is being written to | |
62 | set_discr_local: Local, | |
63 | /// The variant being written | |
64 | set_discr_var_idx: VariantIdx, | |
65 | ||
66 | /// Index of the statement that should be overwritten as a move | |
67 | stmt_to_overwrite: usize, | |
68 | /// SourceInfo for the new move | |
69 | source_info: SourceInfo, | |
70 | ||
71 | /// Indices of matching Storage{Live,Dead} statements encountered. | |
72 | /// (StorageLive index,, StorageDead index, Local) | |
73 | storage_stmts: Vec<(usize, usize, Local)>, | |
74 | ||
75 | /// The statements that should be removed (turned into nops) | |
76 | stmts_to_remove: Vec<usize>, | |
f035d41b XL |
77 | |
78 | /// Indices of debug variables that need to be adjusted to point to | |
79 | // `{local_0}.{dbg_projection}`. | |
80 | dbg_info_to_adjust: Vec<usize>, | |
81 | ||
82 | /// The projection used to rewrite debug info. | |
83 | dbg_projection: &'tcx List<PlaceElem<'tcx>>, | |
f9f354fc XL |
84 | } |
85 | ||
f035d41b XL |
86 | fn get_arm_identity_info<'a, 'tcx>( |
87 | stmts: &'a [Statement<'tcx>], | |
88 | locals_count: usize, | |
89 | debug_info: &'a [VarDebugInfo<'tcx>], | |
90 | ) -> Option<ArmIdentityInfo<'tcx>> { | |
f9f354fc XL |
91 | // This can't possibly match unless there are at least 3 statements in the block |
92 | // so fail fast on tiny blocks. | |
93 | if stmts.len() < 3 { | |
94 | return None; | |
95 | } | |
96 | ||
97 | let mut tmp_assigns = Vec::new(); | |
98 | let mut nop_stmts = Vec::new(); | |
99 | let mut storage_stmts = Vec::new(); | |
100 | let mut storage_live_stmts = Vec::new(); | |
101 | let mut storage_dead_stmts = Vec::new(); | |
102 | ||
103 | type StmtIter<'a, 'tcx> = Peekable<Enumerate<Iter<'a, Statement<'tcx>>>>; | |
104 | ||
105 | fn is_storage_stmt<'tcx>(stmt: &Statement<'tcx>) -> bool { | |
106 | matches!(stmt.kind, StatementKind::StorageLive(_) | StatementKind::StorageDead(_)) | |
107 | } | |
108 | ||
109 | /// Eats consecutive Statements which match `test`, performing the specified `action` for each. | |
110 | /// The iterator `stmt_iter` is not advanced if none were matched. | |
111 | fn try_eat<'a, 'tcx>( | |
112 | stmt_iter: &mut StmtIter<'a, 'tcx>, | |
113 | test: impl Fn(&'a Statement<'tcx>) -> bool, | |
f035d41b | 114 | mut action: impl FnMut(usize, &'a Statement<'tcx>), |
f9f354fc XL |
115 | ) { |
116 | while stmt_iter.peek().map(|(_, stmt)| test(stmt)).unwrap_or(false) { | |
117 | let (idx, stmt) = stmt_iter.next().unwrap(); | |
118 | ||
119 | action(idx, stmt); | |
120 | } | |
121 | } | |
122 | ||
123 | /// Eats consecutive `StorageLive` and `StorageDead` Statements. | |
124 | /// The iterator `stmt_iter` is not advanced if none were found. | |
125 | fn try_eat_storage_stmts<'a, 'tcx>( | |
126 | stmt_iter: &mut StmtIter<'a, 'tcx>, | |
127 | storage_live_stmts: &mut Vec<(usize, Local)>, | |
128 | storage_dead_stmts: &mut Vec<(usize, Local)>, | |
129 | ) { | |
130 | try_eat(stmt_iter, is_storage_stmt, |idx, stmt| { | |
131 | if let StatementKind::StorageLive(l) = stmt.kind { | |
132 | storage_live_stmts.push((idx, l)); | |
133 | } else if let StatementKind::StorageDead(l) = stmt.kind { | |
134 | storage_dead_stmts.push((idx, l)); | |
135 | } | |
136 | }) | |
137 | } | |
138 | ||
139 | fn is_tmp_storage_stmt<'tcx>(stmt: &Statement<'tcx>) -> bool { | |
140 | use rustc_middle::mir::StatementKind::Assign; | |
141 | if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) = &stmt.kind { | |
142 | place.as_local().is_some() && p.as_local().is_some() | |
143 | } else { | |
144 | false | |
145 | } | |
146 | } | |
147 | ||
148 | /// Eats consecutive `Assign` Statements. | |
149 | // The iterator `stmt_iter` is not advanced if none were found. | |
150 | fn try_eat_assign_tmp_stmts<'a, 'tcx>( | |
151 | stmt_iter: &mut StmtIter<'a, 'tcx>, | |
152 | tmp_assigns: &mut Vec<(Local, Local)>, | |
153 | nop_stmts: &mut Vec<usize>, | |
154 | ) { | |
155 | try_eat(stmt_iter, is_tmp_storage_stmt, |idx, stmt| { | |
156 | use rustc_middle::mir::StatementKind::Assign; | |
157 | if let Assign(box (place, Rvalue::Use(Operand::Copy(p) | Operand::Move(p)))) = | |
158 | &stmt.kind | |
159 | { | |
160 | tmp_assigns.push((place.as_local().unwrap(), p.as_local().unwrap())); | |
161 | nop_stmts.push(idx); | |
162 | } | |
163 | }) | |
164 | } | |
165 | ||
166 | fn find_storage_live_dead_stmts_for_local<'tcx>( | |
167 | local: Local, | |
168 | stmts: &[Statement<'tcx>], | |
169 | ) -> Option<(usize, usize)> { | |
170 | trace!("looking for {:?}", local); | |
171 | let mut storage_live_stmt = None; | |
172 | let mut storage_dead_stmt = None; | |
173 | for (idx, stmt) in stmts.iter().enumerate() { | |
174 | if stmt.kind == StatementKind::StorageLive(local) { | |
175 | storage_live_stmt = Some(idx); | |
176 | } else if stmt.kind == StatementKind::StorageDead(local) { | |
177 | storage_dead_stmt = Some(idx); | |
178 | } | |
179 | } | |
180 | ||
181 | Some((storage_live_stmt?, storage_dead_stmt.unwrap_or(usize::MAX))) | |
182 | } | |
183 | ||
184 | // Try to match the expected MIR structure with the basic block we're processing. | |
185 | // We want to see something that looks like: | |
186 | // ``` | |
187 | // (StorageLive(_) | StorageDead(_));* | |
188 | // _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY); | |
189 | // (StorageLive(_) | StorageDead(_));* | |
190 | // (tmp_n+1 = tmp_n);* | |
191 | // (StorageLive(_) | StorageDead(_));* | |
192 | // (tmp_n+1 = tmp_n);* | |
193 | // ((LOCAL_FROM as Variant).FIELD: TY) = move tmp; | |
194 | // discriminant(LOCAL_FROM) = VariantIdx; | |
195 | // (StorageLive(_) | StorageDead(_));* | |
196 | // ``` | |
197 | let mut stmt_iter = stmts.iter().enumerate().peekable(); | |
198 | ||
199 | try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); | |
200 | ||
201 | let (get_variant_field_stmt, stmt) = stmt_iter.next()?; | |
f035d41b | 202 | let (local_tmp_s0, local_1, vf_s0, dbg_projection) = match_get_variant_field(stmt)?; |
f9f354fc XL |
203 | |
204 | try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); | |
205 | ||
206 | try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts); | |
207 | ||
208 | try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); | |
209 | ||
210 | try_eat_assign_tmp_stmts(&mut stmt_iter, &mut tmp_assigns, &mut nop_stmts); | |
211 | ||
212 | let (idx, stmt) = stmt_iter.next()?; | |
213 | let (local_tmp_s1, local_0, vf_s1) = match_set_variant_field(stmt)?; | |
214 | nop_stmts.push(idx); | |
215 | ||
216 | let (idx, stmt) = stmt_iter.next()?; | |
217 | let (set_discr_local, set_discr_var_idx) = match_set_discr(stmt)?; | |
218 | let discr_stmt_source_info = stmt.source_info; | |
219 | nop_stmts.push(idx); | |
220 | ||
221 | try_eat_storage_stmts(&mut stmt_iter, &mut storage_live_stmts, &mut storage_dead_stmts); | |
222 | ||
223 | for (live_idx, live_local) in storage_live_stmts { | |
224 | if let Some(i) = storage_dead_stmts.iter().rposition(|(_, l)| *l == live_local) { | |
225 | let (dead_idx, _) = storage_dead_stmts.swap_remove(i); | |
226 | storage_stmts.push((live_idx, dead_idx, live_local)); | |
227 | ||
228 | if live_local == local_tmp_s0 { | |
229 | nop_stmts.push(get_variant_field_stmt); | |
230 | } | |
231 | } | |
232 | } | |
1b1a35ee XL |
233 | // We sort primitive usize here so we can use unstable sort |
234 | nop_stmts.sort_unstable(); | |
f9f354fc XL |
235 | |
236 | // Use one of the statements we're going to discard between the point | |
237 | // where the storage location for the variant field becomes live and | |
238 | // is killed. | |
239 | let (live_idx, dead_idx) = find_storage_live_dead_stmts_for_local(local_tmp_s0, stmts)?; | |
240 | let stmt_to_overwrite = | |
241 | nop_stmts.iter().find(|stmt_idx| live_idx < **stmt_idx && **stmt_idx < dead_idx); | |
242 | ||
f035d41b XL |
243 | let mut tmp_assigned_vars = BitSet::new_empty(locals_count); |
244 | for (l, r) in &tmp_assigns { | |
245 | tmp_assigned_vars.insert(*l); | |
246 | tmp_assigned_vars.insert(*r); | |
247 | } | |
248 | ||
3dfed10e XL |
249 | let dbg_info_to_adjust: Vec<_> = |
250 | debug_info | |
251 | .iter() | |
252 | .enumerate() | |
253 | .filter_map(|(i, var_info)| { | |
254 | if tmp_assigned_vars.contains(var_info.place.local) { Some(i) } else { None } | |
255 | }) | |
256 | .collect(); | |
f035d41b | 257 | |
f9f354fc XL |
258 | Some(ArmIdentityInfo { |
259 | local_temp_0: local_tmp_s0, | |
260 | local_1, | |
261 | vf_s0, | |
262 | get_variant_field_stmt, | |
263 | field_tmp_assignments: tmp_assigns, | |
264 | local_tmp_s1, | |
265 | local_0, | |
266 | vf_s1, | |
267 | set_discr_local, | |
268 | set_discr_var_idx, | |
269 | stmt_to_overwrite: *stmt_to_overwrite?, | |
270 | source_info: discr_stmt_source_info, | |
271 | storage_stmts, | |
272 | stmts_to_remove: nop_stmts, | |
f035d41b XL |
273 | dbg_info_to_adjust, |
274 | dbg_projection, | |
f9f354fc XL |
275 | }) |
276 | } | |
277 | ||
278 | fn optimization_applies<'tcx>( | |
279 | opt_info: &ArmIdentityInfo<'tcx>, | |
280 | local_decls: &IndexVec<Local, LocalDecl<'tcx>>, | |
f035d41b XL |
281 | local_uses: &IndexVec<Local, usize>, |
282 | var_debug_info: &[VarDebugInfo<'tcx>], | |
f9f354fc XL |
283 | ) -> bool { |
284 | trace!("testing if optimization applies..."); | |
285 | ||
286 | // FIXME(wesleywiser): possibly relax this restriction? | |
287 | if opt_info.local_0 == opt_info.local_1 { | |
288 | trace!("NO: moving into ourselves"); | |
289 | return false; | |
290 | } else if opt_info.vf_s0 != opt_info.vf_s1 { | |
291 | trace!("NO: the field-and-variant information do not match"); | |
292 | return false; | |
293 | } else if local_decls[opt_info.local_0].ty != local_decls[opt_info.local_1].ty { | |
294 | // FIXME(Centril,oli-obk): possibly relax to same layout? | |
295 | trace!("NO: source and target locals have different types"); | |
296 | return false; | |
297 | } else if (opt_info.local_0, opt_info.vf_s0.var_idx) | |
298 | != (opt_info.set_discr_local, opt_info.set_discr_var_idx) | |
299 | { | |
300 | trace!("NO: the discriminants do not match"); | |
301 | return false; | |
302 | } | |
303 | ||
304 | // Verify the assigment chain consists of the form b = a; c = b; d = c; etc... | |
f035d41b | 305 | if opt_info.field_tmp_assignments.is_empty() { |
f9f354fc | 306 | trace!("NO: no assignments found"); |
f035d41b | 307 | return false; |
f9f354fc XL |
308 | } |
309 | let mut last_assigned_to = opt_info.field_tmp_assignments[0].1; | |
310 | let source_local = last_assigned_to; | |
311 | for (l, r) in &opt_info.field_tmp_assignments { | |
312 | if *r != last_assigned_to { | |
313 | trace!("NO: found unexpected assignment {:?} = {:?}", l, r); | |
314 | return false; | |
315 | } | |
316 | ||
317 | last_assigned_to = *l; | |
318 | } | |
319 | ||
f035d41b XL |
320 | // Check that the first and last used locals are only used twice |
321 | // since they are of the form: | |
322 | // | |
323 | // ``` | |
324 | // _first = ((_x as Variant).n: ty); | |
325 | // _n = _first; | |
326 | // ... | |
327 | // ((_y as Variant).n: ty) = _n; | |
328 | // discriminant(_y) = z; | |
329 | // ``` | |
330 | for (l, r) in &opt_info.field_tmp_assignments { | |
331 | if local_uses[*l] != 2 { | |
332 | warn!("NO: FAILED assignment chain local {:?} was used more than twice", l); | |
333 | return false; | |
334 | } else if local_uses[*r] != 2 { | |
335 | warn!("NO: FAILED assignment chain local {:?} was used more than twice", r); | |
336 | return false; | |
337 | } | |
338 | } | |
339 | ||
340 | // Check that debug info only points to full Locals and not projections. | |
341 | for dbg_idx in &opt_info.dbg_info_to_adjust { | |
342 | let dbg_info = &var_debug_info[*dbg_idx]; | |
343 | if !dbg_info.place.projection.is_empty() { | |
344 | trace!("NO: debug info for {:?} had a projection {:?}", dbg_info.name, dbg_info.place); | |
345 | return false; | |
346 | } | |
347 | } | |
348 | ||
f9f354fc XL |
349 | if source_local != opt_info.local_temp_0 { |
350 | trace!( | |
351 | "NO: start of assignment chain does not match enum variant temp: {:?} != {:?}", | |
352 | source_local, | |
353 | opt_info.local_temp_0 | |
354 | ); | |
355 | return false; | |
356 | } else if last_assigned_to != opt_info.local_tmp_s1 { | |
357 | trace!( | |
358 | "NO: end of assignemnt chain does not match written enum temp: {:?} != {:?}", | |
359 | last_assigned_to, | |
360 | opt_info.local_tmp_s1 | |
361 | ); | |
362 | return false; | |
363 | } | |
364 | ||
365 | trace!("SUCCESS: optimization applies!"); | |
3dfed10e | 366 | true |
f9f354fc XL |
367 | } |
368 | ||
60c5eb7d | 369 | impl<'tcx> MirPass<'tcx> for SimplifyArmIdentity { |
29967ef6 | 370 | fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { |
1b1a35ee XL |
371 | // FIXME(77359): This optimization can result in unsoundness. |
372 | if !tcx.sess.opts.debugging_opts.unsound_mir_opts { | |
f9f354fc XL |
373 | return; |
374 | } | |
375 | ||
29967ef6 | 376 | let source = body.source; |
f9f354fc | 377 | trace!("running SimplifyArmIdentity on {:?}", source); |
29967ef6 | 378 | |
f035d41b XL |
379 | let local_uses = LocalUseCounter::get_local_uses(body); |
380 | let (basic_blocks, local_decls, debug_info) = | |
381 | body.basic_blocks_local_decls_mut_and_var_debug_info(); | |
60c5eb7d | 382 | for bb in basic_blocks { |
f035d41b XL |
383 | if let Some(opt_info) = |
384 | get_arm_identity_info(&bb.statements, local_decls.len(), debug_info) | |
385 | { | |
f9f354fc | 386 | trace!("got opt_info = {:#?}", opt_info); |
f035d41b | 387 | if !optimization_applies(&opt_info, local_decls, &local_uses, &debug_info) { |
f9f354fc XL |
388 | debug!("optimization skipped for {:?}", source); |
389 | continue; | |
390 | } | |
60c5eb7d | 391 | |
f9f354fc XL |
392 | // Also remove unused Storage{Live,Dead} statements which correspond |
393 | // to temps used previously. | |
394 | for (live_idx, dead_idx, local) in &opt_info.storage_stmts { | |
395 | // The temporary that we've read the variant field into is scoped to this block, | |
396 | // so we can remove the assignment. | |
397 | if *local == opt_info.local_temp_0 { | |
398 | bb.statements[opt_info.get_variant_field_stmt].make_nop(); | |
399 | } | |
400 | ||
401 | for (left, right) in &opt_info.field_tmp_assignments { | |
402 | if local == left || local == right { | |
403 | bb.statements[*live_idx].make_nop(); | |
404 | bb.statements[*dead_idx].make_nop(); | |
405 | } | |
406 | } | |
60c5eb7d | 407 | } |
f9f354fc XL |
408 | |
409 | // Right shape; transform | |
410 | for stmt_idx in opt_info.stmts_to_remove { | |
411 | bb.statements[stmt_idx].make_nop(); | |
412 | } | |
413 | ||
414 | let stmt = &mut bb.statements[opt_info.stmt_to_overwrite]; | |
415 | stmt.source_info = opt_info.source_info; | |
416 | stmt.kind = StatementKind::Assign(box ( | |
417 | opt_info.local_0.into(), | |
418 | Rvalue::Use(Operand::Move(opt_info.local_1.into())), | |
419 | )); | |
420 | ||
421 | bb.statements.retain(|stmt| stmt.kind != StatementKind::Nop); | |
422 | ||
f035d41b XL |
423 | // Fix the debug info to point to the right local |
424 | for dbg_index in opt_info.dbg_info_to_adjust { | |
425 | let dbg_info = &mut debug_info[dbg_index]; | |
426 | assert!(dbg_info.place.projection.is_empty()); | |
427 | dbg_info.place.local = opt_info.local_0; | |
428 | dbg_info.place.projection = opt_info.dbg_projection; | |
429 | } | |
430 | ||
f9f354fc | 431 | trace!("block is now {:?}", bb.statements); |
60c5eb7d | 432 | } |
60c5eb7d XL |
433 | } |
434 | } | |
435 | } | |
436 | ||
f035d41b XL |
437 | struct LocalUseCounter { |
438 | local_uses: IndexVec<Local, usize>, | |
439 | } | |
440 | ||
441 | impl LocalUseCounter { | |
442 | fn get_local_uses<'tcx>(body: &Body<'tcx>) -> IndexVec<Local, usize> { | |
443 | let mut counter = LocalUseCounter { local_uses: IndexVec::from_elem(0, &body.local_decls) }; | |
444 | counter.visit_body(body); | |
445 | counter.local_uses | |
446 | } | |
447 | } | |
448 | ||
449 | impl<'tcx> Visitor<'tcx> for LocalUseCounter { | |
450 | fn visit_local(&mut self, local: &Local, context: PlaceContext, _location: Location) { | |
451 | if context.is_storage_marker() | |
452 | || context == PlaceContext::NonUse(NonUseContext::VarDebugInfo) | |
453 | { | |
454 | return; | |
455 | } | |
456 | ||
457 | self.local_uses[*local] += 1; | |
458 | } | |
459 | } | |
460 | ||
60c5eb7d XL |
461 | /// Match on: |
462 | /// ```rust | |
463 | /// _LOCAL_INTO = ((_LOCAL_FROM as Variant).FIELD: TY); | |
464 | /// ``` | |
f035d41b XL |
465 | fn match_get_variant_field<'tcx>( |
466 | stmt: &Statement<'tcx>, | |
467 | ) -> Option<(Local, Local, VarField<'tcx>, &'tcx List<PlaceElem<'tcx>>)> { | |
60c5eb7d | 468 | match &stmt.kind { |
3dfed10e XL |
469 | StatementKind::Assign(box ( |
470 | place_into, | |
471 | Rvalue::Use(Operand::Copy(pf) | Operand::Move(pf)), | |
472 | )) => { | |
473 | let local_into = place_into.as_local()?; | |
474 | let (local_from, vf) = match_variant_field_place(*pf)?; | |
475 | Some((local_into, local_from, vf, pf.projection)) | |
476 | } | |
60c5eb7d XL |
477 | _ => None, |
478 | } | |
479 | } | |
480 | ||
481 | /// Match on: | |
482 | /// ```rust | |
483 | /// ((_LOCAL_FROM as Variant).FIELD: TY) = move _LOCAL_INTO; | |
484 | /// ``` | |
485 | fn match_set_variant_field<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, Local, VarField<'tcx>)> { | |
486 | match &stmt.kind { | |
3dfed10e XL |
487 | StatementKind::Assign(box (place_from, Rvalue::Use(Operand::Move(place_into)))) => { |
488 | let local_into = place_into.as_local()?; | |
489 | let (local_from, vf) = match_variant_field_place(*place_from)?; | |
490 | Some((local_into, local_from, vf)) | |
491 | } | |
60c5eb7d XL |
492 | _ => None, |
493 | } | |
494 | } | |
495 | ||
496 | /// Match on: | |
497 | /// ```rust | |
498 | /// discriminant(_LOCAL_TO_SET) = VAR_IDX; | |
499 | /// ``` | |
500 | fn match_set_discr<'tcx>(stmt: &Statement<'tcx>) -> Option<(Local, VariantIdx)> { | |
501 | match &stmt.kind { | |
dfeec247 XL |
502 | StatementKind::SetDiscriminant { place, variant_index } => { |
503 | Some((place.as_local()?, *variant_index)) | |
504 | } | |
60c5eb7d XL |
505 | _ => None, |
506 | } | |
507 | } | |
508 | ||
f9f354fc | 509 | #[derive(PartialEq, Debug)] |
60c5eb7d XL |
510 | struct VarField<'tcx> { |
511 | field: Field, | |
512 | field_ty: Ty<'tcx>, | |
513 | var_idx: VariantIdx, | |
514 | } | |
515 | ||
516 | /// Match on `((_LOCAL as Variant).FIELD: TY)`. | |
ba9703b0 | 517 | fn match_variant_field_place<'tcx>(place: Place<'tcx>) -> Option<(Local, VarField<'tcx>)> { |
60c5eb7d XL |
518 | match place.as_ref() { |
519 | PlaceRef { | |
dfeec247 | 520 | local, |
60c5eb7d | 521 | projection: &[ProjectionElem::Downcast(_, var_idx), ProjectionElem::Field(field, ty)], |
74b04a01 | 522 | } => Some((local, VarField { field, field_ty: ty, var_idx })), |
60c5eb7d XL |
523 | _ => None, |
524 | } | |
525 | } | |
526 | ||
527 | /// Simplifies `SwitchInt(_) -> [targets]`, | |
528 | /// where all the `targets` have the same form, | |
529 | /// into `goto -> target_first`. | |
530 | pub struct SimplifyBranchSame; | |
531 | ||
532 | impl<'tcx> MirPass<'tcx> for SimplifyBranchSame { | |
29967ef6 XL |
533 | fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { |
534 | trace!("Running SimplifyBranchSame on {:?}", body.source); | |
3dfed10e XL |
535 | let finder = SimplifyBranchSameOptimizationFinder { body, tcx }; |
536 | let opts = finder.find(); | |
537 | ||
538 | let did_remove_blocks = opts.len() > 0; | |
539 | for opt in opts.iter() { | |
540 | trace!("SUCCESS: Applying optimization {:?}", opt); | |
541 | // Replace `SwitchInt(..) -> [bb_first, ..];` with a `goto -> bb_first;`. | |
542 | body.basic_blocks_mut()[opt.bb_to_opt_terminator].terminator_mut().kind = | |
543 | TerminatorKind::Goto { target: opt.bb_to_goto }; | |
544 | } | |
60c5eb7d | 545 | |
3dfed10e XL |
546 | if did_remove_blocks { |
547 | // We have dead blocks now, so remove those. | |
548 | simplify::remove_dead_blocks(body); | |
549 | } | |
550 | } | |
551 | } | |
552 | ||
553 | #[derive(Debug)] | |
554 | struct SimplifyBranchSameOptimization { | |
555 | /// All basic blocks are equal so go to this one | |
556 | bb_to_goto: BasicBlock, | |
557 | /// Basic block where the terminator can be simplified to a goto | |
558 | bb_to_opt_terminator: BasicBlock, | |
559 | } | |
560 | ||
1b1a35ee XL |
561 | struct SwitchTargetAndValue { |
562 | target: BasicBlock, | |
563 | // None in case of the `otherwise` case | |
564 | value: Option<u128>, | |
565 | } | |
566 | ||
3dfed10e XL |
567 | struct SimplifyBranchSameOptimizationFinder<'a, 'tcx> { |
568 | body: &'a Body<'tcx>, | |
569 | tcx: TyCtxt<'tcx>, | |
570 | } | |
571 | ||
572 | impl<'a, 'tcx> SimplifyBranchSameOptimizationFinder<'a, 'tcx> { | |
573 | fn find(&self) -> Vec<SimplifyBranchSameOptimization> { | |
574 | self.body | |
575 | .basic_blocks() | |
576 | .iter_enumerated() | |
577 | .filter_map(|(bb_idx, bb)| { | |
1b1a35ee | 578 | let (discr_switched_on, targets_and_values) = match &bb.terminator().kind { |
29967ef6 XL |
579 | TerminatorKind::SwitchInt { targets, discr, .. } => { |
580 | let targets_and_values: Vec<_> = targets.iter() | |
581 | .map(|(val, target)| SwitchTargetAndValue { target, value: Some(val) }) | |
582 | .chain(once(SwitchTargetAndValue { target: targets.otherwise(), value: None })) | |
1b1a35ee | 583 | .collect(); |
29967ef6 XL |
584 | (discr, targets_and_values) |
585 | }, | |
3dfed10e XL |
586 | _ => return None, |
587 | }; | |
588 | ||
589 | // find the adt that has its discriminant read | |
590 | // assuming this must be the last statement of the block | |
591 | let adt_matched_on = match &bb.statements.last()?.kind { | |
592 | StatementKind::Assign(box (place, rhs)) | |
593 | if Some(*place) == discr_switched_on.place() => | |
594 | { | |
595 | match rhs { | |
596 | Rvalue::Discriminant(adt_place) if adt_place.ty(self.body, self.tcx).ty.is_enum() => adt_place, | |
597 | _ => { | |
598 | trace!("NO: expected a discriminant read of an enum instead of: {:?}", rhs); | |
599 | return None; | |
600 | } | |
601 | } | |
602 | } | |
603 | other => { | |
604 | trace!("NO: expected an assignment of a discriminant read to a place. Found: {:?}", other); | |
605 | return None | |
606 | }, | |
607 | }; | |
608 | ||
1b1a35ee | 609 | let mut iter_bbs_reachable = targets_and_values |
3dfed10e | 610 | .iter() |
1b1a35ee | 611 | .map(|target_and_value| (target_and_value, &self.body.basic_blocks()[target_and_value.target])) |
3dfed10e XL |
612 | .filter(|(_, bb)| { |
613 | // Reaching `unreachable` is UB so assume it doesn't happen. | |
614 | bb.terminator().kind != TerminatorKind::Unreachable | |
60c5eb7d XL |
615 | // But `asm!(...)` could abort the program, |
616 | // so we cannot assume that the `unreachable` terminator itself is reachable. | |
617 | // FIXME(Centril): use a normalization pass instead of a check. | |
618 | || bb.statements.iter().any(|stmt| match stmt.kind { | |
ba9703b0 | 619 | StatementKind::LlvmInlineAsm(..) => true, |
60c5eb7d XL |
620 | _ => false, |
621 | }) | |
3dfed10e XL |
622 | }) |
623 | .peekable(); | |
624 | ||
1b1a35ee | 625 | let bb_first = iter_bbs_reachable.peek().map(|(idx, _)| *idx).unwrap_or(&targets_and_values[0]); |
3dfed10e XL |
626 | let mut all_successors_equivalent = StatementEquality::TrivialEqual; |
627 | ||
628 | // All successor basic blocks must be equal or contain statements that are pairwise considered equal. | |
1b1a35ee | 629 | for ((target_and_value_l,bb_l), (target_and_value_r,bb_r)) in iter_bbs_reachable.tuple_windows() { |
3dfed10e | 630 | let trivial_checks = bb_l.is_cleanup == bb_r.is_cleanup |
6c58768f XL |
631 | && bb_l.terminator().kind == bb_r.terminator().kind |
632 | && bb_l.statements.len() == bb_r.statements.len(); | |
3dfed10e XL |
633 | let statement_check = || { |
634 | bb_l.statements.iter().zip(&bb_r.statements).try_fold(StatementEquality::TrivialEqual, |acc,(l,r)| { | |
1b1a35ee | 635 | let stmt_equality = self.statement_equality(*adt_matched_on, &l, target_and_value_l, &r, target_and_value_r); |
3dfed10e XL |
636 | if matches!(stmt_equality, StatementEquality::NotEqual) { |
637 | // short circuit | |
638 | None | |
639 | } else { | |
640 | Some(acc.combine(&stmt_equality)) | |
641 | } | |
642 | }) | |
643 | .unwrap_or(StatementEquality::NotEqual) | |
644 | }; | |
645 | if !trivial_checks { | |
646 | all_successors_equivalent = StatementEquality::NotEqual; | |
647 | break; | |
648 | } | |
649 | all_successors_equivalent = all_successors_equivalent.combine(&statement_check()); | |
650 | }; | |
651 | ||
652 | match all_successors_equivalent{ | |
653 | StatementEquality::TrivialEqual => { | |
654 | // statements are trivially equal, so just take first | |
655 | trace!("Statements are trivially equal"); | |
656 | Some(SimplifyBranchSameOptimization { | |
1b1a35ee | 657 | bb_to_goto: bb_first.target, |
3dfed10e XL |
658 | bb_to_opt_terminator: bb_idx, |
659 | }) | |
660 | } | |
661 | StatementEquality::ConsideredEqual(bb_to_choose) => { | |
662 | trace!("Statements are considered equal"); | |
663 | Some(SimplifyBranchSameOptimization { | |
664 | bb_to_goto: bb_to_choose, | |
665 | bb_to_opt_terminator: bb_idx, | |
666 | }) | |
667 | } | |
668 | StatementEquality::NotEqual => { | |
669 | trace!("NO: not all successors of basic block {:?} were equivalent", bb_idx); | |
670 | None | |
671 | } | |
672 | } | |
673 | }) | |
674 | .collect() | |
675 | } | |
676 | ||
677 | /// Tests if two statements can be considered equal | |
678 | /// | |
679 | /// Statements can be trivially equal if the kinds match. | |
680 | /// But they can also be considered equal in the following case A: | |
681 | /// ``` | |
682 | /// discriminant(_0) = 0; // bb1 | |
683 | /// _0 = move _1; // bb2 | |
684 | /// ``` | |
685 | /// In this case the two statements are equal iff | |
686 | /// 1: _0 is an enum where the variant index 0 is fieldless, and | |
687 | /// 2: bb1 was targeted by a switch where the discriminant of _1 was switched on | |
688 | fn statement_equality( | |
689 | &self, | |
690 | adt_matched_on: Place<'tcx>, | |
691 | x: &Statement<'tcx>, | |
1b1a35ee | 692 | x_target_and_value: &SwitchTargetAndValue, |
3dfed10e | 693 | y: &Statement<'tcx>, |
1b1a35ee | 694 | y_target_and_value: &SwitchTargetAndValue, |
3dfed10e XL |
695 | ) -> StatementEquality { |
696 | let helper = |rhs: &Rvalue<'tcx>, | |
1b1a35ee | 697 | place: &Place<'tcx>, |
3dfed10e XL |
698 | variant_index: &VariantIdx, |
699 | side_to_choose| { | |
700 | let place_type = place.ty(self.body, self.tcx).ty; | |
1b1a35ee | 701 | let adt = match *place_type.kind() { |
3dfed10e XL |
702 | ty::Adt(adt, _) if adt.is_enum() => adt, |
703 | _ => return StatementEquality::NotEqual, | |
704 | }; | |
705 | let variant_is_fieldless = adt.variants[*variant_index].fields.is_empty(); | |
706 | if !variant_is_fieldless { | |
707 | trace!("NO: variant {:?} was not fieldless", variant_index); | |
708 | return StatementEquality::NotEqual; | |
709 | } | |
710 | ||
711 | match rhs { | |
712 | Rvalue::Use(operand) if operand.place() == Some(adt_matched_on) => { | |
1b1a35ee | 713 | StatementEquality::ConsideredEqual(side_to_choose) |
3dfed10e XL |
714 | } |
715 | _ => { | |
716 | trace!( | |
717 | "NO: RHS of assignment was {:?}, but expected it to match the adt being matched on in the switch, which is {:?}", | |
718 | rhs, | |
719 | adt_matched_on | |
720 | ); | |
721 | StatementEquality::NotEqual | |
722 | } | |
723 | } | |
724 | }; | |
725 | match (&x.kind, &y.kind) { | |
726 | // trivial case | |
727 | (x, y) if x == y => StatementEquality::TrivialEqual, | |
728 | ||
729 | // check for case A | |
730 | ( | |
731 | StatementKind::Assign(box (_, rhs)), | |
732 | StatementKind::SetDiscriminant { place, variant_index }, | |
1b1a35ee XL |
733 | ) |
734 | // we need to make sure that the switch value that targets the bb with SetDiscriminant (y), is the same as the variant index | |
735 | if Some(variant_index.index() as u128) == y_target_and_value.value => { | |
3dfed10e | 736 | // choose basic block of x, as that has the assign |
1b1a35ee | 737 | helper(rhs, place, variant_index, x_target_and_value.target) |
3dfed10e XL |
738 | } |
739 | ( | |
740 | StatementKind::SetDiscriminant { place, variant_index }, | |
741 | StatementKind::Assign(box (_, rhs)), | |
1b1a35ee XL |
742 | ) |
743 | // we need to make sure that the switch value that targets the bb with SetDiscriminant (x), is the same as the variant index | |
744 | if Some(variant_index.index() as u128) == x_target_and_value.value => { | |
3dfed10e | 745 | // choose basic block of y, as that has the assign |
1b1a35ee | 746 | helper(rhs, place, variant_index, y_target_and_value.target) |
3dfed10e XL |
747 | } |
748 | _ => { | |
749 | trace!("NO: statements `{:?}` and `{:?}` not considered equal", x, y); | |
750 | StatementEquality::NotEqual | |
60c5eb7d XL |
751 | } |
752 | } | |
3dfed10e XL |
753 | } |
754 | } | |
60c5eb7d | 755 | |
3dfed10e XL |
756 | #[derive(Copy, Clone, Eq, PartialEq)] |
757 | enum StatementEquality { | |
758 | /// The two statements are trivially equal; same kind | |
759 | TrivialEqual, | |
760 | /// The two statements are considered equal, but may be of different kinds. The BasicBlock field is the basic block to jump to when performing the branch-same optimization. | |
761 | /// For example, `_0 = _1` and `discriminant(_0) = discriminant(0)` are considered equal if 0 is a fieldless variant of an enum. But we don't want to jump to the basic block with the SetDiscriminant, as that is not legal if _1 is not the 0 variant index | |
762 | ConsideredEqual(BasicBlock), | |
763 | /// The two statements are not equal | |
764 | NotEqual, | |
765 | } | |
766 | ||
767 | impl StatementEquality { | |
768 | fn combine(&self, other: &StatementEquality) -> StatementEquality { | |
769 | use StatementEquality::*; | |
770 | match (self, other) { | |
771 | (TrivialEqual, TrivialEqual) => TrivialEqual, | |
772 | (TrivialEqual, ConsideredEqual(b)) | (ConsideredEqual(b), TrivialEqual) => { | |
773 | ConsideredEqual(*b) | |
774 | } | |
775 | (ConsideredEqual(b1), ConsideredEqual(b2)) => { | |
776 | if b1 == b2 { | |
777 | ConsideredEqual(*b1) | |
778 | } else { | |
779 | NotEqual | |
780 | } | |
781 | } | |
782 | (_, NotEqual) | (NotEqual, _) => NotEqual, | |
60c5eb7d XL |
783 | } |
784 | } | |
785 | } |