]> git.lizzy.rs Git - rust.git/commitdiff
Simplify generated PartialOrd code
authorYoshua Wuyts <yoshuawuyts@gmail.com>
Tue, 12 Oct 2021 15:44:57 +0000 (17:44 +0200)
committerYoshua Wuyts <yoshuawuyts@gmail.com>
Tue, 12 Oct 2021 15:44:57 +0000 (17:44 +0200)
crates/ide_assists/src/handlers/replace_derive_with_manual_impl.rs
crates/ide_assists/src/utils/gen_trait_fn_body.rs

index 4fceefe331dc10baeea4056f4a2e8c27ed6315e7..b04bd6ba09845165cbb439acc3006849b3da3fe0 100644 (file)
@@ -682,6 +682,31 @@ fn add_custom_impl_partial_ord_record_struct() {
             r#"
 //- minicore: ord
 #[derive(Partial$0Ord)]
+struct Foo {
+    bin: usize,
+}
+"#,
+            r#"
+struct Foo {
+    bin: usize,
+}
+
+impl PartialOrd for Foo {
+    $0fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
+        self.bin.partial_cmp(other.bin)
+    }
+}
+"#,
+        )
+    }
+
+    #[test]
+    fn add_custom_impl_partial_ord_record_struct_multi_field() {
+        check_assist(
+            replace_derive_with_manual_impl,
+            r#"
+//- minicore: ord
+#[derive(Partial$0Ord)]
 struct Foo {
     bin: usize,
     bar: usize,
@@ -697,15 +722,7 @@ struct Foo {
 
 impl PartialOrd for Foo {
     $0fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
-        match self.bin.partial_cmp(other.bin) {
-            Some(core::cmp::Ordering::Eq) => {}
-            ord => return ord,
-        }
-        match self.bar.partial_cmp(other.bar) {
-            Some(core::cmp::Ordering::Eq) => {}
-            ord => return ord,
-        }
-        self.baz.partial_cmp(other.baz)
+        (self.bin, self.bar, self.baz).partial_cmp((other.bin, other.bar, other.baz))
     }
 }
 "#,
@@ -726,15 +743,7 @@ fn add_custom_impl_partial_ord_tuple_struct() {
 
 impl PartialOrd for Foo {
     $0fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
-        match self.0.partial_cmp(other.0) {
-            Some(core::cmp::Ordering::Eq) => {}
-            ord => return ord,
-        }
-        match self.1.partial_cmp(other.1) {
-            Some(core::cmp::Ordering::Eq) => {}
-            ord => return ord,
-        }
-        self.2.partial_cmp(other.2)
+        (self.0, self.1, self.2).partial_cmp((other.0, other.1, other.2))
     }
 }
 "#,
@@ -807,11 +816,7 @@ impl PartialOrd for Foo {
         match (self, other) {
             (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin.partial_cmp(r_bin),
             (Self::Baz { qux: l_qux, fez: l_fez }, Self::Baz { qux: r_qux, fez: r_fez }) => {
-                match l_qux.partial_cmp(r_qux) {
-                    Some(core::cmp::Ordering::Eq) => {}
-                    ord => return ord,
-                }
-                l_fez.partial_cmp(r_fez)
+                (l_qux, l_fez).partial_cmp((r_qux, r_fez))
             }
             _ => core::mem::discriminant(self).partial_cmp(core::mem::discriminant(other)),
         }
@@ -848,11 +853,7 @@ impl PartialOrd for Foo {
         match (self, other) {
             (Self::Bar(l0), Self::Bar(r0)) => l0.partial_cmp(r0),
             (Self::Baz(l0, l1), Self::Baz(r0, r1)) => {
-                match l0.partial_cmp(r0) {
-                    Some(core::cmp::Ordering::Eq) => {}
-                    ord => return ord,
-                }
-                l1.partial_cmp(r1)
+                (l0, l1).partial_cmp((r0, r1))
             }
             _ => core::mem::discriminant(self).partial_cmp(core::mem::discriminant(other)),
         }
index 10b781636f439ea202c8e776c80af14373a10c50..c883e6fb11ba947b04f9e5dcb332a28f173520be 100644 (file)
@@ -574,27 +574,18 @@ fn gen_tuple_field(field_name: &String) -> ast::Pat {
 }
 
 fn gen_partial_ord(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
-    fn gen_partial_eq_match(match_target: ast::Expr) -> Option<ast::Stmt> {
-        let mut arms = vec![];
-
-        let variant_name =
-            make::path_pat(make::ext::path_from_idents(["core", "cmp", "Ordering", "Eq"])?);
-        let lhs = make::tuple_struct_pat(make::ext::path_from_idents(["Some"])?, [variant_name]);
-        arms.push(make::match_arm(Some(lhs.into()), None, make::expr_empty_block()));
-
-        arms.push(make::match_arm(
-            [make::ident_pat(false, false, make::name("ord")).into()],
-            None,
-            make::expr_return(Some(make::expr_path(make::ext::ident_path("ord")))),
-        ));
-        let list = make::match_arm_list(arms).indent(ast::edit::IndentLevel(1));
-        Some(make::expr_stmt(make::expr_match(match_target, list)).into())
-    }
-
     fn gen_partial_cmp_call(lhs: ast::Expr, rhs: ast::Expr) -> ast::Expr {
         let method = make::name_ref("partial_cmp");
         make::expr_method_call(lhs, method, make::arg_list(Some(rhs)))
     }
+    fn gen_partial_cmp_call2(mut lhs: Vec<ast::Expr>, mut rhs: Vec<ast::Expr>) -> ast::Expr {
+        let (lhs, rhs) = match (lhs.len(), rhs.len()) {
+            (1, 1) => (lhs.pop().unwrap(), rhs.pop().unwrap()),
+            _ => (make::expr_tuple(lhs.into_iter()), make::expr_tuple(rhs.into_iter())),
+        };
+        let method = make::name_ref("partial_cmp");
+        make::expr_method_call(lhs, method, make::arg_list(Some(rhs)))
+    }
 
     fn gen_record_pat_field(field_name: &str, pat_name: &str) -> ast::RecordPatField {
         let pat = make::ext::simple_ident_pat(make::name(&pat_name));
@@ -637,7 +628,8 @@ fn gen_tuple_field(field_name: &String) -> ast::Pat {
                 match variant.field_list() {
                     // => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin,
                     Some(ast::FieldList::RecordFieldList(list)) => {
-                        let mut exprs = vec![];
+                        let mut l_pat_fields = vec![];
+                        let mut r_pat_fields = vec![];
                         let mut l_fields = vec![];
                         let mut r_fields = vec![];
 
@@ -645,38 +637,36 @@ fn gen_tuple_field(field_name: &String) -> ast::Pat {
                             let field_name = field.name()?.to_string();
 
                             let l_name = &format!("l_{}", field_name);
-                            l_fields.push(gen_record_pat_field(&field_name, &l_name));
+                            l_pat_fields.push(gen_record_pat_field(&field_name, &l_name));
 
                             let r_name = &format!("r_{}", field_name);
-                            r_fields.push(gen_record_pat_field(&field_name, &r_name));
+                            r_pat_fields.push(gen_record_pat_field(&field_name, &r_name));
 
                             let lhs = make::expr_path(make::ext::ident_path(l_name));
                             let rhs = make::expr_path(make::ext::ident_path(r_name));
-                            let ord = gen_partial_cmp_call(lhs, rhs);
-                            exprs.push(ord);
+                            l_fields.push(lhs);
+                            r_fields.push(rhs);
                         }
 
-                        let left = gen_record_pat(gen_variant_path(&variant)?, l_fields);
-                        let right = gen_record_pat(gen_variant_path(&variant)?, r_fields);
-                        let tuple = make::tuple_pat(vec![left.into(), right.into()]);
+                        let left_pat = gen_record_pat(gen_variant_path(&variant)?, l_pat_fields);
+                        let right_pat = gen_record_pat(gen_variant_path(&variant)?, r_pat_fields);
+                        let tuple_pat = make::tuple_pat(vec![left_pat.into(), right_pat.into()]);
 
-                        if let Some(tail) = exprs.pop() {
-                            let stmts = exprs
-                                .into_iter()
-                                .map(gen_partial_eq_match)
-                                .collect::<Option<Vec<ast::Stmt>>>()?;
-                            let expr = match stmts.len() {
-                                0 => tail,
-                                _ => make::block_expr(stmts.into_iter(), Some(tail))
+                        let len = l_fields.len();
+                        if len != 0 {
+                            let mut expr = gen_partial_cmp_call2(l_fields, r_fields);
+                            if len >= 2 {
+                                expr = make::block_expr(None, Some(expr))
                                     .indent(ast::edit::IndentLevel(1))
-                                    .into(),
-                            };
-                            arms.push(make::match_arm(Some(tuple.into()), None, expr.into()));
+                                    .into();
+                            }
+                            arms.push(make::match_arm(Some(tuple_pat.into()), None, expr));
                         }
                     }
 
                     Some(ast::FieldList::TupleFieldList(list)) => {
-                        let mut exprs = vec![];
+                        let mut l_pat_fields = vec![];
+                        let mut r_pat_fields = vec![];
                         let mut l_fields = vec![];
                         let mut r_fields = vec![];
 
@@ -684,33 +674,32 @@ fn gen_tuple_field(field_name: &String) -> ast::Pat {
                             let field_name = format!("{}", i);
 
                             let l_name = format!("l{}", field_name);
-                            l_fields.push(gen_tuple_field(&l_name));
+                            l_pat_fields.push(gen_tuple_field(&l_name));
 
                             let r_name = format!("r{}", field_name);
-                            r_fields.push(gen_tuple_field(&r_name));
+                            r_pat_fields.push(gen_tuple_field(&r_name));
 
                             let lhs = make::expr_path(make::ext::ident_path(&l_name));
                             let rhs = make::expr_path(make::ext::ident_path(&r_name));
-                            let ord = gen_partial_cmp_call(lhs, rhs);
-                            exprs.push(ord);
+                            l_fields.push(lhs);
+                            r_fields.push(rhs);
                         }
 
-                        let left = make::tuple_struct_pat(gen_variant_path(&variant)?, l_fields);
-                        let right = make::tuple_struct_pat(gen_variant_path(&variant)?, r_fields);
-                        let tuple = make::tuple_pat(vec![left.into(), right.into()]);
-
-                        if let Some(tail) = exprs.pop() {
-                            let stmts = exprs
-                                .into_iter()
-                                .map(gen_partial_eq_match)
-                                .collect::<Option<Vec<ast::Stmt>>>()?;
-                            let expr = match stmts.len() {
-                                0 => tail,
-                                _ => make::block_expr(stmts.into_iter(), Some(tail))
+                        let left_pat =
+                            make::tuple_struct_pat(gen_variant_path(&variant)?, l_pat_fields);
+                        let right_pat =
+                            make::tuple_struct_pat(gen_variant_path(&variant)?, r_pat_fields);
+                        let tuple_pat = make::tuple_pat(vec![left_pat.into(), right_pat.into()]);
+
+                        let len = l_fields.len();
+                        if len != 0 {
+                            let mut expr = gen_partial_cmp_call2(l_fields, r_fields);
+                            if len >= 2 {
+                                expr = make::block_expr(None, Some(expr))
                                     .indent(ast::edit::IndentLevel(1))
-                                    .into(),
-                            };
-                            arms.push(make::match_arm(Some(tuple.into()), None, expr.into()));
+                                    .into();
+                            }
+                            arms.push(make::match_arm(Some(tuple_pat.into()), None, expr));
                         }
                     }
                     None => continue,
@@ -735,41 +724,35 @@ fn gen_tuple_field(field_name: &String) -> ast::Pat {
         }
         ast::Adt::Struct(strukt) => match strukt.field_list() {
             Some(ast::FieldList::RecordFieldList(field_list)) => {
-                let mut exprs = vec![];
+                let mut l_fields = vec![];
+                let mut r_fields = vec![];
                 for field in field_list.fields() {
                     let lhs = make::expr_path(make::ext::ident_path("self"));
                     let lhs = make::expr_field(lhs, &field.name()?.to_string());
                     let rhs = make::expr_path(make::ext::ident_path("other"));
                     let rhs = make::expr_field(rhs, &field.name()?.to_string());
-                    let ord = gen_partial_cmp_call(lhs, rhs);
-                    exprs.push(ord);
+                    l_fields.push(lhs);
+                    r_fields.push(rhs);
                 }
 
-                let tail = exprs.pop();
-                let stmts = exprs
-                    .into_iter()
-                    .map(gen_partial_eq_match)
-                    .collect::<Option<Vec<ast::Stmt>>>()?;
-                make::block_expr(stmts.into_iter(), tail).indent(ast::edit::IndentLevel(1))
+                let expr = gen_partial_cmp_call2(l_fields, r_fields);
+                make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
             }
 
             Some(ast::FieldList::TupleFieldList(field_list)) => {
-                let mut exprs = vec![];
+                let mut l_fields = vec![];
+                let mut r_fields = vec![];
                 for (i, _) in field_list.fields().enumerate() {
                     let idx = format!("{}", i);
                     let lhs = make::expr_path(make::ext::ident_path("self"));
                     let lhs = make::expr_field(lhs, &idx);
                     let rhs = make::expr_path(make::ext::ident_path("other"));
                     let rhs = make::expr_field(rhs, &idx);
-                    let ord = gen_partial_cmp_call(lhs, rhs);
-                    exprs.push(ord);
+                    l_fields.push(lhs);
+                    r_fields.push(rhs);
                 }
-                let tail = exprs.pop();
-                let stmts = exprs
-                    .into_iter()
-                    .map(gen_partial_eq_match)
-                    .collect::<Option<Vec<ast::Stmt>>>()?;
-                make::block_expr(stmts.into_iter(), tail).indent(ast::edit::IndentLevel(1))
+                let expr = gen_partial_cmp_call2(l_fields, r_fields);
+                make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
             }
 
             // No fields in the body means there's nothing to hash.