]>
Commit | Line | Data |
---|---|---|
9ffffee4 FG |
1 | use crate::rustc_middle::ty::util::IntTypeExt; |
2 | use crate::MirPass; | |
3 | use rustc_data_structures::fx::FxHashMap; | |
4 | use rustc_middle::mir::interpret::AllocId; | |
5 | use rustc_middle::mir::*; | |
6 | use rustc_middle::ty::{self, AdtDef, ParamEnv, Ty, TyCtxt}; | |
7 | use rustc_session::Session; | |
8 | use rustc_target::abi::{HasDataLayout, Size, TagEncoding, Variants}; | |
9 | ||
/// A pass that seeks to optimize unnecessary copies of large enum types, if there is a large
/// enough discrepancy between the sizes of their variants.
///
/// For example, given an enum with two variants of very different sizes:
/// ```
/// enum Example {
///   Small,
///   Large([u32; 1024]),
/// }
/// ```
/// instead of always emitting a move sized for the large variant,
/// a memcpy sized for the active variant is performed.
/// Based off of [this HackMD](https://hackmd.io/@ft4bxUsFT5CEUBmRKYHr7w/rJM8BBPzD).
///
/// In summary, what this does is at runtime determine which enum variant is active,
/// and instead of copying all the bytes of the largest possible variant,
/// copy only the bytes for the currently active variant.
pub struct EnumSizeOpt {
    /// Minimum difference, in bytes, between the largest and the smallest variant of an enum
    /// for the pass to rewrite copies of that enum.
    pub(crate) discrepancy: u64,
}
30 | ||
31 | impl<'tcx> MirPass<'tcx> for EnumSizeOpt { | |
32 | fn is_enabled(&self, sess: &Session) -> bool { | |
33 | sess.opts.unstable_opts.unsound_mir_opts || sess.mir_opt_level() >= 3 | |
34 | } | |
35 | fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { | |
36 | // NOTE: This pass may produce different MIR based on the alignment of the target | |
37 | // platform, but it will still be valid. | |
38 | self.optim(tcx, body); | |
39 | } | |
40 | } | |
41 | ||
42 | impl EnumSizeOpt { | |
43 | fn candidate<'tcx>( | |
44 | &self, | |
45 | tcx: TyCtxt<'tcx>, | |
46 | param_env: ParamEnv<'tcx>, | |
47 | ty: Ty<'tcx>, | |
48 | alloc_cache: &mut FxHashMap<Ty<'tcx>, AllocId>, | |
49 | ) -> Option<(AdtDef<'tcx>, usize, AllocId)> { | |
50 | let adt_def = match ty.kind() { | |
51 | ty::Adt(adt_def, _substs) if adt_def.is_enum() => adt_def, | |
52 | _ => return None, | |
53 | }; | |
54 | let layout = tcx.layout_of(param_env.and(ty)).ok()?; | |
55 | let variants = match &layout.variants { | |
56 | Variants::Single { .. } => return None, | |
57 | Variants::Multiple { tag_encoding, .. } | |
58 | if matches!(tag_encoding, TagEncoding::Niche { .. }) => | |
59 | { | |
60 | return None; | |
61 | } | |
62 | Variants::Multiple { variants, .. } if variants.len() <= 1 => return None, | |
63 | Variants::Multiple { variants, .. } => variants, | |
64 | }; | |
65 | let min = variants.iter().map(|v| v.size).min().unwrap(); | |
66 | let max = variants.iter().map(|v| v.size).max().unwrap(); | |
67 | if max.bytes() - min.bytes() < self.discrepancy { | |
68 | return None; | |
69 | } | |
70 | ||
71 | let num_discrs = adt_def.discriminants(tcx).count(); | |
72 | if variants.iter_enumerated().any(|(var_idx, _)| { | |
73 | let discr_for_var = adt_def.discriminant_for_variant(tcx, var_idx).val; | |
74 | (discr_for_var > usize::MAX as u128) || (discr_for_var as usize >= num_discrs) | |
75 | }) { | |
76 | return None; | |
77 | } | |
78 | if let Some(alloc_id) = alloc_cache.get(&ty) { | |
79 | return Some((*adt_def, num_discrs, *alloc_id)); | |
80 | } | |
81 | ||
82 | let data_layout = tcx.data_layout(); | |
83 | let ptr_sized_int = data_layout.ptr_sized_integer(); | |
84 | let target_bytes = ptr_sized_int.size().bytes() as usize; | |
85 | let mut data = vec![0; target_bytes * num_discrs]; | |
86 | macro_rules! encode_store { | |
87 | ($curr_idx: expr, $endian: expr, $bytes: expr) => { | |
88 | let bytes = match $endian { | |
89 | rustc_target::abi::Endian::Little => $bytes.to_le_bytes(), | |
90 | rustc_target::abi::Endian::Big => $bytes.to_be_bytes(), | |
91 | }; | |
92 | for (i, b) in bytes.into_iter().enumerate() { | |
93 | data[$curr_idx + i] = b; | |
94 | } | |
95 | }; | |
96 | } | |
97 | ||
98 | for (var_idx, layout) in variants.iter_enumerated() { | |
99 | let curr_idx = | |
100 | target_bytes * adt_def.discriminant_for_variant(tcx, var_idx).val as usize; | |
101 | let sz = layout.size; | |
102 | match ptr_sized_int { | |
103 | rustc_target::abi::Integer::I32 => { | |
104 | encode_store!(curr_idx, data_layout.endian, sz.bytes() as u32); | |
105 | } | |
106 | rustc_target::abi::Integer::I64 => { | |
107 | encode_store!(curr_idx, data_layout.endian, sz.bytes()); | |
108 | } | |
109 | _ => unreachable!(), | |
110 | }; | |
111 | } | |
112 | let alloc = interpret::Allocation::from_bytes( | |
113 | data, | |
114 | tcx.data_layout.ptr_sized_integer().align(&tcx.data_layout).abi, | |
115 | Mutability::Not, | |
116 | ); | |
117 | let alloc = tcx.create_memory_alloc(tcx.mk_const_alloc(alloc)); | |
118 | Some((*adt_def, num_discrs, *alloc_cache.entry(ty).or_insert(alloc))) | |
119 | } | |
120 | fn optim<'tcx>(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { | |
121 | let mut alloc_cache = FxHashMap::default(); | |
122 | let body_did = body.source.def_id(); | |
123 | let param_env = tcx.param_env(body_did); | |
124 | ||
125 | let blocks = body.basic_blocks.as_mut(); | |
126 | let local_decls = &mut body.local_decls; | |
127 | ||
128 | for bb in blocks { | |
129 | bb.expand_statements(|st| { | |
130 | if let StatementKind::Assign(box ( | |
131 | lhs, | |
132 | Rvalue::Use(Operand::Copy(rhs) | Operand::Move(rhs)), | |
133 | )) = &st.kind | |
134 | { | |
135 | let ty = lhs.ty(local_decls, tcx).ty; | |
136 | ||
137 | let source_info = st.source_info; | |
138 | let span = source_info.span; | |
139 | ||
140 | let (adt_def, num_variants, alloc_id) = | |
141 | self.candidate(tcx, param_env, ty, &mut alloc_cache)?; | |
142 | let alloc = tcx.global_alloc(alloc_id).unwrap_memory(); | |
143 | ||
144 | let tmp_ty = tcx.mk_array(tcx.types.usize, num_variants as u64); | |
145 | ||
146 | let size_array_local = local_decls.push(LocalDecl::new(tmp_ty, span)); | |
147 | let store_live = Statement { | |
148 | source_info, | |
149 | kind: StatementKind::StorageLive(size_array_local), | |
150 | }; | |
151 | ||
152 | let place = Place::from(size_array_local); | |
153 | let constant_vals = Constant { | |
154 | span, | |
155 | user_ty: None, | |
156 | literal: ConstantKind::Val( | |
157 | interpret::ConstValue::ByRef { alloc, offset: Size::ZERO }, | |
158 | tmp_ty, | |
159 | ), | |
160 | }; | |
353b0b11 | 161 | let rval = Rvalue::Use(Operand::Constant(Box::new(constant_vals))); |
9ffffee4 | 162 | |
353b0b11 FG |
163 | let const_assign = Statement { |
164 | source_info, | |
165 | kind: StatementKind::Assign(Box::new((place, rval))), | |
166 | }; | |
9ffffee4 FG |
167 | |
168 | let discr_place = Place::from( | |
169 | local_decls | |
170 | .push(LocalDecl::new(adt_def.repr().discr_type().to_ty(tcx), span)), | |
171 | ); | |
172 | ||
173 | let store_discr = Statement { | |
174 | source_info, | |
353b0b11 FG |
175 | kind: StatementKind::Assign(Box::new(( |
176 | discr_place, | |
177 | Rvalue::Discriminant(*rhs), | |
178 | ))), | |
9ffffee4 FG |
179 | }; |
180 | ||
181 | let discr_cast_place = | |
182 | Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span))); | |
183 | ||
184 | let cast_discr = Statement { | |
185 | source_info, | |
353b0b11 | 186 | kind: StatementKind::Assign(Box::new(( |
9ffffee4 FG |
187 | discr_cast_place, |
188 | Rvalue::Cast( | |
189 | CastKind::IntToInt, | |
190 | Operand::Copy(discr_place), | |
191 | tcx.types.usize, | |
192 | ), | |
353b0b11 | 193 | ))), |
9ffffee4 FG |
194 | }; |
195 | ||
196 | let size_place = | |
197 | Place::from(local_decls.push(LocalDecl::new(tcx.types.usize, span))); | |
198 | ||
199 | let store_size = Statement { | |
200 | source_info, | |
353b0b11 | 201 | kind: StatementKind::Assign(Box::new(( |
9ffffee4 FG |
202 | size_place, |
203 | Rvalue::Use(Operand::Copy(Place { | |
204 | local: size_array_local, | |
205 | projection: tcx | |
206 | .mk_place_elems(&[PlaceElem::Index(discr_cast_place.local)]), | |
207 | })), | |
353b0b11 | 208 | ))), |
9ffffee4 FG |
209 | }; |
210 | ||
211 | let dst = | |
212 | Place::from(local_decls.push(LocalDecl::new(tcx.mk_mut_ptr(ty), span))); | |
213 | ||
214 | let dst_ptr = Statement { | |
215 | source_info, | |
353b0b11 | 216 | kind: StatementKind::Assign(Box::new(( |
9ffffee4 FG |
217 | dst, |
218 | Rvalue::AddressOf(Mutability::Mut, *lhs), | |
353b0b11 | 219 | ))), |
9ffffee4 FG |
220 | }; |
221 | ||
222 | let dst_cast_ty = tcx.mk_mut_ptr(tcx.types.u8); | |
223 | let dst_cast_place = | |
224 | Place::from(local_decls.push(LocalDecl::new(dst_cast_ty, span))); | |
225 | ||
226 | let dst_cast = Statement { | |
227 | source_info, | |
353b0b11 | 228 | kind: StatementKind::Assign(Box::new(( |
9ffffee4 FG |
229 | dst_cast_place, |
230 | Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(dst), dst_cast_ty), | |
353b0b11 | 231 | ))), |
9ffffee4 FG |
232 | }; |
233 | ||
234 | let src = | |
235 | Place::from(local_decls.push(LocalDecl::new(tcx.mk_imm_ptr(ty), span))); | |
236 | ||
237 | let src_ptr = Statement { | |
238 | source_info, | |
353b0b11 | 239 | kind: StatementKind::Assign(Box::new(( |
9ffffee4 FG |
240 | src, |
241 | Rvalue::AddressOf(Mutability::Not, *rhs), | |
353b0b11 | 242 | ))), |
9ffffee4 FG |
243 | }; |
244 | ||
245 | let src_cast_ty = tcx.mk_imm_ptr(tcx.types.u8); | |
246 | let src_cast_place = | |
247 | Place::from(local_decls.push(LocalDecl::new(src_cast_ty, span))); | |
248 | ||
249 | let src_cast = Statement { | |
250 | source_info, | |
353b0b11 | 251 | kind: StatementKind::Assign(Box::new(( |
9ffffee4 FG |
252 | src_cast_place, |
253 | Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(src), src_cast_ty), | |
353b0b11 | 254 | ))), |
9ffffee4 FG |
255 | }; |
256 | ||
257 | let deinit_old = | |
353b0b11 | 258 | Statement { source_info, kind: StatementKind::Deinit(Box::new(dst)) }; |
9ffffee4 FG |
259 | |
260 | let copy_bytes = Statement { | |
261 | source_info, | |
353b0b11 FG |
262 | kind: StatementKind::Intrinsic(Box::new( |
263 | NonDivergingIntrinsic::CopyNonOverlapping(CopyNonOverlapping { | |
9ffffee4 FG |
264 | src: Operand::Copy(src_cast_place), |
265 | dst: Operand::Copy(dst_cast_place), | |
266 | count: Operand::Copy(size_place), | |
267 | }), | |
353b0b11 | 268 | )), |
9ffffee4 FG |
269 | }; |
270 | ||
271 | let store_dead = Statement { | |
272 | source_info, | |
273 | kind: StatementKind::StorageDead(size_array_local), | |
274 | }; | |
275 | let iter = [ | |
276 | store_live, | |
277 | const_assign, | |
278 | store_discr, | |
279 | cast_discr, | |
280 | store_size, | |
281 | dst_ptr, | |
282 | dst_cast, | |
283 | src_ptr, | |
284 | src_cast, | |
285 | deinit_old, | |
286 | copy_bytes, | |
287 | store_dead, | |
288 | ] | |
289 | .into_iter(); | |
290 | ||
291 | st.make_nop(); | |
292 | Some(iter) | |
293 | } else { | |
294 | None | |
295 | } | |
296 | }); | |
297 | } | |
298 | } | |
299 | } |