]> git.lizzy.rs Git - rust.git/blobdiff - crates/ide_assists/src/utils/gen_trait_fn_body.rs
Simplify generated PartialOrd code
[rust.git] / crates / ide_assists / src / utils / gen_trait_fn_body.rs
index 5ec8adc2d4cf9bd3a8f310a214e0342ec5b3aac8..c883e6fb11ba947b04f9e5dcb332a28f173520be 100644 (file)
@@ -574,11 +574,17 @@ fn gen_tuple_field(field_name: &String) -> ast::Pat {
 }
 
 fn gen_partial_ord(adt: &ast::Adt, func: &ast::Fn) -> Option<()> {
-    fn gen_eq_chain(expr: Option<ast::Expr>, cmp: ast::Expr) -> Option<ast::Expr> {
-        match expr {
-            Some(expr) => Some(make::expr_op(ast::BinOp::BooleanAnd, expr, cmp)),
-            None => Some(cmp),
-        }
+    fn gen_partial_cmp_call(lhs: ast::Expr, rhs: ast::Expr) -> ast::Expr {
+        let method = make::name_ref("partial_cmp");
+        make::expr_method_call(lhs, method, make::arg_list(Some(rhs)))
+    }
+    fn gen_partial_cmp_call2(mut lhs: Vec<ast::Expr>, mut rhs: Vec<ast::Expr>) -> ast::Expr {
+        let (lhs, rhs) = match (lhs.len(), rhs.len()) {
+            (1, 1) => (lhs.pop().unwrap(), rhs.pop().unwrap()),
+            _ => (make::expr_tuple(lhs.into_iter()), make::expr_tuple(rhs.into_iter())),
+        };
+        let method = make::name_ref("partial_cmp");
+        make::expr_method_call(lhs, method, make::arg_list(Some(rhs)))
     }
 
     fn gen_record_pat_field(field_name: &str, pat_name: &str) -> ast::RecordPatField {
@@ -613,7 +619,7 @@ fn gen_tuple_field(field_name: &String) -> ast::Pat {
             let lhs = make::expr_call(make_discriminant()?, make::arg_list(Some(lhs_name.clone())));
             let rhs_name = make::expr_path(make::ext::ident_path("other"));
             let rhs = make::expr_call(make_discriminant()?, make::arg_list(Some(rhs_name.clone())));
-            let eq_check = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
+            let ord_check = gen_partial_cmp_call(lhs, rhs);
 
             let mut case_count = 0;
             let mut arms = vec![];
@@ -622,7 +628,8 @@ fn gen_tuple_field(field_name: &String) -> ast::Pat {
                 match variant.field_list() {
                     // => (Self::Bar { bin: l_bin }, Self::Bar { bin: r_bin }) => l_bin == r_bin,
                     Some(ast::FieldList::RecordFieldList(list)) => {
-                        let mut expr = None;
+                        let mut l_pat_fields = vec![];
+                        let mut r_pat_fields = vec![];
                         let mut l_fields = vec![];
                         let mut r_fields = vec![];
 
@@ -630,28 +637,36 @@ fn gen_tuple_field(field_name: &String) -> ast::Pat {
                             let field_name = field.name()?.to_string();
 
                             let l_name = &format!("l_{}", field_name);
-                            l_fields.push(gen_record_pat_field(&field_name, &l_name));
+                            l_pat_fields.push(gen_record_pat_field(&field_name, &l_name));
 
                             let r_name = &format!("r_{}", field_name);
-                            r_fields.push(gen_record_pat_field(&field_name, &r_name));
+                            r_pat_fields.push(gen_record_pat_field(&field_name, &r_name));
 
                             let lhs = make::expr_path(make::ext::ident_path(l_name));
                             let rhs = make::expr_path(make::ext::ident_path(r_name));
-                            let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
-                            expr = gen_eq_chain(expr, cmp);
+                            l_fields.push(lhs);
+                            r_fields.push(rhs);
                         }
 
-                        let left = gen_record_pat(gen_variant_path(&variant)?, l_fields);
-                        let right = gen_record_pat(gen_variant_path(&variant)?, r_fields);
-                        let tuple = make::tuple_pat(vec![left.into(), right.into()]);
-
-                        if let Some(expr) = expr {
-                            arms.push(make::match_arm(Some(tuple.into()), None, expr));
+                        let left_pat = gen_record_pat(gen_variant_path(&variant)?, l_pat_fields);
+                        let right_pat = gen_record_pat(gen_variant_path(&variant)?, r_pat_fields);
+                        let tuple_pat = make::tuple_pat(vec![left_pat.into(), right_pat.into()]);
+
+                        let len = l_fields.len();
+                        if len != 0 {
+                            let mut expr = gen_partial_cmp_call2(l_fields, r_fields);
+                            if len >= 2 {
+                                expr = make::block_expr(None, Some(expr))
+                                    .indent(ast::edit::IndentLevel(1))
+                                    .into();
+                            }
+                            arms.push(make::match_arm(Some(tuple_pat.into()), None, expr));
                         }
                     }
 
                     Some(ast::FieldList::TupleFieldList(list)) => {
-                        let mut expr = None;
+                        let mut l_pat_fields = vec![];
+                        let mut r_pat_fields = vec![];
                         let mut l_fields = vec![];
                         let mut r_fields = vec![];
 
@@ -659,23 +674,32 @@ fn gen_tuple_field(field_name: &String) -> ast::Pat {
                             let field_name = format!("{}", i);
 
                             let l_name = format!("l{}", field_name);
-                            l_fields.push(gen_tuple_field(&l_name));
+                            l_pat_fields.push(gen_tuple_field(&l_name));
 
                             let r_name = format!("r{}", field_name);
-                            r_fields.push(gen_tuple_field(&r_name));
+                            r_pat_fields.push(gen_tuple_field(&r_name));
 
                             let lhs = make::expr_path(make::ext::ident_path(&l_name));
                             let rhs = make::expr_path(make::ext::ident_path(&r_name));
-                            let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
-                            expr = gen_eq_chain(expr, cmp);
+                            l_fields.push(lhs);
+                            r_fields.push(rhs);
                         }
 
-                        let left = make::tuple_struct_pat(gen_variant_path(&variant)?, l_fields);
-                        let right = make::tuple_struct_pat(gen_variant_path(&variant)?, r_fields);
-                        let tuple = make::tuple_pat(vec![left.into(), right.into()]);
-
-                        if let Some(expr) = expr {
-                            arms.push(make::match_arm(Some(tuple.into()), None, expr));
+                        let left_pat =
+                            make::tuple_struct_pat(gen_variant_path(&variant)?, l_pat_fields);
+                        let right_pat =
+                            make::tuple_struct_pat(gen_variant_path(&variant)?, r_pat_fields);
+                        let tuple_pat = make::tuple_pat(vec![left_pat.into(), right_pat.into()]);
+
+                        let len = l_fields.len();
+                        if len != 0 {
+                            let mut expr = gen_partial_cmp_call2(l_fields, r_fields);
+                            if len >= 2 {
+                                expr = make::block_expr(None, Some(expr))
+                                    .indent(ast::edit::IndentLevel(1))
+                                    .into();
+                            }
+                            arms.push(make::match_arm(Some(tuple_pat.into()), None, expr));
                         }
                     }
                     None => continue,
@@ -683,11 +707,11 @@ fn gen_tuple_field(field_name: &String) -> ast::Pat {
             }
 
             let expr = match arms.len() {
-                0 => eq_check,
+                0 => ord_check,
                 _ => {
                     if case_count > arms.len() {
                         let lhs = make::wildcard_pat().into();
-                        arms.push(make::match_arm(Some(lhs), None, eq_check));
+                        arms.push(make::match_arm(Some(lhs), None, ord_check));
                     }
 
                     let match_target = make::expr_tuple(vec![lhs_name, rhs_name]);
@@ -700,30 +724,35 @@ fn gen_tuple_field(field_name: &String) -> ast::Pat {
         }
         ast::Adt::Struct(strukt) => match strukt.field_list() {
             Some(ast::FieldList::RecordFieldList(field_list)) => {
-                let mut expr = None;
+                let mut l_fields = vec![];
+                let mut r_fields = vec![];
                 for field in field_list.fields() {
                     let lhs = make::expr_path(make::ext::ident_path("self"));
                     let lhs = make::expr_field(lhs, &field.name()?.to_string());
                     let rhs = make::expr_path(make::ext::ident_path("other"));
                     let rhs = make::expr_field(rhs, &field.name()?.to_string());
-                    let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
-                    expr = gen_eq_chain(expr, cmp);
+                    l_fields.push(lhs);
+                    r_fields.push(rhs);
                 }
-                make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
+
+                let expr = gen_partial_cmp_call2(l_fields, r_fields);
+                make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
             }
 
             Some(ast::FieldList::TupleFieldList(field_list)) => {
-                let mut expr = None;
+                let mut l_fields = vec![];
+                let mut r_fields = vec![];
                 for (i, _) in field_list.fields().enumerate() {
                     let idx = format!("{}", i);
                     let lhs = make::expr_path(make::ext::ident_path("self"));
                     let lhs = make::expr_field(lhs, &idx);
                     let rhs = make::expr_path(make::ext::ident_path("other"));
                     let rhs = make::expr_field(rhs, &idx);
-                    let cmp = make::expr_op(ast::BinOp::EqualityTest, lhs, rhs);
-                    expr = gen_eq_chain(expr, cmp);
+                    l_fields.push(lhs);
+                    r_fields.push(rhs);
                 }
-                make::block_expr(None, expr).indent(ast::edit::IndentLevel(1))
+                let expr = gen_partial_cmp_call2(l_fields, r_fields);
+                make::block_expr(None, Some(expr)).indent(ast::edit::IndentLevel(1))
             }
 
             // No fields in the body means there's nothing to hash.