]> git.proxmox.com Git - rustc.git/blob - compiler/rustc_mir_transform/src/sroa.rs
New upstream version 1.69.0+dfsg1
[rustc.git] / compiler / rustc_mir_transform / src / sroa.rs
1 use crate::MirPass;
2 use rustc_index::bit_set::{BitSet, GrowableBitSet};
3 use rustc_index::vec::IndexVec;
4 use rustc_middle::mir::patch::MirPatch;
5 use rustc_middle::mir::visit::*;
6 use rustc_middle::mir::*;
7 use rustc_middle::ty::{Ty, TyCtxt};
8 use rustc_mir_dataflow::value_analysis::{excluded_locals, iter_fields};
9
10 pub struct ScalarReplacementOfAggregates;
11
12 impl<'tcx> MirPass<'tcx> for ScalarReplacementOfAggregates {
13 fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
14 sess.mir_opt_level() >= 3
15 }
16
17 #[instrument(level = "debug", skip(self, tcx, body))]
18 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
19 debug!(def_id = ?body.source.def_id());
20 let mut excluded = excluded_locals(body);
21 loop {
22 debug!(?excluded);
23 let escaping = escaping_locals(&excluded, body);
24 debug!(?escaping);
25 let replacements = compute_flattening(tcx, body, escaping);
26 debug!(?replacements);
27 let all_dead_locals = replace_flattened_locals(tcx, body, replacements);
28 if !all_dead_locals.is_empty() {
29 excluded.union(&all_dead_locals);
30 excluded = {
31 let mut growable = GrowableBitSet::from(excluded);
32 growable.ensure(body.local_decls.len());
33 growable.into()
34 };
35 } else {
36 break;
37 }
38 }
39 }
40 }
41
42 /// Identify all locals that are not eligible for SROA.
43 ///
44 /// There are 3 cases:
45 /// - the aggregated local is used or passed to other code (function parameters and arguments);
46 /// - the locals is a union or an enum;
47 /// - the local's address is taken, and thus the relative addresses of the fields are observable to
48 /// client code.
49 fn escaping_locals(excluded: &BitSet<Local>, body: &Body<'_>) -> BitSet<Local> {
50 let mut set = BitSet::new_empty(body.local_decls.len());
51 set.insert_range(RETURN_PLACE..=Local::from_usize(body.arg_count));
52 for (local, decl) in body.local_decls().iter_enumerated() {
53 if decl.ty.is_union() || decl.ty.is_enum() || excluded.contains(local) {
54 set.insert(local);
55 }
56 }
57 let mut visitor = EscapeVisitor { set };
58 visitor.visit_body(body);
59 return visitor.set;
60
61 struct EscapeVisitor {
62 set: BitSet<Local>,
63 }
64
65 impl<'tcx> Visitor<'tcx> for EscapeVisitor {
66 fn visit_local(&mut self, local: Local, _: PlaceContext, _: Location) {
67 self.set.insert(local);
68 }
69
70 fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
71 // Mirror the implementation in PreFlattenVisitor.
72 if let &[PlaceElem::Field(..), ..] = &place.projection[..] {
73 return;
74 }
75 self.super_place(place, context, location);
76 }
77
78 fn visit_assign(
79 &mut self,
80 lvalue: &Place<'tcx>,
81 rvalue: &Rvalue<'tcx>,
82 location: Location,
83 ) {
84 if lvalue.as_local().is_some() {
85 match rvalue {
86 // Aggregate assignments are expanded in run_pass.
87 Rvalue::Aggregate(..) | Rvalue::Use(..) => {
88 self.visit_rvalue(rvalue, location);
89 return;
90 }
91 _ => {}
92 }
93 }
94 self.super_assign(lvalue, rvalue, location)
95 }
96
97 fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
98 match statement.kind {
99 // Storage statements are expanded in run_pass.
100 StatementKind::StorageLive(..)
101 | StatementKind::StorageDead(..)
102 | StatementKind::Deinit(..) => return,
103 _ => self.super_statement(statement, location),
104 }
105 }
106
107 // We ignore anything that happens in debuginfo, since we expand it using
108 // `VarDebugInfoContents::Composite`.
109 fn visit_var_debug_info(&mut self, _: &VarDebugInfo<'tcx>) {}
110 }
111 }
112
113 #[derive(Default, Debug)]
114 struct ReplacementMap<'tcx> {
115 /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
116 /// and deinit statement and debuginfo.
117 fragments: IndexVec<Local, Option<IndexVec<Field, Option<(Ty<'tcx>, Local)>>>>,
118 }
119
120 impl<'tcx> ReplacementMap<'tcx> {
121 fn replace_place(&self, tcx: TyCtxt<'tcx>, place: PlaceRef<'tcx>) -> Option<Place<'tcx>> {
122 let &[PlaceElem::Field(f, _), ref rest @ ..] = place.projection else { return None; };
123 let fields = self.fragments[place.local].as_ref()?;
124 let (_, new_local) = fields[f]?;
125 Some(Place { local: new_local, projection: tcx.mk_place_elems(&rest) })
126 }
127
128 fn place_fragments(
129 &self,
130 place: Place<'tcx>,
131 ) -> Option<impl Iterator<Item = (Field, Ty<'tcx>, Local)> + '_> {
132 let local = place.as_local()?;
133 let fields = self.fragments[local].as_ref()?;
134 Some(fields.iter_enumerated().filter_map(|(field, &opt_ty_local)| {
135 let (ty, local) = opt_ty_local?;
136 Some((field, ty, local))
137 }))
138 }
139 }
140
141 /// Compute the replacement of flattened places into locals.
142 ///
143 /// For each eligible place, we assign a new local to each accessed field.
144 /// The replacement will be done later in `ReplacementVisitor`.
145 fn compute_flattening<'tcx>(
146 tcx: TyCtxt<'tcx>,
147 body: &mut Body<'tcx>,
148 escaping: BitSet<Local>,
149 ) -> ReplacementMap<'tcx> {
150 let mut fragments = IndexVec::from_elem(None, &body.local_decls);
151
152 for local in body.local_decls.indices() {
153 if escaping.contains(local) {
154 continue;
155 }
156 let decl = body.local_decls[local].clone();
157 let ty = decl.ty;
158 iter_fields(ty, tcx, |variant, field, field_ty| {
159 if variant.is_some() {
160 // Downcasts are currently not supported.
161 return;
162 };
163 let new_local =
164 body.local_decls.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() });
165 fragments.get_or_insert_with(local, IndexVec::new).insert(field, (field_ty, new_local));
166 });
167 }
168 ReplacementMap { fragments }
169 }
170
171 /// Perform the replacement computed by `compute_flattening`.
172 fn replace_flattened_locals<'tcx>(
173 tcx: TyCtxt<'tcx>,
174 body: &mut Body<'tcx>,
175 replacements: ReplacementMap<'tcx>,
176 ) -> BitSet<Local> {
177 let mut all_dead_locals = BitSet::new_empty(replacements.fragments.len());
178 for (local, replacements) in replacements.fragments.iter_enumerated() {
179 if replacements.is_some() {
180 all_dead_locals.insert(local);
181 }
182 }
183 debug!(?all_dead_locals);
184 if all_dead_locals.is_empty() {
185 return all_dead_locals;
186 }
187
188 let mut visitor = ReplacementVisitor {
189 tcx,
190 local_decls: &body.local_decls,
191 replacements: &replacements,
192 all_dead_locals,
193 patch: MirPatch::new(body),
194 };
195 for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
196 visitor.visit_basic_block_data(bb, data);
197 }
198 for scope in &mut body.source_scopes {
199 visitor.visit_source_scope_data(scope);
200 }
201 for (index, annotation) in body.user_type_annotations.iter_enumerated_mut() {
202 visitor.visit_user_type_annotation(index, annotation);
203 }
204 for var_debug_info in &mut body.var_debug_info {
205 visitor.visit_var_debug_info(var_debug_info);
206 }
207 let ReplacementVisitor { patch, all_dead_locals, .. } = visitor;
208 patch.apply(body);
209 all_dead_locals
210 }
211
212 struct ReplacementVisitor<'tcx, 'll> {
213 tcx: TyCtxt<'tcx>,
214 /// This is only used to compute the type for `VarDebugInfoContents::Composite`.
215 local_decls: &'ll LocalDecls<'tcx>,
216 /// Work to do.
217 replacements: &'ll ReplacementMap<'tcx>,
218 /// This is used to check that we are not leaving references to replaced locals behind.
219 all_dead_locals: BitSet<Local>,
220 patch: MirPatch<'tcx>,
221 }
222
223 impl<'tcx> ReplacementVisitor<'tcx, '_> {
224 fn gather_debug_info_fragments(&self, local: Local) -> Option<Vec<VarDebugInfoFragment<'tcx>>> {
225 let mut fragments = Vec::new();
226 let parts = self.replacements.place_fragments(local.into())?;
227 for (field, ty, replacement_local) in parts {
228 fragments.push(VarDebugInfoFragment {
229 projection: vec![PlaceElem::Field(field, ty)],
230 contents: Place::from(replacement_local),
231 });
232 }
233 Some(fragments)
234 }
235 }
236
// Rewrites the body so that flattened locals are only ever reached through their fragments:
// - field accesses `old.f.rest` become `new_f.rest`;
// - whole-local statements (storage, deinit, aggregate/use assignments) are expanded into one
//   statement per fragment, queued in `self.patch`;
// - debuginfo is redirected to the fragments.
impl<'tcx, 'll> MutVisitor<'tcx> for ReplacementVisitor<'tcx, 'll> {
    fn tcx(&self) -> TyCtxt<'tcx> {
        self.tcx
    }

    // Redirects a place that projects a field out of a flattened local to that field's
    // replacement local. Any other place is visited normally. A flattened local that is still
    // used as a whole will therefore reach the assertion in `visit_local`.
    fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
        if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
            *place = repl
        } else {
            self.super_place(place, context, location)
        }
    }

    // Expands statements that mention a flattened local as a whole. Replacement statements
    // are queued in `self.patch` at `location`. Fully expanded statements are turned into
    // `Nop`s and not visited further.
    // Arms that find no fragments fall through to `super_statement`, which rewrites any
    // field accesses via `visit_place`. The exception is storage statements, which always
    // return early.
    #[instrument(level = "trace", skip(self))]
    fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
        match statement.kind {
            // Duplicate storage and deinit statements, as they pretty much apply to all fields.
            StatementKind::StorageLive(l) => {
                if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
                    for (_, _, fl) in final_locals {
                        self.patch.add_statement(location, StatementKind::StorageLive(fl));
                    }
                    statement.make_nop();
                }
                // A storage statement only names a bare local, which needs no field rewriting,
                // so it is never passed to `super_statement`.
                return;
            }
            StatementKind::StorageDead(l) => {
                if let Some(final_locals) = self.replacements.place_fragments(l.into()) {
                    for (_, _, fl) in final_locals {
                        self.patch.add_statement(location, StatementKind::StorageDead(fl));
                    }
                    statement.make_nop();
                }
                return;
            }
            StatementKind::Deinit(box place) => {
                if let Some(final_locals) = self.replacements.place_fragments(place) {
                    for (_, _, fl) in final_locals {
                        self.patch
                            .add_statement(location, StatementKind::Deinit(Box::new(fl.into())));
                    }
                    statement.make_nop();
                    return;
                }
                // Not a bare flattened local: fall through to `super_statement`.
            }

            // We have `a = Struct { 0: x, 1: y, .. }`.
            // We replace it by
            // ```
            // a_0 = x
            // a_1 = y
            // ...
            // ```
            StatementKind::Assign(box (place, Rvalue::Aggregate(_, ref mut operands))) => {
                if let Some(local) = place.as_local()
                    && let Some(final_locals) = &self.replacements.fragments[local]
                {
                    // This is ok as we delete the statement later.
                    let operands = std::mem::take(operands);
                    // Operand i is paired with field i; `zip` stops at the shorter sequence.
                    for (&opt_ty_local, mut operand) in final_locals.iter().zip(operands) {
                        if let Some((_, new_local)) = opt_ty_local {
                            // Replace mentions of SROA'd locals that appear in the operand.
                            self.visit_operand(&mut operand, location);

                            let rvalue = Rvalue::Use(operand);
                            self.patch.add_statement(
                                location,
                                StatementKind::Assign(Box::new((new_local.into(), rvalue))),
                            );
                        }
                    }
                    statement.make_nop();
                    return;
                }
            }

            // We have `a = some constant`
            // We add the projections.
            // ```
            // a_0 = a.0
            // a_1 = a.1
            // ...
            // ```
            // ConstProp will pick up the pieces and replace them by actual constants.
            StatementKind::Assign(box (place, Rvalue::Use(Operand::Constant(_)))) => {
                if let Some(final_locals) = self.replacements.place_fragments(place) {
                    // Put the deaggregated statements *after* the original one.
                    let location = location.successor_within_block();
                    for (field, ty, new_local) in final_locals {
                        let rplace = self.tcx.mk_place_field(place, field, ty);
                        let rvalue = Rvalue::Use(Operand::Move(rplace));
                        self.patch.add_statement(
                            location,
                            StatementKind::Assign(Box::new((new_local.into(), rvalue))),
                        );
                    }
                    // We still need `place.local` to exist, so don't make it nop.
                    return;
                }
            }

            // We have `a = move? place`
            // We replace it by
            // ```
            // a_0 = move? place.0
            // a_1 = move? place.1
            // ...
            // ```
            StatementKind::Assign(box (lhs, Rvalue::Use(ref op))) => {
                // Constants are matched by the previous arm, so `Operand::Constant` cannot
                // reach this point.
                let (rplace, copy) = match *op {
                    Operand::Copy(rplace) => (rplace, true),
                    Operand::Move(rplace) => (rplace, false),
                    Operand::Constant(_) => bug!(),
                };
                if let Some(final_locals) = self.replacements.place_fragments(lhs) {
                    for (field, ty, new_local) in final_locals {
                        let rplace = self.tcx.mk_place_field(rplace, field, ty);
                        debug!(?rplace);
                        // If the source is itself (a field of) a flattened local, read from
                        // the corresponding fragment instead.
                        let rplace = self
                            .replacements
                            .replace_place(self.tcx, rplace.as_ref())
                            .unwrap_or(rplace);
                        debug!(?rplace);
                        // Preserve the original copy/move kind for every field.
                        let rvalue = if copy {
                            Rvalue::Use(Operand::Copy(rplace))
                        } else {
                            Rvalue::Use(Operand::Move(rplace))
                        };
                        self.patch.add_statement(
                            location,
                            StatementKind::Assign(Box::new((new_local.into(), rvalue))),
                        );
                    }
                    statement.make_nop();
                    return;
                }
            }

            _ => {}
        }
        self.super_statement(statement, location)
    }

    // Redirects debuginfo that refers to flattened locals.
    // - A place that is a field of a flattened local is rewritten to the fragment.
    // - A place that is a whole flattened local becomes a `Composite` of its fragments.
    #[instrument(level = "trace", skip(self))]
    fn visit_var_debug_info(&mut self, var_debug_info: &mut VarDebugInfo<'tcx>) {
        match &mut var_debug_info.value {
            VarDebugInfoContents::Place(ref mut place) => {
                if let Some(repl) = self.replacements.replace_place(self.tcx, place.as_ref()) {
                    *place = repl;
                } else if let Some(local) = place.as_local()
                    && let Some(fragments) = self.gather_debug_info_fragments(local)
                {
                    // The composite keeps the type of the original, unflattened local.
                    let ty = place.ty(self.local_decls, self.tcx).ty;
                    var_debug_info.value = VarDebugInfoContents::Composite { ty, fragments };
                }
            }
            VarDebugInfoContents::Composite { ty: _, ref mut fragments } => {
                let mut new_fragments = Vec::new();
                debug!(?fragments);
                // `drain_filter` removes every fragment for which the closure returns `true`;
                // the removed fragments are simply dropped.
                // - Fragments pointing at a field of a flattened local are rewritten in place
                //   and kept.
                // - Fragments pointing at a whole flattened local are removed. Their nested
                //   fragments, with the outer projection prepended, are appended afterwards.
                fragments
                    .drain_filter(|fragment| {
                        if let Some(repl) =
                            self.replacements.replace_place(self.tcx, fragment.contents.as_ref())
                        {
                            fragment.contents = repl;
                            false
                        } else if let Some(local) = fragment.contents.as_local()
                            && let Some(frg) = self.gather_debug_info_fragments(local)
                        {
                            new_fragments.extend(frg.into_iter().map(|mut f| {
                                f.projection.splice(0..0, fragment.projection.iter().copied());
                                f
                            }));
                            true
                        } else {
                            false
                        }
                    })
                    .for_each(drop);
                debug!(?fragments);
                debug!(?new_fragments);
                fragments.extend(new_fragments);
            }
            VarDebugInfoContents::Const(_) => {}
        }
    }

    // Every mention of a flattened local must have been rewritten before it could be visited
    // here; reaching one means a replacement was missed.
    fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
        assert!(!self.all_dead_locals.contains(*local));
    }
}