]>
Commit | Line | Data |
---|---|---|
353b0b11 FG |
use crate::MirPass;
use rustc_hir::def_id::DefId;
use rustc_hir::lang_items::LangItem;
use rustc_index::vec::IndexVec;
use rustc_middle::mir::*;
use rustc_middle::mir::{
    interpret::{ConstValue, Scalar},
    visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor},
};
use rustc_middle::ty::{Ty, TyCtxt, TypeAndMut};
use rustc_session::Session;
12 | ||
13 | pub struct CheckAlignment; | |
14 | ||
15 | impl<'tcx> MirPass<'tcx> for CheckAlignment { | |
16 | fn is_enabled(&self, sess: &Session) -> bool { | |
17 | sess.opts.debug_assertions | |
18 | } | |
19 | ||
20 | fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) { | |
21 | // This pass emits new panics. If for whatever reason we do not have a panic | |
22 | // implementation, running this pass may cause otherwise-valid code to not compile. | |
23 | if tcx.lang_items().get(LangItem::PanicImpl).is_none() { | |
24 | return; | |
25 | } | |
26 | ||
27 | let basic_blocks = body.basic_blocks.as_mut(); | |
28 | let local_decls = &mut body.local_decls; | |
29 | ||
30 | for block in (0..basic_blocks.len()).rev() { | |
31 | let block = block.into(); | |
32 | for statement_index in (0..basic_blocks[block].statements.len()).rev() { | |
33 | let location = Location { block, statement_index }; | |
34 | let statement = &basic_blocks[block].statements[statement_index]; | |
35 | let source_info = statement.source_info; | |
36 | ||
37 | let mut finder = PointerFinder { | |
38 | local_decls, | |
39 | tcx, | |
40 | pointers: Vec::new(), | |
41 | def_id: body.source.def_id(), | |
42 | }; | |
43 | for (pointer, pointee_ty) in finder.find_pointers(statement) { | |
44 | debug!("Inserting alignment check for {:?}", pointer.ty(&*local_decls, tcx).ty); | |
45 | ||
46 | let new_block = split_block(basic_blocks, location); | |
47 | insert_alignment_check( | |
48 | tcx, | |
49 | local_decls, | |
50 | &mut basic_blocks[block], | |
51 | pointer, | |
52 | pointee_ty, | |
53 | source_info, | |
54 | new_block, | |
55 | ); | |
56 | } | |
57 | } | |
58 | } | |
59 | } | |
60 | } | |
61 | ||
62 | impl<'tcx, 'a> PointerFinder<'tcx, 'a> { | |
63 | fn find_pointers(&mut self, statement: &Statement<'tcx>) -> Vec<(Place<'tcx>, Ty<'tcx>)> { | |
64 | self.pointers.clear(); | |
65 | self.visit_statement(statement, Location::START); | |
66 | core::mem::take(&mut self.pointers) | |
67 | } | |
68 | } | |
69 | ||
70 | struct PointerFinder<'tcx, 'a> { | |
71 | local_decls: &'a mut LocalDecls<'tcx>, | |
72 | tcx: TyCtxt<'tcx>, | |
73 | def_id: DefId, | |
74 | pointers: Vec<(Place<'tcx>, Ty<'tcx>)>, | |
75 | } | |
76 | ||
77 | impl<'tcx, 'a> Visitor<'tcx> for PointerFinder<'tcx, 'a> { | |
78 | fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, location: Location) { | |
79 | if let Rvalue::AddressOf(..) = rvalue { | |
80 | // Ignore dereferences inside of an AddressOf | |
81 | return; | |
82 | } | |
83 | self.super_rvalue(rvalue, location); | |
84 | } | |
85 | ||
86 | fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, _location: Location) { | |
87 | if let PlaceContext::NonUse(_) = context { | |
88 | return; | |
89 | } | |
90 | if !place.is_indirect() { | |
91 | return; | |
92 | } | |
93 | ||
94 | let pointer = Place::from(place.local); | |
95 | let pointer_ty = pointer.ty(&*self.local_decls, self.tcx).ty; | |
96 | ||
97 | // We only want to check unsafe pointers | |
98 | if !pointer_ty.is_unsafe_ptr() { | |
99 | trace!("Indirect, but not an unsafe ptr, not checking {:?}", pointer_ty); | |
100 | return; | |
101 | } | |
102 | ||
103 | let Some(pointee) = pointer_ty.builtin_deref(true) else { | |
104 | debug!("Indirect but no builtin deref: {:?}", pointer_ty); | |
105 | return; | |
106 | }; | |
107 | let mut pointee_ty = pointee.ty; | |
108 | if pointee_ty.is_array() || pointee_ty.is_slice() || pointee_ty.is_str() { | |
109 | pointee_ty = pointee_ty.sequence_element_type(self.tcx); | |
110 | } | |
111 | ||
112 | if !pointee_ty.is_sized(self.tcx, self.tcx.param_env_reveal_all_normalized(self.def_id)) { | |
113 | debug!("Unsafe pointer, but unsized: {:?}", pointer_ty); | |
114 | return; | |
115 | } | |
116 | ||
117 | if [self.tcx.types.bool, self.tcx.types.i8, self.tcx.types.u8, self.tcx.types.str_] | |
118 | .contains(&pointee_ty) | |
119 | { | |
120 | debug!("Trivially aligned pointee type: {:?}", pointer_ty); | |
121 | return; | |
122 | } | |
123 | ||
124 | self.pointers.push((pointer, pointee_ty)) | |
125 | } | |
126 | } | |
127 | ||
128 | fn split_block( | |
129 | basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>, | |
130 | location: Location, | |
131 | ) -> BasicBlock { | |
132 | let block_data = &mut basic_blocks[location.block]; | |
133 | ||
134 | // Drain every statement after this one and move the current terminator to a new basic block | |
135 | let new_block = BasicBlockData { | |
136 | statements: block_data.statements.split_off(location.statement_index), | |
137 | terminator: block_data.terminator.take(), | |
138 | is_cleanup: block_data.is_cleanup, | |
139 | }; | |
140 | ||
141 | basic_blocks.push(new_block) | |
142 | } | |
143 | ||
144 | fn insert_alignment_check<'tcx>( | |
145 | tcx: TyCtxt<'tcx>, | |
146 | local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>, | |
147 | block_data: &mut BasicBlockData<'tcx>, | |
148 | pointer: Place<'tcx>, | |
149 | pointee_ty: Ty<'tcx>, | |
150 | source_info: SourceInfo, | |
151 | new_block: BasicBlock, | |
152 | ) { | |
153 | // Cast the pointer to a *const () | |
154 | let const_raw_ptr = tcx.mk_ptr(TypeAndMut { ty: tcx.types.unit, mutbl: Mutability::Not }); | |
155 | let rvalue = Rvalue::Cast(CastKind::PtrToPtr, Operand::Copy(pointer), const_raw_ptr); | |
156 | let thin_ptr = local_decls.push(LocalDecl::with_source_info(const_raw_ptr, source_info)).into(); | |
157 | block_data | |
158 | .statements | |
159 | .push(Statement { source_info, kind: StatementKind::Assign(Box::new((thin_ptr, rvalue))) }); | |
160 | ||
161 | // Transmute the pointer to a usize (equivalent to `ptr.addr()`) | |
162 | let rvalue = Rvalue::Cast(CastKind::Transmute, Operand::Copy(thin_ptr), tcx.types.usize); | |
163 | let addr = local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); | |
164 | block_data | |
165 | .statements | |
166 | .push(Statement { source_info, kind: StatementKind::Assign(Box::new((addr, rvalue))) }); | |
167 | ||
168 | // Get the alignment of the pointee | |
169 | let alignment = | |
170 | local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); | |
171 | let rvalue = Rvalue::NullaryOp(NullOp::AlignOf, pointee_ty); | |
172 | block_data.statements.push(Statement { | |
173 | source_info, | |
174 | kind: StatementKind::Assign(Box::new((alignment, rvalue))), | |
175 | }); | |
176 | ||
177 | // Subtract 1 from the alignment to get the alignment mask | |
178 | let alignment_mask = | |
179 | local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); | |
180 | let one = Operand::Constant(Box::new(Constant { | |
181 | span: source_info.span, | |
182 | user_ty: None, | |
183 | literal: ConstantKind::Val( | |
184 | ConstValue::Scalar(Scalar::from_target_usize(1, &tcx)), | |
185 | tcx.types.usize, | |
186 | ), | |
187 | })); | |
188 | block_data.statements.push(Statement { | |
189 | source_info, | |
190 | kind: StatementKind::Assign(Box::new(( | |
191 | alignment_mask, | |
192 | Rvalue::BinaryOp(BinOp::Sub, Box::new((Operand::Copy(alignment), one))), | |
193 | ))), | |
194 | }); | |
195 | ||
196 | // BitAnd the alignment mask with the pointer | |
197 | let alignment_bits = | |
198 | local_decls.push(LocalDecl::with_source_info(tcx.types.usize, source_info)).into(); | |
199 | block_data.statements.push(Statement { | |
200 | source_info, | |
201 | kind: StatementKind::Assign(Box::new(( | |
202 | alignment_bits, | |
203 | Rvalue::BinaryOp( | |
204 | BinOp::BitAnd, | |
205 | Box::new((Operand::Copy(addr), Operand::Copy(alignment_mask))), | |
206 | ), | |
207 | ))), | |
208 | }); | |
209 | ||
210 | // Check if the alignment bits are all zero | |
211 | let is_ok = local_decls.push(LocalDecl::with_source_info(tcx.types.bool, source_info)).into(); | |
212 | let zero = Operand::Constant(Box::new(Constant { | |
213 | span: source_info.span, | |
214 | user_ty: None, | |
215 | literal: ConstantKind::Val( | |
216 | ConstValue::Scalar(Scalar::from_target_usize(0, &tcx)), | |
217 | tcx.types.usize, | |
218 | ), | |
219 | })); | |
220 | block_data.statements.push(Statement { | |
221 | source_info, | |
222 | kind: StatementKind::Assign(Box::new(( | |
223 | is_ok, | |
224 | Rvalue::BinaryOp(BinOp::Eq, Box::new((Operand::Copy(alignment_bits), zero.clone()))), | |
225 | ))), | |
226 | }); | |
227 | ||
228 | // Set this block's terminator to our assert, continuing to new_block if we pass | |
229 | block_data.terminator = Some(Terminator { | |
230 | source_info, | |
231 | kind: TerminatorKind::Assert { | |
232 | cond: Operand::Copy(is_ok), | |
233 | expected: true, | |
234 | target: new_block, | |
235 | msg: AssertKind::MisalignedPointerDereference { | |
236 | required: Operand::Copy(alignment), | |
237 | found: Operand::Copy(addr), | |
238 | }, | |
239 | unwind: UnwindAction::Terminate, | |
240 | }, | |
241 | }); | |
242 | } |