2 use rustc_index
::bit_set
::{BitSet, GrowableBitSet}
;
3 use rustc_index
::vec
::IndexVec
;
4 use rustc_middle
::mir
::patch
::MirPatch
;
5 use rustc_middle
::mir
::visit
::*;
6 use rustc_middle
::mir
::*;
7 use rustc_middle
::ty
::{Ty, TyCtxt}
;
8 use rustc_mir_dataflow
::value_analysis
::{excluded_locals, iter_fields}
;
/// Unit struct identifying this pass; the `MirPass` implementation for it carries the logic.
10 pub struct ScalarReplacementOfAggregates
;
12 impl<'tcx
> MirPass
<'tcx
> for ScalarReplacementOfAggregates
{
/// Reports whether the pass should run: only when the session's MIR optimization level is
/// at least 3.
13 fn is_enabled(&self, sess
: &rustc_session
::Session
) -> bool
{
14 sess
.mir_opt_level() >= 3
/// Entry point of the pass: computes which locals must stay whole, plans a per-field
/// replacement for the others, and rewrites `body` accordingly.
///
/// NOTE(review): several original lines between the statements below are not visible in
/// this excerpt, so the surrounding control flow cannot be described from here.
17 #[instrument(level = "debug", skip(self, tcx, body))]
18 fn run_pass(&self, tcx
: TyCtxt
<'tcx
>, body
: &mut Body
<'tcx
>) {
19 debug
!(def_id
= ?body
.source
.def_id());
// Seed the exclusion set with whatever `excluded_locals` reports for this body.
20 let mut excluded
= excluded_locals(body
);
// Locals that may not be split, given the current exclusion set.
23 let escaping
= escaping_locals(&excluded
, body
);
// Plan: for every local that does not escape, pick a fresh local per field.
25 let replacements
= compute_flattening(tcx
, body
, escaping
);
26 debug
!(?replacements
);
// Apply the plan; the result is the set of locals that were split apart.
27 let all_dead_locals
= replace_flattened_locals(tcx
, body
, replacements
);
28 if !all_dead_locals
.is_empty() {
// Locals that were just split are added to the exclusion set.
29 excluded
.union(&all_dead_locals
);
// `compute_flattening` pushed new entries onto `body.local_decls`; presumably `ensure`
// grows the set so it covers the current number of locals — confirm against
// `GrowableBitSet`.
31 let mut growable
= GrowableBitSet
::from(excluded
);
32 growable
.ensure(body
.local_decls
.len());
42 /// Identify all locals that are not eligible for SROA.
44 /// There are 3 cases:
45 /// - the aggregated local is used or passed to other code (function parameters and arguments);
46 /// - the local is a union or an enum;
47 /// - the local's address is taken, and thus the relative addresses of the fields are observable.
49 fn escaping_locals(excluded
: &BitSet
<Local
>, body
: &Body
<'_
>) -> BitSet
<Local
> {
// Start with an empty set sized to the number of locals in `body`.
50 let mut set
= BitSet
::new_empty(body
.local_decls
.len());
// The range from `RETURN_PLACE` through local number `arg_count` (inclusive) is always
// marked as escaping.
51 set
.insert_range(RETURN_PLACE
..=Local
::from_usize(body
.arg_count
));
// Check every local's declared type and the caller-provided exclusion set.
// NOTE(review): the body of this `if` is not visible here; presumably it inserts `local`
// into `set`.
52 for (local
, decl
) in body
.local_decls().iter_enumerated() {
53 if decl
.ty
.is_union() || decl
.ty
.is_enum() || excluded
.contains(local
) {
// Walk the whole body; `EscapeVisitor` adds every local it reaches through `visit_local`.
57 let mut visitor
= EscapeVisitor { set }
;
58 visitor
.visit_body(body
);
/// Visitor state for the escape analysis; `escaping_locals` builds it as
/// `EscapeVisitor { set }`, where `set` is the bit set of escaping locals.
61 struct EscapeVisitor
{
/// Read-only MIR visitor that records escaping locals in `self.set`.
///
/// Any local that reaches `visit_local` is recorded; the other hooks below decide which
/// parts of the MIR are visited at all.
65 impl<'tcx
> Visitor
<'tcx
> for EscapeVisitor
{
// Every local that gets here is marked as escaping.
66 fn visit_local(&mut self, local
: Local
, _
: PlaceContext
, _
: Location
) {
67 self.set
.insert(local
);
// Places whose projection starts with a field access are special-cased.
// NOTE(review): the body of the `if let` is not visible here; presumably it returns
// early so that the base local is not recorded.
70 fn visit_place(&mut self, place
: &Place
<'tcx
>, context
: PlaceContext
, location
: Location
) {
71 // Mirror the implementation in PreFlattenVisitor.
72 if let &[PlaceElem
::Field(..), ..] = &place
.projection
[..] {
75 self.super_place(place
, context
, location
);
// Assignment hook (its header is not visible here). When the left-hand side is a bare
// local and the right-hand side is an aggregate or a plain use, only the right-hand side
// is visited; the final `super_assign` call handles the remaining cases.
81 rvalue
: &Rvalue
<'tcx
>,
84 if lvalue
.as_local().is_some() {
86 // Aggregate assignments are expanded in run_pass.
87 Rvalue
::Aggregate(..) | Rvalue
::Use(..) => {
88 self.visit_rvalue(rvalue
, location
);
94 self.super_assign(lvalue
, rvalue
, location
)
// Storage and deinit statements return without visiting their operand, so they never
// mark a local as escaping; every other statement is traversed normally.
97 fn visit_statement(&mut self, statement
: &Statement
<'tcx
>, location
: Location
) {
98 match statement
.kind
{
99 // Storage statements are expanded in run_pass.
100 StatementKind
::StorageLive(..)
101 | StatementKind
::StorageDead(..)
102 | StatementKind
::Deinit(..) => return,
103 _
=> self.super_statement(statement
, location
),
107 // We ignore anything that happens in debuginfo, since we expand it using
108 // `VarDebugInfoContents::Composite`.
109 fn visit_var_debug_info(&mut self, _
: &VarDebugInfo
<'tcx
>) {}
/// Result of `compute_flattening`: for each original local, the locals that replace its
/// fields.
113 #[derive(Default, Debug)]
114 struct ReplacementMap
<'tcx
> {
115 /// Pre-computed list of all "new" locals for each "old" local. This is used to expand storage
116 /// and deinit statements and debuginfo.
///
/// The outer `Option` is `None` for a local that has no replacement at all; an inner
/// `None` is a field with no replacement local. A `Some` entry holds the field's type
/// and the local that stands in for it.
117 fragments
: IndexVec
<Local
, Option
<IndexVec
<Field
, Option
<(Ty
<'tcx
>, Local
)>>>>,
/// Lookup helpers over `fragments`.
120 impl<'tcx
> ReplacementMap
<'tcx
> {
// Rewrites `local.field.rest` into `new_local.rest`.
//
// Returns `None` when the projection does not start with a field access, when
// `place.local` has no fragments, or when that particular field has no replacement local.
121 fn replace_place(&self, tcx
: TyCtxt
<'tcx
>, place
: PlaceRef
<'tcx
>) -> Option
<Place
<'tcx
>> {
122 let &[PlaceElem
::Field(f
, _
), ref rest @
..] = place
.projection
else { return None; }
;
123 let fields
= self.fragments
[place
.local
].as_ref()?
;
124 let (_
, new_local
) = fields
[f
]?
;
125 Some(Place { local: new_local, projection: tcx.mk_place_elems(&rest) }
)
// Second lookup helper (its header is not visible here; callers in this file invoke it as
// `place_fragments`). For a place that is a bare local with fragments, it yields one
// `(field, type, replacement local)` triple per field that has a replacement; otherwise
// it returns `None`.
131 ) -> Option
<impl Iterator
<Item
= (Field
, Ty
<'tcx
>, Local
)> + '_
> {
132 let local
= place
.as_local()?
;
133 let fields
= self.fragments
[local
].as_ref()?
;
// Fields whose entry is `None` are skipped by the `?` inside the closure.
134 Some(fields
.iter_enumerated().filter_map(|(field
, &opt_ty_local
)| {
135 let (ty
, local
) = opt_ty_local?
;
136 Some((field
, ty
, local
))
141 /// Compute the replacement of flattened places into locals.
143 /// For each eligible place, we assign a new local to each accessed field.
144 /// The replacement will be done later in `ReplacementVisitor`.
///
/// Takes `body` mutably because new local declarations are pushed onto
/// `body.local_decls`; `escaping` is the set of locals that are checked with `contains`
/// at the top of the loop.
145 fn compute_flattening
<'tcx
>(
147 body
: &mut Body
<'tcx
>,
148 escaping
: BitSet
<Local
>,
149 ) -> ReplacementMap
<'tcx
> {
// One entry per existing local, all initially `None`.
150 let mut fragments
= IndexVec
::from_elem(None
, &body
.local_decls
);
152 for local
in body
.local_decls
.indices() {
// NOTE(review): the body of this `if` is not visible here; presumably escaping locals
// are skipped.
153 if escaping
.contains(local
) {
// Keep an owned copy of the declaration; it is the template for the per-field locals.
156 let decl
= body
.local_decls
[local
].clone();
// NOTE(review): `ty` and `new_local` are bound on lines not visible in this excerpt.
158 iter_fields(ty
, tcx
, |variant
, field
, field_ty
| {
159 if variant
.is_some() {
160 // Downcasts are currently not supported.
// The new declaration copies everything from the original except the type (the field's
// type) and `user_ty` (cleared).
164 body
.local_decls
.push(LocalDecl { ty: field_ty, user_ty: None, ..decl.clone() }
);
// Record the field's type and its replacement local, creating the per-local table on
// first use.
165 fragments
.get_or_insert_with(local
, IndexVec
::new
).insert(field
, (field_ty
, new_local
));
168 ReplacementMap { fragments }
171 /// Perform the replacement computed by `compute_flattening`.
///
/// Builds the set of locals that have fragments, returns it immediately when it is empty,
/// and otherwise runs a `ReplacementVisitor` over the basic blocks, source scopes, user
/// type annotations and debuginfo of `body`.
172 fn replace_flattened_locals
<'tcx
>(
174 body
: &mut Body
<'tcx
>,
175 replacements
: ReplacementMap
<'tcx
>,
// A local is "dead" after this pass if it has any fragments at all.
177 let mut all_dead_locals
= BitSet
::new_empty(replacements
.fragments
.len());
178 for (local
, replacements
) in replacements
.fragments
.iter_enumerated() {
179 if replacements
.is_some() {
180 all_dead_locals
.insert(local
);
183 debug
!(?all_dead_locals
);
// Nothing to replace: hand back the empty set without touching `body`.
184 if all_dead_locals
.is_empty() {
185 return all_dead_locals
;
// NOTE(review): some of the visitor's field initialisers are on lines not visible here.
188 let mut visitor
= ReplacementVisitor
{
190 local_decls
: &body
.local_decls
,
191 replacements
: &replacements
,
193 patch
: MirPatch
::new(body
),
// Visit each part of the body that the visitor needs to rewrite.
195 for (bb
, data
) in body
.basic_blocks
.as_mut_preserves_cfg().iter_enumerated_mut() {
196 visitor
.visit_basic_block_data(bb
, data
);
198 for scope
in &mut body
.source_scopes
{
199 visitor
.visit_source_scope_data(scope
);
201 for (index
, annotation
) in body
.user_type_annotations
.iter_enumerated_mut() {
202 visitor
.visit_user_type_annotation(index
, annotation
);
204 for var_debug_info
in &mut body
.var_debug_info
{
205 visitor
.visit_var_debug_info(var_debug_info
);
// Take the accumulated patch and the dead-local set back out of the visitor.
207 let ReplacementVisitor { patch, all_dead_locals, .. }
= visitor
;
/// Mutable MIR visitor that applies a `ReplacementMap` to a body.
///
/// NOTE(review): at least one field is declared on a line not visible in this excerpt.
212 struct ReplacementVisitor
<'tcx
, 'll
> {
214 /// This is only used to compute the type for `VarDebugInfoContents::Composite`.
215 local_decls
: &'ll LocalDecls
<'tcx
>,
/// The field-to-local mapping being applied.
217 replacements
: &'ll ReplacementMap
<'tcx
>,
218 /// This is used to check that we are not leaving references to replaced locals behind.
219 all_dead_locals
: BitSet
<Local
>,
/// Collects the statements that the visitor adds while rewriting.
220 patch
: MirPatch
<'tcx
>,
223 impl<'tcx
> ReplacementVisitor
<'tcx
, '_
> {
// Builds the debuginfo fragments describing `local` after it has been split.
//
// Returns `None` when `place_fragments` has nothing for `local`. Otherwise each
// `(field, type, replacement local)` triple becomes one `VarDebugInfoFragment` whose
// projection is the single field access and whose contents is the replacement local.
224 fn gather_debug_info_fragments(&self, local
: Local
) -> Option
<Vec
<VarDebugInfoFragment
<'tcx
>>> {
225 let mut fragments
= Vec
::new();
226 let parts
= self.replacements
.place_fragments(local
.into())?
;
227 for (field
, ty
, replacement_local
) in parts
{
228 fragments
.push(VarDebugInfoFragment
{
229 projection
: vec
![PlaceElem
::Field(field
, ty
)],
230 contents
: Place
::from(replacement_local
),
237 impl<'tcx
, 'll
> MutVisitor
<'tcx
> for ReplacementVisitor
<'tcx
, 'll
> {
// Accessor returning the type context; its body is on a line not visible in this excerpt.
238 fn tcx(&self) -> TyCtxt
<'tcx
> {
// Looks the place up in the replacement map first; the default traversal `super_place`
// is the other path.
// NOTE(review): the body of the `if let` is not visible here; presumably it overwrites
// `*place` with `repl`, and `super_place` runs only when there is no replacement.
242 fn visit_place(&mut self, place
: &mut Place
<'tcx
>, context
: PlaceContext
, location
: Location
) {
243 if let Some(repl
) = self.replacements
.replace_place(self.tcx
, place
.as_ref()) {
246 self.super_place(place
, context
, location
)
// Rewrites statements that mention a split local as a whole.
//
// Storage markers, deinit statements and whole-aggregate assignments are expanded into
// one statement per replacement local, added through `self.patch`. Anything not handled
// by a dedicated arm reaches `super_statement` at the end.
// NOTE(review): a number of original lines inside this function are not visible in this
// excerpt (closing braces, some `self.patch` receivers and location arguments).
250 #[instrument(level = "trace", skip(self))]
251 fn visit_statement(&mut self, statement
: &mut Statement
<'tcx
>, location
: Location
) {
252 match statement
.kind
{
253 // Duplicate storage and deinit statements, as they pretty much apply to all fields.
254 StatementKind
::StorageLive(l
) => {
255 if let Some(final_locals
) = self.replacements
.place_fragments(l
.into()) {
// One `StorageLive` per replacement local, then the original becomes a nop.
256 for (_
, _
, fl
) in final_locals
{
257 self.patch
.add_statement(location
, StatementKind
::StorageLive(fl
));
259 statement
.make_nop();
263 StatementKind
::StorageDead(l
) => {
264 if let Some(final_locals
) = self.replacements
.place_fragments(l
.into()) {
// One `StorageDead` per replacement local, then the original becomes a nop.
265 for (_
, _
, fl
) in final_locals
{
266 self.patch
.add_statement(location
, StatementKind
::StorageDead(fl
));
268 statement
.make_nop();
272 StatementKind
::Deinit(box place
) => {
273 if let Some(final_locals
) = self.replacements
.place_fragments(place
) {
// One `Deinit` per replacement local, then the original becomes a nop.
274 for (_
, _
, fl
) in final_locals
{
276 .add_statement(location
, StatementKind
::Deinit(Box
::new(fl
.into())));
278 statement
.make_nop();
283 // We have `a = Struct { 0: x, 1: y, .. }`.
290 StatementKind
::Assign(box (place
, Rvalue
::Aggregate(_
, ref mut operands
))) => {
291 if let Some(local
) = place
.as_local()
292 && let Some(final_locals
) = &self.replacements
.fragments
[local
]
294 // This is ok as we delete the statement later.
295 let operands
= std
::mem
::take(operands
);
// Fields and operands are paired positionally; a field without a replacement local
// produces no assignment.
296 for (&opt_ty_local
, mut operand
) in final_locals
.iter().zip(operands
) {
297 if let Some((_
, new_local
)) = opt_ty_local
{
298 // Replace mentions of SROA'd locals that appear in the operand.
299 self.visit_operand(&mut operand
, location
);
301 let rvalue
= Rvalue
::Use(operand
);
302 self.patch
.add_statement(
304 StatementKind
::Assign(Box
::new((new_local
.into(), rvalue
))),
308 statement
.make_nop();
313 // We have `a = some constant`
314 // We add the projections.
320 // ConstProp will pick up the pieces and replace them by actual constants.
321 StatementKind
::Assign(box (place
, Rvalue
::Use(Operand
::Constant(_
)))) => {
322 if let Some(final_locals
) = self.replacements
.place_fragments(place
) {
323 // Put the deaggregated statements *after* the original one.
324 let location
= location
.successor_within_block();
// Each replacement local is assigned from a move out of the matching field of `place`.
325 for (field
, ty
, new_local
) in final_locals
{
326 let rplace
= self.tcx
.mk_place_field(place
, field
, ty
);
327 let rvalue
= Rvalue
::Use(Operand
::Move(rplace
));
328 self.patch
.add_statement(
330 StatementKind
::Assign(Box
::new((new_local
.into(), rvalue
))),
333 // We still need `place.local` to exist, so don't make it nop.
338 // We have `a = move? place`
341 // a_0 = move? place.0
342 // a_1 = move? place.1
345 StatementKind
::Assign(box (lhs
, Rvalue
::Use(ref op
))) => {
// Remember whether the source was copied or moved so the per-field assignments use the
// same kind of operand. The constant case is matched by the arm above, hence `bug!()`.
346 let (rplace
, copy
) = match *op
{
347 Operand
::Copy(rplace
) => (rplace
, true),
348 Operand
::Move(rplace
) => (rplace
, false),
349 Operand
::Constant(_
) => bug
!(),
351 if let Some(final_locals
) = self.replacements
.place_fragments(lhs
) {
352 for (field
, ty
, new_local
) in final_locals
{
// Project the source place to the same field.
353 let rplace
= self.tcx
.mk_place_field(rplace
, field
, ty
);
// NOTE(review): the projected source is itself looked up with `replace_place`; the
// lines around this call are not visible here.
357 .replace_place(self.tcx
, rplace
.as_ref())
360 let rvalue
= if copy
{
361 Rvalue
::Use(Operand
::Copy(rplace
))
363 Rvalue
::Use(Operand
::Move(rplace
))
365 self.patch
.add_statement(
367 StatementKind
::Assign(Box
::new((new_local
.into(), rvalue
))),
370 statement
.make_nop();
// Default traversal for statements not fully handled above.
377 self.super_statement(statement
, location
)
// Updates one debuginfo entry so it refers to the replacement locals.
//
// A `Place` entry is looked up with `replace_place`; failing that, if it is a bare split
// local it becomes a `Composite` built from `gather_debug_info_fragments`. A `Composite`
// entry has each of its fragments checked the same way. `Const` entries are left alone.
// NOTE(review): several original lines inside this function are not visible in this
// excerpt, including the body of the first `if let` and the closure's return values.
380 #[instrument(level = "trace", skip(self))]
381 fn visit_var_debug_info(&mut self, var_debug_info
: &mut VarDebugInfo
<'tcx
>) {
382 match &mut var_debug_info
.value
{
383 VarDebugInfoContents
::Place(ref mut place
) => {
384 if let Some(repl
) = self.replacements
.replace_place(self.tcx
, place
.as_ref()) {
386 } else if let Some(local
) = place
.as_local()
387 && let Some(fragments
) = self.gather_debug_info_fragments(local
)
// The composite keeps the type of the original place.
389 let ty
= place
.ty(self.local_decls
, self.tcx
).ty
;
390 var_debug_info
.value
= VarDebugInfoContents
::Composite { ty, fragments }
;
393 VarDebugInfoContents
::Composite { ty: _, ref mut fragments }
=> {
// Sub-fragments generated for split locals are collected here and appended afterwards.
394 let mut new_fragments
= Vec
::new();
397 .drain_filter(|fragment
| {
399 self.replacements
.replace_place(self.tcx
, fragment
.contents
.as_ref())
401 fragment
.contents
= repl
;
403 } else if let Some(local
) = fragment
.contents
.as_local()
404 && let Some(frg
) = self.gather_debug_info_fragments(local
)
// Each sub-fragment gets the parent fragment's projection inserted at the front of its
// own projection.
406 new_fragments
.extend(frg
.into_iter().map(|mut f
| {
407 f
.projection
.splice(0..0, fragment
.projection
.iter().copied());
417 debug
!(?new_fragments
);
418 fragments
.extend(new_fragments
);
420 VarDebugInfoContents
::Const(_
) => {}
// Sanity check: panics if the visited local is in `all_dead_locals`, i.e. if a reference
// to a replaced local reaches this hook.
424 fn visit_local(&mut self, local
: &mut Local
, _
: PlaceContext
, _
: Location
) {
425 assert
!(!self.all_dead_locals
.contains(*local
));