]>
Commit | Line | Data |
---|---|---|
1 | //! A constant propagation optimization pass based on dataflow analysis. | |
2 | //! | |
3 | //! Currently, this pass only propagates scalar values. | |
4 | ||
5 | use rustc_const_eval::const_eval::{throw_machine_stop_str, DummyMachine}; | |
6 | use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable}; | |
7 | use rustc_data_structures::fx::FxHashMap; | |
8 | use rustc_hir::def::DefKind; | |
9 | use rustc_middle::bug; | |
10 | use rustc_middle::mir::interpret::{InterpResult, Scalar}; | |
11 | use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; | |
12 | use rustc_middle::mir::*; | |
13 | use rustc_middle::ty::layout::LayoutOf; | |
14 | use rustc_middle::ty::{self, Ty, TyCtxt}; | |
15 | use rustc_mir_dataflow::value_analysis::{ | |
16 | Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace, | |
17 | }; | |
18 | use rustc_mir_dataflow::{lattice::FlatSet, Analysis, Results, ResultsVisitor}; | |
19 | use rustc_span::DUMMY_SP; | |
20 | use rustc_target::abi::{Abi, FieldIdx, Size, VariantIdx, FIRST_VARIANT}; | |
21 | ||
22 | // These constants are somewhat random guesses and have not been optimized. | |
23 | // If `tcx.sess.mir_opt_level() >= 4`, we ignore the limits (this can become very expensive). | |
24 | const BLOCK_LIMIT: usize = 100; | |
25 | const PLACE_LIMIT: usize = 100; | |
26 | ||
/// A constant-propagation MIR optimization pass built on the dataflow
/// `ValueAnalysis` framework. As the module docs note, only scalar values
/// are propagated.
pub struct DataflowConstProp;
28 | ||
29 | impl<'tcx> MirPass<'tcx> for DataflowConstProp { | |
30 | fn is_enabled(&self, sess: &rustc_session::Session) -> bool { | |
31 | sess.mir_opt_level() >= 3 | |
32 | } | |
33 | ||
34 | #[instrument(skip_all level = "debug")] | |
35 | fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { | |
36 | debug!(def_id = ?body.source.def_id()); | |
37 | if tcx.sess.mir_opt_level() < 4 && body.basic_blocks.len() > BLOCK_LIMIT { | |
38 | debug!("aborted dataflow const prop due too many basic blocks"); | |
39 | return; | |
40 | } | |
41 | ||
42 | // We want to have a somewhat linear runtime w.r.t. the number of statements/terminators. | |
43 | // Let's call this number `n`. Dataflow analysis has `O(h*n)` transfer function | |
44 | // applications, where `h` is the height of the lattice. Because the height of our lattice | |
45 | // is linear w.r.t. the number of tracked places, this is `O(tracked_places * n)`. However, | |
46 | // because every transfer function application could traverse the whole map, this becomes | |
47 | // `O(num_nodes * tracked_places * n)` in terms of time complexity. Since the number of | |
48 | // map nodes is strongly correlated to the number of tracked places, this becomes more or | |
49 | // less `O(n)` if we place a constant limit on the number of tracked places. | |
50 | let place_limit = if tcx.sess.mir_opt_level() < 4 { Some(PLACE_LIMIT) } else { None }; | |
51 | ||
52 | // Decide which places to track during the analysis. | |
53 | let map = Map::new(tcx, body, place_limit); | |
54 | ||
55 | // Perform the actual dataflow analysis. | |
56 | let analysis = ConstAnalysis::new(tcx, body, map); | |
57 | let mut results = debug_span!("analyze") | |
58 | .in_scope(|| analysis.wrap().into_engine(tcx, body).iterate_to_fixpoint()); | |
59 | ||
60 | // Collect results and patch the body afterwards. | |
61 | let mut visitor = Collector::new(tcx, &body.local_decls); | |
62 | debug_span!("collect").in_scope(|| results.visit_reachable_with(body, &mut visitor)); | |
63 | let mut patch = visitor.patch; | |
64 | debug_span!("patch").in_scope(|| patch.visit_body_preserves_cfg(body)); | |
65 | } | |
66 | } | |
67 | ||
/// The dataflow analysis state shared by all transfer functions of the pass.
struct ConstAnalysis<'a, 'tcx> {
    /// Which places are tracked, decided up front in `run_pass`.
    map: Map,
    tcx: TyCtxt<'tcx>,
    /// Local declarations of the analyzed body, used to compute place/operand types.
    local_decls: &'a LocalDecls<'tcx>,
    /// A const-eval interpreter (with a no-op machine) used to evaluate
    /// casts, binary/unary ops, discriminants and constants.
    ecx: InterpCx<'tcx, DummyMachine>,
    param_env: ty::ParamEnv<'tcx>,
}
75 | ||
impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> {
    // The lattice value for every tracked place: a known scalar (`Elem`),
    // uninitialized/unreachable (`Bottom`), or unknown (`Top`).
    type Value = FlatSet<Scalar>;

    const NAME: &'static str = "ConstAnalysis";

    fn map(&self) -> &Map {
        &self.map
    }

    /// Transfer function for `SetDiscriminant`: invalidate the old discriminant,
    /// then record the new one if it can be evaluated to a scalar.
    fn handle_set_discriminant(
        &self,
        place: Place<'tcx>,
        variant_index: VariantIdx,
        state: &mut State<Self::Value>,
    ) {
        state.flood_discr(place.as_ref(), &self.map);
        // Only bother evaluating if the discriminant is actually tracked.
        if self.map.find_discr(place.as_ref()).is_some() {
            let enum_ty = place.ty(self.local_decls, self.tcx).ty;
            if let Some(discr) = self.eval_discriminant(enum_ty, variant_index) {
                state.assign_discr(
                    place.as_ref(),
                    ValueOrPlace::Value(FlatSet::Elem(discr)),
                    &self.map,
                );
            }
        }
    }

    /// Transfer function for assignments. Each arm floods (invalidates) the
    /// target first, so that the subsequent fine-grained `insert_*` calls only
    /// ever refine a `Top` state.
    fn handle_assign(
        &self,
        target: Place<'tcx>,
        rvalue: &Rvalue<'tcx>,
        state: &mut State<Self::Value>,
    ) {
        match rvalue {
            Rvalue::Use(operand) => {
                state.flood(target.as_ref(), self.map());
                if let Some(target) = self.map.find(target.as_ref()) {
                    self.assign_operand(state, target, operand);
                }
            }
            Rvalue::CopyForDeref(rhs) => {
                // `CopyForDeref` behaves like a by-copy use of the place.
                state.flood(target.as_ref(), self.map());
                if let Some(target) = self.map.find(target.as_ref()) {
                    self.assign_operand(state, target, &Operand::Copy(*rhs));
                }
            }
            Rvalue::Aggregate(kind, operands) => {
                // If we assign `target = Enum::Variant#0(operand)`,
                // we must make sure that all `target as Variant#i` are `Top`.
                state.flood(target.as_ref(), self.map());

                let Some(target_idx) = self.map().find(target.as_ref()) else { return };

                // Figure out which sub-place the field operands should be
                // written into (for enums: the variant's downcast place).
                let (variant_target, variant_index) = match **kind {
                    AggregateKind::Tuple | AggregateKind::Closure(..) => (Some(target_idx), None),
                    AggregateKind::Adt(def_id, variant_index, ..) => {
                        match self.tcx.def_kind(def_id) {
                            DefKind::Struct => (Some(target_idx), None),
                            DefKind::Enum => (
                                self.map.apply(target_idx, TrackElem::Variant(variant_index)),
                                Some(variant_index),
                            ),
                            // Unions and other ADT kinds are not handled.
                            _ => return,
                        }
                    }
                    _ => return,
                };
                if let Some(variant_target_idx) = variant_target {
                    for (field_index, operand) in operands.iter_enumerated() {
                        if let Some(field) =
                            self.map().apply(variant_target_idx, TrackElem::Field(field_index))
                        {
                            self.assign_operand(state, field, operand);
                        }
                    }
                }
                if let Some(variant_index) = variant_index
                    && let Some(discr_idx) = self.map().apply(target_idx, TrackElem::Discriminant)
                {
                    // We are assigning the discriminant as part of an aggregate.
                    // This discriminant can only alias a variant field's value if the operand
                    // had an invalid value for that type.
                    // Using invalid values is UB, so we are allowed to perform the assignment
                    // without extra flooding.
                    let enum_ty = target.ty(self.local_decls, self.tcx).ty;
                    if let Some(discr_val) = self.eval_discriminant(enum_ty, variant_index) {
                        state.insert_value_idx(discr_idx, FlatSet::Elem(discr_val), &self.map);
                    }
                }
            }
            Rvalue::BinaryOp(op, box (left, right)) if op.is_overflowing() => {
                // Flood everything now, so we can use `insert_value_idx` directly later.
                state.flood(target.as_ref(), self.map());

                let Some(target) = self.map().find(target.as_ref()) else { return };

                // The result of an overflowing op is a `(value, overflowed)` pair.
                let value_target = self.map().apply(target, TrackElem::Field(0_u32.into()));
                let overflow_target = self.map().apply(target, TrackElem::Field(1_u32.into()));

                if value_target.is_some() || overflow_target.is_some() {
                    let (val, overflow) = self.binary_op(state, *op, left, right);

                    if let Some(value_target) = value_target {
                        // We have flooded `target` earlier.
                        state.insert_value_idx(value_target, val, self.map());
                    }
                    if let Some(overflow_target) = overflow_target {
                        let overflow = match overflow {
                            FlatSet::Top => FlatSet::Top,
                            FlatSet::Elem(overflow) => FlatSet::Elem(overflow),
                            FlatSet::Bottom => FlatSet::Bottom,
                        };
                        // We have flooded `target` earlier.
                        state.insert_value_idx(overflow_target, overflow, self.map());
                    }
                }
            }
            Rvalue::Cast(
                CastKind::PointerCoercion(ty::adjustment::PointerCoercion::Unsize),
                operand,
                _,
            ) => {
                // An unsizing coercion preserves the pointer value itself...
                let pointer = self.handle_operand(operand, state);
                state.assign(target.as_ref(), pointer, self.map());

                // ...and for `&[T; N] -> &[T]` we additionally learn the slice
                // length from the source array type.
                if let Some(target_len) = self.map().find_len(target.as_ref())
                    && let operand_ty = operand.ty(self.local_decls, self.tcx)
                    && let Some(operand_ty) = operand_ty.builtin_deref(true)
                    && let ty::Array(_, len) = operand_ty.kind()
                    && let Some(len) = Const::Ty(self.tcx.types.usize, *len)
                        .try_eval_scalar_int(self.tcx, self.param_env)
                {
                    state.insert_value_idx(target_len, FlatSet::Elem(len.into()), self.map());
                }
            }
            _ => self.super_assign(target, rvalue, state),
        }
    }

    /// Transfer function for rvalues whose result is consumed as a value
    /// (rather than handled specially by `handle_assign`).
    fn handle_rvalue(
        &self,
        rvalue: &Rvalue<'tcx>,
        state: &mut State<Self::Value>,
    ) -> ValueOrPlace<Self::Value> {
        let val = match rvalue {
            Rvalue::Len(place) => {
                let place_ty = place.ty(self.local_decls, self.tcx);
                if let ty::Array(_, len) = place_ty.ty.kind() {
                    // Arrays carry their length in the type.
                    Const::Ty(self.tcx.types.usize, *len)
                        .try_eval_scalar(self.tcx, self.param_env)
                        .map_or(FlatSet::Top, FlatSet::Elem)
                } else if let [ProjectionElem::Deref] = place.projection[..] {
                    // `Len(*ptr)`: use the tracked length of the pointee, if any.
                    state.get_len(place.local.into(), self.map())
                } else {
                    FlatSet::Top
                }
            }
            Rvalue::Cast(CastKind::IntToInt | CastKind::IntToFloat, operand, ty) => {
                let Ok(layout) = self.tcx.layout_of(self.param_env.and(*ty)) else {
                    return ValueOrPlace::Value(FlatSet::Top);
                };
                match self.eval_operand(operand, state) {
                    FlatSet::Elem(op) => self
                        .ecx
                        .int_to_int_or_float(&op, layout)
                        .map_or(FlatSet::Top, |result| self.wrap_immediate(*result)),
                    FlatSet::Bottom => FlatSet::Bottom,
                    FlatSet::Top => FlatSet::Top,
                }
            }
            Rvalue::Cast(CastKind::FloatToInt | CastKind::FloatToFloat, operand, ty) => {
                let Ok(layout) = self.tcx.layout_of(self.param_env.and(*ty)) else {
                    return ValueOrPlace::Value(FlatSet::Top);
                };
                match self.eval_operand(operand, state) {
                    FlatSet::Elem(op) => self
                        .ecx
                        .float_to_float_or_int(&op, layout)
                        .map_or(FlatSet::Top, |result| self.wrap_immediate(*result)),
                    FlatSet::Bottom => FlatSet::Bottom,
                    FlatSet::Top => FlatSet::Top,
                }
            }
            Rvalue::Cast(CastKind::Transmute, operand, _) => {
                // A transmute of a known scalar keeps its bits.
                match self.eval_operand(operand, state) {
                    FlatSet::Elem(op) => self.wrap_immediate(*op),
                    FlatSet::Bottom => FlatSet::Bottom,
                    FlatSet::Top => FlatSet::Top,
                }
            }
            Rvalue::BinaryOp(op, box (left, right)) if !op.is_overflowing() => {
                // Overflows must be ignored here.
                // The overflowing operators are handled in `handle_assign`.
                let (val, _overflow) = self.binary_op(state, *op, left, right);
                val
            }
            Rvalue::UnaryOp(op, operand) => match self.eval_operand(operand, state) {
                FlatSet::Elem(value) => self
                    .ecx
                    .unary_op(*op, &value)
                    .map_or(FlatSet::Top, |val| self.wrap_immediate(*val)),
                FlatSet::Bottom => FlatSet::Bottom,
                FlatSet::Top => FlatSet::Top,
            },
            Rvalue::NullaryOp(null_op, ty) => {
                let Ok(layout) = self.tcx.layout_of(self.param_env.and(*ty)) else {
                    return ValueOrPlace::Value(FlatSet::Top);
                };
                let val = match null_op {
                    NullOp::SizeOf if layout.is_sized() => layout.size.bytes(),
                    NullOp::AlignOf if layout.is_sized() => layout.align.abi.bytes(),
                    NullOp::OffsetOf(fields) => {
                        layout.offset_of_subfield(&self.ecx, fields.iter()).bytes()
                    }
                    _ => return ValueOrPlace::Value(FlatSet::Top),
                };
                FlatSet::Elem(Scalar::from_target_usize(val, &self.tcx))
            }
            Rvalue::Discriminant(place) => state.get_discr(place.as_ref(), self.map()),
            _ => return self.super_rvalue(rvalue, state),
        };
        ValueOrPlace::Value(val)
    }

    /// Evaluates a MIR constant operand to a scalar lattice value, `Top` on failure.
    fn handle_constant(
        &self,
        constant: &ConstOperand<'tcx>,
        _state: &mut State<Self::Value>,
    ) -> Self::Value {
        constant
            .const_
            .try_eval_scalar(self.tcx, self.param_env)
            .map_or(FlatSet::Top, FlatSet::Elem)
    }

    /// Transfer function for `SwitchInt`: if the discriminee is a known
    /// constant, only the matching edge is taken.
    fn handle_switch_int<'mir>(
        &self,
        discr: &'mir Operand<'tcx>,
        targets: &'mir SwitchTargets,
        state: &mut State<Self::Value>,
    ) -> TerminatorEdges<'mir, 'tcx> {
        let value = match self.handle_operand(discr, state) {
            ValueOrPlace::Value(value) => value,
            ValueOrPlace::Place(place) => state.get_idx(place, self.map()),
        };
        match value {
            // We are branching on uninitialized data, this is UB, treat it as unreachable.
            // This allows the set of visited edges to grow monotonically with the lattice.
            FlatSet::Bottom => TerminatorEdges::None,
            FlatSet::Elem(scalar) => {
                let choice = scalar.assert_bits(scalar.size());
                TerminatorEdges::Single(targets.target_for_value(choice))
            }
            FlatSet::Top => TerminatorEdges::SwitchInt { discr, targets },
        }
    }
}
334 | ||
335 | impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> { | |
336 | pub fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, map: Map) -> Self { | |
337 | let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id()); | |
338 | Self { | |
339 | map, | |
340 | tcx, | |
341 | local_decls: &body.local_decls, | |
342 | ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine), | |
343 | param_env: param_env, | |
344 | } | |
345 | } | |
346 | ||
347 | /// The caller must have flooded `place`. | |
348 | fn assign_operand( | |
349 | &self, | |
350 | state: &mut State<FlatSet<Scalar>>, | |
351 | place: PlaceIndex, | |
352 | operand: &Operand<'tcx>, | |
353 | ) { | |
354 | match operand { | |
355 | Operand::Copy(rhs) | Operand::Move(rhs) => { | |
356 | if let Some(rhs) = self.map.find(rhs.as_ref()) { | |
357 | state.insert_place_idx(place, rhs, &self.map); | |
358 | } else if rhs.projection.first() == Some(&PlaceElem::Deref) | |
359 | && let FlatSet::Elem(pointer) = state.get(rhs.local.into(), &self.map) | |
360 | && let rhs_ty = self.local_decls[rhs.local].ty | |
361 | && let Ok(rhs_layout) = self.tcx.layout_of(self.param_env.and(rhs_ty)) | |
362 | { | |
363 | let op = ImmTy::from_scalar(pointer, rhs_layout).into(); | |
364 | self.assign_constant(state, place, op, rhs.projection); | |
365 | } | |
366 | } | |
367 | Operand::Constant(box constant) => { | |
368 | if let Ok(constant) = | |
369 | self.ecx.eval_mir_constant(&constant.const_, constant.span, None) | |
370 | { | |
371 | self.assign_constant(state, place, constant, &[]); | |
372 | } | |
373 | } | |
374 | } | |
375 | } | |
376 | ||
377 | /// The caller must have flooded `place`. | |
378 | /// | |
379 | /// Perform: `place = operand.projection`. | |
380 | #[instrument(level = "trace", skip(self, state))] | |
381 | fn assign_constant( | |
382 | &self, | |
383 | state: &mut State<FlatSet<Scalar>>, | |
384 | place: PlaceIndex, | |
385 | mut operand: OpTy<'tcx>, | |
386 | projection: &[PlaceElem<'tcx>], | |
387 | ) -> Option<!> { | |
388 | for &(mut proj_elem) in projection { | |
389 | if let PlaceElem::Index(index) = proj_elem { | |
390 | if let FlatSet::Elem(index) = state.get(index.into(), &self.map) | |
391 | && let Ok(offset) = index.to_target_usize(&self.tcx) | |
392 | && let Some(min_length) = offset.checked_add(1) | |
393 | { | |
394 | proj_elem = PlaceElem::ConstantIndex { offset, min_length, from_end: false }; | |
395 | } else { | |
396 | return None; | |
397 | } | |
398 | } | |
399 | operand = self.ecx.project(&operand, proj_elem).ok()?; | |
400 | } | |
401 | ||
402 | self.map.for_each_projection_value( | |
403 | place, | |
404 | operand, | |
405 | &mut |elem, op| match elem { | |
406 | TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(), | |
407 | TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(), | |
408 | TrackElem::Discriminant => { | |
409 | let variant = self.ecx.read_discriminant(op).ok()?; | |
410 | let discr_value = | |
411 | self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?; | |
412 | Some(discr_value.into()) | |
413 | } | |
414 | TrackElem::DerefLen => { | |
415 | let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into(); | |
416 | let len_usize = op.len(&self.ecx).ok()?; | |
417 | let layout = | |
418 | self.tcx.layout_of(self.param_env.and(self.tcx.types.usize)).unwrap(); | |
419 | Some(ImmTy::from_uint(len_usize, layout).into()) | |
420 | } | |
421 | }, | |
422 | &mut |place, op| { | |
423 | if let Ok(imm) = self.ecx.read_immediate_raw(op) | |
424 | && let Some(imm) = imm.right() | |
425 | { | |
426 | let elem = self.wrap_immediate(*imm); | |
427 | state.insert_value_idx(place, elem, &self.map); | |
428 | } | |
429 | }, | |
430 | ); | |
431 | ||
432 | None | |
433 | } | |
434 | ||
435 | fn binary_op( | |
436 | &self, | |
437 | state: &mut State<FlatSet<Scalar>>, | |
438 | op: BinOp, | |
439 | left: &Operand<'tcx>, | |
440 | right: &Operand<'tcx>, | |
441 | ) -> (FlatSet<Scalar>, FlatSet<Scalar>) { | |
442 | let left = self.eval_operand(left, state); | |
443 | let right = self.eval_operand(right, state); | |
444 | ||
445 | match (left, right) { | |
446 | (FlatSet::Bottom, _) | (_, FlatSet::Bottom) => (FlatSet::Bottom, FlatSet::Bottom), | |
447 | // Both sides are known, do the actual computation. | |
448 | (FlatSet::Elem(left), FlatSet::Elem(right)) => { | |
449 | match self.ecx.binary_op(op, &left, &right) { | |
450 | // Ideally this would return an Immediate, since it's sometimes | |
451 | // a pair and sometimes not. But as a hack we always return a pair | |
452 | // and just make the 2nd component `Bottom` when it does not exist. | |
453 | Ok(val) => { | |
454 | if matches!(val.layout.abi, Abi::ScalarPair(..)) { | |
455 | let (val, overflow) = val.to_scalar_pair(); | |
456 | (FlatSet::Elem(val), FlatSet::Elem(overflow)) | |
457 | } else { | |
458 | (FlatSet::Elem(val.to_scalar()), FlatSet::Bottom) | |
459 | } | |
460 | } | |
461 | _ => (FlatSet::Top, FlatSet::Top), | |
462 | } | |
463 | } | |
464 | // Exactly one side is known, attempt some algebraic simplifications. | |
465 | (FlatSet::Elem(const_arg), _) | (_, FlatSet::Elem(const_arg)) => { | |
466 | let layout = const_arg.layout; | |
467 | if !matches!(layout.abi, rustc_target::abi::Abi::Scalar(..)) { | |
468 | return (FlatSet::Top, FlatSet::Top); | |
469 | } | |
470 | ||
471 | let arg_scalar = const_arg.to_scalar(); | |
472 | let Ok(arg_value) = arg_scalar.to_bits(layout.size) else { | |
473 | return (FlatSet::Top, FlatSet::Top); | |
474 | }; | |
475 | ||
476 | match op { | |
477 | BinOp::BitAnd if arg_value == 0 => (FlatSet::Elem(arg_scalar), FlatSet::Bottom), | |
478 | BinOp::BitOr | |
479 | if arg_value == layout.size.truncate(u128::MAX) | |
480 | || (layout.ty.is_bool() && arg_value == 1) => | |
481 | { | |
482 | (FlatSet::Elem(arg_scalar), FlatSet::Bottom) | |
483 | } | |
484 | BinOp::Mul if layout.ty.is_integral() && arg_value == 0 => { | |
485 | (FlatSet::Elem(arg_scalar), FlatSet::Elem(Scalar::from_bool(false))) | |
486 | } | |
487 | _ => (FlatSet::Top, FlatSet::Top), | |
488 | } | |
489 | } | |
490 | (FlatSet::Top, FlatSet::Top) => (FlatSet::Top, FlatSet::Top), | |
491 | } | |
492 | } | |
493 | ||
494 | fn eval_operand( | |
495 | &self, | |
496 | op: &Operand<'tcx>, | |
497 | state: &mut State<FlatSet<Scalar>>, | |
498 | ) -> FlatSet<ImmTy<'tcx>> { | |
499 | let value = match self.handle_operand(op, state) { | |
500 | ValueOrPlace::Value(value) => value, | |
501 | ValueOrPlace::Place(place) => state.get_idx(place, &self.map), | |
502 | }; | |
503 | match value { | |
504 | FlatSet::Top => FlatSet::Top, | |
505 | FlatSet::Elem(scalar) => { | |
506 | let ty = op.ty(self.local_decls, self.tcx); | |
507 | self.tcx.layout_of(self.param_env.and(ty)).map_or(FlatSet::Top, |layout| { | |
508 | FlatSet::Elem(ImmTy::from_scalar(scalar, layout)) | |
509 | }) | |
510 | } | |
511 | FlatSet::Bottom => FlatSet::Bottom, | |
512 | } | |
513 | } | |
514 | ||
515 | fn eval_discriminant(&self, enum_ty: Ty<'tcx>, variant_index: VariantIdx) -> Option<Scalar> { | |
516 | if !enum_ty.is_enum() { | |
517 | return None; | |
518 | } | |
519 | let enum_ty_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?; | |
520 | let discr_value = | |
521 | self.ecx.discriminant_for_variant(enum_ty_layout.ty, variant_index).ok()?; | |
522 | Some(discr_value.to_scalar()) | |
523 | } | |
524 | ||
525 | fn wrap_immediate(&self, imm: Immediate) -> FlatSet<Scalar> { | |
526 | match imm { | |
527 | Immediate::Scalar(scalar) => FlatSet::Elem(scalar), | |
528 | Immediate::Uninit => FlatSet::Bottom, | |
529 | _ => FlatSet::Top, | |
530 | } | |
531 | } | |
532 | } | |
533 | ||
/// The set of rewrites collected during the analysis, applied to the body
/// afterwards via its `MutVisitor` impl.
pub(crate) struct Patch<'tcx> {
    tcx: TyCtxt<'tcx>,

    /// For a given MIR location, this stores the values of the operands used by that location. In
    /// particular, this is before the effect, such that the operands of `_1 = _1 + _2` are
    /// properly captured. (This may become UB soon, but it is currently emitted even by safe code.)
    pub(crate) before_effect: FxHashMap<(Location, Place<'tcx>), Const<'tcx>>,

    /// Stores the assigned values for assignments where the Rvalue is constant.
    pub(crate) assignments: FxHashMap<Location, Const<'tcx>>,
}
545 | ||
546 | impl<'tcx> Patch<'tcx> { | |
547 | pub(crate) fn new(tcx: TyCtxt<'tcx>) -> Self { | |
548 | Self { tcx, before_effect: FxHashMap::default(), assignments: FxHashMap::default() } | |
549 | } | |
550 | ||
551 | fn make_operand(&self, const_: Const<'tcx>) -> Operand<'tcx> { | |
552 | Operand::Constant(Box::new(ConstOperand { span: DUMMY_SP, user_ty: None, const_ })) | |
553 | } | |
554 | } | |
555 | ||
/// A `ResultsVisitor` that walks the fixpoint results and records every
/// operand/assignment that can be replaced by a constant into `patch`.
struct Collector<'tcx, 'locals> {
    patch: Patch<'tcx>,
    local_decls: &'locals LocalDecls<'tcx>,
}
560 | ||
impl<'tcx, 'locals> Collector<'tcx, 'locals> {
    pub(crate) fn new(tcx: TyCtxt<'tcx>, local_decls: &'locals LocalDecls<'tcx>) -> Self {
        Self { patch: Patch::new(tcx), local_decls }
    }

    /// Tries to build a MIR constant for the value of `place` in `state`.
    ///
    /// Returns `None` if the place is not tracked, is unsized, or its value is
    /// not fully known.
    fn try_make_constant(
        &self,
        ecx: &mut InterpCx<'tcx, DummyMachine>,
        place: Place<'tcx>,
        state: &State<FlatSet<Scalar>>,
        map: &Map,
    ) -> Option<Const<'tcx>> {
        let ty = place.ty(self.local_decls, self.patch.tcx).ty;
        let layout = ecx.layout_of(ty).ok()?;

        // ZSTs have no bits to get wrong; a zero-sized constant always works.
        if layout.is_zst() {
            return Some(Const::zero_sized(ty));
        }

        if layout.is_unsized() {
            return None;
        }

        let place = map.find(place.as_ref())?;
        // Fast path: a plain scalar can be emitted directly, no allocation needed.
        if layout.abi.is_scalar()
            && let Some(value) = propagatable_scalar(place, state, map)
        {
            return Some(Const::Val(ConstValue::Scalar(value), ty));
        }

        // Otherwise, materialize the value into a fresh allocation by writing
        // each tracked sub-value through the interpreter.
        if matches!(layout.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) {
            let alloc_id = ecx
                .intern_with_temp_alloc(layout, |ecx, dest| {
                    try_write_constant(ecx, dest, place, ty, state, map)
                })
                .ok()?;
            return Some(Const::Val(ConstValue::Indirect { alloc_id, offset: Size::ZERO }, ty));
        }

        None
    }
}
603 | ||
604 | fn propagatable_scalar( | |
605 | place: PlaceIndex, | |
606 | state: &State<FlatSet<Scalar>>, | |
607 | map: &Map, | |
608 | ) -> Option<Scalar> { | |
609 | if let FlatSet::Elem(value) = state.get_idx(place, map) | |
610 | && value.try_to_int().is_ok() | |
611 | { | |
612 | // Do not attempt to propagate pointers, as we may fail to preserve their identity. | |
613 | Some(value) | |
614 | } else { | |
615 | None | |
616 | } | |
617 | } | |
618 | ||
/// Recursively writes the known value of tracked place `place` (of type `ty`)
/// into the interpreter destination `dest`.
///
/// Aborts interpretation (via `throw_machine_stop_str!`) whenever any part of
/// the value is unknown or the type is unsupported, which makes the enclosing
/// `intern_with_temp_alloc` fail and the constant not be propagated.
#[instrument(level = "trace", skip(ecx, state, map))]
fn try_write_constant<'tcx>(
    ecx: &mut InterpCx<'tcx, DummyMachine>,
    dest: &PlaceTy<'tcx>,
    place: PlaceIndex,
    ty: Ty<'tcx>,
    state: &State<FlatSet<Scalar>>,
    map: &Map,
) -> InterpResult<'tcx> {
    let layout = ecx.layout_of(ty)?;

    // Fast path for ZSTs.
    if layout.is_zst() {
        return Ok(());
    }

    // Fast path for scalars.
    if layout.abi.is_scalar()
        && let Some(value) = propagatable_scalar(place, state, map)
    {
        return ecx.write_immediate(Immediate::Scalar(value), dest);
    }

    match ty.kind() {
        // ZSTs. Nothing to do.
        ty::FnDef(..) => {}

        // Those are scalars, must be handled above.
        ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => throw_machine_stop_str!("primitive type with provenance"),

        ty::Tuple(elem_tys) => {
            // Write each tuple field individually.
            for (i, elem) in elem_tys.iter().enumerate() {
                let Some(field) = map.apply(place, TrackElem::Field(FieldIdx::from_usize(i))) else {
                    throw_machine_stop_str!("missing field in tuple")
                };
                let field_dest = ecx.project_field(dest, i)?;
                try_write_constant(ecx, &field_dest, field, elem, state, map)?;
            }
        }

        ty::Adt(def, args) => {
            if def.is_union() {
                throw_machine_stop_str!("cannot propagate unions")
            }

            // For enums, resolve the active variant from the tracked
            // discriminant; for structs, use the single variant directly.
            let (variant_idx, variant_def, variant_place, variant_dest) = if def.is_enum() {
                let Some(discr) = map.apply(place, TrackElem::Discriminant) else {
                    throw_machine_stop_str!("missing discriminant for enum")
                };
                let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else {
                    throw_machine_stop_str!("discriminant with provenance")
                };
                let discr_bits = discr.assert_bits(discr.size());
                let Some((variant, _)) = def.discriminants(*ecx.tcx).find(|(_, var)| discr_bits == var.val) else {
                    throw_machine_stop_str!("illegal discriminant for enum")
                };
                let Some(variant_place) = map.apply(place, TrackElem::Variant(variant)) else {
                    throw_machine_stop_str!("missing variant for enum")
                };
                let variant_dest = ecx.project_downcast(dest, variant)?;
                (variant, def.variant(variant), variant_place, variant_dest)
            } else {
                (FIRST_VARIANT, def.non_enum_variant(), place, dest.clone())
            };

            // Write the fields of the chosen variant...
            for (i, field) in variant_def.fields.iter_enumerated() {
                let ty = field.ty(*ecx.tcx, args);
                let Some(field) = map.apply(variant_place, TrackElem::Field(i)) else {
                    throw_machine_stop_str!("missing field in ADT")
                };
                let field_dest = ecx.project_field(&variant_dest, i.as_usize())?;
                try_write_constant(ecx, &field_dest, field, ty, state, map)?;
            }
            // ...and finally the discriminant itself.
            ecx.write_discriminant(variant_idx, dest)?;
        }

        // Unsupported for now.
        ty::Array(_, _)
        | ty::Pat(_, _)

        // Do not attempt to support indirection in constants.
        | ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_)

        | ty::Never
        | ty::Foreign(..)
        | ty::Alias(..)
        | ty::Param(_)
        | ty::Bound(..)
        | ty::Placeholder(..)
        | ty::Closure(..)
        | ty::CoroutineClosure(..)
        | ty::Coroutine(..)
        | ty::Dynamic(..) => throw_machine_stop_str!("unsupported type"),

        // These type kinds should not reach codegen/MIR optimization.
        ty::Error(_) | ty::Infer(..) | ty::CoroutineWitness(..) => bug!(),
    }

    Ok(())
}
718 | ||
impl<'mir, 'tcx>
    ResultsVisitor<'mir, 'tcx, Results<'tcx, ValueAnalysisWrapper<ConstAnalysis<'_, 'tcx>>>>
    for Collector<'tcx, '_>
{
    type FlowState = State<FlatSet<Scalar>>;

    /// Before a statement's effect: record constant *operands* of an
    /// assignment's rvalue (the state still holds pre-assignment values).
    fn visit_statement_before_primary_effect(
        &mut self,
        results: &mut Results<'tcx, ValueAnalysisWrapper<ConstAnalysis<'_, 'tcx>>>,
        state: &Self::FlowState,
        statement: &'mir Statement<'tcx>,
        location: Location,
    ) {
        match &statement.kind {
            StatementKind::Assign(box (_, rvalue)) => {
                OperandCollector {
                    state,
                    visitor: self,
                    ecx: &mut results.analysis.0.ecx,
                    map: &results.analysis.0.map,
                }
                .visit_rvalue(rvalue, location);
            }
            _ => (),
        }
    }

    /// After a statement's effect: if the assigned place now has a fully known
    /// value, record it so the whole rvalue can be replaced by a constant.
    fn visit_statement_after_primary_effect(
        &mut self,
        results: &mut Results<'tcx, ValueAnalysisWrapper<ConstAnalysis<'_, 'tcx>>>,
        state: &Self::FlowState,
        statement: &'mir Statement<'tcx>,
        location: Location,
    ) {
        match statement.kind {
            StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(_)))) => {
                // Don't overwrite the assignment if it already uses a constant (to keep the span).
            }
            StatementKind::Assign(box (place, _)) => {
                if let Some(value) = self.try_make_constant(
                    &mut results.analysis.0.ecx,
                    place,
                    state,
                    &results.analysis.0.map,
                ) {
                    self.patch.assignments.insert(location, value);
                }
            }
            _ => (),
        }
    }

    /// Before a terminator's effect: record constant operands used by it.
    fn visit_terminator_before_primary_effect(
        &mut self,
        results: &mut Results<'tcx, ValueAnalysisWrapper<ConstAnalysis<'_, 'tcx>>>,
        state: &Self::FlowState,
        terminator: &'mir Terminator<'tcx>,
        location: Location,
    ) {
        OperandCollector {
            state,
            visitor: self,
            ecx: &mut results.analysis.0.ecx,
            map: &results.analysis.0.map,
        }
        .visit_terminator(terminator, location);
    }
}
787 | ||
/// Applies the collected rewrites to the MIR body.
impl<'tcx> MutVisitor<'tcx> for Patch<'tcx> {
    fn tcx(&self) -> TyCtxt<'tcx> {
        self.tcx
    }

    /// Replaces a whole assignment's rvalue by a recorded constant; otherwise
    /// recurses so operand-level replacements can apply.
    fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
        if let Some(value) = self.assignments.get(&location) {
            match &mut statement.kind {
                StatementKind::Assign(box (_, rvalue)) => {
                    *rvalue = Rvalue::Use(self.make_operand(*value));
                }
                _ => bug!("found assignment info for non-assign statement"),
            }
        } else {
            self.super_statement(statement, location);
        }
    }

    /// Replaces a copy/move operand by its recorded constant, if any.
    fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) {
        match operand {
            Operand::Copy(place) | Operand::Move(place) => {
                if let Some(value) = self.before_effect.get(&(location, *place)) {
                    *operand = self.make_operand(*value);
                } else if !place.projection.is_empty() {
                    // The whole place has no recorded value, but an `Index`
                    // projection inside it might (see below).
                    self.super_operand(operand, location)
                }
            }
            Operand::Constant(_) => {}
        }
    }

    /// Rewrites `Index(local)` projections whose index is a known constant
    /// into bounds-checked `ConstantIndex` projections.
    fn process_projection_elem(
        &mut self,
        elem: PlaceElem<'tcx>,
        location: Location,
    ) -> Option<PlaceElem<'tcx>> {
        if let PlaceElem::Index(local) = elem {
            let offset = self.before_effect.get(&(location, local.into()))?;
            let offset = offset.try_to_scalar()?;
            let offset = offset.to_target_usize(&self.tcx).ok()?;
            // `min_length = offset + 1` encodes the implied bounds check.
            let min_length = offset.checked_add(1)?;
            Some(PlaceElem::ConstantIndex { offset, min_length, from_end: false })
        } else {
            None
        }
    }
}
835 | ||
/// A read-only MIR visitor that records, for every operand (and `Index`
/// projection) at a location, its constant value before the statement's effect.
struct OperandCollector<'tcx, 'map, 'locals, 'a> {
    state: &'a State<FlatSet<Scalar>>,
    visitor: &'a mut Collector<'tcx, 'locals>,
    ecx: &'map mut InterpCx<'tcx, DummyMachine>,
    map: &'map Map,
}
842 | ||
impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
    /// Records the value of the local used in an `Index` projection, so the
    /// patch can later turn it into a `ConstantIndex`.
    fn visit_projection_elem(
        &mut self,
        _: PlaceRef<'tcx>,
        elem: PlaceElem<'tcx>,
        _: PlaceContext,
        location: Location,
    ) {
        if let PlaceElem::Index(local) = elem
            && let Some(value) =
                self.visitor.try_make_constant(self.ecx, local.into(), self.state, self.map)
        {
            self.visitor.patch.before_effect.insert((location, local.into()), value);
        }
    }

    /// Records a constant value for the operand's place, or recurses into its
    /// projections when the whole place is not constant.
    fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) {
        if let Some(place) = operand.place() {
            if let Some(value) =
                self.visitor.try_make_constant(self.ecx, place, self.state, self.map)
            {
                self.visitor.patch.before_effect.insert((location, place), value);
            } else if !place.projection.is_empty() {
                // Try to propagate into `Index` projections.
                self.super_operand(operand, location)
            }
        }
    }
}