]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_mir/src/transform/early_otherwise_branch.rs
Auto merge of #79895 - Kerollmops:slice-group-by, r=m-ou-se
[rust.git] / compiler / rustc_mir / src / transform / early_otherwise_branch.rs
1 use crate::{transform::MirPass, util::patch::MirPatch};
2 use rustc_middle::mir::*;
3 use rustc_middle::ty::{Ty, TyCtxt};
4 use std::fmt::Debug;
5
6 use super::simplify::simplify_cfg;
7
/// Merges a `SwitchInt` on one enum discriminant with the `SwitchInt`s on a second discriminant
/// (of the same enum type) found in each of its targets, by first comparing the two discriminants.
///
/// This pass optimizes something like
/// ```text
/// let x: Option<()>;
/// let y: Option<()>;
/// match (x,y) {
///     (Some(_), Some(_)) => {0},
///     _ => {1}
/// }
/// ```
/// into something like
/// ```text
/// let x: Option<()>;
/// let y: Option<()>;
/// let discriminant_x = // get discriminant of x
/// let discriminant_y = // get discriminant of y
/// if discriminant_x != discriminant_y {
///     1 // the shared `otherwise` target
/// } else {
///     // both discriminants are equal, so switching on one of them is enough
///     match discriminant_x { Some => 0, _ => 1 }
/// }
/// ```
/// The pass only runs at `-Z mir-opt-level=2` or higher.
pub struct EarlyOtherwiseBranch;
26
27 impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
28     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
29         if tcx.sess.opts.debugging_opts.mir_opt_level < 2 {
30             return;
31         }
32         trace!("running EarlyOtherwiseBranch on {:?}", body.source);
33         // we are only interested in this bb if the terminator is a switchInt
34         let bbs_with_switch =
35             body.basic_blocks().iter_enumerated().filter(|(_, bb)| is_switch(bb.terminator()));
36
37         let opts_to_apply: Vec<OptimizationToApply<'tcx>> = bbs_with_switch
38             .flat_map(|(bb_idx, bb)| {
39                 let switch = bb.terminator();
40                 let helper = Helper { body, tcx };
41                 let infos = helper.go(bb, switch)?;
42                 Some(OptimizationToApply { infos, basic_block_first_switch: bb_idx })
43             })
44             .collect();
45
46         let should_cleanup = !opts_to_apply.is_empty();
47
48         for opt_to_apply in opts_to_apply {
49             if !tcx.consider_optimizing(|| format!("EarlyOtherwiseBranch {:?}", &opt_to_apply)) {
50                 break;
51             }
52
53             trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_to_apply);
54
55             let statements_before =
56                 body.basic_blocks()[opt_to_apply.basic_block_first_switch].statements.len();
57             let end_of_block_location = Location {
58                 block: opt_to_apply.basic_block_first_switch,
59                 statement_index: statements_before,
60             };
61
62             let mut patch = MirPatch::new(body);
63
64             // create temp to store second discriminant in
65             let discr_type = opt_to_apply.infos[0].second_switch_info.discr_ty;
66             let discr_span = opt_to_apply.infos[0].second_switch_info.discr_source_info.span;
67             let second_discriminant_temp = patch.new_temp(discr_type, discr_span);
68
69             patch.add_statement(
70                 end_of_block_location,
71                 StatementKind::StorageLive(second_discriminant_temp),
72             );
73
74             // create assignment of discriminant
75             let place_of_adt_to_get_discriminant_of =
76                 opt_to_apply.infos[0].second_switch_info.place_of_adt_discr_read;
77             patch.add_assign(
78                 end_of_block_location,
79                 Place::from(second_discriminant_temp),
80                 Rvalue::Discriminant(place_of_adt_to_get_discriminant_of),
81             );
82
83             // create temp to store NotEqual comparison between the two discriminants
84             let not_equal = BinOp::Ne;
85             let not_equal_res_type = not_equal.ty(tcx, discr_type, discr_type);
86             let not_equal_temp = patch.new_temp(not_equal_res_type, discr_span);
87             patch.add_statement(end_of_block_location, StatementKind::StorageLive(not_equal_temp));
88
89             // create NotEqual comparison between the two discriminants
90             let first_descriminant_place =
91                 opt_to_apply.infos[0].first_switch_info.discr_used_in_switch;
92             let not_equal_rvalue = Rvalue::BinaryOp(
93                 not_equal,
94                 Operand::Copy(Place::from(second_discriminant_temp)),
95                 Operand::Copy(first_descriminant_place),
96             );
97             patch.add_statement(
98                 end_of_block_location,
99                 StatementKind::Assign(box (Place::from(not_equal_temp), not_equal_rvalue)),
100             );
101
102             let new_targets = opt_to_apply
103                 .infos
104                 .iter()
105                 .flat_map(|x| x.second_switch_info.targets_with_values.iter())
106                 .cloned();
107
108             let targets = SwitchTargets::new(
109                 new_targets,
110                 opt_to_apply.infos[0].first_switch_info.otherwise_bb,
111             );
112
113             // new block that jumps to the correct discriminant case. This block is switched to if the discriminants are equal
114             let new_switch_data = BasicBlockData::new(Some(Terminator {
115                 source_info: opt_to_apply.infos[0].second_switch_info.discr_source_info,
116                 kind: TerminatorKind::SwitchInt {
117                     // the first and second discriminants are equal, so just pick one
118                     discr: Operand::Copy(first_descriminant_place),
119                     switch_ty: discr_type,
120                     targets,
121                 },
122             }));
123
124             let new_switch_bb = patch.new_block(new_switch_data);
125
126             // switch on the NotEqual. If true, then jump to the `otherwise` case.
127             // If false, then jump to a basic block that then jumps to the correct disciminant case
128             let true_case = opt_to_apply.infos[0].first_switch_info.otherwise_bb;
129             let false_case = new_switch_bb;
130             patch.patch_terminator(
131                 opt_to_apply.basic_block_first_switch,
132                 TerminatorKind::if_(
133                     tcx,
134                     Operand::Move(Place::from(not_equal_temp)),
135                     true_case,
136                     false_case,
137                 ),
138             );
139
140             // generate StorageDead for the second_discriminant_temp not in use anymore
141             patch.add_statement(
142                 end_of_block_location,
143                 StatementKind::StorageDead(second_discriminant_temp),
144             );
145
146             // Generate a StorageDead for not_equal_temp in each of the targets, since we moved it into the switch
147             for bb in [false_case, true_case].iter() {
148                 patch.add_statement(
149                     Location { block: *bb, statement_index: 0 },
150                     StatementKind::StorageDead(not_equal_temp),
151                 );
152             }
153
154             patch.apply(body);
155         }
156
157         // Since this optimization adds new basic blocks and invalidates others,
158         // clean up the cfg to make it nicer for other passes
159         if should_cleanup {
160             simplify_cfg(body);
161         }
162     }
163 }
164
165 fn is_switch<'tcx>(terminator: &Terminator<'tcx>) -> bool {
166     match terminator.kind {
167         TerminatorKind::SwitchInt { .. } => true,
168         _ => false,
169     }
170 }
171
/// Read-only context used while searching a body for optimization candidates.
struct Helper<'a, 'tcx> {
    /// The MIR body being analyzed.
    body: &'a Body<'tcx>,
    /// Type context, used to compute the type of the place whose discriminant is read.
    tcx: TyCtxt<'tcx>,
}
176
/// Describes a `SwitchInt` on a local whose value was produced by a `discriminant(...)` read,
/// found as the last statement of the switch's block.
#[derive(Debug, Clone)]
struct SwitchDiscriminantInfo<'tcx> {
    /// Type of the discriminant being switched on
    discr_ty: Ty<'tcx>,
    /// The basic block that the otherwise branch points to
    otherwise_bb: BasicBlock,
    /// Target along with the value being branched from. Otherwise is not included
    targets_with_values: Vec<(u128, BasicBlock)>,
    /// Source info of the local declaration of the switched-on discriminant
    discr_source_info: SourceInfo,
    /// The place of the discriminant used in the switch
    discr_used_in_switch: Place<'tcx>,
    /// The place of the adt that has its discriminant read
    place_of_adt_discr_read: Place<'tcx>,
    /// The type of the adt that has its discriminant read
    type_adt_matched_on: Ty<'tcx>,
}
193
/// A rewrite candidate: the block holding the first switch, plus one pairing per explicit
/// target of that switch.
#[derive(Debug)]
struct OptimizationToApply<'tcx> {
    /// One entry per explicit target of the first switch, in the order of those targets.
    /// Never empty when produced by `Helper::go`.
    infos: Vec<OptimizationInfo<'tcx>>,
    /// Basic block of the original first switch
    basic_block_first_switch: BasicBlock,
}
200
/// Pairs the first switch with the compatible second switch found in one of its targets.
#[derive(Debug)]
struct OptimizationInfo<'tcx> {
    /// Info about the first switch and discriminant
    first_switch_info: SwitchDiscriminantInfo<'tcx>,
    /// Info about the second switch and discriminant
    second_switch_info: SwitchDiscriminantInfo<'tcx>,
}
208
209 impl<'a, 'tcx> Helper<'a, 'tcx> {
210     pub fn go(
211         &self,
212         bb: &BasicBlockData<'tcx>,
213         switch: &Terminator<'tcx>,
214     ) -> Option<Vec<OptimizationInfo<'tcx>>> {
215         // try to find the statement that defines the discriminant that is used for the switch
216         let discr = self.find_switch_discriminant_info(bb, switch)?;
217
218         // go through each target, finding a discriminant read, and a switch
219         let results = discr
220             .targets_with_values
221             .iter()
222             .map(|(value, target)| self.find_discriminant_switch_pairing(&discr, *target, *value));
223
224         // if the optimization did not apply for one of the targets, then abort
225         if results.clone().any(|x| x.is_none()) || results.len() == 0 {
226             trace!("NO: not all of the targets matched the pattern for optimization");
227             return None;
228         }
229
230         Some(results.flatten().collect())
231     }
232
233     fn find_discriminant_switch_pairing(
234         &self,
235         discr_info: &SwitchDiscriminantInfo<'tcx>,
236         target: BasicBlock,
237         value: u128,
238     ) -> Option<OptimizationInfo<'tcx>> {
239         let bb = &self.body.basic_blocks()[target];
240         // find switch
241         let terminator = bb.terminator();
242         if is_switch(terminator) {
243             let this_bb_discr_info = self.find_switch_discriminant_info(bb, terminator)?;
244
245             // the types of the two adts matched on have to be equalfor this optimization to apply
246             if discr_info.type_adt_matched_on != this_bb_discr_info.type_adt_matched_on {
247                 trace!(
248                     "NO: types do not match. LHS: {:?}, RHS: {:?}",
249                     discr_info.type_adt_matched_on,
250                     this_bb_discr_info.type_adt_matched_on
251                 );
252                 return None;
253             }
254
255             // the otherwise branch of the two switches have to point to the same bb
256             if discr_info.otherwise_bb != this_bb_discr_info.otherwise_bb {
257                 trace!("NO: otherwise target is not the same");
258                 return None;
259             }
260
261             // check that the value being matched on is the same. The
262             if this_bb_discr_info.targets_with_values.iter().find(|x| x.0 == value).is_none() {
263                 trace!("NO: values being matched on are not the same");
264                 return None;
265             }
266
267             // only allow optimization if the left and right of the tuple being matched are the same variants.
268             // so the following should not optimize
269             //  ```rust
270             // let x: Option<()>;
271             // let y: Option<()>;
272             // match (x,y) {
273             //     (Some(_), None) => {},
274             //     _ => {}
275             // }
276             //  ```
277             // We check this by seeing that the value of the first discriminant is the only other discriminant value being used as a target in the second switch
278             if !(this_bb_discr_info.targets_with_values.len() == 1
279                 && this_bb_discr_info.targets_with_values[0].0 == value)
280             {
281                 trace!(
282                     "NO: The second switch did not have only 1 target (besides otherwise) that had the same value as the value from the first switch that got us here"
283                 );
284                 return None;
285             }
286
287             // when the second place is a projection of the first one, it's not safe to calculate their discriminant values sequentially.
288             // for example, this should not be optimized:
289             //
290             // ```rust
291             // enum E<'a> { Empty, Some(&'a E<'a>), }
292             // let Some(Some(_)) = e;
293             // ```
294             //
295             // ```mir
296             // bb0: {
297             //   _2 = discriminant(*_1)
298             //   switchInt(_2) -> [...]
299             // }
300             // bb1: {
301             //   _3 = discriminant(*(((*_1) as Some).0: &E))
302             //   switchInt(_3) -> [...]
303             // }
304             // ```
305             let discr_place = discr_info.place_of_adt_discr_read;
306             let this_discr_place = this_bb_discr_info.place_of_adt_discr_read;
307             if discr_place.local == this_discr_place.local
308                 && this_discr_place.projection.starts_with(discr_place.projection)
309             {
310                 trace!("NO: one target is the projection of another");
311                 return None;
312             }
313
314             // if we reach this point, the optimization applies, and we should be able to optimize this case
315             // store the info that is needed to apply the optimization
316
317             Some(OptimizationInfo {
318                 first_switch_info: discr_info.clone(),
319                 second_switch_info: this_bb_discr_info,
320             })
321         } else {
322             None
323         }
324     }
325
326     fn find_switch_discriminant_info(
327         &self,
328         bb: &BasicBlockData<'tcx>,
329         switch: &Terminator<'tcx>,
330     ) -> Option<SwitchDiscriminantInfo<'tcx>> {
331         match &switch.kind {
332             TerminatorKind::SwitchInt { discr, targets, .. } => {
333                 let discr_local = discr.place()?.as_local()?;
334                 // the declaration of the discriminant read. Place of this read is being used in the switch
335                 let discr_decl = &self.body.local_decls()[discr_local];
336                 let discr_ty = discr_decl.ty;
337                 // the otherwise target lies as the last element
338                 let otherwise_bb = targets.otherwise();
339                 let targets_with_values = targets.iter().collect();
340
341                 // find the place of the adt where the discriminant is being read from
342                 // assume this is the last statement of the block
343                 let place_of_adt_discr_read = match bb.statements.last()?.kind {
344                     StatementKind::Assign(box (_, Rvalue::Discriminant(adt_place))) => {
345                         Some(adt_place)
346                     }
347                     _ => None,
348                 }?;
349
350                 let type_adt_matched_on = place_of_adt_discr_read.ty(self.body, self.tcx).ty;
351
352                 Some(SwitchDiscriminantInfo {
353                     discr_used_in_switch: discr.place()?,
354                     discr_ty,
355                     otherwise_bb,
356                     targets_with_values,
357                     discr_source_info: discr_decl.source_info,
358                     place_of_adt_discr_read,
359                     type_adt_matched_on,
360                 })
361             }
362             _ => unreachable!("must only be passed terminator that is a switch"),
363         }
364     }
365 }