]> git.lizzy.rs Git - rust.git/blob - src/librustc_mir/transform/instrument_coverage.rs
Rollup merge of #74759 - carbotaniuman:uabs, r=shepmaster
[rust.git] / src / librustc_mir / transform / instrument_coverage.rs
1 use crate::transform::{MirPass, MirSource};
2 use crate::util::patch::MirPatch;
3 use rustc_data_structures::fingerprint::Fingerprint;
4 use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
5 use rustc_hir::lang_items;
6 use rustc_middle::hir;
7 use rustc_middle::ich::StableHashingContext;
8 use rustc_middle::mir::coverage::*;
9 use rustc_middle::mir::interpret::Scalar;
10 use rustc_middle::mir::{
11     self, traversal, BasicBlock, BasicBlockData, CoverageInfo, Operand, Place, SourceInfo,
12     SourceScope, StatementKind, Terminator, TerminatorKind,
13 };
14 use rustc_middle::ty;
15 use rustc_middle::ty::query::Providers;
16 use rustc_middle::ty::FnDef;
17 use rustc_middle::ty::TyCtxt;
18 use rustc_span::def_id::DefId;
19 use rustc_span::{Pos, Span};
20
/// Inserts call to count_code_region() as a placeholder to be replaced during code generation with
/// the intrinsic llvm.instrprof.increment.
///
/// This is a `MirPass` (see the `impl MirPass for InstrumentCoverage` below); the pass itself
/// holds no state.
pub struct InstrumentCoverage;
24
/// The `query` provider for `CoverageInfo`, requested by `codegen_intrinsic_call()` when
/// constructing the arguments for `llvm.instrprof.increment`.
pub(crate) fn provide(providers: &mut Providers) {
    // Delegates to `coverageinfo_from_mir`, which recomputes the counter/expression totals by
    // scanning the optimized MIR rather than caching them in the pass itself.
    providers.coverageinfo = |tcx, def_id| coverageinfo_from_mir(tcx, def_id);
}
30
/// Computes `CoverageInfo` (`num_counters` and `num_expressions`) for one instrumented
/// function, by scanning its optimized MIR for calls to the coverage lang items
/// (`count_code_region`, and the counter add/subtract expression functions) that the
/// `InstrumentCoverage` pass injected.
fn coverageinfo_from_mir<'tcx>(tcx: TyCtxt<'tcx>, mir_def_id: DefId) -> CoverageInfo {
    let mir_body = tcx.optimized_mir(mir_def_id);
    // FIXME(richkadel): The current implementation assumes the MIR for the given DefId
    // represents a single function. Validate and/or correct if inlining (which should be disabled
    // if -Zinstrument-coverage is enabled) and/or monomorphization invalidates these assumptions.
    let count_code_region_fn = tcx.require_lang_item(lang_items::CountCodeRegionFnLangItem, None);
    let coverage_counter_add_fn =
        tcx.require_lang_item(lang_items::CoverageCounterAddFnLangItem, None);
    let coverage_counter_subtract_fn =
        tcx.require_lang_item(lang_items::CoverageCounterSubtractFnLangItem, None);

    // The `num_counters` argument to `llvm.instrprof.increment` is the number of injected
    // counters, with each counter having a counter ID from `0..num_counters-1`. MIR optimization
    // may split and duplicate some BasicBlock sequences. Simply counting the calls may not
    // work; but computing the num_counters by adding `1` to the highest counter_id (for a given
    // instrumented function) is valid.
    //
    // `num_expressions` is the number of counter expressions added to the MIR body. Both
    // `num_counters` and `num_expressions` are used to initialize new vectors, during backend
    // code generate, to lookup counters and expressions by simple u32 indexes.
    let mut num_counters: u32 = 0;
    let mut num_expressions: u32 = 0;
    for terminator in
        traversal::preorder(mir_body).map(|(_, data)| data).filter_map(call_terminators)
    {
        if let TerminatorKind::Call { func: Operand::Constant(func), args, .. } = &terminator.kind {
            match func.literal.ty.kind {
                FnDef(id, _) if id == count_code_region_fn => {
                    // Counter IDs ascend from 0 (see `Instrumentor::next_counter`), so
                    // `highest ID + 1` is the counter count.
                    let counter_id_arg =
                        args.get(count_code_region_args::COUNTER_ID).expect("arg found");
                    let counter_id = mir::Operand::scalar_from_const(counter_id_arg)
                        .to_u32()
                        .expect("counter_id arg is u32");
                    num_counters = std::cmp::max(num_counters, counter_id + 1);
                }
                FnDef(id, _)
                    if id == coverage_counter_add_fn || id == coverage_counter_subtract_fn =>
                {
                    let expression_id_arg = args
                        .get(coverage_counter_expression_args::EXPRESSION_ID)
                        .expect("arg found");
                    let id_descending_from_max = mir::Operand::scalar_from_const(expression_id_arg)
                        .to_u32()
                        .expect("expression_id arg is u32");
                    // Counter expressions are initially assigned IDs descending from `u32::MAX`, so
                    // the range of expression IDs is disjoint from the range of counter IDs. This
                    // way, both counters and expressions can be operands in other expressions.
                    let expression_index = u32::MAX - id_descending_from_max;
                    num_expressions = std::cmp::max(num_expressions, expression_index + 1);
                }
                _ => {}
            }
        }
    }
    CoverageInfo { num_counters, num_expressions }
}
87
88 fn call_terminators(data: &'tcx BasicBlockData<'tcx>) -> Option<&'tcx Terminator<'tcx>> {
89     let terminator = data.terminator();
90     match terminator.kind {
91         TerminatorKind::Call { .. } => Some(terminator),
92         _ => None,
93     }
94 }
95
96 impl<'tcx> MirPass<'tcx> for InstrumentCoverage {
97     fn run_pass(&self, tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, mir_body: &mut mir::Body<'tcx>) {
98         // If the InstrumentCoverage pass is called on promoted MIRs, skip them.
99         // See: https://github.com/rust-lang/rust/pull/73011#discussion_r438317601
100         if src.promoted.is_none() {
101             Instrumentor::new(tcx, src, mir_body).inject_counters();
102         }
103     }
104 }
105
/// Distinguishes the expression operators.
enum Op {
    /// Maps to the `CoverageCounterAddFnLangItem` lang item (see `make_expression`).
    Add,
    /// Maps to the `CoverageCounterSubtractFnLangItem` lang item (see `make_expression`).
    Subtract,
}
111
/// A fully-formed coverage call (callee, arguments, and location), built by `make_counter` or
/// `make_expression` and consumed by `inject_call`.
struct InjectedCall<'tcx> {
    /// The function operand for the coverage lang-item function to call.
    func: Operand<'tcx>,
    /// Constant arguments for the call (IDs, hash, and byte positions).
    args: Vec<Operand<'tcx>>,
    /// Span at which the call is injected (the start of the counted code region).
    inject_at: Span,
}
117
/// Per-function state for injecting coverage counter calls into a MIR body.
struct Instrumentor<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    /// `DefId` of the function whose MIR is being instrumented.
    mir_def_id: DefId,
    /// The MIR body being mutated by `inject_counters` / `inject_call`.
    mir_body: &'a mut mir::Body<'tcx>,
    /// The HIR body of the same function, hashed for `function_source_hash`.
    hir_body: &'tcx rustc_hir::Body<'tcx>,
    /// Lazily-computed source hash (see `function_source_hash`); `None` until first requested.
    function_source_hash: Option<u64>,
    /// Number of counter IDs handed out so far; IDs ascend from 0.
    num_counters: u32,
    /// Number of expression IDs handed out so far; IDs descend from `u32::MAX`.
    num_expressions: u32,
}
127
impl<'a, 'tcx> Instrumentor<'a, 'tcx> {
    /// Creates an `Instrumentor` for the MIR body of `src`, resolving the corresponding
    /// local HIR body (via the free function `hir_body` below) so it can later be hashed.
    fn new(tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, mir_body: &'a mut mir::Body<'tcx>) -> Self {
        let mir_def_id = src.def_id();
        let hir_body = hir_body(tcx, mir_def_id);
        Self {
            tcx,
            mir_def_id,
            mir_body,
            hir_body,
            function_source_hash: None,
            num_counters: 0,
            num_expressions: 0,
        }
    }

    /// Counter IDs start from zero and go up.
    fn next_counter(&mut self) -> u32 {
        // Ensure the ascending counter-ID range never collides with the descending
        // expression-ID range (expressions count down from `u32::MAX`).
        assert!(self.num_counters < u32::MAX - self.num_expressions);
        let next = self.num_counters;
        self.num_counters += 1;
        next
    }

    /// Expression IDs start from u32::MAX and go down because a CounterExpression can reference
    /// (add or subtract counts) of both Counter regions and CounterExpression regions. The counter
    /// expression operand IDs must be unique across both types.
    fn next_expression(&mut self) -> u32 {
        // Same disjointness guard as in `next_counter`, checked from this side as well.
        assert!(self.num_counters < u32::MAX - self.num_expressions);
        let next = u32::MAX - self.num_expressions;
        self.num_expressions += 1;
        next
    }

    /// Returns the stable hash of this function's HIR body, computing and caching it in
    /// `self.function_source_hash` on first use.
    fn function_source_hash(&mut self) -> u64 {
        match self.function_source_hash {
            Some(hash) => hash,
            None => {
                let hash = hash_mir_source(self.tcx, self.hir_body);
                self.function_source_hash.replace(hash);
                hash
            }
        }
    }

    /// Entry point of the pass for one function: injects a single counter call at the start
    /// of the function body (counter coverage of branches is not implemented yet — see the
    /// FIXMEs below).
    fn inject_counters(&mut self) {
        let mir_body = &self.mir_body;
        let body_span = self.hir_body.value.span;
        debug!("instrumenting {:?}, span: {:?}", self.mir_def_id, body_span);

        // FIXME(richkadel): As a first step, counters are only injected at the top of each
        // function. The complete solution will inject counters at each conditional code branch.
        let _ignore = mir_body;
        let id = self.next_counter();
        let function_source_hash = self.function_source_hash();
        let code_region = body_span;
        let scope = rustc_middle::mir::OUTERMOST_SOURCE_SCOPE;
        let is_cleanup = false;
        let next_block = rustc_middle::mir::START_BLOCK;
        self.inject_call(
            self.make_counter(id, function_source_hash, code_region),
            scope,
            is_cleanup,
            next_block,
        );

        // FIXME(richkadel): The next step to implement source based coverage analysis will be
        // instrumenting branches within functions, and some regions will be counted by "counter
        // expression". The function to inject counter expression is implemented. Replace this
        // "fake use" with real use.
        let fake_use = false;
        if fake_use {
            let add = false;
            let lhs = 1;
            let op = if add { Op::Add } else { Op::Subtract };
            let rhs = 2;

            let code_region = body_span;
            let scope = rustc_middle::mir::OUTERMOST_SOURCE_SCOPE;
            let is_cleanup = false;
            let next_block = rustc_middle::mir::START_BLOCK;

            let id = self.next_expression();
            self.inject_call(
                self.make_expression(id, code_region, lhs, op, rhs),
                scope,
                is_cleanup,
                next_block,
            );
        }
    }

    /// Builds (but does not yet inject) a call to the `count_code_region` lang item, with the
    /// function source hash, counter ID, and the code region's start/end byte positions as
    /// constant arguments.
    fn make_counter(
        &self,
        id: u32,
        function_source_hash: u64,
        code_region: Span,
    ) -> InjectedCall<'tcx> {
        let inject_at = code_region.shrink_to_lo();

        let func = function_handle(
            self.tcx,
            self.tcx.require_lang_item(lang_items::CountCodeRegionFnLangItem, None),
            inject_at,
        );

        let mut args = Vec::new();

        // The debug_assert_eq! calls pin each push to the argument index expected by the
        // `count_code_region_args` constants.
        use count_code_region_args::*;
        debug_assert_eq!(FUNCTION_SOURCE_HASH, args.len());
        args.push(self.const_u64(function_source_hash, inject_at));

        debug_assert_eq!(COUNTER_ID, args.len());
        args.push(self.const_u32(id, inject_at));

        debug_assert_eq!(START_BYTE_POS, args.len());
        args.push(self.const_u32(code_region.lo().to_u32(), inject_at));

        debug_assert_eq!(END_BYTE_POS, args.len());
        args.push(self.const_u32(code_region.hi().to_u32(), inject_at));

        InjectedCall { func, args, inject_at }
    }

    /// Builds (but does not yet inject) a call to the counter-expression add/subtract lang item
    /// selected by `op`, with the expression ID, operand IDs, and the code region's start/end
    /// byte positions as constant arguments.
    fn make_expression(
        &self,
        id: u32,
        code_region: Span,
        lhs: u32,
        op: Op,
        rhs: u32,
    ) -> InjectedCall<'tcx> {
        let inject_at = code_region.shrink_to_lo();

        let func = function_handle(
            self.tcx,
            self.tcx.require_lang_item(
                match op {
                    Op::Add => lang_items::CoverageCounterAddFnLangItem,
                    Op::Subtract => lang_items::CoverageCounterSubtractFnLangItem,
                },
                None,
            ),
            inject_at,
        );

        let mut args = Vec::new();

        // As in `make_counter`, the debug_assert_eq! calls pin each push to the argument index
        // expected by the `coverage_counter_expression_args` constants.
        use coverage_counter_expression_args::*;
        debug_assert_eq!(EXPRESSION_ID, args.len());
        args.push(self.const_u32(id, inject_at));

        debug_assert_eq!(LEFT_ID, args.len());
        args.push(self.const_u32(lhs, inject_at));

        debug_assert_eq!(RIGHT_ID, args.len());
        args.push(self.const_u32(rhs, inject_at));

        debug_assert_eq!(START_BYTE_POS, args.len());
        args.push(self.const_u32(code_region.lo().to_u32(), inject_at));

        debug_assert_eq!(END_BYTE_POS, args.len());
        args.push(self.const_u32(code_region.hi().to_u32(), inject_at));

        InjectedCall { func, args, inject_at }
    }

    /// Injects `call` as a new basic block in front of `next_block`: a placeholder block is
    /// created, patched with the `Call` terminator, and then swapped with `next_block` so the
    /// call executes first and falls through to the original block.
    fn inject_call(
        &mut self,
        call: InjectedCall<'tcx>,
        scope: SourceScope,
        is_cleanup: bool,
        next_block: BasicBlock,
    ) {
        let InjectedCall { func, args, inject_at } = call;
        debug!(
            "  injecting {}call to {:?}({:?}) at: {:?}, scope: {:?}",
            if is_cleanup { "cleanup " } else { "" },
            func,
            args,
            inject_at,
            scope,
        );

        let mut patch = MirPatch::new(self.mir_body);

        // Unit temporary to receive the call's (unused) return value.
        let temp = patch.new_temp(self.tcx.mk_unit(), inject_at);
        let new_block = patch.new_block(placeholder_block(inject_at, scope, is_cleanup));
        patch.patch_terminator(
            new_block,
            TerminatorKind::Call {
                func,
                args,
                // new_block will swapped with the next_block, after applying patch
                destination: Some((Place::from(temp), new_block)),
                cleanup: None,
                from_hir_call: false,
                fn_span: inject_at,
            },
        );

        patch.add_statement(new_block.start_location(), StatementKind::StorageLive(temp));
        patch.add_statement(next_block.start_location(), StatementKind::StorageDead(temp));

        patch.apply(self.mir_body);

        // To insert the `new_block` in front of the first block in the counted branch (the
        // `next_block`), just swap the indexes, leaving the rest of the graph unchanged.
        self.mir_body.basic_blocks_mut().swap(next_block, new_block);
    }

    /// Builds a constant `u32` operand at `span`.
    fn const_u32(&self, value: u32, span: Span) -> Operand<'tcx> {
        Operand::const_from_scalar(self.tcx, self.tcx.types.u32, Scalar::from_u32(value), span)
    }

    /// Builds a constant `u64` operand at `span`.
    fn const_u64(&self, value: u64, span: Span) -> Operand<'tcx> {
        Operand::const_from_scalar(self.tcx, self.tcx.types.u64, Scalar::from_u64(value), span)
    }
}
346
347 fn function_handle<'tcx>(tcx: TyCtxt<'tcx>, fn_def_id: DefId, span: Span) -> Operand<'tcx> {
348     let ret_ty = tcx.fn_sig(fn_def_id).output();
349     let ret_ty = ret_ty.no_bound_vars().unwrap();
350     let substs = tcx.mk_substs(::std::iter::once(ty::subst::GenericArg::from(ret_ty)));
351     Operand::function_handle(tcx, fn_def_id, substs, span)
352 }
353
/// Builds an empty basic block whose `Unreachable` terminator is only a placeholder; the
/// caller (`inject_call`) replaces it with the actual `Call` via `MirPatch::patch_terminator`.
fn placeholder_block(span: Span, scope: SourceScope, is_cleanup: bool) -> BasicBlockData<'tcx> {
    BasicBlockData {
        statements: vec![],
        terminator: Some(Terminator {
            source_info: SourceInfo { span, scope },
            // this gets overwritten by the counter Call
            kind: TerminatorKind::Unreachable,
        }),
        is_cleanup,
    }
}
365
366 fn hir_body<'tcx>(tcx: TyCtxt<'tcx>, def_id: DefId) -> &'tcx rustc_hir::Body<'tcx> {
367     let hir_node = tcx.hir().get_if_local(def_id).expect("DefId is local");
368     let fn_body_id = hir::map::associated_body(hir_node).expect("HIR node is a function with body");
369     tcx.hir().body(fn_body_id)
370 }
371
372 fn hash_mir_source<'tcx>(tcx: TyCtxt<'tcx>, hir_body: &'tcx rustc_hir::Body<'tcx>) -> u64 {
373     let mut hcx = tcx.create_no_span_stable_hashing_context();
374     hash(&mut hcx, &hir_body.value).to_smaller_hash()
375 }
376
377 fn hash(
378     hcx: &mut StableHashingContext<'tcx>,
379     node: &impl HashStable<StableHashingContext<'tcx>>,
380 ) -> Fingerprint {
381     let mut stable_hasher = StableHasher::new();
382     node.hash_stable(hcx, &mut stable_hasher);
383     stable_hasher.finish()
384 }