// Copyright 2015 The Rust Project Developers. See the COPYRIGHT
// file at the top-level directory of this distribution and at
// http://rust-lang.org/COPYRIGHT.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! This pass adds validation calls (AcquireValid, ReleaseValid) where appropriate.
//! It has to be run really early, before transformations like inlining, because
//! introducing these calls *adds* UB -- so, conceptually, this pass is actually part
//! of MIR building, and only after this pass we think of the program as having the
//! normal MIR semantics.

use rustc::ty::{self, TyCtxt, RegionKind};
use rustc::hir;
use rustc::mir::*;
use rustc::middle::region;
use transform::{MirPass, MirSource};
/// MIR pass that inserts `Validate` statements (Acquire/Release/Suspend) into
/// freshly built MIR. The pass itself carries no state.
pub struct AddValidation;
25 | /// Determine the "context" of the lval: Mutability and region. | |
26 | fn lval_context<'a, 'tcx, D>( | |
27 | lval: &Lvalue<'tcx>, | |
28 | local_decls: &D, | |
29 | tcx: TyCtxt<'a, 'tcx, 'tcx> | |
ea8adc8c | 30 | ) -> (Option<region::Scope>, hir::Mutability) |
3b2f2976 XL |
31 | where D: HasLocalDecls<'tcx> |
32 | { | |
33 | use rustc::mir::Lvalue::*; | |
34 | ||
35 | match *lval { | |
36 | Local { .. } => (None, hir::MutMutable), | |
37 | Static(_) => (None, hir::MutImmutable), | |
38 | Projection(ref proj) => { | |
39 | match proj.elem { | |
40 | ProjectionElem::Deref => { | |
41 | // Computing the inside the recursion makes this quadratic. | |
42 | // We don't expect deep paths though. | |
43 | let ty = proj.base.ty(local_decls, tcx).to_ty(tcx); | |
44 | // A Deref projection may restrict the context, this depends on the type | |
45 | // being deref'd. | |
46 | let context = match ty.sty { | |
47 | ty::TyRef(re, tam) => { | |
48 | let re = match re { | |
49 | &RegionKind::ReScope(ce) => Some(ce), | |
50 | &RegionKind::ReErased => | |
51 | bug!("AddValidation pass must be run before erasing lifetimes"), | |
52 | _ => None | |
53 | }; | |
54 | (re, tam.mutbl) | |
55 | } | |
56 | ty::TyRawPtr(_) => | |
57 | // There is no guarantee behind even a mutable raw pointer, | |
58 | // no write locks are acquired there, so we also don't want to | |
59 | // release any. | |
60 | (None, hir::MutImmutable), | |
61 | ty::TyAdt(adt, _) if adt.is_box() => (None, hir::MutMutable), | |
62 | _ => bug!("Deref on a non-pointer type {:?}", ty), | |
63 | }; | |
64 | // "Intersect" this restriction with proj.base. | |
65 | if let (Some(_), hir::MutImmutable) = context { | |
66 | // This is already as restricted as it gets, no need to even recurse | |
67 | context | |
68 | } else { | |
69 | let base_context = lval_context(&proj.base, local_decls, tcx); | |
70 | // The region of the outermost Deref is always most restrictive. | |
71 | let re = context.0.or(base_context.0); | |
72 | let mutbl = context.1.and(base_context.1); | |
73 | (re, mutbl) | |
74 | } | |
75 | ||
76 | } | |
77 | _ => lval_context(&proj.base, local_decls, tcx), | |
78 | } | |
79 | } | |
80 | } | |
81 | } | |
82 | ||
83 | /// Check if this function contains an unsafe block or is an unsafe function. | |
84 | fn fn_contains_unsafe<'a, 'tcx>(tcx: TyCtxt<'a, 'tcx, 'tcx>, src: MirSource) -> bool { | |
85 | use rustc::hir::intravisit::{self, Visitor, FnKind}; | |
86 | use rustc::hir::map::blocks::FnLikeNode; | |
87 | use rustc::hir::map::Node; | |
88 | ||
89 | /// Decide if this is an unsafe block | |
90 | fn block_is_unsafe(block: &hir::Block) -> bool { | |
91 | use rustc::hir::BlockCheckMode::*; | |
92 | ||
93 | match block.rules { | |
94 | UnsafeBlock(_) | PushUnsafeBlock(_) => true, | |
95 | // For PopUnsafeBlock, we don't actually know -- but we will always also check all | |
96 | // parent blocks, so we can safely declare the PopUnsafeBlock to not be unsafe. | |
97 | DefaultBlock | PopUnsafeBlock(_) => false, | |
98 | } | |
99 | } | |
100 | ||
101 | /// Decide if this FnLike is a closure | |
102 | fn fn_is_closure<'a>(fn_like: FnLikeNode<'a>) -> bool { | |
103 | match fn_like.kind() { | |
104 | FnKind::Closure(_) => true, | |
105 | FnKind::Method(..) | FnKind::ItemFn(..) => false, | |
106 | } | |
107 | } | |
108 | ||
abe05a73 XL |
109 | let node_id = tcx.hir.as_local_node_id(src.def_id).unwrap(); |
110 | let fn_like = match tcx.hir.body_owner_kind(node_id) { | |
111 | hir::BodyOwnerKind::Fn => { | |
3b2f2976 XL |
112 | match FnLikeNode::from_node(tcx.hir.get(node_id)) { |
113 | Some(fn_like) => fn_like, | |
114 | None => return false, // e.g. struct ctor shims -- such auto-generated code cannot | |
115 | // contain unsafe. | |
116 | } | |
117 | }, | |
118 | _ => return false, // only functions can have unsafe | |
119 | }; | |
120 | ||
121 | // Test if the function is marked unsafe. | |
122 | if fn_like.unsafety() == hir::Unsafety::Unsafe { | |
123 | return true; | |
124 | } | |
125 | ||
126 | // For closures, we need to walk up the parents and see if we are inside an unsafe fn or | |
127 | // unsafe block. | |
128 | if fn_is_closure(fn_like) { | |
129 | let mut cur = fn_like.id(); | |
130 | loop { | |
131 | // Go further upwards. | |
132 | cur = tcx.hir.get_parent_node(cur); | |
133 | let node = tcx.hir.get(cur); | |
134 | // Check if this is an unsafe function | |
135 | if let Some(fn_like) = FnLikeNode::from_node(node) { | |
136 | if !fn_is_closure(fn_like) { | |
137 | if fn_like.unsafety() == hir::Unsafety::Unsafe { | |
138 | return true; | |
139 | } | |
140 | } | |
141 | } | |
142 | // Check if this is an unsafe block, or an item | |
143 | match node { | |
144 | Node::NodeExpr(&hir::Expr { node: hir::ExprBlock(ref block), ..}) => { | |
145 | if block_is_unsafe(&*block) { | |
146 | // Found an unsafe block, we can bail out here. | |
147 | return true; | |
148 | } | |
149 | } | |
150 | Node::NodeItem(..) => { | |
151 | // No walking up beyond items. This makes sure the loop always terminates. | |
152 | break; | |
153 | } | |
154 | _ => {}, | |
155 | } | |
156 | } | |
157 | } | |
158 | ||
159 | // Visit the entire body of the function and check for unsafe blocks in there | |
160 | struct FindUnsafe { | |
161 | found_unsafe: bool, | |
162 | } | |
163 | let mut finder = FindUnsafe { found_unsafe: false }; | |
164 | // Run the visitor on the NodeId we got. Seems like there is no uniform way to do that. | |
165 | finder.visit_body(tcx.hir.body(fn_like.body())); | |
166 | ||
167 | impl<'tcx> Visitor<'tcx> for FindUnsafe { | |
168 | fn nested_visit_map<'this>(&'this mut self) -> intravisit::NestedVisitorMap<'this, 'tcx> { | |
169 | intravisit::NestedVisitorMap::None | |
170 | } | |
171 | ||
172 | fn visit_block(&mut self, b: &'tcx hir::Block) { | |
173 | if self.found_unsafe { return; } // short-circuit | |
174 | ||
175 | if block_is_unsafe(b) { | |
176 | // We found an unsafe block. We can stop searching. | |
177 | self.found_unsafe = true; | |
178 | } else { | |
179 | // No unsafe block here, go on searching. | |
180 | intravisit::walk_block(self, b); | |
181 | } | |
182 | } | |
183 | } | |
184 | ||
185 | finder.found_unsafe | |
186 | } | |
187 | ||
188 | impl MirPass for AddValidation { | |
189 | fn run_pass<'a, 'tcx>(&self, | |
190 | tcx: TyCtxt<'a, 'tcx, 'tcx>, | |
191 | src: MirSource, | |
192 | mir: &mut Mir<'tcx>) | |
193 | { | |
194 | let emit_validate = tcx.sess.opts.debugging_opts.mir_emit_validate; | |
195 | if emit_validate == 0 { | |
196 | return; | |
197 | } | |
198 | let restricted_validation = emit_validate == 1 && fn_contains_unsafe(tcx, src); | |
199 | let local_decls = mir.local_decls.clone(); // FIXME: Find a way to get rid of this clone. | |
200 | ||
201 | // Convert an lvalue to a validation operand. | |
202 | let lval_to_operand = |lval: Lvalue<'tcx>| -> ValidationOperand<'tcx, Lvalue<'tcx>> { | |
203 | let (re, mutbl) = lval_context(&lval, &local_decls, tcx); | |
204 | let ty = lval.ty(&local_decls, tcx).to_ty(tcx); | |
205 | ValidationOperand { lval, ty, re, mutbl } | |
206 | }; | |
207 | ||
208 | // Emit an Acquire at the beginning of the given block. If we are in restricted emission | |
209 | // mode (mir_emit_validate=1), also emit a Release immediately after the Acquire. | |
210 | let emit_acquire = |block: &mut BasicBlockData<'tcx>, source_info, operands: Vec<_>| { | |
211 | if operands.len() == 0 { | |
212 | return; // Nothing to do | |
213 | } | |
214 | // Emit the release first, to avoid cloning if we do not emit it | |
215 | if restricted_validation { | |
216 | let release_stmt = Statement { | |
217 | source_info, | |
218 | kind: StatementKind::Validate(ValidationOp::Release, operands.clone()), | |
219 | }; | |
220 | block.statements.insert(0, release_stmt); | |
221 | } | |
222 | // Now, the acquire | |
223 | let acquire_stmt = Statement { | |
224 | source_info, | |
225 | kind: StatementKind::Validate(ValidationOp::Acquire, operands), | |
226 | }; | |
227 | block.statements.insert(0, acquire_stmt); | |
228 | }; | |
229 | ||
230 | // PART 1 | |
231 | // Add an AcquireValid at the beginning of the start block. | |
232 | { | |
233 | let source_info = SourceInfo { | |
234 | scope: ARGUMENT_VISIBILITY_SCOPE, | |
235 | span: mir.span, // FIXME: Consider using just the span covering the function | |
236 | // argument declaration. | |
237 | }; | |
238 | // Gather all arguments, skip return value. | |
239 | let operands = mir.local_decls.iter_enumerated().skip(1).take(mir.arg_count) | |
240 | .map(|(local, _)| lval_to_operand(Lvalue::Local(local))).collect(); | |
241 | emit_acquire(&mut mir.basic_blocks_mut()[START_BLOCK], source_info, operands); | |
242 | } | |
243 | ||
244 | // PART 2 | |
245 | // Add ReleaseValid/AcquireValid around function call terminators. We don't use a visitor | |
246 | // because we need to access the block that a Call jumps to. | |
247 | let mut returns : Vec<(SourceInfo, Lvalue<'tcx>, BasicBlock)> = Vec::new(); | |
248 | for block_data in mir.basic_blocks_mut() { | |
249 | match block_data.terminator { | |
250 | Some(Terminator { kind: TerminatorKind::Call { ref args, ref destination, .. }, | |
251 | source_info }) => { | |
252 | // Before the call: Release all arguments *and* the return value. | |
253 | // The callee may write into the return value! Note that this relies | |
254 | // on "release of uninitialized" to be a NOP. | |
255 | if !restricted_validation { | |
256 | let release_stmt = Statement { | |
257 | source_info, | |
258 | kind: StatementKind::Validate(ValidationOp::Release, | |
259 | destination.iter().map(|dest| lval_to_operand(dest.0.clone())) | |
260 | .chain( | |
261 | args.iter().filter_map(|op| { | |
262 | match op { | |
263 | &Operand::Consume(ref lval) => | |
264 | Some(lval_to_operand(lval.clone())), | |
265 | &Operand::Constant(..) => { None }, | |
266 | } | |
267 | }) | |
268 | ).collect()) | |
269 | }; | |
270 | block_data.statements.push(release_stmt); | |
271 | } | |
272 | // Remember the return destination for later | |
273 | if let &Some(ref destination) = destination { | |
274 | returns.push((source_info, destination.0.clone(), destination.1)); | |
275 | } | |
276 | } | |
277 | Some(Terminator { kind: TerminatorKind::Drop { location: ref lval, .. }, | |
278 | source_info }) | | |
279 | Some(Terminator { kind: TerminatorKind::DropAndReplace { location: ref lval, .. }, | |
280 | source_info }) => { | |
281 | // Before the call: Release all arguments | |
282 | if !restricted_validation { | |
283 | let release_stmt = Statement { | |
284 | source_info, | |
285 | kind: StatementKind::Validate(ValidationOp::Release, | |
286 | vec![lval_to_operand(lval.clone())]), | |
287 | }; | |
288 | block_data.statements.push(release_stmt); | |
289 | } | |
290 | // drop doesn't return anything, so we need no acquire. | |
291 | } | |
292 | _ => { | |
293 | // Not a block ending in a Call -> ignore. | |
294 | } | |
295 | } | |
296 | } | |
297 | // Now we go over the returns we collected to acquire the return values. | |
298 | for (source_info, dest_lval, dest_block) in returns { | |
299 | emit_acquire( | |
300 | &mut mir.basic_blocks_mut()[dest_block], | |
301 | source_info, | |
302 | vec![lval_to_operand(dest_lval)] | |
303 | ); | |
304 | } | |
305 | ||
306 | if restricted_validation { | |
307 | // No part 3 for us. | |
308 | return; | |
309 | } | |
310 | ||
311 | // PART 3 | |
312 | // Add ReleaseValid/AcquireValid around Ref and Cast. Again an iterator does not seem very | |
313 | // suited as we need to add new statements before and after each Ref. | |
314 | for block_data in mir.basic_blocks_mut() { | |
315 | // We want to insert statements around Ref commands as we iterate. To this end, we | |
316 | // iterate backwards using indices. | |
317 | for i in (0..block_data.statements.len()).rev() { | |
318 | match block_data.statements[i].kind { | |
319 | // When the borrow of this ref expires, we need to recover validation. | |
320 | StatementKind::Assign(_, Rvalue::Ref(_, _, _)) => { | |
321 | // Due to a lack of NLL; we can't capture anything directly here. | |
322 | // Instead, we have to re-match and clone there. | |
323 | let (dest_lval, re, src_lval) = match block_data.statements[i].kind { | |
324 | StatementKind::Assign(ref dest_lval, | |
325 | Rvalue::Ref(re, _, ref src_lval)) => { | |
326 | (dest_lval.clone(), re, src_lval.clone()) | |
327 | }, | |
328 | _ => bug!("We already matched this."), | |
329 | }; | |
330 | // So this is a ref, and we got all the data we wanted. | |
331 | // Do an acquire of the result -- but only what it points to, so add a Deref | |
332 | // projection. | |
333 | let dest_lval = Projection { base: dest_lval, elem: ProjectionElem::Deref }; | |
334 | let dest_lval = Lvalue::Projection(Box::new(dest_lval)); | |
335 | let acquire_stmt = Statement { | |
336 | source_info: block_data.statements[i].source_info, | |
337 | kind: StatementKind::Validate(ValidationOp::Acquire, | |
338 | vec![lval_to_operand(dest_lval)]), | |
339 | }; | |
340 | block_data.statements.insert(i+1, acquire_stmt); | |
341 | ||
342 | // The source is released until the region of the borrow ends. | |
343 | let op = match re { | |
344 | &RegionKind::ReScope(ce) => ValidationOp::Suspend(ce), | |
345 | &RegionKind::ReErased => | |
346 | bug!("AddValidation pass must be run before erasing lifetimes"), | |
347 | _ => ValidationOp::Release, | |
348 | }; | |
349 | let release_stmt = Statement { | |
350 | source_info: block_data.statements[i].source_info, | |
351 | kind: StatementKind::Validate(op, vec![lval_to_operand(src_lval)]), | |
352 | }; | |
353 | block_data.statements.insert(i, release_stmt); | |
354 | } | |
355 | // Casts can change what validation does (e.g. unsizing) | |
356 | StatementKind::Assign(_, Rvalue::Cast(kind, Operand::Consume(_), _)) | |
357 | if kind != CastKind::Misc => | |
358 | { | |
359 | // Due to a lack of NLL; we can't capture anything directly here. | |
360 | // Instead, we have to re-match and clone there. | |
361 | let (dest_lval, src_lval) = match block_data.statements[i].kind { | |
362 | StatementKind::Assign(ref dest_lval, | |
363 | Rvalue::Cast(_, Operand::Consume(ref src_lval), _)) => | |
364 | { | |
365 | (dest_lval.clone(), src_lval.clone()) | |
366 | }, | |
367 | _ => bug!("We already matched this."), | |
368 | }; | |
369 | ||
370 | // Acquire of the result | |
371 | let acquire_stmt = Statement { | |
372 | source_info: block_data.statements[i].source_info, | |
373 | kind: StatementKind::Validate(ValidationOp::Acquire, | |
374 | vec![lval_to_operand(dest_lval)]), | |
375 | }; | |
376 | block_data.statements.insert(i+1, acquire_stmt); | |
377 | ||
378 | // Release of the input | |
379 | let release_stmt = Statement { | |
380 | source_info: block_data.statements[i].source_info, | |
381 | kind: StatementKind::Validate(ValidationOp::Release, | |
382 | vec![lval_to_operand(src_lval)]), | |
383 | }; | |
384 | block_data.statements.insert(i, release_stmt); | |
385 | } | |
386 | _ => {}, | |
387 | } | |
388 | } | |
389 | } | |
390 | } | |
391 | } |