]> git.lizzy.rs Git - rust.git/blobdiff - compiler/rustc_mir_transform/src/normalize_array_len.rs
Auto merge of #107443 - cjgillot:generator-less-query, r=compiler-errors
[rust.git] / compiler / rustc_mir_transform / src / normalize_array_len.rs
index 1708b287e56f25bca285136fca3a92649a0bafd6..b36c8a0bd5369b45fbb2864b5302957c576c4d7c 100644 (file)
 //! This pass eliminates casting of arrays into slices when their length
 //! is taken using `.len()` method. Handy to preserve information in MIR for const prop
 
+use crate::ssa::SsaLocals;
 use crate::MirPass;
-use rustc_data_structures::fx::FxIndexMap;
-use rustc_data_structures::intern::Interned;
-use rustc_index::bit_set::BitSet;
 use rustc_index::vec::IndexVec;
+use rustc_middle::mir::visit::*;
 use rustc_middle::mir::*;
-use rustc_middle::ty::{self, ReErased, Region, TyCtxt};
-
-const MAX_NUM_BLOCKS: usize = 800;
-const MAX_NUM_LOCALS: usize = 3000;
+use rustc_middle::ty::{self, TyCtxt};
+use rustc_mir_dataflow::impls::borrowed_locals;
 
 pub struct NormalizeArrayLen;
 
 impl<'tcx> MirPass<'tcx> for NormalizeArrayLen {
     fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
-        // See #105929
-        sess.mir_opt_level() >= 4 && sess.opts.unstable_opts.unsound_mir_opts
+        sess.mir_opt_level() >= 3
     }
 
+    #[instrument(level = "trace", skip(self, tcx, body))]
     fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
-        // early returns for edge cases of highly unrolled functions
-        if body.basic_blocks.len() > MAX_NUM_BLOCKS {
-            return;
-        }
-        if body.local_decls.len() > MAX_NUM_LOCALS {
-            return;
-        }
+        debug!(def_id = ?body.source.def_id());
         normalize_array_len_calls(tcx, body)
     }
 }
 
-pub fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
-    // We don't ever touch terminators, so no need to invalidate the CFG cache
-    let basic_blocks = body.basic_blocks.as_mut_preserves_cfg();
-    let local_decls = &mut body.local_decls;
+fn normalize_array_len_calls<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
+    let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
+    let borrowed_locals = borrowed_locals(body);
+    let ssa = SsaLocals::new(tcx, param_env, body, &borrowed_locals);
 
-    // do a preliminary analysis to see if we ever have locals of type `[T;N]` or `&[T;N]`
-    let mut interesting_locals = BitSet::new_empty(local_decls.len());
-    for (local, decl) in local_decls.iter_enumerated() {
-        match decl.ty.kind() {
-            ty::Array(..) => {
-                interesting_locals.insert(local);
-            }
-            ty::Ref(.., ty, Mutability::Not) => match ty.kind() {
-                ty::Array(..) => {
-                    interesting_locals.insert(local);
-                }
-                _ => {}
-            },
-            _ => {}
-        }
-    }
-    if interesting_locals.is_empty() {
-        // we have found nothing to analyze
-        return;
-    }
-    let num_intesting_locals = interesting_locals.count();
-    let mut state = FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
-    let mut patches_scratchpad =
-        FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
-    let mut replacements_scratchpad =
-        FxIndexMap::with_capacity_and_hasher(num_intesting_locals, Default::default());
-    for block in basic_blocks {
-        // make length calls for arrays [T; N] not to decay into length calls for &[T]
-        // that forbids constant propagation
-        normalize_array_len_call(
-            tcx,
-            block,
-            local_decls,
-            &interesting_locals,
-            &mut state,
-            &mut patches_scratchpad,
-            &mut replacements_scratchpad,
-        );
-        state.clear();
-        patches_scratchpad.clear();
-        replacements_scratchpad.clear();
-    }
-}
+    let slice_lengths = compute_slice_length(tcx, &ssa, body);
+    debug!(?slice_lengths);
 
-struct Patcher<'a, 'tcx> {
-    tcx: TyCtxt<'tcx>,
-    patches_scratchpad: &'a FxIndexMap<usize, usize>,
-    replacements_scratchpad: &'a mut FxIndexMap<usize, Local>,
-    local_decls: &'a mut IndexVec<Local, LocalDecl<'tcx>>,
-    statement_idx: usize,
+    Replacer { tcx, slice_lengths }.visit_body_preserves_cfg(body);
 }
 
-impl<'tcx> Patcher<'_, 'tcx> {
-    fn patch_expand_statement(
-        &mut self,
-        statement: &mut Statement<'tcx>,
-    ) -> Option<std::vec::IntoIter<Statement<'tcx>>> {
-        let idx = self.statement_idx;
-        if let Some(len_statemnt_idx) = self.patches_scratchpad.get(&idx).copied() {
-            let mut statements = Vec::with_capacity(2);
-
-            // we are at statement that performs a cast. The only sound way is
-            // to create another local that performs a similar copy without a cast and then
-            // use this copy in the Len operation
-
-            match &statement.kind {
-                StatementKind::Assign(box (
-                    ..,
-                    Rvalue::Cast(
-                        CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
-                        operand,
-                        _,
-                    ),
-                )) => {
-                    match operand {
-                        Operand::Copy(place) | Operand::Move(place) => {
-                            // create new local
-                            let ty = operand.ty(self.local_decls, self.tcx);
-                            let local_decl = LocalDecl::with_source_info(ty, statement.source_info);
-                            let local = self.local_decls.push(local_decl);
-                            // make it live
-                            let mut make_live_statement = statement.clone();
-                            make_live_statement.kind = StatementKind::StorageLive(local);
-                            statements.push(make_live_statement);
-                            // copy into it
-
-                            let operand = Operand::Copy(*place);
-                            let mut make_copy_statement = statement.clone();
-                            let assign_to = Place::from(local);
-                            let rvalue = Rvalue::Use(operand);
-                            make_copy_statement.kind =
-                                StatementKind::Assign(Box::new((assign_to, rvalue)));
-                            statements.push(make_copy_statement);
-
-                            // to reorder we have to copy and make NOP
-                            statements.push(statement.clone());
-                            statement.make_nop();
-
-                            self.replacements_scratchpad.insert(len_statemnt_idx, local);
-                        }
-                        _ => {
-                            unreachable!("it's a bug in the implementation")
-                        }
-                    }
-                }
-                _ => {
-                    unreachable!("it's a bug in the implementation")
+fn compute_slice_length<'tcx>(
+    tcx: TyCtxt<'tcx>,
+    ssa: &SsaLocals,
+    body: &Body<'tcx>,
+) -> IndexVec<Local, Option<ty::Const<'tcx>>> {
+    let mut slice_lengths = IndexVec::from_elem(None, &body.local_decls);
+
+    for (local, rvalue) in ssa.assignments(body) {
+        match rvalue {
+            Rvalue::Cast(
+                CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
+                operand,
+                cast_ty,
+            ) => {
+                let operand_ty = operand.ty(body, tcx);
+                debug!(?operand_ty);
+                if let Some(operand_ty) = operand_ty.builtin_deref(true)
+                    && let ty::Array(_, len) = operand_ty.ty.kind()
+                    && let Some(cast_ty) = cast_ty.builtin_deref(true)
+                    && let ty::Slice(..) = cast_ty.ty.kind()
+                {
+                    slice_lengths[local] = Some(*len);
                 }
             }
-
-            self.statement_idx += 1;
-
-            Some(statements.into_iter())
-        } else if let Some(local) = self.replacements_scratchpad.get(&idx).copied() {
-            let mut statements = Vec::with_capacity(2);
-
-            match &statement.kind {
-                StatementKind::Assign(box (into, Rvalue::Len(place))) => {
-                    let add_deref = if let Some(..) = place.as_local() {
-                        false
-                    } else if let Some(..) = place.local_or_deref_local() {
-                        true
-                    } else {
-                        unreachable!("it's a bug in the implementation")
-                    };
-                    // replace len statement
-                    let mut len_statement = statement.clone();
-                    let mut place = Place::from(local);
-                    if add_deref {
-                        place = self.tcx.mk_place_deref(place);
-                    }
-                    len_statement.kind =
-                        StatementKind::Assign(Box::new((*into, Rvalue::Len(place))));
-                    statements.push(len_statement);
-
-                    // make temporary dead
-                    let mut make_dead_statement = statement.clone();
-                    make_dead_statement.kind = StatementKind::StorageDead(local);
-                    statements.push(make_dead_statement);
-
-                    // make original statement NOP
-                    statement.make_nop();
+            // The length information is stored in the fat pointer, so we treat `operand` as a value.
+            Rvalue::Use(operand) => {
+                if let Some(rhs) = operand.place() && let Some(rhs) = rhs.as_local() {
+                    slice_lengths[local] = slice_lengths[rhs];
                 }
-                _ => {
-                    unreachable!("it's a bug in the implementation")
+            }
+            // The length information is stored in the fat pointer.
+            // Reborrowing copies length information from one pointer to the other.
+            Rvalue::Ref(_, _, rhs) | Rvalue::AddressOf(_, rhs) => {
+                if let [PlaceElem::Deref] = rhs.projection[..] {
+                    slice_lengths[local] = slice_lengths[rhs.local];
                 }
             }
-
-            self.statement_idx += 1;
-
-            Some(statements.into_iter())
-        } else {
-            self.statement_idx += 1;
-            None
+            _ => {}
         }
     }
+
+    slice_lengths
 }
 
-fn normalize_array_len_call<'tcx>(
+struct Replacer<'tcx> {
     tcx: TyCtxt<'tcx>,
-    block: &mut BasicBlockData<'tcx>,
-    local_decls: &mut IndexVec<Local, LocalDecl<'tcx>>,
-    interesting_locals: &BitSet<Local>,
-    state: &mut FxIndexMap<Local, usize>,
-    patches_scratchpad: &mut FxIndexMap<usize, usize>,
-    replacements_scratchpad: &mut FxIndexMap<usize, Local>,
-) {
-    for (statement_idx, statement) in block.statements.iter_mut().enumerate() {
-        match &mut statement.kind {
-            StatementKind::Assign(box (place, rvalue)) => {
-                match rvalue {
-                    Rvalue::Cast(
-                        CastKind::Pointer(ty::adjustment::PointerCast::Unsize),
-                        operand,
-                        cast_ty,
-                    ) => {
-                        let Some(local) = place.as_local() else { return };
-                        match operand {
-                            Operand::Copy(place) | Operand::Move(place) => {
-                                let Some(operand_local) = place.local_or_deref_local() else { return; };
-                                if !interesting_locals.contains(operand_local) {
-                                    return;
-                                }
-                                let operand_ty = local_decls[operand_local].ty;
-                                match (operand_ty.kind(), cast_ty.kind()) {
-                                    (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => {
-                                        if of_ty_src == of_ty_dst {
-                                            // this is a cast from [T; N] into [T], so we are good
-                                            state.insert(local, statement_idx);
-                                        }
-                                    }
-                                    // current way of patching doesn't allow to work with `mut`
-                                    (
-                                        ty::Ref(
-                                            Region(Interned(ReErased, _)),
-                                            operand_ty,
-                                            Mutability::Not,
-                                        ),
-                                        ty::Ref(
-                                            Region(Interned(ReErased, _)),
-                                            cast_ty,
-                                            Mutability::Not,
-                                        ),
-                                    ) => {
-                                        match (operand_ty.kind(), cast_ty.kind()) {
-                                            // current way of patching doesn't allow to work with `mut`
-                                            (ty::Array(of_ty_src, ..), ty::Slice(of_ty_dst)) => {
-                                                if of_ty_src == of_ty_dst {
-                                                    // this is a cast from [T; N] into [T], so we are good
-                                                    state.insert(local, statement_idx);
-                                                }
-                                            }
-                                            _ => {}
-                                        }
-                                    }
-                                    _ => {}
-                                }
-                            }
-                            _ => {}
-                        }
-                    }
-                    Rvalue::Len(place) => {
-                        let Some(local) = place.local_or_deref_local() else {
-                            return;
-                        };
-                        if let Some(cast_statement_idx) = state.get(&local).copied() {
-                            patches_scratchpad.insert(cast_statement_idx, statement_idx);
-                        }
-                    }
-                    _ => {
-                        // invalidate
-                        state.remove(&place.local);
-                    }
-                }
-            }
-            _ => {}
-        }
-    }
+    slice_lengths: IndexVec<Local, Option<ty::Const<'tcx>>>,
+}
 
-    let mut patcher = Patcher {
-        tcx,
-        patches_scratchpad: &*patches_scratchpad,
-        replacements_scratchpad,
-        local_decls,
-        statement_idx: 0,
-    };
+impl<'tcx> MutVisitor<'tcx> for Replacer<'tcx> {
+    fn tcx(&self) -> TyCtxt<'tcx> {
+        self.tcx
+    }
 
-    block.expand_statements(|st| patcher.patch_expand_statement(st));
+    fn visit_rvalue(&mut self, rvalue: &mut Rvalue<'tcx>, loc: Location) {
+        if let Rvalue::Len(place) = rvalue
+            && let [PlaceElem::Deref] = &place.projection[..]
+            && let Some(len) = self.slice_lengths[place.local]
+        {
+            *rvalue = Rvalue::Use(Operand::Constant(Box::new(Constant {
+                span: rustc_span::DUMMY_SP,
+                user_ty: None,
+                literal: ConstantKind::from_const(len, self.tcx),
+            })));
+        }
+        self.super_rvalue(rvalue, loc);
+    }
 }