// Copyright 2016 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
use rustc::hir;
use rustc::mir::*;
use rustc::ty::TyCtxt;
use rustc_data_structures::indexed_vec::Idx;
use transform::{MirPass, MirSource};
/// MIR pass that splits assignments of ADT aggregates (`x = Adt(ops..)`)
/// into one assignment per field, plus a discriminant write for ADTs with
/// more than one variant.
///
/// The pass only does work when the `mir_opt_level` debugging option is
/// greater than 2.
pub struct Deaggregator;
19 impl MirPass
for Deaggregator
{
20 fn run_pass
<'a
, 'tcx
>(&self,
21 tcx
: TyCtxt
<'a
, 'tcx
, 'tcx
>,
23 mir
: &mut Mir
<'tcx
>) {
24 let node_path
= tcx
.item_path_str(source
.def_id
);
25 debug
!("running on: {:?}", node_path
);
26 // we only run when mir_opt_level > 2
27 if tcx
.sess
.opts
.debugging_opts
.mir_opt_level
<= 2 {
31 // Don't run on constant MIR, because trans might not be able to
32 // evaluate the modified MIR.
33 // FIXME(eddyb) Remove check after miri is merged.
34 let id
= tcx
.hir
.as_local_node_id(source
.def_id
).unwrap();
35 match (tcx
.hir
.body_owner_kind(id
), source
.promoted
) {
36 (hir
::BodyOwnerKind
::Fn
, None
) => {}
,
39 // In fact, we might not want to trigger in other cases.
40 // Ex: when we could use SROA. See issue #35259
42 for bb
in mir
.basic_blocks_mut() {
43 let mut curr
: usize = 0;
44 while let Some(idx
) = get_aggregate_statement_index(curr
, &bb
.statements
) {
46 debug
!("removing statement {:?}", idx
);
47 let src_info
= bb
.statements
[idx
].source_info
;
48 let suffix_stmts
= bb
.statements
.split_off(idx
+1);
49 let orig_stmt
= bb
.statements
.pop().unwrap();
50 let (lhs
, rhs
) = match orig_stmt
.kind
{
51 StatementKind
::Assign(ref lhs
, ref rhs
) => (lhs
, rhs
),
52 _
=> span_bug
!(src_info
.span
, "expected assign, not {:?}", orig_stmt
),
54 let (agg_kind
, operands
) = match rhs
{
55 &Rvalue
::Aggregate(ref agg_kind
, ref operands
) => (agg_kind
, operands
),
56 _
=> span_bug
!(src_info
.span
, "expected aggregate, not {:?}", rhs
),
58 let (adt_def
, variant
, substs
) = match **agg_kind
{
59 AggregateKind
::Adt(adt_def
, variant
, substs
, None
)
60 => (adt_def
, variant
, substs
),
61 _
=> span_bug
!(src_info
.span
, "expected struct, not {:?}", rhs
),
63 let n
= bb
.statements
.len();
64 bb
.statements
.reserve(n
+ operands
.len() + suffix_stmts
.len());
65 for (i
, op
) in operands
.iter().enumerate() {
66 let ref variant_def
= adt_def
.variants
[variant
];
67 let ty
= variant_def
.fields
[i
].ty(tcx
, substs
);
68 let rhs
= Rvalue
::Use(op
.clone());
70 let lhs_cast
= if adt_def
.variants
.len() > 1 {
71 Lvalue
::Projection(Box
::new(LvalueProjection
{
73 elem
: ProjectionElem
::Downcast(adt_def
, variant
),
79 let lhs_proj
= Lvalue
::Projection(Box
::new(LvalueProjection
{
81 elem
: ProjectionElem
::Field(Field
::new(i
), ty
),
83 let new_statement
= Statement
{
84 source_info
: src_info
,
85 kind
: StatementKind
::Assign(lhs_proj
, rhs
),
87 debug
!("inserting: {:?} @ {:?}", new_statement
, idx
+ i
);
88 bb
.statements
.push(new_statement
);
91 // if the aggregate was an enum, we need to set the discriminant
92 if adt_def
.variants
.len() > 1 {
93 let set_discriminant
= Statement
{
94 kind
: StatementKind
::SetDiscriminant
{
96 variant_index
: variant
,
98 source_info
: src_info
,
100 bb
.statements
.push(set_discriminant
);
103 curr
= bb
.statements
.len();
104 bb
.statements
.extend(suffix_stmts
);
110 fn get_aggregate_statement_index
<'a
, 'tcx
, 'b
>(start
: usize,
111 statements
: &Vec
<Statement
<'tcx
>>)
113 for i
in start
..statements
.len() {
114 let ref statement
= statements
[i
];
115 let rhs
= match statement
.kind
{
116 StatementKind
::Assign(_
, ref rhs
) => rhs
,
119 let (kind
, operands
) = match rhs
{
120 &Rvalue
::Aggregate(ref kind
, ref operands
) => (kind
, operands
),
123 let (adt_def
, variant
) = match **kind
{
124 AggregateKind
::Adt(adt_def
, variant
, _
, None
) => (adt_def
, variant
),
127 if operands
.len() == 0 {
128 // don't deaggregate ()
131 debug
!("getting variant {:?}", variant
);
132 debug
!("for adt_def {:?}", adt_def
);