]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_mir_transform/src/early_otherwise_branch.rs
Rollup merge of #93269 - jacobbramley:dev/pauth-option-1, r=petrochenkov
[rust.git] / compiler / rustc_mir_transform / src / early_otherwise_branch.rs
1 use rustc_middle::mir::patch::MirPatch;
2 use rustc_middle::mir::*;
3 use rustc_middle::ty::{Ty, TyCtxt};
4 use std::fmt::Debug;
5
6 use super::simplify::simplify_cfg;
7
/// MIR pass that collapses two nested discriminant switches on values of the same ADT
/// type into one comparison of the two discriminants.
///
/// Code such as
/// ```text
/// let x: Option<()>;
/// let y: Option<()>;
/// match (x,y) {
///     (Some(_), Some(_)) => {0},
///     _ => {1}
/// }
/// ```
/// switches on the discriminant of `x` and then, inside the matching arm, on the
/// discriminant of `y`. The pass reads the second discriminant already in the block of the
/// first switch, producing roughly
/// ```text
/// let x: Option<()>;
/// let y: Option<()>;
/// let discriminant_x = // get discriminant of x
/// let discriminant_y = // get discriminant of y
/// if discriminant_x != discriminant_y || discriminant_x == None {1} else {0}
/// ```
///
/// It is only enabled when the `unsound_mir_opts` debugging option is set and the MIR
/// optimization level is at least 3 (FIXME #78496).
pub struct EarlyOtherwiseBranch;
26
27 impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
28     fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
29         //  FIXME(#78496)
30         sess.opts.debugging_opts.unsound_mir_opts && sess.mir_opt_level() >= 3
31     }
32
33     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
34         trace!("running EarlyOtherwiseBranch on {:?}", body.source);
35
36         // we are only interested in this bb if the terminator is a switchInt
37         let bbs_with_switch =
38             body.basic_blocks().iter_enumerated().filter(|(_, bb)| is_switch(bb.terminator()));
39
40         let opts_to_apply: Vec<OptimizationToApply<'tcx>> = bbs_with_switch
41             .flat_map(|(bb_idx, bb)| {
42                 let switch = bb.terminator();
43                 let helper = Helper { body, tcx };
44                 let infos = helper.go(bb, switch)?;
45                 Some(OptimizationToApply { infos, basic_block_first_switch: bb_idx })
46             })
47             .collect();
48
49         let should_cleanup = !opts_to_apply.is_empty();
50
51         for opt_to_apply in opts_to_apply {
52             if !tcx.consider_optimizing(|| format!("EarlyOtherwiseBranch {:?}", &opt_to_apply)) {
53                 break;
54             }
55
56             trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_to_apply);
57
58             let statements_before =
59                 body.basic_blocks()[opt_to_apply.basic_block_first_switch].statements.len();
60             let end_of_block_location = Location {
61                 block: opt_to_apply.basic_block_first_switch,
62                 statement_index: statements_before,
63             };
64
65             let mut patch = MirPatch::new(body);
66
67             // create temp to store second discriminant in
68             let discr_type = opt_to_apply.infos[0].second_switch_info.discr_ty;
69             let discr_span = opt_to_apply.infos[0].second_switch_info.discr_source_info.span;
70             let second_discriminant_temp = patch.new_temp(discr_type, discr_span);
71
72             patch.add_statement(
73                 end_of_block_location,
74                 StatementKind::StorageLive(second_discriminant_temp),
75             );
76
77             // create assignment of discriminant
78             let place_of_adt_to_get_discriminant_of =
79                 opt_to_apply.infos[0].second_switch_info.place_of_adt_discr_read;
80             patch.add_assign(
81                 end_of_block_location,
82                 Place::from(second_discriminant_temp),
83                 Rvalue::Discriminant(place_of_adt_to_get_discriminant_of),
84             );
85
86             // create temp to store NotEqual comparison between the two discriminants
87             let not_equal = BinOp::Ne;
88             let not_equal_res_type = not_equal.ty(tcx, discr_type, discr_type);
89             let not_equal_temp = patch.new_temp(not_equal_res_type, discr_span);
90             patch.add_statement(end_of_block_location, StatementKind::StorageLive(not_equal_temp));
91
92             // create NotEqual comparison between the two discriminants
93             let first_descriminant_place =
94                 opt_to_apply.infos[0].first_switch_info.discr_used_in_switch;
95             let not_equal_rvalue = Rvalue::BinaryOp(
96                 not_equal,
97                 Box::new((
98                     Operand::Copy(Place::from(second_discriminant_temp)),
99                     Operand::Copy(first_descriminant_place),
100                 )),
101             );
102             patch.add_statement(
103                 end_of_block_location,
104                 StatementKind::Assign(Box::new((Place::from(not_equal_temp), not_equal_rvalue))),
105             );
106
107             let new_targets = opt_to_apply
108                 .infos
109                 .iter()
110                 .flat_map(|x| x.second_switch_info.targets_with_values.iter())
111                 .cloned();
112
113             let targets = SwitchTargets::new(
114                 new_targets,
115                 opt_to_apply.infos[0].first_switch_info.otherwise_bb,
116             );
117
118             // new block that jumps to the correct discriminant case. This block is switched to if the discriminants are equal
119             let new_switch_data = BasicBlockData::new(Some(Terminator {
120                 source_info: opt_to_apply.infos[0].second_switch_info.discr_source_info,
121                 kind: TerminatorKind::SwitchInt {
122                     // the first and second discriminants are equal, so just pick one
123                     discr: Operand::Copy(first_descriminant_place),
124                     switch_ty: discr_type,
125                     targets,
126                 },
127             }));
128
129             let new_switch_bb = patch.new_block(new_switch_data);
130
131             // switch on the NotEqual. If true, then jump to the `otherwise` case.
132             // If false, then jump to a basic block that then jumps to the correct disciminant case
133             let true_case = opt_to_apply.infos[0].first_switch_info.otherwise_bb;
134             let false_case = new_switch_bb;
135             patch.patch_terminator(
136                 opt_to_apply.basic_block_first_switch,
137                 TerminatorKind::if_(
138                     tcx,
139                     Operand::Move(Place::from(not_equal_temp)),
140                     true_case,
141                     false_case,
142                 ),
143             );
144
145             // generate StorageDead for the second_discriminant_temp not in use anymore
146             patch.add_statement(
147                 end_of_block_location,
148                 StatementKind::StorageDead(second_discriminant_temp),
149             );
150
151             // Generate a StorageDead for not_equal_temp in each of the targets, since we moved it into the switch
152             for bb in [false_case, true_case].iter() {
153                 patch.add_statement(
154                     Location { block: *bb, statement_index: 0 },
155                     StatementKind::StorageDead(not_equal_temp),
156                 );
157             }
158
159             patch.apply(body);
160         }
161
162         // Since this optimization adds new basic blocks and invalidates others,
163         // clean up the cfg to make it nicer for other passes
164         if should_cleanup {
165             simplify_cfg(tcx, body);
166         }
167     }
168 }
169
170 fn is_switch(terminator: &Terminator<'_>) -> bool {
171     matches!(terminator.kind, TerminatorKind::SwitchInt { .. })
172 }
173
174 struct Helper<'a, 'tcx> {
175     body: &'a Body<'tcx>,
176     tcx: TyCtxt<'tcx>,
177 }
178
179 #[derive(Debug, Clone)]
180 struct SwitchDiscriminantInfo<'tcx> {
181     /// Type of the discriminant being switched on
182     discr_ty: Ty<'tcx>,
183     /// The basic block that the otherwise branch points to
184     otherwise_bb: BasicBlock,
185     /// Target along with the value being branched from. Otherwise is not included
186     targets_with_values: Vec<(u128, BasicBlock)>,
187     discr_source_info: SourceInfo,
188     /// The place of the discriminant used in the switch
189     discr_used_in_switch: Place<'tcx>,
190     /// The place of the adt that has its discriminant read
191     place_of_adt_discr_read: Place<'tcx>,
192     /// The type of the adt that has its discriminant read
193     type_adt_matched_on: Ty<'tcx>,
194 }
195
196 #[derive(Debug)]
197 struct OptimizationToApply<'tcx> {
198     infos: Vec<OptimizationInfo<'tcx>>,
199     /// Basic block of the original first switch
200     basic_block_first_switch: BasicBlock,
201 }
202
203 #[derive(Debug)]
204 struct OptimizationInfo<'tcx> {
205     /// Info about the first switch and discriminant
206     first_switch_info: SwitchDiscriminantInfo<'tcx>,
207     /// Info about the second switch and discriminant
208     second_switch_info: SwitchDiscriminantInfo<'tcx>,
209 }
210
211 impl<'tcx> Helper<'_, 'tcx> {
212     pub fn go(
213         &self,
214         bb: &BasicBlockData<'tcx>,
215         switch: &Terminator<'tcx>,
216     ) -> Option<Vec<OptimizationInfo<'tcx>>> {
217         // try to find the statement that defines the discriminant that is used for the switch
218         let discr = self.find_switch_discriminant_info(bb, switch)?;
219
220         // go through each target, finding a discriminant read, and a switch
221         let results = discr
222             .targets_with_values
223             .iter()
224             .map(|(value, target)| self.find_discriminant_switch_pairing(&discr, *target, *value));
225
226         // if the optimization did not apply for one of the targets, then abort
227         if results.clone().any(|x| x.is_none()) || results.len() == 0 {
228             trace!("NO: not all of the targets matched the pattern for optimization");
229             return None;
230         }
231
232         Some(results.flatten().collect())
233     }
234
235     fn find_discriminant_switch_pairing(
236         &self,
237         discr_info: &SwitchDiscriminantInfo<'tcx>,
238         target: BasicBlock,
239         value: u128,
240     ) -> Option<OptimizationInfo<'tcx>> {
241         let bb = &self.body.basic_blocks()[target];
242         // find switch
243         let terminator = bb.terminator();
244         if is_switch(terminator) {
245             let this_bb_discr_info = self.find_switch_discriminant_info(bb, terminator)?;
246
247             // the types of the two adts matched on have to be equalfor this optimization to apply
248             if discr_info.type_adt_matched_on != this_bb_discr_info.type_adt_matched_on {
249                 trace!(
250                     "NO: types do not match. LHS: {:?}, RHS: {:?}",
251                     discr_info.type_adt_matched_on,
252                     this_bb_discr_info.type_adt_matched_on
253                 );
254                 return None;
255             }
256
257             // the otherwise branch of the two switches have to point to the same bb
258             if discr_info.otherwise_bb != this_bb_discr_info.otherwise_bb {
259                 trace!("NO: otherwise target is not the same");
260                 return None;
261             }
262
263             // check that the value being matched on is the same. The
264             if !this_bb_discr_info.targets_with_values.iter().any(|x| x.0 == value) {
265                 trace!("NO: values being matched on are not the same");
266                 return None;
267             }
268
269             // only allow optimization if the left and right of the tuple being matched are the same variants.
270             // so the following should not optimize
271             //  ```rust
272             // let x: Option<()>;
273             // let y: Option<()>;
274             // match (x,y) {
275             //     (Some(_), None) => {},
276             //     _ => {}
277             // }
278             //  ```
279             // We check this by seeing that the value of the first discriminant is the only other discriminant value being used as a target in the second switch
280             if !(this_bb_discr_info.targets_with_values.len() == 1
281                 && this_bb_discr_info.targets_with_values[0].0 == value)
282             {
283                 trace!(
284                     "NO: The second switch did not have only 1 target (besides otherwise) that had the same value as the value from the first switch that got us here"
285                 );
286                 return None;
287             }
288
289             // when the second place is a projection of the first one, it's not safe to calculate their discriminant values sequentially.
290             // for example, this should not be optimized:
291             //
292             // ```rust
293             // enum E<'a> { Empty, Some(&'a E<'a>), }
294             // let Some(Some(_)) = e;
295             // ```
296             //
297             // ```mir
298             // bb0: {
299             //   _2 = discriminant(*_1)
300             //   switchInt(_2) -> [...]
301             // }
302             // bb1: {
303             //   _3 = discriminant(*(((*_1) as Some).0: &E))
304             //   switchInt(_3) -> [...]
305             // }
306             // ```
307             let discr_place = discr_info.place_of_adt_discr_read;
308             let this_discr_place = this_bb_discr_info.place_of_adt_discr_read;
309             if discr_place.local == this_discr_place.local
310                 && this_discr_place.projection.starts_with(discr_place.projection)
311             {
312                 trace!("NO: one target is the projection of another");
313                 return None;
314             }
315
316             // if we reach this point, the optimization applies, and we should be able to optimize this case
317             // store the info that is needed to apply the optimization
318
319             Some(OptimizationInfo {
320                 first_switch_info: discr_info.clone(),
321                 second_switch_info: this_bb_discr_info,
322             })
323         } else {
324             None
325         }
326     }
327
328     fn find_switch_discriminant_info(
329         &self,
330         bb: &BasicBlockData<'tcx>,
331         switch: &Terminator<'tcx>,
332     ) -> Option<SwitchDiscriminantInfo<'tcx>> {
333         match &switch.kind {
334             TerminatorKind::SwitchInt { discr, targets, .. } => {
335                 let discr_local = discr.place()?.as_local()?;
336                 // the declaration of the discriminant read. Place of this read is being used in the switch
337                 let discr_decl = &self.body.local_decls()[discr_local];
338                 let discr_ty = discr_decl.ty;
339                 // the otherwise target lies as the last element
340                 let otherwise_bb = targets.otherwise();
341                 let targets_with_values = targets.iter().collect();
342
343                 // find the place of the adt where the discriminant is being read from
344                 // assume this is the last statement of the block
345                 let place_of_adt_discr_read = match bb.statements.last()?.kind {
346                     StatementKind::Assign(box (_, Rvalue::Discriminant(adt_place))) => {
347                         Some(adt_place)
348                     }
349                     _ => None,
350                 }?;
351
352                 let type_adt_matched_on = place_of_adt_discr_read.ty(self.body, self.tcx).ty;
353
354                 Some(SwitchDiscriminantInfo {
355                     discr_used_in_switch: discr.place()?,
356                     discr_ty,
357                     otherwise_bb,
358                     targets_with_values,
359                     discr_source_info: discr_decl.source_info,
360                     place_of_adt_discr_read,
361                     type_adt_matched_on,
362                 })
363             }
364             _ => unreachable!("must only be passed terminator that is a switch"),
365         }
366     }
367 }