]> git.lizzy.rs Git - rust.git/commitdiff
refactor: avoid filter map next with find map separate traversal
authorrainy-me <github@yue.coffee>
Sat, 25 Dec 2021 00:05:56 +0000 (09:05 +0900)
committerrainy-me <github@yue.coffee>
Sat, 25 Dec 2021 00:08:13 +0000 (09:08 +0900)
crates/hir_ty/src/diagnostics/expr.rs

index a8c4026e31f3b1a2a7fe9a449ce404dfdce8a430..b7d765c59b47d85d0fb536ad78456acc484e95d6 100644 (file)
@@ -81,9 +81,8 @@ fn new(owner: DefWithBodyId, infer: Arc<InferenceResult>) -> ExprValidator {
     }
 
     fn validate_body(&mut self, db: &dyn HirDatabase) {
-        self.check_for_filter_map_next(db);
-
         let body = db.body(self.owner);
+        let mut filter_map_next_checker = None;
 
         for (id, expr) in body.exprs.iter() {
             if let Some((variant, missed_fields, true)) =
@@ -101,7 +100,7 @@ fn validate_body(&mut self, db: &dyn HirDatabase) {
                     self.validate_match(id, *expr, arms, db, self.infer.clone());
                 }
                 Expr::Call { .. } | Expr::MethodCall { .. } => {
-                    self.validate_call(db, id, expr);
+                    self.validate_call(db, id, expr, &mut filter_map_next_checker);
                 }
                 _ => {}
             }
@@ -143,58 +142,13 @@ fn validate_body(&mut self, db: &dyn HirDatabase) {
             });
     }
 
-    fn check_for_filter_map_next(&mut self, db: &dyn HirDatabase) {
-        // Find the FunctionIds for Iterator::filter_map and Iterator::next
-        let iterator_path = path![core::iter::Iterator];
-        let resolver = self.owner.resolver(db.upcast());
-        let iterator_trait_id = match resolver.resolve_known_trait(db.upcast(), &iterator_path) {
-            Some(id) => id,
-            None => return,
-        };
-        let iterator_trait_items = &db.trait_data(iterator_trait_id).items;
-        let filter_map_function_id =
-            match iterator_trait_items.iter().find(|item| item.0 == name![filter_map]) {
-                Some((_, AssocItemId::FunctionId(id))) => id,
-                _ => return,
-            };
-        let next_function_id = match iterator_trait_items.iter().find(|item| item.0 == name![next])
-        {
-            Some((_, AssocItemId::FunctionId(id))) => id,
-            _ => return,
-        };
-
-        // Search function body for instances of .filter_map(..).next()
-        let body = db.body(self.owner);
-        let mut prev = None;
-        for (id, expr) in body.exprs.iter() {
-            if let Expr::MethodCall { receiver, .. } = expr {
-                let function_id = match self.infer.method_resolution(id) {
-                    Some((id, _)) => id,
-                    None => continue,
-                };
-
-                if function_id == *filter_map_function_id {
-                    prev = Some(id);
-                    continue;
-                }
-
-                if function_id == *next_function_id {
-                    if let Some(filter_map_id) = prev {
-                        if *receiver == filter_map_id {
-                            self.diagnostics.push(
-                                BodyValidationDiagnostic::ReplaceFilterMapNextWithFindMap {
-                                    method_call_expr: id,
-                                },
-                            );
-                        }
-                    }
-                }
-            }
-            prev = None;
-        }
-    }
-
-    fn validate_call(&mut self, db: &dyn HirDatabase, call_id: ExprId, expr: &Expr) {
+    fn validate_call(
+        &mut self,
+        db: &dyn HirDatabase,
+        call_id: ExprId,
+        expr: &Expr,
+        filter_map_next_checker: &mut Option<FilterMapNextChecker>,
+    ) {
         // Check that the number of arguments matches the number of parameters.
 
         // FIXME: Due to shortcomings in the current type system implementation, only emit this
@@ -214,6 +168,24 @@ fn validate_call(&mut self, db: &dyn HirDatabase, call_id: ExprId, expr: &Expr)
                 (sig, args.len())
             }
             Expr::MethodCall { receiver, args, .. } => {
+                let (callee, subst) = match self.infer.method_resolution(call_id) {
+                    Some(it) => it,
+                    None => return,
+                };
+
+                if filter_map_next_checker
+                    .get_or_insert_with(|| {
+                        FilterMapNextChecker::new(&self.owner.resolver(db.upcast()), db)
+                    })
+                    .check(call_id, receiver, &callee)
+                    .is_some()
+                {
+                    self.diagnostics.push(
+                        BodyValidationDiagnostic::ReplaceFilterMapNextWithFindMap {
+                            method_call_expr: call_id,
+                        },
+                    );
+                }
                 let receiver = &self.infer.type_of_expr[*receiver];
                 if receiver.strip_references().is_unknown() {
                     // if the receiver is of unknown type, it's very likely we
@@ -222,10 +194,6 @@ fn validate_call(&mut self, db: &dyn HirDatabase, call_id: ExprId, expr: &Expr)
                     return;
                 }
 
-                let (callee, subst) = match self.infer.method_resolution(call_id) {
-                    Some(it) => it,
-                    None => return,
-                };
                 let sig = db.callable_item_signature(callee.into()).substitute(Interner, &subst);
 
                 (sig, args.len() + 1)
@@ -424,6 +392,63 @@ fn validate_missing_tail_expr(&mut self, body_id: ExprId, possible_tail_id: Expr
     }
 }
 
+struct FilterMapNextChecker {
+    filter_map_function_id: Option<hir_def::FunctionId>,
+    next_function_id: Option<hir_def::FunctionId>,
+    prev_filter_map_expr_id: Option<ExprId>,
+}
+
+impl FilterMapNextChecker {
+    fn new(resolver: &hir_def::resolver::Resolver, db: &dyn HirDatabase) -> Self {
+        // Find and store the FunctionIds for Iterator::filter_map and Iterator::next
+        let iterator_path = path![core::iter::Iterator];
+        let mut filter_map_function_id = None;
+        let mut next_function_id = None;
+
+        if let Some(iterator_trait_id) = resolver.resolve_known_trait(db.upcast(), &iterator_path) {
+            let iterator_trait_items = &db.trait_data(iterator_trait_id).items;
+            for item in iterator_trait_items.iter() {
+                if let (name, AssocItemId::FunctionId(id)) = item {
+                    if *name == name![filter_map] {
+                        filter_map_function_id = Some(*id);
+                    }
+                    if *name == name![next] {
+                        next_function_id = Some(*id);
+                    }
+                }
+                if filter_map_function_id.is_some() && next_function_id.is_some() {
+                    break;
+                }
+            }
+        }
+        Self { filter_map_function_id, next_function_id, prev_filter_map_expr_id: None }
+    }
+
+    // check for instances of .filter_map(..).next()
+    fn check(
+        &mut self,
+        current_expr_id: ExprId,
+        receiver_expr_id: &ExprId,
+        function_id: &hir_def::FunctionId,
+    ) -> Option<()> {
+        if *function_id == self.filter_map_function_id? {
+            self.prev_filter_map_expr_id = Some(current_expr_id);
+            return None;
+        }
+
+        if *function_id == self.next_function_id? {
+            if let Some(prev_filter_map_expr_id) = self.prev_filter_map_expr_id {
+                if *receiver_expr_id == prev_filter_map_expr_id {
+                    return Some(());
+                }
+            }
+        }
+
+        self.prev_filter_map_expr_id = None;
+        None
+    }
+}
+
 pub fn record_literal_missing_fields(
     db: &dyn HirDatabase,
     infer: &InferenceResult,