]> git.proxmox.com Git - rustc.git/blame - src/librustc_mir/transform/add_validation.rs
New upstream version 1.23.0+dfsg1
[rustc.git] / src / librustc_mir / transform / add_validation.rs
CommitLineData
3b2f2976
XL
1// Copyright 2015 The Rust Project Developers. See the COPYRIGHT
2// file at the top-level directory of this distribution and at
3// http://rust-lang.org/COPYRIGHT.
4//
5// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8// option. This file may not be copied, modified, or distributed
9// except according to those terms.
10
//! This pass adds validation calls (AcquireValid, ReleaseValid) where appropriate.
//! It has to be run really early, before transformations like inlining, because
//! introducing these calls *adds* UB -- so, conceptually, this pass is actually part
//! of MIR building, and only after this pass do we think of the program as having the
//! normal MIR semantics.
16
17use rustc::ty::{self, TyCtxt, RegionKind};
18use rustc::hir;
19use rustc::mir::*;
ea8adc8c 20use rustc::middle::region;
abe05a73 21use transform::{MirPass, MirSource};
3b2f2976
XL
22
/// MIR pass that inserts `Validate` statements (Acquire / Release / Suspend) into a
/// function's MIR; see the module documentation for why it must run early.
pub struct AddValidation;
24
25/// Determine the "context" of the lval: Mutability and region.
26fn lval_context<'a, 'tcx, D>(
27 lval: &Lvalue<'tcx>,
28 local_decls: &D,
29 tcx: TyCtxt<'a, 'tcx, 'tcx>
ea8adc8c 30) -> (Option<region::Scope>, hir::Mutability)
3b2f2976
XL
31 where D: HasLocalDecls<'tcx>
32{
33 use rustc::mir::Lvalue::*;
34
35 match *lval {
36 Local { .. } => (None, hir::MutMutable),
37 Static(_) => (None, hir::MutImmutable),
38 Projection(ref proj) => {
39 match proj.elem {
40 ProjectionElem::Deref => {
41 // Computing the inside the recursion makes this quadratic.
42 // We don't expect deep paths though.
43 let ty = proj.base.ty(local_decls, tcx).to_ty(tcx);
44 // A Deref projection may restrict the context, this depends on the type
45 // being deref'd.
46 let context = match ty.sty {
47 ty::TyRef(re, tam) => {
48 let re = match re {
49 &RegionKind::ReScope(ce) => Some(ce),
50 &RegionKind::ReErased =>
51 bug!("AddValidation pass must be run before erasing lifetimes"),
52 _ => None
53 };
54 (re, tam.mutbl)
55 }
56 ty::TyRawPtr(_) =>
57 // There is no guarantee behind even a mutable raw pointer,
58 // no write locks are acquired there, so we also don't want to
59 // release any.
60 (None, hir::MutImmutable),
61 ty::TyAdt(adt, _) if adt.is_box() => (None, hir::MutMutable),
62 _ => bug!("Deref on a non-pointer type {:?}", ty),
63 };
64 // "Intersect" this restriction with proj.base.
65 if let (Some(_), hir::MutImmutable) = context {
66 // This is already as restricted as it gets, no need to even recurse
67 context
68 } else {
69 let base_context = lval_context(&proj.base, local_decls, tcx);
70 // The region of the outermost Deref is always most restrictive.
71 let re = context.0.or(base_context.0);
72 let mutbl = context.1.and(base_context.1);
73 (re, mutbl)
74 }
75
76 }
77 _ => lval_context(&proj.base, local_decls, tcx),
78 }
79 }
80 }
81}
82
83/// Check if this function contains an unsafe block or is an unsafe function.
84fn fn_contains_unsafe<'a, 'tcx>(tcx: TyCtxt<'a, 'tcx, 'tcx>, src: MirSource) -> bool {
85 use rustc::hir::intravisit::{self, Visitor, FnKind};
86 use rustc::hir::map::blocks::FnLikeNode;
87 use rustc::hir::map::Node;
88
89 /// Decide if this is an unsafe block
90 fn block_is_unsafe(block: &hir::Block) -> bool {
91 use rustc::hir::BlockCheckMode::*;
92
93 match block.rules {
94 UnsafeBlock(_) | PushUnsafeBlock(_) => true,
95 // For PopUnsafeBlock, we don't actually know -- but we will always also check all
96 // parent blocks, so we can safely declare the PopUnsafeBlock to not be unsafe.
97 DefaultBlock | PopUnsafeBlock(_) => false,
98 }
99 }
100
101 /// Decide if this FnLike is a closure
102 fn fn_is_closure<'a>(fn_like: FnLikeNode<'a>) -> bool {
103 match fn_like.kind() {
104 FnKind::Closure(_) => true,
105 FnKind::Method(..) | FnKind::ItemFn(..) => false,
106 }
107 }
108
abe05a73
XL
109 let node_id = tcx.hir.as_local_node_id(src.def_id).unwrap();
110 let fn_like = match tcx.hir.body_owner_kind(node_id) {
111 hir::BodyOwnerKind::Fn => {
3b2f2976
XL
112 match FnLikeNode::from_node(tcx.hir.get(node_id)) {
113 Some(fn_like) => fn_like,
114 None => return false, // e.g. struct ctor shims -- such auto-generated code cannot
115 // contain unsafe.
116 }
117 },
118 _ => return false, // only functions can have unsafe
119 };
120
121 // Test if the function is marked unsafe.
122 if fn_like.unsafety() == hir::Unsafety::Unsafe {
123 return true;
124 }
125
126 // For closures, we need to walk up the parents and see if we are inside an unsafe fn or
127 // unsafe block.
128 if fn_is_closure(fn_like) {
129 let mut cur = fn_like.id();
130 loop {
131 // Go further upwards.
132 cur = tcx.hir.get_parent_node(cur);
133 let node = tcx.hir.get(cur);
134 // Check if this is an unsafe function
135 if let Some(fn_like) = FnLikeNode::from_node(node) {
136 if !fn_is_closure(fn_like) {
137 if fn_like.unsafety() == hir::Unsafety::Unsafe {
138 return true;
139 }
140 }
141 }
142 // Check if this is an unsafe block, or an item
143 match node {
144 Node::NodeExpr(&hir::Expr { node: hir::ExprBlock(ref block), ..}) => {
145 if block_is_unsafe(&*block) {
146 // Found an unsafe block, we can bail out here.
147 return true;
148 }
149 }
150 Node::NodeItem(..) => {
151 // No walking up beyond items. This makes sure the loop always terminates.
152 break;
153 }
154 _ => {},
155 }
156 }
157 }
158
159 // Visit the entire body of the function and check for unsafe blocks in there
160 struct FindUnsafe {
161 found_unsafe: bool,
162 }
163 let mut finder = FindUnsafe { found_unsafe: false };
164 // Run the visitor on the NodeId we got. Seems like there is no uniform way to do that.
165 finder.visit_body(tcx.hir.body(fn_like.body()));
166
167 impl<'tcx> Visitor<'tcx> for FindUnsafe {
168 fn nested_visit_map<'this>(&'this mut self) -> intravisit::NestedVisitorMap<'this, 'tcx> {
169 intravisit::NestedVisitorMap::None
170 }
171
172 fn visit_block(&mut self, b: &'tcx hir::Block) {
173 if self.found_unsafe { return; } // short-circuit
174
175 if block_is_unsafe(b) {
176 // We found an unsafe block. We can stop searching.
177 self.found_unsafe = true;
178 } else {
179 // No unsafe block here, go on searching.
180 intravisit::walk_block(self, b);
181 }
182 }
183 }
184
185 finder.found_unsafe
186}
187
188impl MirPass for AddValidation {
189 fn run_pass<'a, 'tcx>(&self,
190 tcx: TyCtxt<'a, 'tcx, 'tcx>,
191 src: MirSource,
192 mir: &mut Mir<'tcx>)
193 {
194 let emit_validate = tcx.sess.opts.debugging_opts.mir_emit_validate;
195 if emit_validate == 0 {
196 return;
197 }
198 let restricted_validation = emit_validate == 1 && fn_contains_unsafe(tcx, src);
199 let local_decls = mir.local_decls.clone(); // FIXME: Find a way to get rid of this clone.
200
201 // Convert an lvalue to a validation operand.
202 let lval_to_operand = |lval: Lvalue<'tcx>| -> ValidationOperand<'tcx, Lvalue<'tcx>> {
203 let (re, mutbl) = lval_context(&lval, &local_decls, tcx);
204 let ty = lval.ty(&local_decls, tcx).to_ty(tcx);
205 ValidationOperand { lval, ty, re, mutbl }
206 };
207
208 // Emit an Acquire at the beginning of the given block. If we are in restricted emission
209 // mode (mir_emit_validate=1), also emit a Release immediately after the Acquire.
210 let emit_acquire = |block: &mut BasicBlockData<'tcx>, source_info, operands: Vec<_>| {
211 if operands.len() == 0 {
212 return; // Nothing to do
213 }
214 // Emit the release first, to avoid cloning if we do not emit it
215 if restricted_validation {
216 let release_stmt = Statement {
217 source_info,
218 kind: StatementKind::Validate(ValidationOp::Release, operands.clone()),
219 };
220 block.statements.insert(0, release_stmt);
221 }
222 // Now, the acquire
223 let acquire_stmt = Statement {
224 source_info,
225 kind: StatementKind::Validate(ValidationOp::Acquire, operands),
226 };
227 block.statements.insert(0, acquire_stmt);
228 };
229
230 // PART 1
231 // Add an AcquireValid at the beginning of the start block.
232 {
233 let source_info = SourceInfo {
234 scope: ARGUMENT_VISIBILITY_SCOPE,
235 span: mir.span, // FIXME: Consider using just the span covering the function
236 // argument declaration.
237 };
238 // Gather all arguments, skip return value.
239 let operands = mir.local_decls.iter_enumerated().skip(1).take(mir.arg_count)
240 .map(|(local, _)| lval_to_operand(Lvalue::Local(local))).collect();
241 emit_acquire(&mut mir.basic_blocks_mut()[START_BLOCK], source_info, operands);
242 }
243
244 // PART 2
245 // Add ReleaseValid/AcquireValid around function call terminators. We don't use a visitor
246 // because we need to access the block that a Call jumps to.
247 let mut returns : Vec<(SourceInfo, Lvalue<'tcx>, BasicBlock)> = Vec::new();
248 for block_data in mir.basic_blocks_mut() {
249 match block_data.terminator {
250 Some(Terminator { kind: TerminatorKind::Call { ref args, ref destination, .. },
251 source_info }) => {
252 // Before the call: Release all arguments *and* the return value.
253 // The callee may write into the return value! Note that this relies
254 // on "release of uninitialized" to be a NOP.
255 if !restricted_validation {
256 let release_stmt = Statement {
257 source_info,
258 kind: StatementKind::Validate(ValidationOp::Release,
259 destination.iter().map(|dest| lval_to_operand(dest.0.clone()))
260 .chain(
261 args.iter().filter_map(|op| {
262 match op {
263 &Operand::Consume(ref lval) =>
264 Some(lval_to_operand(lval.clone())),
265 &Operand::Constant(..) => { None },
266 }
267 })
268 ).collect())
269 };
270 block_data.statements.push(release_stmt);
271 }
272 // Remember the return destination for later
273 if let &Some(ref destination) = destination {
274 returns.push((source_info, destination.0.clone(), destination.1));
275 }
276 }
277 Some(Terminator { kind: TerminatorKind::Drop { location: ref lval, .. },
278 source_info }) |
279 Some(Terminator { kind: TerminatorKind::DropAndReplace { location: ref lval, .. },
280 source_info }) => {
281 // Before the call: Release all arguments
282 if !restricted_validation {
283 let release_stmt = Statement {
284 source_info,
285 kind: StatementKind::Validate(ValidationOp::Release,
286 vec![lval_to_operand(lval.clone())]),
287 };
288 block_data.statements.push(release_stmt);
289 }
290 // drop doesn't return anything, so we need no acquire.
291 }
292 _ => {
293 // Not a block ending in a Call -> ignore.
294 }
295 }
296 }
297 // Now we go over the returns we collected to acquire the return values.
298 for (source_info, dest_lval, dest_block) in returns {
299 emit_acquire(
300 &mut mir.basic_blocks_mut()[dest_block],
301 source_info,
302 vec![lval_to_operand(dest_lval)]
303 );
304 }
305
306 if restricted_validation {
307 // No part 3 for us.
308 return;
309 }
310
311 // PART 3
312 // Add ReleaseValid/AcquireValid around Ref and Cast. Again an iterator does not seem very
313 // suited as we need to add new statements before and after each Ref.
314 for block_data in mir.basic_blocks_mut() {
315 // We want to insert statements around Ref commands as we iterate. To this end, we
316 // iterate backwards using indices.
317 for i in (0..block_data.statements.len()).rev() {
318 match block_data.statements[i].kind {
319 // When the borrow of this ref expires, we need to recover validation.
320 StatementKind::Assign(_, Rvalue::Ref(_, _, _)) => {
321 // Due to a lack of NLL; we can't capture anything directly here.
322 // Instead, we have to re-match and clone there.
323 let (dest_lval, re, src_lval) = match block_data.statements[i].kind {
324 StatementKind::Assign(ref dest_lval,
325 Rvalue::Ref(re, _, ref src_lval)) => {
326 (dest_lval.clone(), re, src_lval.clone())
327 },
328 _ => bug!("We already matched this."),
329 };
330 // So this is a ref, and we got all the data we wanted.
331 // Do an acquire of the result -- but only what it points to, so add a Deref
332 // projection.
333 let dest_lval = Projection { base: dest_lval, elem: ProjectionElem::Deref };
334 let dest_lval = Lvalue::Projection(Box::new(dest_lval));
335 let acquire_stmt = Statement {
336 source_info: block_data.statements[i].source_info,
337 kind: StatementKind::Validate(ValidationOp::Acquire,
338 vec![lval_to_operand(dest_lval)]),
339 };
340 block_data.statements.insert(i+1, acquire_stmt);
341
342 // The source is released until the region of the borrow ends.
343 let op = match re {
344 &RegionKind::ReScope(ce) => ValidationOp::Suspend(ce),
345 &RegionKind::ReErased =>
346 bug!("AddValidation pass must be run before erasing lifetimes"),
347 _ => ValidationOp::Release,
348 };
349 let release_stmt = Statement {
350 source_info: block_data.statements[i].source_info,
351 kind: StatementKind::Validate(op, vec![lval_to_operand(src_lval)]),
352 };
353 block_data.statements.insert(i, release_stmt);
354 }
355 // Casts can change what validation does (e.g. unsizing)
356 StatementKind::Assign(_, Rvalue::Cast(kind, Operand::Consume(_), _))
357 if kind != CastKind::Misc =>
358 {
359 // Due to a lack of NLL; we can't capture anything directly here.
360 // Instead, we have to re-match and clone there.
361 let (dest_lval, src_lval) = match block_data.statements[i].kind {
362 StatementKind::Assign(ref dest_lval,
363 Rvalue::Cast(_, Operand::Consume(ref src_lval), _)) =>
364 {
365 (dest_lval.clone(), src_lval.clone())
366 },
367 _ => bug!("We already matched this."),
368 };
369
370 // Acquire of the result
371 let acquire_stmt = Statement {
372 source_info: block_data.statements[i].source_info,
373 kind: StatementKind::Validate(ValidationOp::Acquire,
374 vec![lval_to_operand(dest_lval)]),
375 };
376 block_data.statements.insert(i+1, acquire_stmt);
377
378 // Release of the input
379 let release_stmt = Statement {
380 source_info: block_data.statements[i].source_info,
381 kind: StatementKind::Validate(ValidationOp::Release,
382 vec![lval_to_operand(src_lval)]),
383 };
384 block_data.statements.insert(i, release_stmt);
385 }
386 _ => {},
387 }
388 }
389 }
390 }
391}