]> git.lizzy.rs Git - rust.git/blob - src/librustc_mir/transform/instrument_coverage.rs
pin docs: add some forward references
[rust.git] / src / librustc_mir / transform / instrument_coverage.rs
1 use crate::transform::{MirPass, MirSource};
2 use crate::util::patch::MirPatch;
3 use rustc_data_structures::fingerprint::Fingerprint;
4 use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
5 use rustc_hir::lang_items;
6 use rustc_middle::hir;
7 use rustc_middle::ich::StableHashingContext;
8 use rustc_middle::mir;
9 use rustc_middle::mir::coverage::*;
10 use rustc_middle::mir::interpret::Scalar;
11 use rustc_middle::mir::traversal;
12 use rustc_middle::mir::{
13     BasicBlock, BasicBlockData, CoverageInfo, Operand, Place, SourceInfo, SourceScope,
14     StatementKind, Terminator, TerminatorKind,
15 };
16 use rustc_middle::ty;
17 use rustc_middle::ty::query::Providers;
18 use rustc_middle::ty::{FnDef, TyCtxt};
19 use rustc_span::def_id::DefId;
20 use rustc_span::{FileName, Pos, RealFileName, Span};
21
22 /// Inserts call to count_code_region() as a placeholder to be replaced during code generation with
23 /// the intrinsic llvm.instrprof.increment.
24 pub struct InstrumentCoverage;
25
26 /// The `query` provider for `CoverageInfo`, requested by `codegen_intrinsic_call()` when
27 /// constructing the arguments for `llvm.instrprof.increment`.
28 pub(crate) fn provide(providers: &mut Providers) {
29     providers.coverageinfo = |tcx, def_id| coverageinfo_from_mir(tcx, def_id);
30 }
31
32 fn coverageinfo_from_mir<'tcx>(tcx: TyCtxt<'tcx>, mir_def_id: DefId) -> CoverageInfo {
33     let mir_body = tcx.optimized_mir(mir_def_id);
34     // FIXME(richkadel): The current implementation assumes the MIR for the given DefId
35     // represents a single function. Validate and/or correct if inlining (which should be disabled
36     // if -Zinstrument-coverage is enabled) and/or monomorphization invalidates these assumptions.
37     let count_code_region_fn = tcx.require_lang_item(lang_items::CountCodeRegionFnLangItem, None);
38     let coverage_counter_add_fn =
39         tcx.require_lang_item(lang_items::CoverageCounterAddFnLangItem, None);
40     let coverage_counter_subtract_fn =
41         tcx.require_lang_item(lang_items::CoverageCounterSubtractFnLangItem, None);
42
43     // The `num_counters` argument to `llvm.instrprof.increment` is the number of injected
44     // counters, with each counter having a counter ID from `0..num_counters-1`. MIR optimization
45     // may split and duplicate some BasicBlock sequences. Simply counting the calls may not
46     // work; but computing the num_counters by adding `1` to the highest counter_id (for a given
47     // instrumented function) is valid.
48     //
49     // `num_expressions` is the number of counter expressions added to the MIR body. Both
50     // `num_counters` and `num_expressions` are used to initialize new vectors, during backend
51     // code generate, to lookup counters and expressions by simple u32 indexes.
52     let mut num_counters: u32 = 0;
53     let mut num_expressions: u32 = 0;
54     for terminator in
55         traversal::preorder(mir_body).map(|(_, data)| data).filter_map(call_terminators)
56     {
57         if let TerminatorKind::Call { func: Operand::Constant(func), args, .. } = &terminator.kind {
58             match func.literal.ty.kind {
59                 FnDef(id, _) if id == count_code_region_fn => {
60                     let counter_id_arg =
61                         args.get(count_code_region_args::COUNTER_ID).expect("arg found");
62                     let counter_id = mir::Operand::scalar_from_const(counter_id_arg)
63                         .to_u32()
64                         .expect("counter_id arg is u32");
65                     num_counters = std::cmp::max(num_counters, counter_id + 1);
66                 }
67                 FnDef(id, _)
68                     if id == coverage_counter_add_fn || id == coverage_counter_subtract_fn =>
69                 {
70                     let expression_id_arg = args
71                         .get(coverage_counter_expression_args::EXPRESSION_ID)
72                         .expect("arg found");
73                     let id_descending_from_max = mir::Operand::scalar_from_const(expression_id_arg)
74                         .to_u32()
75                         .expect("expression_id arg is u32");
76                     // Counter expressions are initially assigned IDs descending from `u32::MAX`, so
77                     // the range of expression IDs is disjoint from the range of counter IDs. This
78                     // way, both counters and expressions can be operands in other expressions.
79                     let expression_index = u32::MAX - id_descending_from_max;
80                     num_expressions = std::cmp::max(num_expressions, expression_index + 1);
81                 }
82                 _ => {}
83             }
84         }
85     }
86     CoverageInfo { num_counters, num_expressions }
87 }
88
89 fn call_terminators(data: &'tcx BasicBlockData<'tcx>) -> Option<&'tcx Terminator<'tcx>> {
90     let terminator = data.terminator();
91     match terminator.kind {
92         TerminatorKind::Call { .. } => Some(terminator),
93         _ => None,
94     }
95 }
96
97 impl<'tcx> MirPass<'tcx> for InstrumentCoverage {
98     fn run_pass(&self, tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, mir_body: &mut mir::Body<'tcx>) {
99         // If the InstrumentCoverage pass is called on promoted MIRs, skip them.
100         // See: https://github.com/rust-lang/rust/pull/73011#discussion_r438317601
101         if src.promoted.is_none() {
102             Instrumentor::new(tcx, src, mir_body).inject_counters();
103         }
104     }
105 }
106
107 /// Distinguishes the expression operators.
108 enum Op {
109     Add,
110     Subtract,
111 }
112
113 struct InjectedCall<'tcx> {
114     func: Operand<'tcx>,
115     args: Vec<Operand<'tcx>>,
116     span: Span,
117     inject_at: Span,
118 }
119
120 struct Instrumentor<'a, 'tcx> {
121     tcx: TyCtxt<'tcx>,
122     mir_def_id: DefId,
123     mir_body: &'a mut mir::Body<'tcx>,
124     hir_body: &'tcx rustc_hir::Body<'tcx>,
125     function_source_hash: Option<u64>,
126     num_counters: u32,
127     num_expressions: u32,
128 }
129
130 impl<'a, 'tcx> Instrumentor<'a, 'tcx> {
131     fn new(tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, mir_body: &'a mut mir::Body<'tcx>) -> Self {
132         let mir_def_id = src.def_id();
133         let hir_body = hir_body(tcx, mir_def_id);
134         Self {
135             tcx,
136             mir_def_id,
137             mir_body,
138             hir_body,
139             function_source_hash: None,
140             num_counters: 0,
141             num_expressions: 0,
142         }
143     }
144
145     /// Counter IDs start from zero and go up.
146     fn next_counter(&mut self) -> u32 {
147         assert!(self.num_counters < u32::MAX - self.num_expressions);
148         let next = self.num_counters;
149         self.num_counters += 1;
150         next
151     }
152
153     /// Expression IDs start from u32::MAX and go down because a CounterExpression can reference
154     /// (add or subtract counts) of both Counter regions and CounterExpression regions. The counter
155     /// expression operand IDs must be unique across both types.
156     fn next_expression(&mut self) -> u32 {
157         assert!(self.num_counters < u32::MAX - self.num_expressions);
158         let next = u32::MAX - self.num_expressions;
159         self.num_expressions += 1;
160         next
161     }
162
163     fn function_source_hash(&mut self) -> u64 {
164         match self.function_source_hash {
165             Some(hash) => hash,
166             None => {
167                 let hash = hash_mir_source(self.tcx, self.hir_body);
168                 self.function_source_hash.replace(hash);
169                 hash
170             }
171         }
172     }
173
174     fn inject_counters(&mut self) {
175         let mir_body = &self.mir_body;
176         let body_span = self.hir_body.value.span;
177         debug!("instrumenting {:?}, span: {:?}", self.mir_def_id, body_span);
178
179         // FIXME(richkadel): As a first step, counters are only injected at the top of each
180         // function. The complete solution will inject counters at each conditional code branch.
181         let _ignore = mir_body;
182         let id = self.next_counter();
183         let function_source_hash = self.function_source_hash();
184         let scope = rustc_middle::mir::OUTERMOST_SOURCE_SCOPE;
185         let is_cleanup = false;
186         let next_block = rustc_middle::mir::START_BLOCK;
187         self.inject_call(
188             self.make_counter(id, function_source_hash, body_span),
189             scope,
190             is_cleanup,
191             next_block,
192         );
193
194         // FIXME(richkadel): The next step to implement source based coverage analysis will be
195         // instrumenting branches within functions, and some regions will be counted by "counter
196         // expression". The function to inject counter expression is implemented. Replace this
197         // "fake use" with real use.
198         let fake_use = false;
199         if fake_use {
200             let add = false;
201             let lhs = 1;
202             let op = if add { Op::Add } else { Op::Subtract };
203             let rhs = 2;
204
205             let scope = rustc_middle::mir::OUTERMOST_SOURCE_SCOPE;
206             let is_cleanup = false;
207             let next_block = rustc_middle::mir::START_BLOCK;
208
209             let id = self.next_expression();
210             self.inject_call(
211                 self.make_expression(id, body_span, lhs, op, rhs),
212                 scope,
213                 is_cleanup,
214                 next_block,
215             );
216         }
217     }
218
219     fn make_counter(&self, id: u32, function_source_hash: u64, span: Span) -> InjectedCall<'tcx> {
220         let inject_at = span.shrink_to_lo();
221
222         let func = function_handle(
223             self.tcx,
224             self.tcx.require_lang_item(lang_items::CountCodeRegionFnLangItem, None),
225             inject_at,
226         );
227
228         let mut args = Vec::new();
229
230         use count_code_region_args::*;
231         debug_assert_eq!(FUNCTION_SOURCE_HASH, args.len());
232         args.push(self.const_u64(function_source_hash, inject_at));
233
234         debug_assert_eq!(COUNTER_ID, args.len());
235         args.push(self.const_u32(id, inject_at));
236
237         InjectedCall { func, args, span, inject_at }
238     }
239
240     fn make_expression(
241         &self,
242         id: u32,
243         span: Span,
244         lhs: u32,
245         op: Op,
246         rhs: u32,
247     ) -> InjectedCall<'tcx> {
248         let inject_at = span.shrink_to_lo();
249
250         let func = function_handle(
251             self.tcx,
252             self.tcx.require_lang_item(
253                 match op {
254                     Op::Add => lang_items::CoverageCounterAddFnLangItem,
255                     Op::Subtract => lang_items::CoverageCounterSubtractFnLangItem,
256                 },
257                 None,
258             ),
259             inject_at,
260         );
261
262         let mut args = Vec::new();
263
264         use coverage_counter_expression_args::*;
265         debug_assert_eq!(EXPRESSION_ID, args.len());
266         args.push(self.const_u32(id, inject_at));
267
268         debug_assert_eq!(LEFT_ID, args.len());
269         args.push(self.const_u32(lhs, inject_at));
270
271         debug_assert_eq!(RIGHT_ID, args.len());
272         args.push(self.const_u32(rhs, inject_at));
273
274         InjectedCall { func, args, span, inject_at }
275     }
276
277     fn inject_call(
278         &mut self,
279         call: InjectedCall<'tcx>,
280         scope: SourceScope,
281         is_cleanup: bool,
282         next_block: BasicBlock,
283     ) {
284         let InjectedCall { func, mut args, span, inject_at } = call;
285         debug!(
286             "  injecting {}call to {:?}({:?}) at: {:?}, scope: {:?}",
287             if is_cleanup { "cleanup " } else { "" },
288             func,
289             args,
290             inject_at,
291             scope,
292         );
293
294         let mut patch = MirPatch::new(self.mir_body);
295
296         let (file_name, start_line, start_col, end_line, end_col) = self.code_region(&span);
297
298         args.push(self.const_str(&file_name, inject_at));
299         args.push(self.const_u32(start_line, inject_at));
300         args.push(self.const_u32(start_col, inject_at));
301         args.push(self.const_u32(end_line, inject_at));
302         args.push(self.const_u32(end_col, inject_at));
303
304         let temp = patch.new_temp(self.tcx.mk_unit(), inject_at);
305         let new_block = patch.new_block(placeholder_block(inject_at, scope, is_cleanup));
306         patch.patch_terminator(
307             new_block,
308             TerminatorKind::Call {
309                 func,
310                 args,
311                 // new_block will swapped with the next_block, after applying patch
312                 destination: Some((Place::from(temp), new_block)),
313                 cleanup: None,
314                 from_hir_call: false,
315                 fn_span: inject_at,
316             },
317         );
318
319         patch.add_statement(new_block.start_location(), StatementKind::StorageLive(temp));
320         patch.add_statement(next_block.start_location(), StatementKind::StorageDead(temp));
321
322         patch.apply(self.mir_body);
323
324         // To insert the `new_block` in front of the first block in the counted branch (the
325         // `next_block`), just swap the indexes, leaving the rest of the graph unchanged.
326         self.mir_body.basic_blocks_mut().swap(next_block, new_block);
327     }
328
329     /// Convert the Span into its file name, start line and column, and end line and column
330     fn code_region(&self, span: &Span) -> (String, u32, u32, u32, u32) {
331         let source_map = self.tcx.sess.source_map();
332         let start = source_map.lookup_char_pos(span.lo());
333         let end = if span.hi() == span.lo() {
334             start.clone()
335         } else {
336             let end = source_map.lookup_char_pos(span.hi());
337             debug_assert_eq!(
338                 start.file.name,
339                 end.file.name,
340                 "Region start ({:?} -> {:?}) and end ({:?} -> {:?}) don't come from the same source file!",
341                 span.lo(),
342                 start,
343                 span.hi(),
344                 end
345             );
346             end
347         };
348         match &start.file.name {
349             FileName::Real(RealFileName::Named(path)) => (
350                 path.to_string_lossy().to_string(),
351                 start.line as u32,
352                 start.col.to_u32() + 1,
353                 end.line as u32,
354                 end.col.to_u32() + 1,
355             ),
356             _ => {
357                 bug!("start.file.name should be a RealFileName, but it was: {:?}", start.file.name)
358             }
359         }
360     }
361
362     fn const_str(&self, value: &str, span: Span) -> Operand<'tcx> {
363         Operand::const_from_str(self.tcx, value, span)
364     }
365
366     fn const_u32(&self, value: u32, span: Span) -> Operand<'tcx> {
367         Operand::const_from_scalar(self.tcx, self.tcx.types.u32, Scalar::from_u32(value), span)
368     }
369
370     fn const_u64(&self, value: u64, span: Span) -> Operand<'tcx> {
371         Operand::const_from_scalar(self.tcx, self.tcx.types.u64, Scalar::from_u64(value), span)
372     }
373 }
374
375 fn function_handle<'tcx>(tcx: TyCtxt<'tcx>, fn_def_id: DefId, span: Span) -> Operand<'tcx> {
376     let ret_ty = tcx.fn_sig(fn_def_id).output();
377     let ret_ty = ret_ty.no_bound_vars().unwrap();
378     let substs = tcx.mk_substs(::std::iter::once(ty::subst::GenericArg::from(ret_ty)));
379     Operand::function_handle(tcx, fn_def_id, substs, span)
380 }
381
382 fn placeholder_block(span: Span, scope: SourceScope, is_cleanup: bool) -> BasicBlockData<'tcx> {
383     BasicBlockData {
384         statements: vec![],
385         terminator: Some(Terminator {
386             source_info: SourceInfo { span, scope },
387             // this gets overwritten by the counter Call
388             kind: TerminatorKind::Unreachable,
389         }),
390         is_cleanup,
391     }
392 }
393
394 fn hir_body<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> &'tcx rustc_hir::Body<'tcx> {
395     let hir_node = tcx.hir().get_if_local(def_id).expect("DefId is local");
396     let fn_body_id = hir::map::associated_body(hir_node).expect("HIR node is a function with body");
397     tcx.hir().body(fn_body_id)
398 }
399
400 fn hash_mir_source<'tcx>(tcx: TyCtxt<'tcx>, hir_body: &'tcx rustc_hir::Body<'tcx>) -> u64 {
401     let mut hcx = tcx.create_no_span_stable_hashing_context();
402     hash(&mut hcx, &hir_body.value).to_smaller_hash()
403 }
404
405 fn hash(
406     hcx: &mut StableHashingContext<'tcx>,
407     node: &impl HashStable<StableHashingContext<'tcx>>,
408 ) -> Fingerprint {
409     let mut stable_hasher = StableHasher::new();
410     node.hash_stable(hcx, &mut stable_hasher);
411     stable_hasher.finish()
412 }