]> git.lizzy.rs Git - rust.git/blob - compiler/rustc_mir_transform/src/normalize_array_len.rs
Rollup merge of #93269 - jacobbramley:dev/pauth-option-1, r=petrochenkov
[rust.git] / compiler / rustc_mir_transform / src / normalize_array_len.rs
1 //! This pass eliminates the casting of arrays into slices when their length
2 //! is taken using the `.len()` method. This preserves information in MIR for const prop.
3
4 use crate::MirPass;
5 use rustc_data_structures::fx::FxIndexMap;
6 use rustc_index::bit_set::BitSet;
7 use rustc_index::vec::IndexVec;
8 use rustc_middle::mir::*;
9 use rustc_middle::ty::{self, TyCtxt};
10
11 const MAX_NUM_BLOCKS: usize = 800; // `run_pass` skips bodies with more basic blocks than this
12 const MAX_NUM_LOCALS: usize = 3000; // `run_pass` skips bodies with more local declarations than this
13
14 pub struct NormalizeArrayLen;
15
16 impl<'tcx> MirPass<'tcx> for NormalizeArrayLen {
17     fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
18         sess.mir_opt_level() >= 4
19     }
20
21     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
22         // early returns for edge cases of highly unrolled functions
23         if body.basic_blocks().len() > MAX_NUM_BLOCKS {
24             return;
25         }
26         if body.local_decls().len() > MAX_NUM_LOCALS {
27             return;
28         }
29         normalize_array_len_calls(tcx, body)
30     }
31 }
32
33 pub fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
34     let (basic_blocks, local_decls) = body.basic_blocks_and_local_decls_mut();
35
36     // do a preliminary analysis to see if we ever have locals of type `[T;N]` or `&[T;N]`
37     let mut interesting_locals = BitSet::new_empty(local_decls.len());
38     for (local, decl) in local_decls.iter_enumerated() {
39         match decl.ty.kind() {
40             ty::Array(..) => {
41                 interesting_locals.insert(local);
42             }
43             ty::Ref(.., ty, Mutability::Not) => match ty.kind() {
44                 ty::Array(..) => {
45                     interesting_locals.insert(local);
46                 }
47                 _ => {}
48             },
49             _ => {}
50         }
51     }
52     if interesting_locals.is_empty() {
53         // we have found nothing to analyze
54         return;
55     }
56     let num_intesting_locals = interesting_locals.count();
57     let mut state = FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
58     let mut patches_scratchpad =
59         FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
60     let mut replacements_scratchpad =
61         FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
62     for block in basic_blocks {
63         // make length calls for arrays [T; N] not to decay into length calls for &[T]
64         // that forbids constant propagation
65         normalize_array_len_call(
66             tcx,
67             block,
68             local_decls,
69             &interesting_locals,
70             &mut state,
71             &mut patches_scratchpad,
72             &mut replacements_scratchpad,
73         );
74         state.clear();
75         patches_scratchpad.clear();
76         replacements_scratchpad.clear();
77     }
78 }
79
80 struct Patcher<'a, 'tcx> {
81     tcx: TyCtxt<'tcx>,
82     patches_scratchpad: &'a FxIndexMap<usize, usize>,
83     replacements_scratchpad: &'a mut FxIndexMap<usize, Local>,
84     local_decls: &'a mut IndexVec<Local, LocalDecl<'tcx>>,
85     statement_idx: usize,
86 }
87
88 impl<'tcx> Patcher<'_, 'tcx> {
89     fn patch_expand_statement(
90         &mut self,
91         statement: &mut Statement<'tcx>,
92     ) -> Option<std::vec::IntoIter<Statement<'tcx>>> {
93         let idx = self.statement_idx;
94         if let Some(len_statemnt_idx) = self.patches_scratchpad.get(&idx).copied() {
95             let mut statements = Vec::with_capacity(2);
96
97             // we are at statement that performs a cast. The only sound way is
98             // to create another local that performs a similar copy without a cast and then
99             // use this copy in the Len operation
100
101             match &statement.kind {
102                 StatementKind::Assign(box (
103                     ..,
104                     Rvalue::Cast(
105                         CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
106                         operand,
107                         _,
108                     ),
109                 )) => {
110                     match operand {
111                         Operand::Copy(place) | Operand::Move(place) => {
112                             // create new local
113                             let ty = operand.ty(self.local_decls, self.tcx);
114                             let local_decl = LocalDecl::with_source_info(ty, statement.source_info);
115                             let local = self.local_decls.push(local_decl);
116                             // make it live
117                             let mut make_live_statement = statement.clone();
118                             make_live_statement.kind = StatementKind::StorageLive(local);
119                             statements.push(make_live_statement);
120                             // copy into it
121
122                             let operand = Operand::Copy(*place);
123                             let mut make_copy_statement = statement.clone();
124                             let assign_to = Place::from(local);
125                             let rvalue = Rvalue::Use(operand);
126                             make_copy_statement.kind =
127                                 StatementKind::Assign(box (assign_to, rvalue));
128                             statements.push(make_copy_statement);
129
130                             // to reorder we have to copy and make NOP
131                             statements.push(statement.clone());
132                             statement.make_nop();
133
134                             self.replacements_scratchpad.insert(len_statemnt_idx, local);
135                         }
136                         _ => {
137                             unreachable!("it's a bug in the implementation")
138                         }
139                     }
140                 }
141                 _ => {
142                     unreachable!("it's a bug in the implementation")
143                 }
144             }
145
146             self.statement_idx += 1;
147
148             Some(statements.into_iter())
149         } else if let Some(local) = self.replacements_scratchpad.get(&idx).copied() {
150             let mut statements = Vec::with_capacity(2);
151
152             match &statement.kind {
153                 StatementKind::Assign(box (into, Rvalue::Len(place))) => {
154                     let add_deref = if let Some(..) = place.as_local() {
155                         false
156                     } else if let Some(..) = place.local_or_deref_local() {
157                         true
158                     } else {
159                         unreachable!("it's a bug in the implementation")
160                     };
161                     // replace len statement
162                     let mut len_statement = statement.clone();
163                     let mut place = Place::from(local);
164                     if add_deref {
165                         place = self.tcx.mk_place_deref(place);
166                     }
167                     len_statement.kind = StatementKind::Assign(box (*into, Rvalue::Len(place)));
168                     statements.push(len_statement);
169
170                     // make temporary dead
171                     let mut make_dead_statement = statement.clone();
172                     make_dead_statement.kind = StatementKind::StorageDead(local);
173                     statements.push(make_dead_statement);
174
175                     // make original statement NOP
176                     statement.make_nop();
177                 }
178                 _ => {
179                     unreachable!("it's a bug in the implementation")
180                 }
181             }
182
183             self.statement_idx += 1;
184
185             Some(statements.into_iter())
186         } else {
187             self.statement_idx += 1;
188             None
189         }
190     }
191 }
192
193 fn normalize_array_len_call<'tcx>(
194     tcx: TyCtxt<'tcx>,
195     block: &mut BasicBlockData<'tcx>,
196     local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
197     interesting_locals: &BitSet<Local>,
198     state: &mut FxIndexMap<Local, usize>,
199     patches_scratchpad: &mut FxIndexMap<usize, usize>,
200     replacements_scratchpad: &mut FxIndexMap<usize, Local>,
201 ) {
202     for (statement_idx, statement) in block.statements.iter_mut().enumerate() {
203         match &mut statement.kind {
204             StatementKind::Assign(box (place, rvalue)) => {
205                 match rvalue {
206                     Rvalue::Cast(
207                         CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
208                         operand,
209                         cast_ty,
210                     ) => {
211                         let Some(local) = place.as_local() else { return };
212                         match operand {
213                             Operand::Copy(place) | Operand::Move(place) => {
214                                 let operand_local =
215                                     if let Some(local) = place.local_or_deref_local() {
216                                         local
217                                     } else {
218                                         return;
219                                     };
220                                 if !interesting_locals.contains(operand_local) {
221                                     return;
222                                 }
223                                 let operand_ty = local_decls[operand_local].ty;
224                                 match (operand_ty.kind(), cast_ty.kind()) {
225                                     (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => {
226                                         if of_ty_src == of_ty_dst {
227                                             // this is a cast from [T; N] into [T], so we are good
228                                             state.insert(local, statement_idx);
229                                         }
230                                     }
231                                     // current way of patching doesn't allow to work with `mut`
232                                     (
233                                         ty::Ref(
234                                             ty::RegionKind::ReErased,
235                                             operand_ty,
236                                             Mutability::Not,
237                                         ),
238                                         ty::Ref(ty::RegionKind::ReErased, cast_ty, Mutability::Not),
239                                     ) => {
240                                         match (operand_ty.kind(), cast_ty.kind()) {
241                                             // current way of patching doesn't allow to work with `mut`
242                                             (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => {
243                                                 if of_ty_src == of_ty_dst {
244                                                     // this is a cast from [T; N] into [T], so we are good
245                                                     state.insert(local, statement_idx);
246                                                 }
247                                             }
248                                             _ => {}
249                                         }
250                                     }
251                                     _ => {}
252                                 }
253                             }
254                             _ => {}
255                         }
256                     }
257                     Rvalue::Len(place) => {
258                         let Some(local) = place.local_or_deref_local() else {
259                             return;
260                         };
261                         if let Some(cast_statement_idx) = state.get(&local).copied() {
262                             patches_scratchpad.insert(cast_statement_idx, statement_idx);
263                         }
264                     }
265                     _ => {
266                         // invalidate
267                         state.remove(&place.local);
268                     }
269                 }
270             }
271             _ => {}
272         }
273     }
274
275     let mut patcher = Patcher {
276         tcx,
277         patches_scratchpad: &*patches_scratchpad,
278         replacements_scratchpad,
279         local_decls,
280         statement_idx: 0,
281     };
282
283     block.expand_statements(|st| patcher.patch_expand_statement(st));
284 }