git.lizzy.rs Git - rust.git/blob - compiler/rustc_mir_transform/src/normalize_array_len.rs
Rollup merge of #105208 - chenyukang:yukang/fix-105069, r=cjgillot
[rust.git] / compiler / rustc_mir_transform / src / normalize_array_len.rs
//! This pass eliminates casts of arrays into slices when only their length is subsequently taken
//! (a MIR `Len`, e.g. from the `.len()` method). This preserves the array length in MIR for const prop.
3
4 use crate::MirPass;
5 use rustc_data_structures::fx::FxIndexMap;
6 use rustc_data_structures::intern::Interned;
7 use rustc_index::bit_set::BitSet;
8 use rustc_index::vec::IndexVec;
9 use rustc_middle::mir::*;
10 use rustc_middle::ty::{self, ReErased, Region, TyCtxt};
11
12 const MAX_NUM_BLOCKS: usize = 800;
13 const MAX_NUM_LOCALS: usize = 3000;
14
15 pub struct NormalizeArrayLen;
16
17 impl<'tcx> MirPass<'tcx> for NormalizeArrayLen {
18     fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
19         sess.mir_opt_level() >= 4
20     }
21
22     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
23         // early returns for edge cases of highly unrolled functions
24         if body.basic_blocks.len() > MAX_NUM_BLOCKS {
25             return;
26         }
27         if body.local_decls.len() > MAX_NUM_LOCALS {
28             return;
29         }
30         normalize_array_len_calls(tcx, body)
31     }
32 }
33
34 pub fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
35     // We don't ever touch terminators, so no need to invalidate the CFG cache
36     let basic_blocks = body.basic_blocks.as_mut_preserves_cfg();
37     let local_decls = &mut body.local_decls;
38
39     // do a preliminary analysis to see if we ever have locals of type `[T;N]` or `&[T;N]`
40     let mut interesting_locals = BitSet::new_empty(local_decls.len());
41     for (local, decl) in local_decls.iter_enumerated() {
42         match decl.ty.kind() {
43             ty::Array(..) => {
44                 interesting_locals.insert(local);
45             }
46             ty::Ref(.., ty, Mutability::Not) => match ty.kind() {
47                 ty::Array(..) => {
48                     interesting_locals.insert(local);
49                 }
50                 _ => {}
51             },
52             _ => {}
53         }
54     }
55     if interesting_locals.is_empty() {
56         // we have found nothing to analyze
57         return;
58     }
59     let num_intesting_locals = interesting_locals.count();
60     let mut state = FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
61     let mut patches_scratchpad =
62         FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
63     let mut replacements_scratchpad =
64         FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
65     for block in basic_blocks {
66         // make length calls for arrays [T; N] not to decay into length calls for &[T]
67         // that forbids constant propagation
68         normalize_array_len_call(
69             tcx,
70             block,
71             local_decls,
72             &interesting_locals,
73             &mut state,
74             &mut patches_scratchpad,
75             &mut replacements_scratchpad,
76         );
77         state.clear();
78         patches_scratchpad.clear();
79         replacements_scratchpad.clear();
80     }
81 }
82
83 struct Patcher<'a, 'tcx> {
84     tcx: TyCtxt<'tcx>,
85     patches_scratchpad: &'a FxIndexMap<usize, usize>,
86     replacements_scratchpad: &'a mut FxIndexMap<usize, Local>,
87     local_decls: &'a mut IndexVec<Local, LocalDecl<'tcx>>,
88     statement_idx: usize,
89 }
90
91 impl<'tcx> Patcher<'_, 'tcx> {
92     fn patch_expand_statement(
93         &mut self,
94         statement: &mut Statement<'tcx>,
95     ) -> Option<std::vec::IntoIter<Statement<'tcx>>> {
96         let idx = self.statement_idx;
97         if let Some(len_statemnt_idx) = self.patches_scratchpad.get(&idx).copied() {
98             let mut statements = Vec::with_capacity(2);
99
100             // we are at statement that performs a cast. The only sound way is
101             // to create another local that performs a similar copy without a cast and then
102             // use this copy in the Len operation
103
104             match &statement.kind {
105                 StatementKind::Assign(box (
106                     ..,
107                     Rvalue::Cast(
108                         CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
109                         operand,
110                         _,
111                     ),
112                 )) => {
113                     match operand {
114                         Operand::Copy(place) | Operand::Move(place) => {
115                             // create new local
116                             let ty = operand.ty(self.local_decls, self.tcx);
117                             let local_decl = LocalDecl::with_source_info(ty, statement.source_info);
118                             let local = self.local_decls.push(local_decl);
119                             // make it live
120                             let mut make_live_statement = statement.clone();
121                             make_live_statement.kind = StatementKind::StorageLive(local);
122                             statements.push(make_live_statement);
123                             // copy into it
124
125                             let operand = Operand::Copy(*place);
126                             let mut make_copy_statement = statement.clone();
127                             let assign_to = Place::from(local);
128                             let rvalue = Rvalue::Use(operand);
129                             make_copy_statement.kind =
130                                 StatementKind::Assign(Box::new((assign_to, rvalue)));
131                             statements.push(make_copy_statement);
132
133                             // to reorder we have to copy and make NOP
134                             statements.push(statement.clone());
135                             statement.make_nop();
136
137                             self.replacements_scratchpad.insert(len_statemnt_idx, local);
138                         }
139                         _ => {
140                             unreachable!("it's a bug in the implementation")
141                         }
142                     }
143                 }
144                 _ => {
145                     unreachable!("it's a bug in the implementation")
146                 }
147             }
148
149             self.statement_idx += 1;
150
151             Some(statements.into_iter())
152         } else if let Some(local) = self.replacements_scratchpad.get(&idx).copied() {
153             let mut statements = Vec::with_capacity(2);
154
155             match &statement.kind {
156                 StatementKind::Assign(box (into, Rvalue::Len(place))) => {
157                     let add_deref = if let Some(..) = place.as_local() {
158                         false
159                     } else if let Some(..) = place.local_or_deref_local() {
160                         true
161                     } else {
162                         unreachable!("it's a bug in the implementation")
163                     };
164                     // replace len statement
165                     let mut len_statement = statement.clone();
166                     let mut place = Place::from(local);
167                     if add_deref {
168                         place = self.tcx.mk_place_deref(place);
169                     }
170                     len_statement.kind =
171                         StatementKind::Assign(Box::new((*into, Rvalue::Len(place))));
172                     statements.push(len_statement);
173
174                     // make temporary dead
175                     let mut make_dead_statement = statement.clone();
176                     make_dead_statement.kind = StatementKind::StorageDead(local);
177                     statements.push(make_dead_statement);
178
179                     // make original statement NOP
180                     statement.make_nop();
181                 }
182                 _ => {
183                     unreachable!("it's a bug in the implementation")
184                 }
185             }
186
187             self.statement_idx += 1;
188
189             Some(statements.into_iter())
190         } else {
191             self.statement_idx += 1;
192             None
193         }
194     }
195 }
196
197 fn normalize_array_len_call<'tcx>(
198     tcx: TyCtxt<'tcx>,
199     block: &mut BasicBlockData<'tcx>,
200     local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
201     interesting_locals: &BitSet<Local>,
202     state: &mut FxIndexMap<Local, usize>,
203     patches_scratchpad: &mut FxIndexMap<usize, usize>,
204     replacements_scratchpad: &mut FxIndexMap<usize, Local>,
205 ) {
206     for (statement_idx, statement) in block.statements.iter_mut().enumerate() {
207         match &mut statement.kind {
208             StatementKind::Assign(box (place, rvalue)) => {
209                 match rvalue {
210                     Rvalue::Cast(
211                         CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
212                         operand,
213                         cast_ty,
214                     ) => {
215                         let Some(local) = place.as_local() else { return };
216                         match operand {
217                             Operand::Copy(place) | Operand::Move(place) => {
218                                 let Some(operand_local) = place.local_or_deref_local() else { return; };
219                                 if !interesting_locals.contains(operand_local) {
220                                     return;
221                                 }
222                                 let operand_ty = local_decls[operand_local].ty;
223                                 match (operand_ty.kind(), cast_ty.kind()) {
224                                     (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => {
225                                         if of_ty_src == of_ty_dst {
226                                             // this is a cast from [T; N] into [T], so we are good
227                                             state.insert(local, statement_idx);
228                                         }
229                                     }
230                                     // current way of patching doesn't allow to work with `mut`
231                                     (
232                                         ty::Ref(
233                                             Region(Interned(ReErased, _)),
234                                             operand_ty,
235                                             Mutability::Not,
236                                         ),
237                                         ty::Ref(
238                                             Region(Interned(ReErased, _)),
239                                             cast_ty,
240                                             Mutability::Not,
241                                         ),
242                                     ) => {
243                                         match (operand_ty.kind(), cast_ty.kind()) {
244                                             // current way of patching doesn't allow to work with `mut`
245                                             (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => {
246                                                 if of_ty_src == of_ty_dst {
247                                                     // this is a cast from [T; N] into [T], so we are good
248                                                     state.insert(local, statement_idx);
249                                                 }
250                                             }
251                                             _ => {}
252                                         }
253                                     }
254                                     _ => {}
255                                 }
256                             }
257                             _ => {}
258                         }
259                     }
260                     Rvalue::Len(place) => {
261                         let Some(local) = place.local_or_deref_local() else {
262                             return;
263                         };
264                         if let Some(cast_statement_idx) = state.get(&local).copied() {
265                             patches_scratchpad.insert(cast_statement_idx, statement_idx);
266                         }
267                     }
268                     _ => {
269                         // invalidate
270                         state.remove(&place.local);
271                     }
272                 }
273             }
274             _ => {}
275         }
276     }
277
278     let mut patcher = Patcher {
279         tcx,
280         patches_scratchpad: &*patches_scratchpad,
281         replacements_scratchpad,
282         local_decls,
283         statement_idx: 0,
284     };
285
286     block.expand_statements(|st| patcher.patch_expand_statement(st));
287 }