git.lizzy.rs Git - rust.git/blob - src/librustc_mir/transform/instrument_coverage.rs
add spans to injected coverage counters
[rust.git] / src / librustc_mir / transform / instrument_coverage.rs
1 use crate::transform::{MirPass, MirSource};
2 use crate::util::patch::MirPatch;
3 use rustc_data_structures::fingerprint::Fingerprint;
4 use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
5 use rustc_hir::lang_items;
6 use rustc_middle::hir;
7 use rustc_middle::ich::StableHashingContext;
8 use rustc_middle::mir::coverage::*;
9 use rustc_middle::mir::interpret::Scalar;
10 use rustc_middle::mir::CoverageInfo;
11 use rustc_middle::mir::{
12     self, traversal, BasicBlock, BasicBlockData, Operand, Place, SourceInfo, StatementKind,
13     Terminator, TerminatorKind, START_BLOCK,
14 };
15 use rustc_middle::ty;
16 use rustc_middle::ty::query::Providers;
17 use rustc_middle::ty::FnDef;
18 use rustc_middle::ty::TyCtxt;
19 use rustc_span::def_id::DefId;
20 use rustc_span::{Pos, Span};
21
/// MIR pass that injects a call to `count_code_region()`. The call is only a placeholder: code
/// generation is expected to replace it with the `llvm.instrprof.increment` intrinsic.
pub struct InstrumentCoverage;
25
26 /// The `query` provider for `CoverageInfo`, requested by `codegen_intrinsic_call()` when
27 /// constructing the arguments for `llvm.instrprof.increment`.
28 pub(crate) fn provide(providers: &mut Providers<'_>) {
29     providers.coverageinfo = |tcx, def_id| coverageinfo_from_mir(tcx, def_id);
30 }
31
32 fn coverageinfo_from_mir<'tcx>(tcx: TyCtxt<'tcx>, mir_def_id: DefId) -> CoverageInfo {
33     let mir_body = tcx.optimized_mir(mir_def_id);
34     // FIXME(richkadel): The current implementation assumes the MIR for the given DefId
35     // represents a single function. Validate and/or correct if inlining (which should be disabled
36     // if -Zinstrument-coverage is enabled) and/or monomorphization invalidates these assumptions.
37     let count_code_region_fn = tcx.require_lang_item(lang_items::CountCodeRegionFnLangItem, None);
38
39     // The `num_counters` argument to `llvm.instrprof.increment` is the number of injected
40     // counters, with each counter having an index from `0..num_counters-1`. MIR optimization
41     // may split and duplicate some BasicBlock sequences. Simply counting the calls may not
42     // not work; but computing the num_counters by adding `1` to the highest index (for a given
43     // instrumented function) is valid.
44     let mut num_counters: u32 = 0;
45     for terminator in traversal::preorder(mir_body)
46         .map(|(_, data)| (data, count_code_region_fn))
47         .filter_map(terminators_that_call_given_fn)
48     {
49         if let TerminatorKind::Call { args, .. } = &terminator.kind {
50             let index_arg = args.get(count_code_region_args::COUNTER_INDEX).expect("arg found");
51             let index =
52                 mir::Operand::scalar_from_const(index_arg).to_u32().expect("index arg is u32");
53             num_counters = std::cmp::max(num_counters, index + 1);
54         }
55     }
56     let hash = if num_counters > 0 { hash_mir_source(tcx, mir_def_id) } else { 0 };
57     CoverageInfo { num_counters, hash }
58 }
59
60 fn terminators_that_call_given_fn(
61     (data, fn_def_id): (&'tcx BasicBlockData<'tcx>, DefId),
62 ) -> Option<&'tcx Terminator<'tcx>> {
63     if let Some(terminator) = &data.terminator {
64         if let TerminatorKind::Call { func: Operand::Constant(func), .. } = &terminator.kind {
65             if let FnDef(called_fn_def_id, _) = func.literal.ty.kind {
66                 if called_fn_def_id == fn_def_id {
67                     return Some(&terminator);
68                 }
69             }
70         }
71     }
72     None
73 }
74
75 struct Instrumentor<'tcx> {
76     tcx: TyCtxt<'tcx>,
77     num_counters: u32,
78 }
79
80 impl<'tcx> MirPass<'tcx> for InstrumentCoverage {
81     fn run_pass(&self, tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, mir_body: &mut mir::Body<'tcx>) {
82         if tcx.sess.opts.debugging_opts.instrument_coverage {
83             // If the InstrumentCoverage pass is called on promoted MIRs, skip them.
84             // See: https://github.com/rust-lang/rust/pull/73011#discussion_r438317601
85             if src.promoted.is_none() {
86                 debug!(
87                     "instrumenting {:?}, span: {}",
88                     src.def_id(),
89                     tcx.sess.source_map().span_to_string(mir_body.span)
90                 );
91                 Instrumentor::new(tcx).inject_counters(mir_body);
92             }
93         }
94     }
95 }
96
97 impl<'tcx> Instrumentor<'tcx> {
98     fn new(tcx: TyCtxt<'tcx>) -> Self {
99         Self { tcx, num_counters: 0 }
100     }
101
102     fn next_counter(&mut self) -> u32 {
103         let next = self.num_counters;
104         self.num_counters += 1;
105         next
106     }
107
108     fn inject_counters(&mut self, mir_body: &mut mir::Body<'tcx>) {
109         // FIXME(richkadel): As a first step, counters are only injected at the top of each
110         // function. The complete solution will inject counters at each conditional code branch.
111         let code_region = mir_body.span;
112         let next_block = START_BLOCK;
113         self.inject_counter(mir_body, code_region, next_block);
114     }
115
116     fn inject_counter(
117         &mut self,
118         mir_body: &mut mir::Body<'tcx>,
119         code_region: Span,
120         next_block: BasicBlock,
121     ) {
122         let injection_point = code_region.shrink_to_lo();
123
124         let count_code_region_fn = function_handle(
125             self.tcx,
126             self.tcx.require_lang_item(lang_items::CountCodeRegionFnLangItem, None),
127             injection_point,
128         );
129
130         let index = self.next_counter();
131
132         let mut args = Vec::new();
133
134         use count_code_region_args::*;
135         debug_assert_eq!(COUNTER_INDEX, args.len());
136         args.push(self.const_u32(index, injection_point));
137
138         debug_assert_eq!(START_BYTE_POS, args.len());
139         args.push(self.const_u32(code_region.lo().to_u32(), injection_point));
140
141         debug_assert_eq!(END_BYTE_POS, args.len());
142         args.push(self.const_u32(code_region.hi().to_u32(), injection_point));
143
144         let mut patch = MirPatch::new(mir_body);
145
146         let temp = patch.new_temp(self.tcx.mk_unit(), code_region);
147         let new_block = patch.new_block(placeholder_block(code_region));
148         patch.patch_terminator(
149             new_block,
150             TerminatorKind::Call {
151                 func: count_code_region_fn,
152                 args,
153                 // new_block will swapped with the next_block, after applying patch
154                 destination: Some((Place::from(temp), new_block)),
155                 cleanup: None,
156                 from_hir_call: false,
157                 fn_span: injection_point,
158             },
159         );
160
161         patch.add_statement(new_block.start_location(), StatementKind::StorageLive(temp));
162         patch.add_statement(next_block.start_location(), StatementKind::StorageDead(temp));
163
164         patch.apply(mir_body);
165
166         // To insert the `new_block` in front of the first block in the counted branch (the
167         // `next_block`), just swap the indexes, leaving the rest of the graph unchanged.
168         mir_body.basic_blocks_mut().swap(next_block, new_block);
169     }
170
171     fn const_u32(&self, value: u32, span: Span) -> Operand<'tcx> {
172         Operand::const_from_scalar(self.tcx, self.tcx.types.u32, Scalar::from_u32(value), span)
173     }
174 }
175
176 fn function_handle<'tcx>(tcx: TyCtxt<'tcx>, fn_def_id: DefId, span: Span) -> Operand<'tcx> {
177     let ret_ty = tcx.fn_sig(fn_def_id).output();
178     let ret_ty = ret_ty.no_bound_vars().unwrap();
179     let substs = tcx.mk_substs(::std::iter::once(ty::subst::GenericArg::from(ret_ty)));
180     Operand::function_handle(tcx, fn_def_id, substs, span)
181 }
182
183 fn placeholder_block(span: Span) -> BasicBlockData<'tcx> {
184     BasicBlockData {
185         statements: vec![],
186         terminator: Some(Terminator {
187             source_info: SourceInfo::outermost(span),
188             // this gets overwritten by the counter Call
189             kind: TerminatorKind::Unreachable,
190         }),
191         is_cleanup: false,
192     }
193 }
194
195 fn hash_mir_source<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> u64 {
196     let hir_node = tcx.hir().get_if_local(def_id).expect("DefId is local");
197     let fn_body_id = hir::map::associated_body(hir_node).expect("HIR node is a function with body");
198     let hir_body = tcx.hir().body(fn_body_id);
199     let mut hcx = tcx.create_no_span_stable_hashing_context();
200     hash(&mut hcx, &hir_body.value).to_smaller_hash()
201 }
202
203 fn hash(
204     hcx: &mut StableHashingContext<'tcx>,
205     node: &impl HashStable<StableHashingContext<'tcx>>,
206 ) -> Fingerprint {
207     let mut stable_hasher = StableHasher::new();
208     node.hash_stable(hcx, &mut stable_hasher);
209     stable_hasher.finish()
210 }