]> git.lizzy.rs Git - rust.git/blob - src/librustc_mir/transform/instrument_coverage.rs
94aa26b3081e52b44571a012c56296cd3a644651
[rust.git] / src / librustc_mir / transform / instrument_coverage.rs
1 use crate::transform::{MirPass, MirSource};
2 use crate::util::patch::MirPatch;
3 use rustc_data_structures::fingerprint::Fingerprint;
4 use rustc_data_structures::stable_hasher::{HashStable, StableHasher};
5 use rustc_hir::lang_items;
6 use rustc_middle::hir;
7 use rustc_middle::ich::StableHashingContext;
8 use rustc_middle::mir::interpret::Scalar;
9 use rustc_middle::mir::{
10     self, BasicBlock, BasicBlockData, CoverageData, Operand, Place, SourceInfo, StatementKind,
11     Terminator, TerminatorKind, START_BLOCK,
12 };
13 use rustc_middle::ty;
14 use rustc_middle::ty::TyCtxt;
15 use rustc_span::def_id::DefId;
16 use rustc_span::Span;
17
/// MIR pass that, when the `instrument_coverage` debugging option is enabled, injects a call to
/// the `count_code_region()` lang item into each non-promoted MIR body.
///
/// The injected call is only a placeholder: during code generation it is replaced with the
/// intrinsic `llvm.instrprof.increment`.
pub struct InstrumentCoverage;
21
22 struct Instrumentor<'tcx> {
23     tcx: TyCtxt<'tcx>,
24     num_counters: u32,
25 }
26
27 impl<'tcx> MirPass<'tcx> for InstrumentCoverage {
28     fn run_pass(&self, tcx: TyCtxt<'tcx>, src: MirSource<'tcx>, mir_body: &mut mir::Body<'tcx>) {
29         if tcx.sess.opts.debugging_opts.instrument_coverage {
30             // If the InstrumentCoverage pass is called on promoted MIRs, skip them.
31             // See: https://github.com/rust-lang/rust/pull/73011#discussion_r438317601
32             if src.promoted.is_none() {
33                 assert!(mir_body.coverage_data.is_none());
34
35                 let hash = hash_mir_source(tcx, &src);
36
37                 debug!(
38                     "instrumenting {:?}, hash: {}, span: {}",
39                     src.def_id(),
40                     hash,
41                     tcx.sess.source_map().span_to_string(mir_body.span)
42                 );
43
44                 let num_counters = Instrumentor::new(tcx).inject_counters(mir_body);
45
46                 mir_body.coverage_data = Some(CoverageData { hash, num_counters });
47             }
48         }
49     }
50 }
51
52 impl<'tcx> Instrumentor<'tcx> {
53     fn new(tcx: TyCtxt<'tcx>) -> Self {
54         Self { tcx, num_counters: 0 }
55     }
56
57     fn next_counter(&mut self) -> u32 {
58         let next = self.num_counters;
59         self.num_counters += 1;
60         next
61     }
62
63     fn inject_counters(&mut self, mir_body: &mut mir::Body<'tcx>) -> u32 {
64         // FIXME(richkadel): As a first step, counters are only injected at the top of each
65         // function. The complete solution will inject counters at each conditional code branch.
66         let top_of_function = START_BLOCK;
67         let entire_function = mir_body.span;
68
69         self.inject_counter(mir_body, top_of_function, entire_function);
70
71         self.num_counters
72     }
73
74     fn inject_counter(
75         &mut self,
76         mir_body: &mut mir::Body<'tcx>,
77         next_block: BasicBlock,
78         code_region: Span,
79     ) {
80         let injection_point = code_region.shrink_to_lo();
81
82         let count_code_region_fn = function_handle(
83             self.tcx,
84             self.tcx.require_lang_item(lang_items::CountCodeRegionFnLangItem, None),
85             injection_point,
86         );
87         let counter_index = Operand::const_from_scalar(
88             self.tcx,
89             self.tcx.types.u32,
90             Scalar::from_u32(self.next_counter()),
91             injection_point,
92         );
93
94         let mut patch = MirPatch::new(mir_body);
95
96         let temp = patch.new_temp(self.tcx.mk_unit(), code_region);
97         let new_block = patch.new_block(placeholder_block(code_region));
98         patch.patch_terminator(
99             new_block,
100             TerminatorKind::Call {
101                 func: count_code_region_fn,
102                 args: vec![counter_index],
103                 // new_block will swapped with the next_block, after applying patch
104                 destination: Some((Place::from(temp), new_block)),
105                 cleanup: None,
106                 from_hir_call: false,
107                 fn_span: injection_point,
108             },
109         );
110
111         patch.add_statement(new_block.start_location(), StatementKind::StorageLive(temp));
112         patch.add_statement(next_block.start_location(), StatementKind::StorageDead(temp));
113
114         patch.apply(mir_body);
115
116         // To insert the `new_block` in front of the first block in the counted branch (the
117         // `next_block`), just swap the indexes, leaving the rest of the graph unchanged.
118         mir_body.basic_blocks_mut().swap(next_block, new_block);
119     }
120 }
121
122 fn function_handle<'tcx>(tcx: TyCtxt<'tcx>, fn_def_id: DefId, span: Span) -> Operand<'tcx> {
123     let ret_ty = tcx.fn_sig(fn_def_id).output();
124     let ret_ty = ret_ty.no_bound_vars().unwrap();
125     let substs = tcx.mk_substs(::std::iter::once(ty::subst::GenericArg::from(ret_ty)));
126     Operand::function_handle(tcx, fn_def_id, substs, span)
127 }
128
129 fn placeholder_block(span: Span) -> BasicBlockData<'tcx> {
130     BasicBlockData {
131         statements: vec![],
132         terminator: Some(Terminator {
133             source_info: SourceInfo::outermost(span),
134             // this gets overwritten by the counter Call
135             kind: TerminatorKind::Unreachable,
136         }),
137         is_cleanup: false,
138     }
139 }
140
141 fn hash_mir_source<'tcx>(tcx: TyCtxt<'tcx>, src: &MirSource<'tcx>) -> u64 {
142     let fn_body_id = match tcx.hir().get_if_local(src.def_id()) {
143         Some(node) => match hir::map::associated_body(node) {
144             Some(body_id) => body_id,
145             _ => bug!("instrumented MirSource does not include a function body: {:?}", node),
146         },
147         None => bug!("instrumented MirSource is not local: {:?}", src),
148     };
149     let hir_body = tcx.hir().body(fn_body_id);
150     let mut hcx = tcx.create_no_span_stable_hashing_context();
151     hash(&mut hcx, &hir_body.value).to_smaller_hash()
152 }
153
154 fn hash(
155     hcx: &mut StableHashingContext<'tcx>,
156     node: &impl HashStable<StableHashingContext<'tcx>>,
157 ) -> Fingerprint {
158     let mut stable_hasher = StableHasher::new();
159     node.hash_stable(hcx, &mut stable_hasher);
160     stable_hasher.finish()
161 }