//! A constant propagation optimization pass based on dataflow analysis.
//!
//! Currently, this pass only propagates scalar values.
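//!
//! As a rough, illustrative example (not the exact MIR the compiler produces), in a body like
//!
//! ```ignore (illustrative)
//! let a = 2;
//! let b = a + 3;
//! ```
//!
//! the analysis knows the scalar value of `a` on every path reaching the second statement, so the
//! assignment to `b` can be rewritten to use the constant `5`.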
4 | ||
5 | use rustc_const_eval::const_eval::{throw_machine_stop_str, DummyMachine}; | |
6 | use rustc_const_eval::interpret::{ImmTy, Immediate, InterpCx, OpTy, PlaceTy, Projectable}; | |
7 | use rustc_data_structures::fx::FxHashMap; | |
8 | use rustc_hir::def::DefKind; | |
9 | use rustc_middle::bug; | |
10 | use rustc_middle::mir::interpret::{InterpResult, Scalar}; | |
11 | use rustc_middle::mir::visit::{MutVisitor, PlaceContext, Visitor}; | |
12 | use rustc_middle::mir::*; | |
13 | use rustc_middle::ty::layout::{HasParamEnv, LayoutOf}; | |
14 | use rustc_middle::ty::{self, Ty, TyCtxt}; | |
15 | use rustc_mir_dataflow::lattice::FlatSet; | |
16 | use rustc_mir_dataflow::value_analysis::{ | |
17 | Map, PlaceIndex, State, TrackElem, ValueAnalysis, ValueAnalysisWrapper, ValueOrPlace, | |
18 | }; | |
19 | use rustc_mir_dataflow::{Analysis, Results, ResultsVisitor}; | |
20 | use rustc_span::DUMMY_SP; | |
21 | use rustc_target::abi::{Abi, FieldIdx, Size, VariantIdx, FIRST_VARIANT}; | |
22 | use tracing::{debug, debug_span, instrument}; | |
23 | ||
24 | // These constants are somewhat random guesses and have not been optimized. | |
25 | // If `tcx.sess.mir_opt_level() >= 4`, we ignore the limits (this can become very expensive). | |
26 | const BLOCK_LIMIT: usize = 100; | |
27 | const PLACE_LIMIT: usize = 100; | |
28 | ||
29 | pub struct DataflowConstProp; | |
30 | ||
31 | impl<'tcx> MirPass<'tcx> for DataflowConstProp { | |
32 | fn is_enabled(&self, sess: &rustc_session::Session) -> bool { | |
33 | sess.mir_opt_level() >= 3 | |
34 | } | |
35 | ||
36 | #[instrument(skip_all level = "debug")] | |
37 | fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { | |
38 | debug!(def_id = ?body.source.def_id()); | |
39 | if tcx.sess.mir_opt_level() < 4 && body.basic_blocks.len() > BLOCK_LIMIT { | |
            debug!("aborted dataflow const prop due to too many basic blocks");
            return;
        }

        // We want to have a somewhat linear runtime w.r.t. the number of statements/terminators.
        // Let's call this number `n`. Dataflow analysis has `O(h*n)` transfer function
        // applications, where `h` is the height of the lattice. Because the height of our lattice
        // is linear w.r.t. the number of tracked places, this is `O(tracked_places * n)`. However,
        // because every transfer function application could traverse the whole map, this becomes
        // `O(num_nodes * tracked_places * n)` in terms of time complexity. Since the number of
        // map nodes is strongly correlated to the number of tracked places, this becomes more or
        // less `O(n)` if we place a constant limit on the number of tracked places.
        let place_limit = if tcx.sess.mir_opt_level() < 4 { Some(PLACE_LIMIT) } else { None };

        // Decide which places to track during the analysis.
        let map = Map::new(tcx, body, place_limit);

        // Perform the actual dataflow analysis.
        let analysis = ConstAnalysis::new(tcx, body, map);
        let mut results = debug_span!("analyze")
            .in_scope(|| analysis.wrap().into_engine(tcx, body).iterate_to_fixpoint());

        // Collect results and patch the body afterwards.
        let mut visitor = Collector::new(tcx, &body.local_decls);
        debug_span!("collect").in_scope(|| results.visit_reachable_with(body, &mut visitor));
        let mut patch = visitor.patch;
        debug_span!("patch").in_scope(|| patch.visit_body_preserves_cfg(body));
    }
}

struct ConstAnalysis<'a, 'tcx> {
    map: Map<'tcx>,
    tcx: TyCtxt<'tcx>,
    local_decls: &'a LocalDecls<'tcx>,
    ecx: InterpCx<'tcx, DummyMachine>,
    param_env: ty::ParamEnv<'tcx>,
}

impl<'tcx> ValueAnalysis<'tcx> for ConstAnalysis<'_, 'tcx> {
    type Value = FlatSet<Scalar>;

    const NAME: &'static str = "ConstAnalysis";

    fn map(&self) -> &Map<'tcx> {
        &self.map
    }

    fn handle_set_discriminant(
        &self,
        place: Place<'tcx>,
        variant_index: VariantIdx,
        state: &mut State<Self::Value>,
    ) {
        state.flood_discr(place.as_ref(), &self.map);
        if self.map.find_discr(place.as_ref()).is_some() {
            let enum_ty = place.ty(self.local_decls, self.tcx).ty;
            if let Some(discr) = self.eval_discriminant(enum_ty, variant_index) {
                state.assign_discr(
                    place.as_ref(),
                    ValueOrPlace::Value(FlatSet::Elem(discr)),
                    &self.map,
                );
            }
        }
    }

    fn handle_assign(
        &self,
        target: Place<'tcx>,
        rvalue: &Rvalue<'tcx>,
        state: &mut State<Self::Value>,
    ) {
        match rvalue {
            Rvalue::Use(operand) => {
                state.flood(target.as_ref(), self.map());
                if let Some(target) = self.map.find(target.as_ref()) {
                    self.assign_operand(state, target, operand);
                }
            }
            Rvalue::CopyForDeref(rhs) => {
                state.flood(target.as_ref(), self.map());
                if let Some(target) = self.map.find(target.as_ref()) {
                    self.assign_operand(state, target, &Operand::Copy(*rhs));
                }
            }
            Rvalue::Aggregate(kind, operands) => {
                // If we assign `target = Enum::Variant#0(operand)`,
                // we must make sure that all `target as Variant#i` are `Top`.
                state.flood(target.as_ref(), self.map());

                let Some(target_idx) = self.map().find(target.as_ref()) else { return };

                let (variant_target, variant_index) = match **kind {
                    AggregateKind::Tuple | AggregateKind::Closure(..) => (Some(target_idx), None),
                    AggregateKind::Adt(def_id, variant_index, ..) => {
                        match self.tcx.def_kind(def_id) {
                            DefKind::Struct => (Some(target_idx), None),
                            DefKind::Enum => (
                                self.map.apply(target_idx, TrackElem::Variant(variant_index)),
                                Some(variant_index),
                            ),
                            _ => return,
                        }
                    }
                    _ => return,
                };
                if let Some(variant_target_idx) = variant_target {
                    for (field_index, operand) in operands.iter_enumerated() {
                        if let Some(field) =
                            self.map().apply(variant_target_idx, TrackElem::Field(field_index))
                        {
                            self.assign_operand(state, field, operand);
                        }
                    }
                }
                if let Some(variant_index) = variant_index
                    && let Some(discr_idx) = self.map().apply(target_idx, TrackElem::Discriminant)
                {
                    // We are assigning the discriminant as part of an aggregate.
                    // This discriminant can only alias a variant field's value if the operand
                    // had an invalid value for that type.
                    // Using invalid values is UB, so we are allowed to perform the assignment
                    // without extra flooding.
                    let enum_ty = target.ty(self.local_decls, self.tcx).ty;
                    if let Some(discr_val) = self.eval_discriminant(enum_ty, variant_index) {
                        state.insert_value_idx(discr_idx, FlatSet::Elem(discr_val), &self.map);
                    }
                }
            }
            Rvalue::BinaryOp(op, box (left, right)) if op.is_overflowing() => {
                // Flood everything now, so we can use `insert_value_idx` directly later.
                state.flood(target.as_ref(), self.map());

                let Some(target) = self.map().find(target.as_ref()) else { return };

                let value_target = self.map().apply(target, TrackElem::Field(0_u32.into()));
                let overflow_target = self.map().apply(target, TrackElem::Field(1_u32.into()));

                if value_target.is_some() || overflow_target.is_some() {
                    let (val, overflow) = self.binary_op(state, *op, left, right);

                    if let Some(value_target) = value_target {
                        // We have flooded `target` earlier.
                        state.insert_value_idx(value_target, val, self.map());
                    }
                    if let Some(overflow_target) = overflow_target {
                        // We have flooded `target` earlier.
                        state.insert_value_idx(overflow_target, overflow, self.map());
                    }
                }
            }
            Rvalue::Cast(
                CastKind::PointerCoercion(ty::adjustment::PointerCoercion::Unsize),
                operand,
                _,
            ) => {
                let pointer = self.handle_operand(operand, state);
                state.assign(target.as_ref(), pointer, self.map());

                if let Some(target_len) = self.map().find_len(target.as_ref())
                    && let operand_ty = operand.ty(self.local_decls, self.tcx)
                    && let Some(operand_ty) = operand_ty.builtin_deref(true)
                    && let ty::Array(_, len) = operand_ty.kind()
                    && let Some(len) = Const::Ty(self.tcx.types.usize, *len)
                        .try_eval_scalar_int(self.tcx, self.param_env)
                {
                    state.insert_value_idx(target_len, FlatSet::Elem(len.into()), self.map());
                }
            }
            _ => self.super_assign(target, rvalue, state),
        }
    }

    fn handle_rvalue(
        &self,
        rvalue: &Rvalue<'tcx>,
        state: &mut State<Self::Value>,
    ) -> ValueOrPlace<Self::Value> {
        let val = match rvalue {
            Rvalue::Len(place) => {
                let place_ty = place.ty(self.local_decls, self.tcx);
                if let ty::Array(_, len) = place_ty.ty.kind() {
                    Const::Ty(self.tcx.types.usize, *len)
                        .try_eval_scalar(self.tcx, self.param_env)
                        .map_or(FlatSet::Top, FlatSet::Elem)
                } else if let [ProjectionElem::Deref] = place.projection[..] {
                    state.get_len(place.local.into(), self.map())
                } else {
                    FlatSet::Top
                }
            }
            Rvalue::Cast(CastKind::IntToInt | CastKind::IntToFloat, operand, ty) => {
                let Ok(layout) = self.tcx.layout_of(self.param_env.and(*ty)) else {
                    return ValueOrPlace::Value(FlatSet::Top);
                };
                match self.eval_operand(operand, state) {
                    FlatSet::Elem(op) => self
                        .ecx
                        .int_to_int_or_float(&op, layout)
                        .map_or(FlatSet::Top, |result| self.wrap_immediate(*result)),
                    FlatSet::Bottom => FlatSet::Bottom,
                    FlatSet::Top => FlatSet::Top,
                }
            }
            Rvalue::Cast(CastKind::FloatToInt | CastKind::FloatToFloat, operand, ty) => {
                let Ok(layout) = self.tcx.layout_of(self.param_env.and(*ty)) else {
                    return ValueOrPlace::Value(FlatSet::Top);
                };
                match self.eval_operand(operand, state) {
                    FlatSet::Elem(op) => self
                        .ecx
                        .float_to_float_or_int(&op, layout)
                        .map_or(FlatSet::Top, |result| self.wrap_immediate(*result)),
                    FlatSet::Bottom => FlatSet::Bottom,
                    FlatSet::Top => FlatSet::Top,
                }
            }
            Rvalue::Cast(CastKind::Transmute, operand, _) => {
                match self.eval_operand(operand, state) {
                    FlatSet::Elem(op) => self.wrap_immediate(*op),
                    FlatSet::Bottom => FlatSet::Bottom,
                    FlatSet::Top => FlatSet::Top,
                }
            }
            Rvalue::BinaryOp(op, box (left, right)) if !op.is_overflowing() => {
                // Overflows must be ignored here.
                // The overflowing operators are handled in `handle_assign`.
                let (val, _overflow) = self.binary_op(state, *op, left, right);
                val
            }
            Rvalue::UnaryOp(op, operand) => match self.eval_operand(operand, state) {
                FlatSet::Elem(value) => self
                    .ecx
                    .unary_op(*op, &value)
                    .map_or(FlatSet::Top, |val| self.wrap_immediate(*val)),
                FlatSet::Bottom => FlatSet::Bottom,
                FlatSet::Top => FlatSet::Top,
            },
            Rvalue::NullaryOp(null_op, ty) => {
                let Ok(layout) = self.tcx.layout_of(self.param_env.and(*ty)) else {
                    return ValueOrPlace::Value(FlatSet::Top);
                };
                let val = match null_op {
                    NullOp::SizeOf if layout.is_sized() => layout.size.bytes(),
                    NullOp::AlignOf if layout.is_sized() => layout.align.abi.bytes(),
                    NullOp::OffsetOf(fields) => self
                        .ecx
                        .tcx
                        .offset_of_subfield(self.ecx.param_env(), layout, fields.iter())
                        .bytes(),
                    _ => return ValueOrPlace::Value(FlatSet::Top),
                };
                FlatSet::Elem(Scalar::from_target_usize(val, &self.tcx))
            }
            Rvalue::Discriminant(place) => state.get_discr(place.as_ref(), self.map()),
            _ => return self.super_rvalue(rvalue, state),
        };
        ValueOrPlace::Value(val)
    }

    fn handle_constant(
        &self,
        constant: &ConstOperand<'tcx>,
        _state: &mut State<Self::Value>,
    ) -> Self::Value {
        constant
            .const_
            .try_eval_scalar(self.tcx, self.param_env)
            .map_or(FlatSet::Top, FlatSet::Elem)
    }

    fn handle_switch_int<'mir>(
        &self,
        discr: &'mir Operand<'tcx>,
        targets: &'mir SwitchTargets,
        state: &mut State<Self::Value>,
    ) -> TerminatorEdges<'mir, 'tcx> {
        let value = match self.handle_operand(discr, state) {
            ValueOrPlace::Value(value) => value,
            ValueOrPlace::Place(place) => state.get_idx(place, self.map()),
        };
        match value {
            // We are branching on uninitialized data, this is UB, treat it as unreachable.
            // This allows the set of visited edges to grow monotonically with the lattice.
            FlatSet::Bottom => TerminatorEdges::None,
            FlatSet::Elem(scalar) => {
                let choice = scalar.assert_scalar_int().to_bits_unchecked();
                TerminatorEdges::Single(targets.target_for_value(choice))
            }
            FlatSet::Top => TerminatorEdges::SwitchInt { discr, targets },
        }
    }
}

impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> {
    pub fn new(tcx: TyCtxt<'tcx>, body: &'a Body<'tcx>, map: Map<'tcx>) -> Self {
        let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
        Self {
            map,
            tcx,
            local_decls: &body.local_decls,
            ecx: InterpCx::new(tcx, DUMMY_SP, param_env, DummyMachine),
            param_env,
        }
    }

    /// The caller must have flooded `place`.
    fn assign_operand(
        &self,
        state: &mut State<FlatSet<Scalar>>,
        place: PlaceIndex,
        operand: &Operand<'tcx>,
    ) {
        match operand {
            Operand::Copy(rhs) | Operand::Move(rhs) => {
                if let Some(rhs) = self.map.find(rhs.as_ref()) {
                    state.insert_place_idx(place, rhs, &self.map);
                } else if rhs.projection.first() == Some(&PlaceElem::Deref)
                    && let FlatSet::Elem(pointer) = state.get(rhs.local.into(), &self.map)
                    && let rhs_ty = self.local_decls[rhs.local].ty
                    && let Ok(rhs_layout) = self.tcx.layout_of(self.param_env.and(rhs_ty))
                {
                    let op = ImmTy::from_scalar(pointer, rhs_layout).into();
                    self.assign_constant(state, place, op, rhs.projection);
                }
            }
            Operand::Constant(box constant) => {
                if let Ok(constant) =
                    self.ecx.eval_mir_constant(&constant.const_, constant.span, None)
                {
                    self.assign_constant(state, place, constant, &[]);
                }
            }
        }
    }

    /// The caller must have flooded `place`.
    ///
    /// Perform: `place = operand.projection`.
    #[instrument(level = "trace", skip(self, state))]
    fn assign_constant(
        &self,
        state: &mut State<FlatSet<Scalar>>,
        place: PlaceIndex,
        mut operand: OpTy<'tcx>,
        projection: &[PlaceElem<'tcx>],
    ) {
        for &(mut proj_elem) in projection {
            if let PlaceElem::Index(index) = proj_elem {
                if let FlatSet::Elem(index) = state.get(index.into(), &self.map)
                    && let Ok(offset) = index.to_target_usize(&self.tcx)
                    && let Some(min_length) = offset.checked_add(1)
                {
                    proj_elem = PlaceElem::ConstantIndex { offset, min_length, from_end: false };
                } else {
                    return;
                }
            }
            operand = if let Ok(operand) = self.ecx.project(&operand, proj_elem) {
                operand
            } else {
                return;
            }
        }

        self.map.for_each_projection_value(
            place,
            operand,
            &mut |elem, op| match elem {
                TrackElem::Field(idx) => self.ecx.project_field(op, idx.as_usize()).ok(),
                TrackElem::Variant(idx) => self.ecx.project_downcast(op, idx).ok(),
                TrackElem::Discriminant => {
                    let variant = self.ecx.read_discriminant(op).ok()?;
                    let discr_value =
                        self.ecx.discriminant_for_variant(op.layout.ty, variant).ok()?;
                    Some(discr_value.into())
                }
                TrackElem::DerefLen => {
                    let op: OpTy<'_> = self.ecx.deref_pointer(op).ok()?.into();
                    let len_usize = op.len(&self.ecx).ok()?;
                    let layout =
                        self.tcx.layout_of(self.param_env.and(self.tcx.types.usize)).unwrap();
                    Some(ImmTy::from_uint(len_usize, layout).into())
                }
            },
            &mut |place, op| {
                if let Ok(imm) = self.ecx.read_immediate_raw(op)
                    && let Some(imm) = imm.right()
                {
                    let elem = self.wrap_immediate(*imm);
                    state.insert_value_idx(place, elem, &self.map);
                }
            },
        );
    }

    fn binary_op(
        &self,
        state: &mut State<FlatSet<Scalar>>,
        op: BinOp,
        left: &Operand<'tcx>,
        right: &Operand<'tcx>,
    ) -> (FlatSet<Scalar>, FlatSet<Scalar>) {
        let left = self.eval_operand(left, state);
        let right = self.eval_operand(right, state);

        match (left, right) {
            (FlatSet::Bottom, _) | (_, FlatSet::Bottom) => (FlatSet::Bottom, FlatSet::Bottom),
            // Both sides are known, do the actual computation.
            (FlatSet::Elem(left), FlatSet::Elem(right)) => {
                match self.ecx.binary_op(op, &left, &right) {
                    // Ideally this would return an Immediate, since it's sometimes
                    // a pair and sometimes not. But as a hack we always return a pair
                    // and just make the 2nd component `Bottom` when it does not exist.
                    Ok(val) => {
                        if matches!(val.layout.abi, Abi::ScalarPair(..)) {
                            let (val, overflow) = val.to_scalar_pair();
                            (FlatSet::Elem(val), FlatSet::Elem(overflow))
                        } else {
                            (FlatSet::Elem(val.to_scalar()), FlatSet::Bottom)
                        }
                    }
                    _ => (FlatSet::Top, FlatSet::Top),
                }
            }
            // Exactly one side is known, attempt some algebraic simplifications.
            (FlatSet::Elem(const_arg), _) | (_, FlatSet::Elem(const_arg)) => {
                let layout = const_arg.layout;
                if !matches!(layout.abi, rustc_target::abi::Abi::Scalar(..)) {
                    return (FlatSet::Top, FlatSet::Top);
                }

                let arg_scalar = const_arg.to_scalar();
                let Ok(arg_value) = arg_scalar.to_bits(layout.size) else {
                    return (FlatSet::Top, FlatSet::Top);
                };

                match op {
                    BinOp::BitAnd if arg_value == 0 => (FlatSet::Elem(arg_scalar), FlatSet::Bottom),
                    BinOp::BitOr
                        if arg_value == layout.size.truncate(u128::MAX)
                            || (layout.ty.is_bool() && arg_value == 1) =>
                    {
                        (FlatSet::Elem(arg_scalar), FlatSet::Bottom)
                    }
                    BinOp::Mul if layout.ty.is_integral() && arg_value == 0 => {
                        (FlatSet::Elem(arg_scalar), FlatSet::Elem(Scalar::from_bool(false)))
                    }
                    _ => (FlatSet::Top, FlatSet::Top),
                }
            }
            (FlatSet::Top, FlatSet::Top) => (FlatSet::Top, FlatSet::Top),
        }
    }

    fn eval_operand(
        &self,
        op: &Operand<'tcx>,
        state: &mut State<FlatSet<Scalar>>,
    ) -> FlatSet<ImmTy<'tcx>> {
        let value = match self.handle_operand(op, state) {
            ValueOrPlace::Value(value) => value,
            ValueOrPlace::Place(place) => state.get_idx(place, &self.map),
        };
        match value {
            FlatSet::Top => FlatSet::Top,
            FlatSet::Elem(scalar) => {
                let ty = op.ty(self.local_decls, self.tcx);
                self.tcx.layout_of(self.param_env.and(ty)).map_or(FlatSet::Top, |layout| {
                    FlatSet::Elem(ImmTy::from_scalar(scalar, layout))
                })
            }
            FlatSet::Bottom => FlatSet::Bottom,
        }
    }

    fn eval_discriminant(&self, enum_ty: Ty<'tcx>, variant_index: VariantIdx) -> Option<Scalar> {
        if !enum_ty.is_enum() {
            return None;
        }
        let enum_ty_layout = self.tcx.layout_of(self.param_env.and(enum_ty)).ok()?;
        let discr_value =
            self.ecx.discriminant_for_variant(enum_ty_layout.ty, variant_index).ok()?;
        Some(discr_value.to_scalar())
    }

    fn wrap_immediate(&self, imm: Immediate) -> FlatSet<Scalar> {
        match imm {
            Immediate::Scalar(scalar) => FlatSet::Elem(scalar),
            Immediate::Uninit => FlatSet::Bottom,
            _ => FlatSet::Top,
        }
    }
}

pub(crate) struct Patch<'tcx> {
    tcx: TyCtxt<'tcx>,

    /// For a given MIR location, this stores the values of the operands used by that location. In
    /// particular, this is before the effect, such that the operands of `_1 = _1 + _2` are
    /// properly captured. (This may become UB soon, but it is currently emitted even by safe code.)
    pub(crate) before_effect: FxHashMap<(Location, Place<'tcx>), Const<'tcx>>,

    /// Stores the assigned values for assignments where the Rvalue is constant.
    pub(crate) assignments: FxHashMap<Location, Const<'tcx>>,
}

impl<'tcx> Patch<'tcx> {
    pub(crate) fn new(tcx: TyCtxt<'tcx>) -> Self {
        Self { tcx, before_effect: FxHashMap::default(), assignments: FxHashMap::default() }
    }

    fn make_operand(&self, const_: Const<'tcx>) -> Operand<'tcx> {
        Operand::Constant(Box::new(ConstOperand { span: DUMMY_SP, user_ty: None, const_ }))
    }
}

struct Collector<'tcx, 'locals> {
    patch: Patch<'tcx>,
    local_decls: &'locals LocalDecls<'tcx>,
}

impl<'tcx, 'locals> Collector<'tcx, 'locals> {
    pub(crate) fn new(tcx: TyCtxt<'tcx>, local_decls: &'locals LocalDecls<'tcx>) -> Self {
        Self { patch: Patch::new(tcx), local_decls }
    }

    #[instrument(level = "trace", skip(self, ecx, map), ret)]
    fn try_make_constant(
        &self,
        ecx: &mut InterpCx<'tcx, DummyMachine>,
        place: Place<'tcx>,
        state: &State<FlatSet<Scalar>>,
        map: &Map<'tcx>,
    ) -> Option<Const<'tcx>> {
        let ty = place.ty(self.local_decls, self.patch.tcx).ty;
        let layout = ecx.layout_of(ty).ok()?;

        if layout.is_zst() {
            return Some(Const::zero_sized(ty));
        }

        if layout.is_unsized() {
            return None;
        }

        let place = map.find(place.as_ref())?;
        if layout.abi.is_scalar()
            && let Some(value) = propagatable_scalar(place, state, map)
        {
            return Some(Const::Val(ConstValue::Scalar(value), ty));
        }

        if matches!(layout.abi, Abi::Scalar(..) | Abi::ScalarPair(..)) {
            let alloc_id = ecx
                .intern_with_temp_alloc(layout, |ecx, dest| {
                    try_write_constant(ecx, dest, place, ty, state, map)
                })
                .ok()?;
            return Some(Const::Val(ConstValue::Indirect { alloc_id, offset: Size::ZERO }, ty));
        }

        None
    }
}

#[instrument(level = "trace", skip(map), ret)]
fn propagatable_scalar(
    place: PlaceIndex,
    state: &State<FlatSet<Scalar>>,
    map: &Map<'_>,
) -> Option<Scalar> {
    if let FlatSet::Elem(value) = state.get_idx(place, map)
        && value.try_to_scalar_int().is_ok()
    {
        // Do not attempt to propagate pointers, as we may fail to preserve their identity.
        Some(value)
    } else {
        None
    }
}

#[instrument(level = "trace", skip(ecx, state, map), ret)]
fn try_write_constant<'tcx>(
    ecx: &mut InterpCx<'tcx, DummyMachine>,
    dest: &PlaceTy<'tcx>,
    place: PlaceIndex,
    ty: Ty<'tcx>,
    state: &State<FlatSet<Scalar>>,
    map: &Map<'tcx>,
) -> InterpResult<'tcx> {
    let layout = ecx.layout_of(ty)?;

    // Fast path for ZSTs.
    if layout.is_zst() {
        return Ok(());
    }

    // Fast path for scalars.
    if layout.abi.is_scalar()
        && let Some(value) = propagatable_scalar(place, state, map)
    {
        return ecx.write_immediate(Immediate::Scalar(value), dest);
    }

    match ty.kind() {
        // ZSTs. Nothing to do.
        ty::FnDef(..) => {}

        // Those are scalars, must be handled above.
        ty::Bool | ty::Int(_) | ty::Uint(_) | ty::Float(_) | ty::Char => {
            throw_machine_stop_str!("primitive type with provenance")
        }

        ty::Tuple(elem_tys) => {
            for (i, elem) in elem_tys.iter().enumerate() {
                let Some(field) = map.apply(place, TrackElem::Field(FieldIdx::from_usize(i))) else {
                    throw_machine_stop_str!("missing field in tuple")
                };
                let field_dest = ecx.project_field(dest, i)?;
                try_write_constant(ecx, &field_dest, field, elem, state, map)?;
            }
        }

        ty::Adt(def, args) => {
            if def.is_union() {
                throw_machine_stop_str!("cannot propagate unions")
            }

            let (variant_idx, variant_def, variant_place, variant_dest) = if def.is_enum() {
                let Some(discr) = map.apply(place, TrackElem::Discriminant) else {
                    throw_machine_stop_str!("missing discriminant for enum")
                };
                let FlatSet::Elem(Scalar::Int(discr)) = state.get_idx(discr, map) else {
                    throw_machine_stop_str!("discriminant with provenance")
                };
                let discr_bits = discr.to_bits(discr.size());
                let Some((variant, _)) =
                    def.discriminants(*ecx.tcx).find(|(_, var)| discr_bits == var.val)
                else {
                    throw_machine_stop_str!("illegal discriminant for enum")
                };
                let Some(variant_place) = map.apply(place, TrackElem::Variant(variant)) else {
                    throw_machine_stop_str!("missing variant for enum")
                };
                let variant_dest = ecx.project_downcast(dest, variant)?;
                (variant, def.variant(variant), variant_place, variant_dest)
            } else {
                (FIRST_VARIANT, def.non_enum_variant(), place, dest.clone())
            };

            for (i, field) in variant_def.fields.iter_enumerated() {
                let ty = field.ty(*ecx.tcx, args);
                let Some(field) = map.apply(variant_place, TrackElem::Field(i)) else {
                    throw_machine_stop_str!("missing field in ADT")
                };
                let field_dest = ecx.project_field(&variant_dest, i.as_usize())?;
                try_write_constant(ecx, &field_dest, field, ty, state, map)?;
            }
            ecx.write_discriminant(variant_idx, dest)?;
        }

        // Unsupported for now.
        ty::Array(_, _)
        | ty::Pat(_, _)

        // Do not attempt to support indirection in constants.
        | ty::Ref(..) | ty::RawPtr(..) | ty::FnPtr(..) | ty::Str | ty::Slice(_)

        | ty::Never
        | ty::Foreign(..)
        | ty::Alias(..)
        | ty::Param(_)
        | ty::Bound(..)
        | ty::Placeholder(..)
        | ty::Closure(..)
        | ty::CoroutineClosure(..)
        | ty::Coroutine(..)
        | ty::Dynamic(..) => throw_machine_stop_str!("unsupported type"),

        ty::Error(_) | ty::Infer(..) | ty::CoroutineWitness(..) => bug!(),
    }

    Ok(())
}

impl<'mir, 'tcx>
    ResultsVisitor<'mir, 'tcx, Results<'tcx, ValueAnalysisWrapper<ConstAnalysis<'_, 'tcx>>>>
    for Collector<'tcx, '_>
{
    type FlowState = State<FlatSet<Scalar>>;

    #[instrument(level = "trace", skip(self, results, statement))]
    fn visit_statement_before_primary_effect(
        &mut self,
        results: &mut Results<'tcx, ValueAnalysisWrapper<ConstAnalysis<'_, 'tcx>>>,
        state: &Self::FlowState,
        statement: &'mir Statement<'tcx>,
        location: Location,
    ) {
        match &statement.kind {
            StatementKind::Assign(box (_, rvalue)) => {
                OperandCollector {
                    state,
                    visitor: self,
                    ecx: &mut results.analysis.0.ecx,
                    map: &results.analysis.0.map,
                }
                .visit_rvalue(rvalue, location);
            }
            _ => (),
        }
    }

    #[instrument(level = "trace", skip(self, results, statement))]
    fn visit_statement_after_primary_effect(
        &mut self,
        results: &mut Results<'tcx, ValueAnalysisWrapper<ConstAnalysis<'_, 'tcx>>>,
        state: &Self::FlowState,
        statement: &'mir Statement<'tcx>,
        location: Location,
    ) {
        match statement.kind {
            StatementKind::Assign(box (_, Rvalue::Use(Operand::Constant(_)))) => {
                // Don't overwrite the assignment if it already uses a constant (to keep the span).
            }
            StatementKind::Assign(box (place, _)) => {
                if let Some(value) = self.try_make_constant(
                    &mut results.analysis.0.ecx,
                    place,
                    state,
                    &results.analysis.0.map,
                ) {
                    self.patch.assignments.insert(location, value);
                }
            }
            _ => (),
        }
    }

    fn visit_terminator_before_primary_effect(
        &mut self,
        results: &mut Results<'tcx, ValueAnalysisWrapper<ConstAnalysis<'_, 'tcx>>>,
        state: &Self::FlowState,
        terminator: &'mir Terminator<'tcx>,
        location: Location,
    ) {
        OperandCollector {
            state,
            visitor: self,
            ecx: &mut results.analysis.0.ecx,
            map: &results.analysis.0.map,
        }
        .visit_terminator(terminator, location);
    }
}

impl<'tcx> MutVisitor<'tcx> for Patch<'tcx> {
    fn tcx(&self) -> TyCtxt<'tcx> {
        self.tcx
    }

    fn visit_statement(&mut self, statement: &mut Statement<'tcx>, location: Location) {
        if let Some(value) = self.assignments.get(&location) {
            match &mut statement.kind {
                StatementKind::Assign(box (_, rvalue)) => {
                    *rvalue = Rvalue::Use(self.make_operand(*value));
                }
                _ => bug!("found assignment info for non-assign statement"),
            }
        } else {
            self.super_statement(statement, location);
        }
    }

    fn visit_operand(&mut self, operand: &mut Operand<'tcx>, location: Location) {
        match operand {
            Operand::Copy(place) | Operand::Move(place) => {
                if let Some(value) = self.before_effect.get(&(location, *place)) {
                    *operand = self.make_operand(*value);
                } else if !place.projection.is_empty() {
                    self.super_operand(operand, location)
                }
            }
            Operand::Constant(_) => {}
        }
    }

    fn process_projection_elem(
        &mut self,
        elem: PlaceElem<'tcx>,
        location: Location,
    ) -> Option<PlaceElem<'tcx>> {
        if let PlaceElem::Index(local) = elem {
            let offset = self.before_effect.get(&(location, local.into()))?;
            let offset = offset.try_to_scalar()?;
            let offset = offset.to_target_usize(&self.tcx).ok()?;
            let min_length = offset.checked_add(1)?;
            Some(PlaceElem::ConstantIndex { offset, min_length, from_end: false })
        } else {
            None
        }
    }
}

struct OperandCollector<'tcx, 'map, 'locals, 'a> {
    state: &'a State<FlatSet<Scalar>>,
    visitor: &'a mut Collector<'tcx, 'locals>,
    ecx: &'map mut InterpCx<'tcx, DummyMachine>,
    map: &'map Map<'tcx>,
}

impl<'tcx> Visitor<'tcx> for OperandCollector<'tcx, '_, '_, '_> {
    fn visit_projection_elem(
        &mut self,
        _: PlaceRef<'tcx>,
        elem: PlaceElem<'tcx>,
        _: PlaceContext,
        location: Location,
    ) {
        if let PlaceElem::Index(local) = elem
            && let Some(value) =
                self.visitor.try_make_constant(self.ecx, local.into(), self.state, self.map)
        {
            self.visitor.patch.before_effect.insert((location, local.into()), value);
        }
    }

    fn visit_operand(&mut self, operand: &Operand<'tcx>, location: Location) {
        if let Some(place) = operand.place() {
            if let Some(value) =
                self.visitor.try_make_constant(self.ecx, place, self.state, self.map)
            {
                self.visitor.patch.before_effect.insert((location, place), value);
            } else if !place.projection.is_empty() {
                // Try to propagate into `Index` projections.
                self.super_operand(operand, location)
            }
        }
    }
}