]> git.lizzy.rs Git - rust.git/blob - src/librustc_mir/transform/instrument_coverage.rs
c36614938e10f9ab8b9a006faae3becece49105b
[rust.git] / src / librustc_mir / transform / instrument_coverage.rs
1 use crate::transform::{MirPass, MirSource};
2 use crate::util::patch::MirPatch;
3 use rustc_hir::lang_items;
4 use rustc_middle::mir::interpret::Scalar;
5 use rustc_middle::mir::*;
6 use rustc_middle::ty;
7 use rustc_middle::ty::TyCtxt;
8 use rustc_span::def_id::DefId;
9 use rustc_span::Span;
10
11 /// Inserts call to count_code_region() as a placeholder to be replaced during code generation with
12 /// the intrinsic llvm.instrprof.increment.
13 pub struct InstrumentCoverage;
14
15 impl<'tcx> MirPass<'tcx> for InstrumentCoverage {
16     fn run_pass(&self, tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, body: &mut Body<'tcx>) {
17         if tcx.sess.opts.debugging_opts.instrument_coverage {
18             debug!("instrumenting {:?}", src.def_id());
19             instrument_coverage(tcx, body);
20         }
21     }
22 }
23
/// Counter index used for the counter injected at the start of a function.
/// It is the first counter, so its index is zero.
const INIT_FUNCTION_COUNTER: u32 = 0;
26
27 /// Injects calls to placeholder function `count_code_region()`.
28 // FIXME(richkadel): As a first step, counters are only injected at the top of each function.
29 // The complete solution will inject counters at each conditional code branch.
30 pub fn instrument_coverage<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
31     let span = body.span.shrink_to_lo();
32
33     let count_code_region_fn = function_handle(
34         tcx,
35         tcx.require_lang_item(lang_items::CountCodeRegionFnLangItem, None),
36         span,
37     );
38     let counter_index = Operand::const_from_scalar(
39         tcx,
40         tcx.types.u32,
41         Scalar::from_u32(INIT_FUNCTION_COUNTER),
42         span,
43     );
44
45     let mut patch = MirPatch::new(body);
46
47     let new_block = patch.new_block(placeholder_block(SourceInfo::outermost(body.span)));
48     let next_block = START_BLOCK;
49
50     let temp = patch.new_temp(tcx.mk_unit(), body.span);
51     patch.patch_terminator(
52         new_block,
53         TerminatorKind::Call {
54             func: count_code_region_fn,
55             args: vec![counter_index],
56             // new_block will swapped with the next_block, after applying patch
57             destination: Some((Place::from(temp), new_block)),
58             cleanup: None,
59             from_hir_call: false,
60             fn_span: span,
61         },
62     );
63
64     patch.add_statement(new_block.start_location(), StatementKind::StorageLive(temp));
65     patch.add_statement(next_block.start_location(), StatementKind::StorageDead(temp));
66
67     patch.apply(body);
68
69     // To insert the `new_block` in front of the first block in the counted branch (for example,
70     // the START_BLOCK, at the top of the function), just swap the indexes, leaving the rest of the
71     // graph unchanged.
72     body.basic_blocks_mut().swap(next_block, new_block);
73 }
74
75 fn function_handle<'tcx>(tcx: TyCtxt<'tcx>, fn_def_id: DefId, span: Span) -> Operand<'tcx> {
76     let ret_ty = tcx.fn_sig(fn_def_id).output();
77     let ret_ty = ret_ty.no_bound_vars().unwrap();
78     let substs = tcx.mk_substs(::std::iter::once(ty::subst::GenericArg::from(ret_ty)));
79     Operand::function_handle(tcx, fn_def_id, substs, span)
80 }
81
82 fn placeholder_block<'tcx>(source_info: SourceInfo) -> BasicBlockData<'tcx> {
83     BasicBlockData {
84         statements: vec![],
85         terminator: Some(Terminator {
86             source_info,
87             // this gets overwritten by the counter Call
88             kind: TerminatorKind::Unreachable,
89         }),
90         is_cleanup: false,
91     }
92 }