]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_mir_transform/src/early_otherwise_branch.rs
Reenable feature(nll) in alloc.
[rust.git] / compiler / rustc_mir_transform / src / early_otherwise_branch.rs
1 use rustc_middle::mir::patch::MirPatch;
2 use rustc_middle::mir::*;
3 use rustc_middle::ty::{Ty, TyCtxt};
4 use std::fmt::Debug;
5
6 use super::simplify::simplify_cfg;
7
/// MIR pass that merges two nested discriminant switches sharing the same `otherwise` target.
///
/// This pass optimizes something like
/// ```text
/// let x: Option<()>;
/// let y: Option<()>;
/// match (x,y) {
///     (Some(_), Some(_)) => {0},
///     _ => {1}
/// }
/// ```
/// into something like
/// ```text
/// let x: Option<()>;
/// let y: Option<()>;
/// let discriminant_x = // get discriminant of x
/// let discriminant_y = // get discriminant of y
/// if discriminant_x != discriminant_y || discriminant_x == None {1} else {0}
/// ```
///
/// The pass only runs when the `unsound_mir_opts` debugging option is set and the MIR
/// optimization level is at least 3 (see the FIXME referencing #78496 in `run_pass`).
pub struct EarlyOtherwiseBranch;
26
27 impl<'tcx> MirPass<'tcx> for EarlyOtherwiseBranch {
28     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
29         //  FIXME(#78496)
30         if !tcx.sess.opts.debugging_opts.unsound_mir_opts {
31             return;
32         }
33
34         if tcx.sess.mir_opt_level() < 3 {
35             return;
36         }
37         trace!("running EarlyOtherwiseBranch on {:?}", body.source);
38         // we are only interested in this bb if the terminator is a switchInt
39         let bbs_with_switch =
40             body.basic_blocks().iter_enumerated().filter(|(_, bb)| is_switch(bb.terminator()));
41
42         let opts_to_apply: Vec<OptimizationToApply<'tcx>> = bbs_with_switch
43             .flat_map(|(bb_idx, bb)| {
44                 let switch = bb.terminator();
45                 let helper = Helper { body, tcx };
46                 let infos = helper.go(bb, switch)?;
47                 Some(OptimizationToApply { infos, basic_block_first_switch: bb_idx })
48             })
49             .collect();
50
51         let should_cleanup = !opts_to_apply.is_empty();
52
53         for opt_to_apply in opts_to_apply {
54             if !tcx.consider_optimizing(|| format!("EarlyOtherwiseBranch {:?}", &opt_to_apply)) {
55                 break;
56             }
57
58             trace!("SUCCESS: found optimization possibility to apply: {:?}", &opt_to_apply);
59
60             let statements_before =
61                 body.basic_blocks()[opt_to_apply.basic_block_first_switch].statements.len();
62             let end_of_block_location = Location {
63                 block: opt_to_apply.basic_block_first_switch,
64                 statement_index: statements_before,
65             };
66
67             let mut patch = MirPatch::new(body);
68
69             // create temp to store second discriminant in
70             let discr_type = opt_to_apply.infos[0].second_switch_info.discr_ty;
71             let discr_span = opt_to_apply.infos[0].second_switch_info.discr_source_info.span;
72             let second_discriminant_temp = patch.new_temp(discr_type, discr_span);
73
74             patch.add_statement(
75                 end_of_block_location,
76                 StatementKind::StorageLive(second_discriminant_temp),
77             );
78
79             // create assignment of discriminant
80             let place_of_adt_to_get_discriminant_of =
81                 opt_to_apply.infos[0].second_switch_info.place_of_adt_discr_read;
82             patch.add_assign(
83                 end_of_block_location,
84                 Place::from(second_discriminant_temp),
85                 Rvalue::Discriminant(place_of_adt_to_get_discriminant_of),
86             );
87
88             // create temp to store NotEqual comparison between the two discriminants
89             let not_equal = BinOp::Ne;
90             let not_equal_res_type = not_equal.ty(tcx, discr_type, discr_type);
91             let not_equal_temp = patch.new_temp(not_equal_res_type, discr_span);
92             patch.add_statement(end_of_block_location, StatementKind::StorageLive(not_equal_temp));
93
94             // create NotEqual comparison between the two discriminants
95             let first_descriminant_place =
96                 opt_to_apply.infos[0].first_switch_info.discr_used_in_switch;
97             let not_equal_rvalue = Rvalue::BinaryOp(
98                 not_equal,
99                 Box::new((
100                     Operand::Copy(Place::from(second_discriminant_temp)),
101                     Operand::Copy(first_descriminant_place),
102                 )),
103             );
104             patch.add_statement(
105                 end_of_block_location,
106                 StatementKind::Assign(Box::new((Place::from(not_equal_temp), not_equal_rvalue))),
107             );
108
109             let new_targets = opt_to_apply
110                 .infos
111                 .iter()
112                 .flat_map(|x| x.second_switch_info.targets_with_values.iter())
113                 .cloned();
114
115             let targets = SwitchTargets::new(
116                 new_targets,
117                 opt_to_apply.infos[0].first_switch_info.otherwise_bb,
118             );
119
120             // new block that jumps to the correct discriminant case. This block is switched to if the discriminants are equal
121             let new_switch_data = BasicBlockData::new(Some(Terminator {
122                 source_info: opt_to_apply.infos[0].second_switch_info.discr_source_info,
123                 kind: TerminatorKind::SwitchInt {
124                     // the first and second discriminants are equal, so just pick one
125                     discr: Operand::Copy(first_descriminant_place),
126                     switch_ty: discr_type,
127                     targets,
128                 },
129             }));
130
131             let new_switch_bb = patch.new_block(new_switch_data);
132
133             // switch on the NotEqual. If true, then jump to the `otherwise` case.
134             // If false, then jump to a basic block that then jumps to the correct disciminant case
135             let true_case = opt_to_apply.infos[0].first_switch_info.otherwise_bb;
136             let false_case = new_switch_bb;
137             patch.patch_terminator(
138                 opt_to_apply.basic_block_first_switch,
139                 TerminatorKind::if_(
140                     tcx,
141                     Operand::Move(Place::from(not_equal_temp)),
142                     true_case,
143                     false_case,
144                 ),
145             );
146
147             // generate StorageDead for the second_discriminant_temp not in use anymore
148             patch.add_statement(
149                 end_of_block_location,
150                 StatementKind::StorageDead(second_discriminant_temp),
151             );
152
153             // Generate a StorageDead for not_equal_temp in each of the targets, since we moved it into the switch
154             for bb in [false_case, true_case].iter() {
155                 patch.add_statement(
156                     Location { block: *bb, statement_index: 0 },
157                     StatementKind::StorageDead(not_equal_temp),
158                 );
159             }
160
161             patch.apply(body);
162         }
163
164         // Since this optimization adds new basic blocks and invalidates others,
165         // clean up the cfg to make it nicer for other passes
166         if should_cleanup {
167             simplify_cfg(tcx, body);
168         }
169     }
170 }
171
172 fn is_switch<'tcx>(terminator: &Terminator<'tcx>) -> bool {
173     matches!(terminator.kind, TerminatorKind::SwitchInt { .. })
174 }
175
/// Analysis context used to search a MIR body for the pattern this pass rewrites.
struct Helper<'a, 'tcx> {
    /// The MIR body being inspected (shared borrow; the helper itself never mutates it).
    body: &'a Body<'tcx>,
    /// Type context, used to compute the type of the place whose discriminant is read.
    tcx: TyCtxt<'tcx>,
}
180
/// Everything the pass records about one basic block that ends in a `SwitchInt` and whose last
/// statement is a discriminant read.
#[derive(Debug, Clone)]
struct SwitchDiscriminantInfo<'tcx> {
    /// Type of the discriminant being switched on (taken from the declaration of the local
    /// used as the switch operand)
    discr_ty: Ty<'tcx>,
    /// The basic block that the otherwise branch points to
    otherwise_bb: BasicBlock,
    /// Target along with the value being branched from. Otherwise is not included
    targets_with_values: Vec<(u128, BasicBlock)>,
    /// Source info of the declaration of the local used as the switch operand
    discr_source_info: SourceInfo,
    /// The place of the discriminant used in the switch
    discr_used_in_switch: Place<'tcx>,
    /// The place of the adt that has its discriminant read
    place_of_adt_discr_read: Place<'tcx>,
    /// The type of the adt that has its discriminant read
    type_adt_matched_on: Ty<'tcx>,
}
197
/// One rewrite found by the analysis: the block holding the first switch together with the
/// pairing information for each of its value targets.
#[derive(Debug)]
struct OptimizationToApply<'tcx> {
    /// One entry per value target of the first switch. `run_pass` indexes element 0, so this
    /// is expected to be non-empty.
    infos: Vec<OptimizationInfo<'tcx>>,
    /// Basic block of the original first switch
    basic_block_first_switch: BasicBlock,
}
204
/// A matched pair of switches: the first switch and the switch found in one of its targets.
#[derive(Debug)]
struct OptimizationInfo<'tcx> {
    /// Info about the first switch and discriminant
    first_switch_info: SwitchDiscriminantInfo<'tcx>,
    /// Info about the second switch and discriminant
    second_switch_info: SwitchDiscriminantInfo<'tcx>,
}
212
213 impl<'a, 'tcx> Helper<'a, 'tcx> {
214     pub fn go(
215         &self,
216         bb: &BasicBlockData<'tcx>,
217         switch: &Terminator<'tcx>,
218     ) -> Option<Vec<OptimizationInfo<'tcx>>> {
219         // try to find the statement that defines the discriminant that is used for the switch
220         let discr = self.find_switch_discriminant_info(bb, switch)?;
221
222         // go through each target, finding a discriminant read, and a switch
223         let results = discr
224             .targets_with_values
225             .iter()
226             .map(|(value, target)| self.find_discriminant_switch_pairing(&discr, *target, *value));
227
228         // if the optimization did not apply for one of the targets, then abort
229         if results.clone().any(|x| x.is_none()) || results.len() == 0 {
230             trace!("NO: not all of the targets matched the pattern for optimization");
231             return None;
232         }
233
234         Some(results.flatten().collect())
235     }
236
237     fn find_discriminant_switch_pairing(
238         &self,
239         discr_info: &SwitchDiscriminantInfo<'tcx>,
240         target: BasicBlock,
241         value: u128,
242     ) -> Option<OptimizationInfo<'tcx>> {
243         let bb = &self.body.basic_blocks()[target];
244         // find switch
245         let terminator = bb.terminator();
246         if is_switch(terminator) {
247             let this_bb_discr_info = self.find_switch_discriminant_info(bb, terminator)?;
248
249             // the types of the two adts matched on have to be equalfor this optimization to apply
250             if discr_info.type_adt_matched_on != this_bb_discr_info.type_adt_matched_on {
251                 trace!(
252                     "NO: types do not match. LHS: {:?}, RHS: {:?}",
253                     discr_info.type_adt_matched_on,
254                     this_bb_discr_info.type_adt_matched_on
255                 );
256                 return None;
257             }
258
259             // the otherwise branch of the two switches have to point to the same bb
260             if discr_info.otherwise_bb != this_bb_discr_info.otherwise_bb {
261                 trace!("NO: otherwise target is not the same");
262                 return None;
263             }
264
265             // check that the value being matched on is the same. The
266             if !this_bb_discr_info.targets_with_values.iter().any(|x| x.0 == value) {
267                 trace!("NO: values being matched on are not the same");
268                 return None;
269             }
270
271             // only allow optimization if the left and right of the tuple being matched are the same variants.
272             // so the following should not optimize
273             //  ```rust
274             // let x: Option<()>;
275             // let y: Option<()>;
276             // match (x,y) {
277             //     (Some(_), None) => {},
278             //     _ => {}
279             // }
280             //  ```
281             // We check this by seeing that the value of the first discriminant is the only other discriminant value being used as a target in the second switch
282             if !(this_bb_discr_info.targets_with_values.len() == 1
283                 && this_bb_discr_info.targets_with_values[0].0 == value)
284             {
285                 trace!(
286                     "NO: The second switch did not have only 1 target (besides otherwise) that had the same value as the value from the first switch that got us here"
287                 );
288                 return None;
289             }
290
291             // when the second place is a projection of the first one, it's not safe to calculate their discriminant values sequentially.
292             // for example, this should not be optimized:
293             //
294             // ```rust
295             // enum E<'a> { Empty, Some(&'a E<'a>), }
296             // let Some(Some(_)) = e;
297             // ```
298             //
299             // ```mir
300             // bb0: {
301             //   _2 = discriminant(*_1)
302             //   switchInt(_2) -> [...]
303             // }
304             // bb1: {
305             //   _3 = discriminant(*(((*_1) as Some).0: &E))
306             //   switchInt(_3) -> [...]
307             // }
308             // ```
309             let discr_place = discr_info.place_of_adt_discr_read;
310             let this_discr_place = this_bb_discr_info.place_of_adt_discr_read;
311             if discr_place.local == this_discr_place.local
312                 && this_discr_place.projection.starts_with(discr_place.projection)
313             {
314                 trace!("NO: one target is the projection of another");
315                 return None;
316             }
317
318             // if we reach this point, the optimization applies, and we should be able to optimize this case
319             // store the info that is needed to apply the optimization
320
321             Some(OptimizationInfo {
322                 first_switch_info: discr_info.clone(),
323                 second_switch_info: this_bb_discr_info,
324             })
325         } else {
326             None
327         }
328     }
329
330     fn find_switch_discriminant_info(
331         &self,
332         bb: &BasicBlockData<'tcx>,
333         switch: &Terminator<'tcx>,
334     ) -> Option<SwitchDiscriminantInfo<'tcx>> {
335         match &switch.kind {
336             TerminatorKind::SwitchInt { discr, targets, .. } => {
337                 let discr_local = discr.place()?.as_local()?;
338                 // the declaration of the discriminant read. Place of this read is being used in the switch
339                 let discr_decl = &self.body.local_decls()[discr_local];
340                 let discr_ty = discr_decl.ty;
341                 // the otherwise target lies as the last element
342                 let otherwise_bb = targets.otherwise();
343                 let targets_with_values = targets.iter().collect();
344
345                 // find the place of the adt where the discriminant is being read from
346                 // assume this is the last statement of the block
347                 let place_of_adt_discr_read = match bb.statements.last()?.kind {
348                     StatementKind::Assign(box (_, Rvalue::Discriminant(adt_place))) => {
349                         Some(adt_place)
350                     }
351                     _ => None,
352                 }?;
353
354                 let type_adt_matched_on = place_of_adt_discr_read.ty(self.body, self.tcx).ty;
355
356                 Some(SwitchDiscriminantInfo {
357                     discr_used_in_switch: discr.place()?,
358                     discr_ty,
359                     otherwise_bb,
360                     targets_with_values,
361                     discr_source_info: discr_decl.source_info,
362                     place_of_adt_discr_read,
363                     type_adt_matched_on,
364                 })
365             }
366             _ => unreachable!("must only be passed terminator that is a switch"),
367         }
368     }
369 }