]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_mir_transform/src/normalize_array_len.rs
Rollup merge of #94011 - est31:let_else, r=lcnr
[rust.git] / compiler / rustc_mir_transform / src / normalize_array_len.rs
1 //! This pass eliminates the casting of arrays into slices when their length
2 //! is taken using the `.len()` method. This is handy for preserving information in MIR for const prop.
3
4 use crate::MirPass;
5 use rustc_data_structures::fx::FxIndexMap;
6 use rustc_data_structures::intern::Interned;
7 use rustc_index::bit_set::BitSet;
8 use rustc_index::vec::IndexVec;
9 use rustc_middle::mir::*;
10 use rustc_middle::ty::{self, ReErased, Region, TyCtxt};
11
12 const MAX_NUM_BLOCKS: usize = 800;
13 const MAX_NUM_LOCALS: usize = 3000;
14
15 pub struct NormalizeArrayLen;
16
17 impl<'tcx> MirPass<'tcx> for NormalizeArrayLen {
18     fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
19         sess.mir_opt_level() >= 4
20     }
21
22     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
23         // early returns for edge cases of highly unrolled functions
24         if body.basic_blocks().len() > MAX_NUM_BLOCKS {
25             return;
26         }
27         if body.local_decls().len() > MAX_NUM_LOCALS {
28             return;
29         }
30         normalize_array_len_calls(tcx, body)
31     }
32 }
33
34 pub fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
35     let (basic_blocks, local_decls) = body.basic_blocks_and_local_decls_mut();
36
37     // do a preliminary analysis to see if we ever have locals of type `[T;N]` or `&[T;N]`
38     let mut interesting_locals = BitSet::new_empty(local_decls.len());
39     for (local, decl) in local_decls.iter_enumerated() {
40         match decl.ty.kind() {
41             ty::Array(..) => {
42                 interesting_locals.insert(local);
43             }
44             ty::Ref(.., ty, Mutability::Not) => match ty.kind() {
45                 ty::Array(..) => {
46                     interesting_locals.insert(local);
47                 }
48                 _ => {}
49             },
50             _ => {}
51         }
52     }
53     if interesting_locals.is_empty() {
54         // we have found nothing to analyze
55         return;
56     }
57     let num_intesting_locals = interesting_locals.count();
58     let mut state = FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
59     let mut patches_scratchpad =
60         FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
61     let mut replacements_scratchpad =
62         FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
63     for block in basic_blocks {
64         // make length calls for arrays [T; N] not to decay into length calls for &[T]
65         // that forbids constant propagation
66         normalize_array_len_call(
67             tcx,
68             block,
69             local_decls,
70             &interesting_locals,
71             &mut state,
72             &mut patches_scratchpad,
73             &mut replacements_scratchpad,
74         );
75         state.clear();
76         patches_scratchpad.clear();
77         replacements_scratchpad.clear();
78     }
79 }
80
81 struct Patcher<'a, 'tcx> {
82     tcx: TyCtxt<'tcx>,
83     patches_scratchpad: &'a FxIndexMap<usize, usize>,
84     replacements_scratchpad: &'a mut FxIndexMap<usize, Local>,
85     local_decls: &'a mut IndexVec<Local, LocalDecl<'tcx>>,
86     statement_idx: usize,
87 }
88
89 impl<'tcx> Patcher<'_, 'tcx> {
90     fn patch_expand_statement(
91         &mut self,
92         statement: &mut Statement<'tcx>,
93     ) -> Option<std::vec::IntoIter<Statement<'tcx>>> {
94         let idx = self.statement_idx;
95         if let Some(len_statemnt_idx) = self.patches_scratchpad.get(&idx).copied() {
96             let mut statements = Vec::with_capacity(2);
97
98             // we are at statement that performs a cast. The only sound way is
99             // to create another local that performs a similar copy without a cast and then
100             // use this copy in the Len operation
101
102             match &statement.kind {
103                 StatementKind::Assign(box (
104                     ..,
105                     Rvalue::Cast(
106                         CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
107                         operand,
108                         _,
109                     ),
110                 )) => {
111                     match operand {
112                         Operand::Copy(place) | Operand::Move(place) => {
113                             // create new local
114                             let ty = operand.ty(self.local_decls, self.tcx);
115                             let local_decl = LocalDecl::with_source_info(ty, statement.source_info);
116                             let local = self.local_decls.push(local_decl);
117                             // make it live
118                             let mut make_live_statement = statement.clone();
119                             make_live_statement.kind = StatementKind::StorageLive(local);
120                             statements.push(make_live_statement);
121                             // copy into it
122
123                             let operand = Operand::Copy(*place);
124                             let mut make_copy_statement = statement.clone();
125                             let assign_to = Place::from(local);
126                             let rvalue = Rvalue::Use(operand);
127                             make_copy_statement.kind =
128                                 StatementKind::Assign(box (assign_to, rvalue));
129                             statements.push(make_copy_statement);
130
131                             // to reorder we have to copy and make NOP
132                             statements.push(statement.clone());
133                             statement.make_nop();
134
135                             self.replacements_scratchpad.insert(len_statemnt_idx, local);
136                         }
137                         _ => {
138                             unreachable!("it's a bug in the implementation")
139                         }
140                     }
141                 }
142                 _ => {
143                     unreachable!("it's a bug in the implementation")
144                 }
145             }
146
147             self.statement_idx += 1;
148
149             Some(statements.into_iter())
150         } else if let Some(local) = self.replacements_scratchpad.get(&idx).copied() {
151             let mut statements = Vec::with_capacity(2);
152
153             match &statement.kind {
154                 StatementKind::Assign(box (into, Rvalue::Len(place))) => {
155                     let add_deref = if let Some(..) = place.as_local() {
156                         false
157                     } else if let Some(..) = place.local_or_deref_local() {
158                         true
159                     } else {
160                         unreachable!("it's a bug in the implementation")
161                     };
162                     // replace len statement
163                     let mut len_statement = statement.clone();
164                     let mut place = Place::from(local);
165                     if add_deref {
166                         place = self.tcx.mk_place_deref(place);
167                     }
168                     len_statement.kind = StatementKind::Assign(box (*into, Rvalue::Len(place)));
169                     statements.push(len_statement);
170
171                     // make temporary dead
172                     let mut make_dead_statement = statement.clone();
173                     make_dead_statement.kind = StatementKind::StorageDead(local);
174                     statements.push(make_dead_statement);
175
176                     // make original statement NOP
177                     statement.make_nop();
178                 }
179                 _ => {
180                     unreachable!("it's a bug in the implementation")
181                 }
182             }
183
184             self.statement_idx += 1;
185
186             Some(statements.into_iter())
187         } else {
188             self.statement_idx += 1;
189             None
190         }
191     }
192 }
193
194 fn normalize_array_len_call<'tcx>(
195     tcx: TyCtxt<'tcx>,
196     block: &mut BasicBlockData<'tcx>,
197     local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
198     interesting_locals: &BitSet<Local>,
199     state: &mut FxIndexMap<Local, usize>,
200     patches_scratchpad: &mut FxIndexMap<usize, usize>,
201     replacements_scratchpad: &mut FxIndexMap<usize, Local>,
202 ) {
203     for (statement_idx, statement) in block.statements.iter_mut().enumerate() {
204         match &mut statement.kind {
205             StatementKind::Assign(box (place, rvalue)) => {
206                 match rvalue {
207                     Rvalue::Cast(
208                         CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
209                         operand,
210                         cast_ty,
211                     ) => {
212                         let Some(local) = place.as_local() else { return };
213                         match operand {
214                             Operand::Copy(place) | Operand::Move(place) => {
215                                 let Some(operand_local) = place.local_or_deref_local() else { return; };
216                                 if !interesting_locals.contains(operand_local) {
217                                     return;
218                                 }
219                                 let operand_ty = local_decls[operand_local].ty;
220                                 match (operand_ty.kind(), cast_ty.kind()) {
221                                     (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => {
222                                         if of_ty_src == of_ty_dst {
223                                             // this is a cast from [T; N] into [T], so we are good
224                                             state.insert(local, statement_idx);
225                                         }
226                                     }
227                                     // current way of patching doesn't allow to work with `mut`
228                                     (
229                                         ty::Ref(
230                                             Region(Interned(ReErased, _)),
231                                             operand_ty,
232                                             Mutability::Not,
233                                         ),
234                                         ty::Ref(
235                                             Region(Interned(ReErased, _)),
236                                             cast_ty,
237                                             Mutability::Not,
238                                         ),
239                                     ) => {
240                                         match (operand_ty.kind(), cast_ty.kind()) {
241                                             // current way of patching doesn't allow to work with `mut`
242                                             (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => {
243                                                 if of_ty_src == of_ty_dst {
244                                                     // this is a cast from [T; N] into [T], so we are good
245                                                     state.insert(local, statement_idx);
246                                                 }
247                                             }
248                                             _ => {}
249                                         }
250                                     }
251                                     _ => {}
252                                 }
253                             }
254                             _ => {}
255                         }
256                     }
257                     Rvalue::Len(place) => {
258                         let Some(local) = place.local_or_deref_local() else {
259                             return;
260                         };
261                         if let Some(cast_statement_idx) = state.get(&local).copied() {
262                             patches_scratchpad.insert(cast_statement_idx, statement_idx);
263                         }
264                     }
265                     _ => {
266                         // invalidate
267                         state.remove(&place.local);
268                     }
269                 }
270             }
271             _ => {}
272         }
273     }
274
275     let mut patcher = Patcher {
276         tcx,
277         patches_scratchpad: &*patches_scratchpad,
278         replacements_scratchpad,
279         local_decls,
280         statement_idx: 0,
281     };
282
283     block.expand_statements(|st| patcher.patch_expand_statement(st));
284 }