]> git.proxmox.com Git - rustc.git/blobdiff - compiler/rustc_mir_transform/src/inline.rs
New upstream version 1.65.0+dfsg1
[rustc.git] / compiler / rustc_mir_transform / src / inline.rs
index 76b1522f394f053813bcb03de811ac4a8de15db4..d00a384cb44ec2a6c57a1adef2a85edea65639da 100644 (file)
@@ -8,8 +8,11 @@ use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs}
 use rustc_middle::mir::visit::*;
 use rustc_middle::mir::*;
 use rustc_middle::ty::subst::Subst;
-use rustc_middle::ty::{self, ConstKind, Instance, InstanceDef, ParamEnv, Ty, TyCtxt};
+use rustc_middle::ty::{self, Instance, InstanceDef, ParamEnv, Ty, TyCtxt};
+use rustc_session::config::OptLevel;
+use rustc_span::def_id::DefId;
 use rustc_span::{hygiene::ExpnKind, ExpnData, LocalExpnId, Span};
+use rustc_target::abi::VariantIdx;
 use rustc_target::spec::abi::Abi;
 
 use super::simplify::{remove_dead_blocks, CfgSimplifier};
@@ -43,8 +46,15 @@ impl<'tcx> MirPass<'tcx> for Inline {
             return enabled;
         }
 
-        // rust-lang/rust#101004: reverted to old inlining decision logic
-        sess.mir_opt_level() >= 3
+        match sess.mir_opt_level() {
+            0 | 1 => false,
+            2 => {
+                (sess.opts.optimize == OptLevel::Default
+                    || sess.opts.optimize == OptLevel::Aggressive)
+                    && sess.opts.incremental == None
+            }
+            _ => true,
+        }
     }
 
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
@@ -85,7 +95,7 @@ fn inline<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> bool {
         history: Vec::new(),
         changed: false,
     };
-    let blocks = BasicBlock::new(0)..body.basic_blocks().next_index();
+    let blocks = BasicBlock::new(0)..body.basic_blocks.next_index();
     this.process_blocks(body, blocks);
     this.changed
 }
@@ -95,8 +105,12 @@ struct Inliner<'tcx> {
     param_env: ParamEnv<'tcx>,
     /// Caller codegen attributes.
     codegen_fn_attrs: &'tcx CodegenFnAttrs,
-    /// Stack of inlined Instances.
-    history: Vec<ty::Instance<'tcx>>,
+    /// Stack of inlined instances.
+    /// We only check the `DefId` and not the substs because we want to
+    /// avoid inlining cases of polymorphic recursion.
+    /// The number of `DefId`s is finite, so checking history is enough
+    /// to ensure that we do not loop endlessly while inlining.
+    history: Vec<DefId>,
     /// Indicates that the caller body has been modified.
     changed: bool,
 }
@@ -124,7 +138,7 @@ impl<'tcx> Inliner<'tcx> {
                 Ok(new_blocks) => {
                     debug!("inlined {}", callsite.callee);
                     self.changed = true;
-                    self.history.push(callsite.callee);
+                    self.history.push(callsite.callee.def_id());
                     self.process_blocks(caller_body, new_blocks);
                     self.history.pop();
                 }
@@ -203,9 +217,9 @@ impl<'tcx> Inliner<'tcx> {
             }
         }
 
-        let old_blocks = caller_body.basic_blocks().next_index();
+        let old_blocks = caller_body.basic_blocks.next_index();
         self.inline_call(caller_body, &callsite, callee_body);
-        let new_blocks = old_blocks..caller_body.basic_blocks().next_index();
+        let new_blocks = old_blocks..caller_body.basic_blocks.next_index();
 
         Ok(new_blocks)
     }
@@ -300,7 +314,7 @@ impl<'tcx> Inliner<'tcx> {
                     return None;
                 }
 
-                if self.history.contains(&callee) {
+                if self.history.contains(&callee.def_id()) {
                     return None;
                 }
 
@@ -395,124 +409,66 @@ impl<'tcx> Inliner<'tcx> {
         // Give a bonus functions with a small number of blocks,
         // We normally have two or three blocks for even
         // very small functions.
-        if callee_body.basic_blocks().len() <= 3 {
+        if callee_body.basic_blocks.len() <= 3 {
             threshold += threshold / 4;
         }
         debug!("    final inline threshold = {}", threshold);
 
         // FIXME: Give a bonus to functions with only a single caller
-        let mut first_block = true;
-        let mut cost = 0;
+        let diverges = matches!(
+            callee_body.basic_blocks[START_BLOCK].terminator().kind,
+            TerminatorKind::Unreachable | TerminatorKind::Call { target: None, .. }
+        );
+        if diverges && !matches!(callee_attrs.inline, InlineAttr::Always) {
+            return Err("callee diverges unconditionally");
+        }
 
-        // Traverse the MIR manually so we can account for the effects of
-        // inlining on the CFG.
+        let mut checker = CostChecker {
+            tcx: self.tcx,
+            param_env: self.param_env,
+            instance: callsite.callee,
+            callee_body,
+            cost: 0,
+            validation: Ok(()),
+        };
+
+        // Traverse the MIR manually so we can account for the effects of inlining on the CFG.
         let mut work_list = vec![START_BLOCK];
-        let mut visited = BitSet::new_empty(callee_body.basic_blocks().len());
+        let mut visited = BitSet::new_empty(callee_body.basic_blocks.len());
         while let Some(bb) = work_list.pop() {
             if !visited.insert(bb.index()) {
                 continue;
             }
-            let blk = &callee_body.basic_blocks()[bb];
 
-            for stmt in &blk.statements {
-                // Don't count StorageLive/StorageDead in the inlining cost.
-                match stmt.kind {
-                    StatementKind::StorageLive(_)
-                    | StatementKind::StorageDead(_)
-                    | StatementKind::Deinit(_)
-                    | StatementKind::Nop => {}
-                    _ => cost += INSTR_COST,
-                }
-            }
-            let term = blk.terminator();
-            let mut is_drop = false;
-            match term.kind {
-                TerminatorKind::Drop { ref place, target, unwind }
-                | TerminatorKind::DropAndReplace { ref place, target, unwind, .. } => {
-                    is_drop = true;
-                    work_list.push(target);
-                    // If the place doesn't actually need dropping, treat it like
-                    // a regular goto.
-                    let ty = callsite.callee.subst_mir(self.tcx, &place.ty(callee_body, tcx).ty);
-                    if ty.needs_drop(tcx, self.param_env) {
-                        cost += CALL_PENALTY;
-                        if let Some(unwind) = unwind {
-                            cost += LANDINGPAD_PENALTY;
-                            work_list.push(unwind);
-                        }
-                    } else {
-                        cost += INSTR_COST;
-                    }
-                }
+            let blk = &callee_body.basic_blocks[bb];
+            checker.visit_basic_block_data(bb, blk);
 
-                TerminatorKind::Unreachable | TerminatorKind::Call { target: None, .. }
-                    if first_block =>
-                {
-                    // If the function always diverges, don't inline
-                    // unless the cost is zero
-                    threshold = 0;
-                }
-
-                TerminatorKind::Call { func: Operand::Constant(ref f), cleanup, .. } => {
-                    if let ty::FnDef(def_id, _) =
-                        *callsite.callee.subst_mir(self.tcx, &f.literal.ty()).kind()
-                    {
-                        // Don't give intrinsics the extra penalty for calls
-                        if tcx.is_intrinsic(def_id) {
-                            cost += INSTR_COST;
-                        } else {
-                            cost += CALL_PENALTY;
-                        }
-                    } else {
-                        cost += CALL_PENALTY;
-                    }
-                    if cleanup.is_some() {
-                        cost += LANDINGPAD_PENALTY;
-                    }
-                }
-                TerminatorKind::Assert { cleanup, .. } => {
-                    cost += CALL_PENALTY;
-
-                    if cleanup.is_some() {
-                        cost += LANDINGPAD_PENALTY;
-                    }
-                }
-                TerminatorKind::Resume => cost += RESUME_PENALTY,
-                TerminatorKind::InlineAsm { cleanup, .. } => {
-                    cost += INSTR_COST;
+            let term = blk.terminator();
+            if let TerminatorKind::Drop { ref place, target, unwind }
+            | TerminatorKind::DropAndReplace { ref place, target, unwind, .. } = term.kind
+            {
+                work_list.push(target);
 
-                    if cleanup.is_some() {
-                        cost += LANDINGPAD_PENALTY;
+                // If the place doesn't actually need dropping, treat it like a regular goto.
+                let ty = callsite.callee.subst_mir(self.tcx, &place.ty(callee_body, tcx).ty);
+                if ty.needs_drop(tcx, self.param_env) && let Some(unwind) = unwind {
+                        work_list.push(unwind);
                     }
-                }
-                _ => cost += INSTR_COST,
-            }
-
-            if !is_drop {
-                for succ in term.successors() {
-                    work_list.push(succ);
-                }
+            } else {
+                work_list.extend(term.successors())
             }
-
-            first_block = false;
         }
 
         // Count up the cost of local variables and temps, if we know the size
         // use that, otherwise we use a moderately-large dummy cost.
-
-        let ptr_size = tcx.data_layout.pointer_size.bytes();
-
         for v in callee_body.vars_and_temps_iter() {
-            let ty = callsite.callee.subst_mir(self.tcx, &callee_body.local_decls[v].ty);
-            // Cost of the var is the size in machine-words, if we know
-            // it.
-            if let Some(size) = type_size_of(tcx, self.param_env, ty) {
-                cost += ((size + ptr_size - 1) / ptr_size) as usize;
-            } else {
-                cost += UNKNOWN_SIZE_COST;
-            }
+            checker.visit_local_decl(v, &callee_body.local_decls[v]);
         }
 
+        // Abort if type validation found anything fishy.
+        checker.validation?;
+
+        let cost = checker.cost;
         if let InlineAttr::Always = callee_attrs.inline {
             debug!("INLINING {:?} because inline(always) [cost={}]", callsite, cost);
             Ok(())
@@ -585,7 +541,7 @@ impl<'tcx> Inliner<'tcx> {
                     args: &args,
                     new_locals: Local::new(caller_body.local_decls.len())..,
                     new_scopes: SourceScope::new(caller_body.source_scopes.len())..,
-                    new_blocks: BasicBlock::new(caller_body.basic_blocks().len())..,
+                    new_blocks: BasicBlock::new(caller_body.basic_blocks.len())..,
                     destination: dest,
                     callsite_scope: caller_body.source_scopes[callsite.source_info.scope].clone(),
                     callsite,
@@ -603,7 +559,9 @@ impl<'tcx> Inliner<'tcx> {
                 // If there are any locals without storage markers, give them storage only for the
                 // duration of the call.
                 for local in callee_body.vars_and_temps_iter() {
-                    if integrator.always_live_locals.contains(local) {
+                    if !callee_body.local_decls[local].internal
+                        && integrator.always_live_locals.contains(local)
+                    {
                         let new_local = integrator.map_local(local);
                         caller_body[callsite.block].statements.push(Statement {
                             source_info: callsite.source_info,
@@ -616,7 +574,9 @@ impl<'tcx> Inliner<'tcx> {
                     // the slice once.
                     let mut n = 0;
                     for local in callee_body.vars_and_temps_iter().rev() {
-                        if integrator.always_live_locals.contains(local) {
+                        if !callee_body.local_decls[local].internal
+                            && integrator.always_live_locals.contains(local)
+                        {
                             let new_local = integrator.map_local(local);
                             caller_body[block].statements.push(Statement {
                                 source_info: callsite.source_info,
@@ -644,11 +604,11 @@ impl<'tcx> Inliner<'tcx> {
                 // `required_consts`, here we may not only have `ConstKind::Unevaluated`
                 // because we are calling `subst_and_normalize_erasing_regions`.
                 caller_body.required_consts.extend(
-                    callee_body.required_consts.iter().copied().filter(|&ct| {
-                        match ct.literal.const_for_ty() {
-                            Some(ct) => matches!(ct.kind(), ConstKind::Unevaluated(_)),
-                            None => true,
+                    callee_body.required_consts.iter().copied().filter(|&ct| match ct.literal {
+                        ConstantKind::Ty(_) => {
+                            bug!("should never encounter ty::Unevaluated in `required_consts`")
                         }
+                        ConstantKind::Val(..) | ConstantKind::Unevaluated(..) => true,
                     }),
                 );
             }
@@ -782,6 +742,193 @@ fn type_size_of<'tcx>(
     tcx.layout_of(param_env.and(ty)).ok().map(|layout| layout.size.bytes())
 }
 
+/// Verify that the callee body is compatible with the caller.
+///
+/// This visitor mostly computes the inlining cost,
+/// but also needs to verify that types match because of normalization failure.
+struct CostChecker<'b, 'tcx> {
+    tcx: TyCtxt<'tcx>,
+    param_env: ParamEnv<'tcx>,
+    cost: usize,
+    callee_body: &'b Body<'tcx>,
+    instance: ty::Instance<'tcx>,
+    validation: Result<(), &'static str>,
+}
+
+impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
+    fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
+        // Don't count StorageLive/StorageDead in the inlining cost.
+        match statement.kind {
+            StatementKind::StorageLive(_)
+            | StatementKind::StorageDead(_)
+            | StatementKind::Deinit(_)
+            | StatementKind::Nop => {}
+            _ => self.cost += INSTR_COST,
+        }
+
+        self.super_statement(statement, location);
+    }
+
+    fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
+        let tcx = self.tcx;
+        match terminator.kind {
+            TerminatorKind::Drop { ref place, unwind, .. }
+            | TerminatorKind::DropAndReplace { ref place, unwind, .. } => {
+                // If the place doesn't actually need dropping, treat it like a regular goto.
+                let ty = self.instance.subst_mir(tcx, &place.ty(self.callee_body, tcx).ty);
+                if ty.needs_drop(tcx, self.param_env) {
+                    self.cost += CALL_PENALTY;
+                    if unwind.is_some() {
+                        self.cost += LANDINGPAD_PENALTY;
+                    }
+                } else {
+                    self.cost += INSTR_COST;
+                }
+            }
+            TerminatorKind::Call { func: Operand::Constant(ref f), cleanup, .. } => {
+                let fn_ty = self.instance.subst_mir(tcx, &f.literal.ty());
+                self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind() && tcx.is_intrinsic(def_id) {
+                    // Don't give intrinsics the extra penalty for calls
+                    INSTR_COST
+                } else {
+                    CALL_PENALTY
+                };
+                if cleanup.is_some() {
+                    self.cost += LANDINGPAD_PENALTY;
+                }
+            }
+            TerminatorKind::Assert { cleanup, .. } => {
+                self.cost += CALL_PENALTY;
+                if cleanup.is_some() {
+                    self.cost += LANDINGPAD_PENALTY;
+                }
+            }
+            TerminatorKind::Resume => self.cost += RESUME_PENALTY,
+            TerminatorKind::InlineAsm { cleanup, .. } => {
+                self.cost += INSTR_COST;
+                if cleanup.is_some() {
+                    self.cost += LANDINGPAD_PENALTY;
+                }
+            }
+            _ => self.cost += INSTR_COST,
+        }
+
+        self.super_terminator(terminator, location);
+    }
+
+    /// Count up the cost of local variables and temps, if we know the size
+    /// use that, otherwise we use a moderately-large dummy cost.
+    fn visit_local_decl(&mut self, local: Local, local_decl: &LocalDecl<'tcx>) {
+        let tcx = self.tcx;
+        let ptr_size = tcx.data_layout.pointer_size.bytes();
+
+        let ty = self.instance.subst_mir(tcx, &local_decl.ty);
+        // Cost of the var is the size in machine-words, if we know
+        // it.
+        if let Some(size) = type_size_of(tcx, self.param_env, ty) {
+            self.cost += ((size + ptr_size - 1) / ptr_size) as usize;
+        } else {
+            self.cost += UNKNOWN_SIZE_COST;
+        }
+
+        self.super_local_decl(local, local_decl)
+    }
+
+    /// This method duplicates code from MIR validation in an attempt to detect type mismatches due
+    /// to normalization failure.
+    fn visit_projection_elem(
+        &mut self,
+        local: Local,
+        proj_base: &[PlaceElem<'tcx>],
+        elem: PlaceElem<'tcx>,
+        context: PlaceContext,
+        location: Location,
+    ) {
+        if let ProjectionElem::Field(f, ty) = elem {
+            let parent = Place { local, projection: self.tcx.intern_place_elems(proj_base) };
+            let parent_ty = parent.ty(&self.callee_body.local_decls, self.tcx);
+            let check_equal = |this: &mut Self, f_ty| {
+                if !equal_up_to_regions(this.tcx, this.param_env, ty, f_ty) {
+                    trace!(?ty, ?f_ty);
+                    this.validation = Err("failed to normalize projection type");
+                    return;
+                }
+            };
+
+            let kind = match parent_ty.ty.kind() {
+                &ty::Opaque(def_id, substs) => {
+                    self.tcx.bound_type_of(def_id).subst(self.tcx, substs).kind()
+                }
+                kind => kind,
+            };
+
+            match kind {
+                ty::Tuple(fields) => {
+                    let Some(f_ty) = fields.get(f.as_usize()) else {
+                        self.validation = Err("malformed MIR");
+                        return;
+                    };
+                    check_equal(self, *f_ty);
+                }
+                ty::Adt(adt_def, substs) => {
+                    let var = parent_ty.variant_index.unwrap_or(VariantIdx::from_u32(0));
+                    let Some(field) = adt_def.variant(var).fields.get(f.as_usize()) else {
+                        self.validation = Err("malformed MIR");
+                        return;
+                    };
+                    check_equal(self, field.ty(self.tcx, substs));
+                }
+                ty::Closure(_, substs) => {
+                    let substs = substs.as_closure();
+                    let Some(f_ty) = substs.upvar_tys().nth(f.as_usize()) else {
+                        self.validation = Err("malformed MIR");
+                        return;
+                    };
+                    check_equal(self, f_ty);
+                }
+                &ty::Generator(def_id, substs, _) => {
+                    let f_ty = if let Some(var) = parent_ty.variant_index {
+                        let gen_body = if def_id == self.callee_body.source.def_id() {
+                            self.callee_body
+                        } else {
+                            self.tcx.optimized_mir(def_id)
+                        };
+
+                        let Some(layout) = gen_body.generator_layout() else {
+                            self.validation = Err("malformed MIR");
+                            return;
+                        };
+
+                        let Some(&local) = layout.variant_fields[var].get(f) else {
+                            self.validation = Err("malformed MIR");
+                            return;
+                        };
+
+                        let Some(&f_ty) = layout.field_tys.get(local) else {
+                            self.validation = Err("malformed MIR");
+                            return;
+                        };
+
+                        f_ty
+                    } else {
+                        let Some(f_ty) = substs.as_generator().prefix_tys().nth(f.index()) else {
+                            self.validation = Err("malformed MIR");
+                            return;
+                        };
+
+                        f_ty
+                    };
+
+                    check_equal(self, f_ty);
+                }
+                _ => self.validation = Err("malformed MIR"),
+            }
+        }
+
+        self.super_projection_elem(local, proj_base, elem, context, location);
+    }
+}
+
 /**
  * Integrator.
  *