]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_mir_transform/src/normalize_array_len.rs
Rollup merge of #98811 - RalfJung:interpret-alloc-range, r=oli-obk
[rust.git] / compiler / rustc_mir_transform / src / normalize_array_len.rs
1 //! This pass eliminates casting of arrays into slices when their length
2 //! is taken using `.len()` method. Handy to preserve information in MIR for const prop
3
4 use crate::MirPass;
5 use rustc_data_structures::fx::FxIndexMap;
6 use rustc_data_structures::intern::Interned;
7 use rustc_index::bit_set::BitSet;
8 use rustc_index::vec::IndexVec;
9 use rustc_middle::mir::*;
10 use rustc_middle::ty::{self, ReErased, Region, TyCtxt};
11
// Size limits for the pass: `NormalizeArrayLen::run_pass` returns without doing
// anything when a body has more basic blocks than `MAX_NUM_BLOCKS` or more local
// declarations than `MAX_NUM_LOCALS`.
const MAX_NUM_BLOCKS: usize = 800;
const MAX_NUM_LOCALS: usize = 3000;
14
/// MIR pass that rewrites `Len` computations on slices obtained by unsizing an
/// array so that they refer to the array itself instead of the slice.
pub struct NormalizeArrayLen;
16
17 impl<'tcx> MirPass<'tcx> for NormalizeArrayLen {
18     fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
19         sess.mir_opt_level() >= 4
20     }
21
22     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
23         // early returns for edge cases of highly unrolled functions
24         if body.basic_blocks().len() > MAX_NUM_BLOCKS {
25             return;
26         }
27         if body.local_decls().len() > MAX_NUM_LOCALS {
28             return;
29         }
30         normalize_array_len_calls(tcx, body)
31     }
32 }
33
34 pub fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
35     let (basic_blocks, local_decls) = body.basic_blocks_and_local_decls_mut();
36
37     // do a preliminary analysis to see if we ever have locals of type `[T;N]` or `&[T;N]`
38     let mut interesting_locals = BitSet::new_empty(local_decls.len());
39     for (local, decl) in local_decls.iter_enumerated() {
40         match decl.ty.kind() {
41             ty::Array(..) => {
42                 interesting_locals.insert(local);
43             }
44             ty::Ref(.., ty, Mutability::Not) => match ty.kind() {
45                 ty::Array(..) => {
46                     interesting_locals.insert(local);
47                 }
48                 _ => {}
49             },
50             _ => {}
51         }
52     }
53     if interesting_locals.is_empty() {
54         // we have found nothing to analyze
55         return;
56     }
57     let num_intesting_locals = interesting_locals.count();
58     let mut state = FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
59     let mut patches_scratchpad =
60         FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
61     let mut replacements_scratchpad =
62         FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
63     for block in basic_blocks {
64         // make length calls for arrays [T; N] not to decay into length calls for &[T]
65         // that forbids constant propagation
66         normalize_array_len_call(
67             tcx,
68             block,
69             local_decls,
70             &interesting_locals,
71             &mut state,
72             &mut patches_scratchpad,
73             &mut replacements_scratchpad,
74         );
75         state.clear();
76         patches_scratchpad.clear();
77         replacements_scratchpad.clear();
78     }
79 }
80
/// Per-block rewriting state handed to `BasicBlockData::expand_statements` through
/// `Patcher::patch_expand_statement`.
struct Patcher<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    /// Maps the index of an unsizing cast statement to the index of the `Len`
    /// statement that should be rewritten because of it.
    patches_scratchpad: &'a FxIndexMap<usize, usize>,
    /// Maps the index of a `Len` statement to the temporary local it should read
    /// from; filled in while the corresponding cast statement is being patched.
    replacements_scratchpad: &'a mut FxIndexMap<usize, Local>,
    /// Local declarations of the body; new temporaries are pushed here.
    local_decls: &'a mut IndexVec<Local, LocalDecl<'tcx>>,
    /// Index (in the original, unexpanded statement list) of the statement that the
    /// next call to `patch_expand_statement` will see.
    statement_idx: usize,
}
88
89 impl<'tcx> Patcher<'_, 'tcx> {
90     fn patch_expand_statement(
91         &mut self,
92         statement: &mut Statement<'tcx>,
93     ) -> Option<std::vec::IntoIter<Statement<'tcx>>> {
94         let idx = self.statement_idx;
95         if let Some(len_statemnt_idx) = self.patches_scratchpad.get(&idx).copied() {
96             let mut statements = Vec::with_capacity(2);
97
98             // we are at statement that performs a cast. The only sound way is
99             // to create another local that performs a similar copy without a cast and then
100             // use this copy in the Len operation
101
102             match &statement.kind {
103                 StatementKind::Assign(box (
104                     ..,
105                     Rvalue::Cast(
106                         CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
107                         operand,
108                         _,
109                     ),
110                 )) => {
111                     match operand {
112                         Operand::Copy(place) | Operand::Move(place) => {
113                             // create new local
114                             let ty = operand.ty(self.local_decls, self.tcx);
115                             let local_decl = LocalDecl::with_source_info(ty, statement.source_info);
116                             let local = self.local_decls.push(local_decl);
117                             // make it live
118                             let mut make_live_statement = statement.clone();
119                             make_live_statement.kind = StatementKind::StorageLive(local);
120                             statements.push(make_live_statement);
121                             // copy into it
122
123                             let operand = Operand::Copy(*place);
124                             let mut make_copy_statement = statement.clone();
125                             let assign_to = Place::from(local);
126                             let rvalue = Rvalue::Use(operand);
127                             make_copy_statement.kind =
128                                 StatementKind::Assign(Box::new((assign_to, rvalue)));
129                             statements.push(make_copy_statement);
130
131                             // to reorder we have to copy and make NOP
132                             statements.push(statement.clone());
133                             statement.make_nop();
134
135                             self.replacements_scratchpad.insert(len_statemnt_idx, local);
136                         }
137                         _ => {
138                             unreachable!("it's a bug in the implementation")
139                         }
140                     }
141                 }
142                 _ => {
143                     unreachable!("it's a bug in the implementation")
144                 }
145             }
146
147             self.statement_idx += 1;
148
149             Some(statements.into_iter())
150         } else if let Some(local) = self.replacements_scratchpad.get(&idx).copied() {
151             let mut statements = Vec::with_capacity(2);
152
153             match &statement.kind {
154                 StatementKind::Assign(box (into, Rvalue::Len(place))) => {
155                     let add_deref = if let Some(..) = place.as_local() {
156                         false
157                     } else if let Some(..) = place.local_or_deref_local() {
158                         true
159                     } else {
160                         unreachable!("it's a bug in the implementation")
161                     };
162                     // replace len statement
163                     let mut len_statement = statement.clone();
164                     let mut place = Place::from(local);
165                     if add_deref {
166                         place = self.tcx.mk_place_deref(place);
167                     }
168                     len_statement.kind =
169                         StatementKind::Assign(Box::new((*into, Rvalue::Len(place))));
170                     statements.push(len_statement);
171
172                     // make temporary dead
173                     let mut make_dead_statement = statement.clone();
174                     make_dead_statement.kind = StatementKind::StorageDead(local);
175                     statements.push(make_dead_statement);
176
177                     // make original statement NOP
178                     statement.make_nop();
179                 }
180                 _ => {
181                     unreachable!("it's a bug in the implementation")
182                 }
183             }
184
185             self.statement_idx += 1;
186
187             Some(statements.into_iter())
188         } else {
189             self.statement_idx += 1;
190             None
191         }
192     }
193 }
194
195 fn normalize_array_len_call<'tcx>(
196     tcx: TyCtxt<'tcx>,
197     block: &mut BasicBlockData<'tcx>,
198     local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
199     interesting_locals: &BitSet<Local>,
200     state: &mut FxIndexMap<Local, usize>,
201     patches_scratchpad: &mut FxIndexMap<usize, usize>,
202     replacements_scratchpad: &mut FxIndexMap<usize, Local>,
203 ) {
204     for (statement_idx, statement) in block.statements.iter_mut().enumerate() {
205         match &mut statement.kind {
206             StatementKind::Assign(box (place, rvalue)) => {
207                 match rvalue {
208                     Rvalue::Cast(
209                         CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
210                         operand,
211                         cast_ty,
212                     ) => {
213                         let Some(local) = place.as_local() else { return };
214                         match operand {
215                             Operand::Copy(place) | Operand::Move(place) => {
216                                 let Some(operand_local) = place.local_or_deref_local() else { return; };
217                                 if !interesting_locals.contains(operand_local) {
218                                     return;
219                                 }
220                                 let operand_ty = local_decls[operand_local].ty;
221                                 match (operand_ty.kind(), cast_ty.kind()) {
222                                     (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => {
223                                         if of_ty_src == of_ty_dst {
224                                             // this is a cast from [T; N] into [T], so we are good
225                                             state.insert(local, statement_idx);
226                                         }
227                                     }
228                                     // current way of patching doesn't allow to work with `mut`
229                                     (
230                                         ty::Ref(
231                                             Region(Interned(ReErased, _)),
232                                             operand_ty,
233                                             Mutability::Not,
234                                         ),
235                                         ty::Ref(
236                                             Region(Interned(ReErased, _)),
237                                             cast_ty,
238                                             Mutability::Not,
239                                         ),
240                                     ) => {
241                                         match (operand_ty.kind(), cast_ty.kind()) {
242                                             // current way of patching doesn't allow to work with `mut`
243                                             (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => {
244                                                 if of_ty_src == of_ty_dst {
245                                                     // this is a cast from [T; N] into [T], so we are good
246                                                     state.insert(local, statement_idx);
247                                                 }
248                                             }
249                                             _ => {}
250                                         }
251                                     }
252                                     _ => {}
253                                 }
254                             }
255                             _ => {}
256                         }
257                     }
258                     Rvalue::Len(place) => {
259                         let Some(local) = place.local_or_deref_local() else {
260                             return;
261                         };
262                         if let Some(cast_statement_idx) = state.get(&local).copied() {
263                             patches_scratchpad.insert(cast_statement_idx, statement_idx);
264                         }
265                     }
266                     _ => {
267                         // invalidate
268                         state.remove(&place.local);
269                     }
270                 }
271             }
272             _ => {}
273         }
274     }
275
276     let mut patcher = Patcher {
277         tcx,
278         patches_scratchpad: &*patches_scratchpad,
279         replacements_scratchpad,
280         local_decls,
281         statement_idx: 0,
282     };
283
284     block.expand_statements(|st| patcher.patch_expand_statement(st));
285 }