]> git.proxmox.com Git - rustc.git/blame - compiler/rustc_mir_transform/src/large_enums.rs
New upstream version 1.70.0+dfsg1
[rustc.git] / compiler / rustc_mir_transform / src / large_enums.rs
CommitLineData
9ffffee4
FG
1use crate::rustc_middle::ty::util::IntTypeExt;
2use crate::MirPass;
3use rustc_data_structures::fx::FxHashMap;
4use rustc_middle::mir::interpret::AllocId;
5use rustc_middle::mir::*;
6use rustc_middle::ty::{self, AdtDef, ParamEnv, Ty, TyCtxt};
7use rustc_session::Session;
8use rustc_target::abi::{HasDataLayout, Size, TagEncoding, Variants};
9
/// A pass that seeks to optimize unnecessary moves of large enum types, if there is a large
/// enough discrepancy between the sizes of their variants.
///
/// For example, given an enum with two variants:
/// ```
/// enum Example {
///     Small,
///     Large([u32; 1024]),
/// }
/// ```
/// instead of emitting a move that always copies all the bytes of the large variant,
/// emit a memcpy whose length depends on the variant that is actually active.
/// Based off of [this HackMD](https://hackmd.io/@ft4bxUsFT5CEUBmRKYHr7w/rJM8BBPzD).
///
/// In summary, what this does is at runtime determine which enum variant is active,
/// and instead of copying all the bytes of the largest possible variant,
/// copy only the bytes for the currently active variant.
pub struct EnumSizeOpt {
    /// Minimum difference, in bytes, between the largest and the smallest variant size for
    /// an enum to be considered a candidate; enums whose variant sizes differ by less than
    /// this are left untouched.
    pub(crate) discrepancy: u64,
}
30
31impl<'tcx> MirPass<'tcx> for EnumSizeOpt {
32 fn is_enabled(&self, sess: &Session) -> bool {
33 sess.opts.unstable_opts.unsound_mir_opts || sess.mir_opt_level() >= 3
34 }
35 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
36 // NOTE: This pass may produce different MIR based on the alignment of the target
37 // platform, but it will still be valid.
38 self.optim(tcx, body);
39 }
40}
41
42impl EnumSizeOpt {
43 fn candidate<'tcx>(
44 &self,
45 tcx: TyCtxt<'tcx>,
46 param_env: ParamEnv<'tcx>,
47 ty: Ty<'tcx>,
48 alloc_cache: &mut FxHashMap<Ty<'tcx>, AllocId>,
49 ) -> Option<(AdtDef<'tcx>, usize, AllocId)> {
50 let adt_def = match ty.kind() {
51 ty::Adt(adt_def, _substs) if adt_def.is_enum() => adt_def,
52 _ => return None,
53 };
54 let layout = tcx.layout_of(param_env.and(ty)).ok()?;
55 let variants = match &layout.variants {
56 Variants::Single { .. } => return None,
57 Variants::Multiple { tag_encoding, .. }
58 if matches!(tag_encoding, TagEncoding::Niche { .. }) =>
59 {
60 return None;
61 }
62 Variants::Multiple { variants, .. } if variants.len() <= 1 => return None,
63 Variants::Multiple { variants, .. } => variants,
64 };
65 let min = variants.iter().map(|v| v.size).min().unwrap();
66 let max = variants.iter().map(|v| v.size).max().unwrap();
67 if max.bytes() - min.bytes() < self.discrepancy {
68 return None;
69 }
70
71 let num_discrs = adt_def.discriminants(tcx).count();
72 if variants.iter_enumerated().any(|(var_idx, _)| {
73 let discr_for_var = adt_def.discriminant_for_variant(tcx, var_idx).val;
74 (discr_for_var > usize::MAX as u128) || (discr_for_var as usize >= num_discrs)
75 }) {
76 return None;
77 }
78 if let Some(alloc_id) = alloc_cache.get(&ty) {
79 return Some((*adt_def, num_discrs, *alloc_id));
80 }
81
82 let data_layout = tcx.data_layout();
83 let ptr_sized_int = data_layout.ptr_sized_integer();
84 let target_bytes = ptr_sized_int.size().bytes() as usize;
85 let mut data = vec![0; target_bytes * num_discrs];
86 macro_rules! encode_store {
87 ($curr_idx: expr, $endian: expr, $bytes: expr) => {
88 let bytes = match $endian {
89 rustc_target::abi::Endian::Little => $bytes.to_le_bytes(),
90 rustc_target::abi::Endian::Big => $bytes.to_be_bytes(),
91 };
92 for (i, b) in bytes.into_iter().enumerate() {
93 data[$curr_idx + i] = b;
94 }
95 };
96 }
97
98 for (var_idx, layout) in variants.iter_enumerated() {
99 let curr_idx =
100 target_bytes * adt_def.discriminant_for_variant(tcx, var_idx).val as usize;
101 let sz = layout.size;
102 match ptr_sized_int {
103 rustc_target::abi::Integer::I32 => {
104 encode_store!(curr_idx, data_layout.endian, sz.bytes() as u32);
105 }
106 rustc_target::abi::Integer::I64 => {
107 encode_store!(curr_idx, data_layout.endian, sz.bytes());
108 }
109 _ => unreachable!(),
110 };
111 }
112 let alloc = interpret::Allocation::from_bytes(
113 data,
114 tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi,
115 Mutability::Not,
116 );
117 let alloc = tcx.create_memory_alloc(tcx.mk_const_alloc(alloc));
118 Some((*adt_def, num_discrs, *alloc_cache.entry(ty).or_insert(alloc)))
119 }
120 fn optim<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
121 let mut alloc_cache = FxHashMap::default();
122 let body_did = body.source.def_id();
123 let param_env = tcx.param_env(body_did);
124
125 let blocks = body.basic_blocks.as_mut();
126 let local_decls = &mut body.local_decls;
127
128 for bb in blocks {
129 bb.expand_statements(|st| {
130 if let StatementKind::Assign(box (
131 lhs,
132 Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)),
133 )) = &st.kind
134 {
135 let ty = lhs.ty(local_decls, tcx).ty;
136
137 let source_info = st.source_info;
138 let span = source_info.span;
139
140 let (adt_def, num_variants, alloc_id) =
141 self.candidate(tcx, param_env, ty, &mut alloc_cache)?;
142 let alloc = tcx.global_alloc(alloc_id).unwrap_memory();
143
144 let tmp_ty = tcx.mk_array(tcx.types.usize, num_variants as u64);
145
146 let size_array_local = local_decls.push(LocalDecl::new(tmp_ty, span));
147 let store_live = Statement {
148 source_info,
149 kind: StatementKind::StorageLive(size_array_local),
150 };
151
152 let place = Place::from(size_array_local);
153 let constant_vals = Constant {
154 span,
155 user_ty: None,
156 literal: ConstantKind::Val(
157 interpret::ConstValue::ByRef { alloc, offset: Size::ZERO },
158 tmp_ty,
159 ),
160 };
353b0b11 161 let rval = Rvalue::Use(Operand::Constant(Box::new(constant_vals)));
9ffffee4 162
353b0b11
FG
163 let const_assign = Statement {
164 source_info,
165 kind: StatementKind::Assign(Box::new((place, rval))),
166 };
9ffffee4
FG
167
168 let discr_place = Place::from(
169 local_decls
170 .push(LocalDecl::new(adt_def.repr().discr_type().to_ty(tcx), span)),
171 );
172
173 let store_discr = Statement {
174 source_info,
353b0b11
FG
175 kind: StatementKind::Assign(Box::new((
176 discr_place,
177 Rvalue::Discriminant(*rhs),
178 ))),
9ffffee4
FG
179 };
180
181 let discr_cast_place =
182 Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
183
184 let cast_discr = Statement {
185 source_info,
353b0b11 186 kind: StatementKind::Assign(Box::new((
9ffffee4
FG
187 discr_cast_place,
188 Rvalue::Cast(
189 CastKind::IntToInt,
190 Operand::Copy(discr_place),
191 tcx.types.usize,
192 ),
353b0b11 193 ))),
9ffffee4
FG
194 };
195
196 let size_place =
197 Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span)));
198
199 let store_size = Statement {
200 source_info,
353b0b11 201 kind: StatementKind::Assign(Box::new((
9ffffee4
FG
202 size_place,
203 Rvalue::Use(Operand::Copy(Place {
204 local: size_array_local,
205 projection: tcx
206 .mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]),
207 })),
353b0b11 208 ))),
9ffffee4
FG
209 };
210
211 let dst =
212 Place::from(local_decls.push(LocalDecl::new(tcx.mk_mut_ptr(ty), span)));
213
214 let dst_ptr = Statement {
215 source_info,
353b0b11 216 kind: StatementKind::Assign(Box::new((
9ffffee4
FG
217 dst,
218 Rvalue::AddressOf(Mutability::Mut, *lhs),
353b0b11 219 ))),
9ffffee4
FG
220 };
221
222 let dst_cast_ty = tcx.mk_mut_ptr(tcx.types.u8);
223 let dst_cast_place =
224 Place::from(local_decls.push(LocalDecl::new(dst_cast_ty, span)));
225
226 let dst_cast = Statement {
227 source_info,
353b0b11 228 kind: StatementKind::Assign(Box::new((
9ffffee4
FG
229 dst_cast_place,
230 Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty),
353b0b11 231 ))),
9ffffee4
FG
232 };
233
234 let src =
235 Place::from(local_decls.push(LocalDecl::new(tcx.mk_imm_ptr(ty), span)));
236
237 let src_ptr = Statement {
238 source_info,
353b0b11 239 kind: StatementKind::Assign(Box::new((
9ffffee4
FG
240 src,
241 Rvalue::AddressOf(Mutability::Not, *rhs),
353b0b11 242 ))),
9ffffee4
FG
243 };
244
245 let src_cast_ty = tcx.mk_imm_ptr(tcx.types.u8);
246 let src_cast_place =
247 Place::from(local_decls.push(LocalDecl::new(src_cast_ty, span)));
248
249 let src_cast = Statement {
250 source_info,
353b0b11 251 kind: StatementKind::Assign(Box::new((
9ffffee4
FG
252 src_cast_place,
253 Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty),
353b0b11 254 ))),
9ffffee4
FG
255 };
256
257 let deinit_old =
353b0b11 258 Statement { source_info, kind: StatementKind::Deinit(Box::new(dst)) };
9ffffee4
FG
259
260 let copy_bytes = Statement {
261 source_info,
353b0b11
FG
262 kind: StatementKind::Intrinsic(Box::new(
263 NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping {
9ffffee4
FG
264 src: Operand::Copy(src_cast_place),
265 dst: Operand::Copy(dst_cast_place),
266 count: Operand::Copy(size_place),
267 }),
353b0b11 268 )),
9ffffee4
FG
269 };
270
271 let store_dead = Statement {
272 source_info,
273 kind: StatementKind::StorageDead(size_array_local),
274 };
275 let iter = [
276 store_live,
277 const_assign,
278 store_discr,
279 cast_discr,
280 store_size,
281 dst_ptr,
282 dst_cast,
283 src_ptr,
284 src_cast,
285 deinit_old,
286 copy_bytes,
287 store_dead,
288 ]
289 .into_iter();
290
291 st.make_nop();
292 Some(iter)
293 } else {
294 None
295 }
296 });
297 }
298 }
299}