]> git.lizzy.rs Git - rust.git/commitdiff
Fix equality checks in matches
authorSimonas Kazlauskas <git@kazlauskas.me>
Wed, 30 Dec 2015 20:21:13 +0000 (22:21 +0200)
committerSimonas Kazlauskas <git@kazlauskas.me>
Fri, 1 Jan 2016 12:55:57 +0000 (14:55 +0200)
src/librustc_mir/build/matches/mod.rs
src/librustc_mir/build/matches/test.rs
src/librustc_mir/hair/cx/mod.rs
src/librustc_mir/hair/cx/pattern.rs
src/librustc_mir/hair/mod.rs
src/test/run-pass/mir_build_match_comparisons.rs [new file with mode: 0644]
src/test/run-pass/mir_trans_match_range.rs [deleted file]

index a02ed06ad099e4e77e50cc935e19f0c3d53db833..b456aabbf524bd566e82f247ee3b34d7f85667d5 100644 (file)
@@ -259,7 +259,7 @@ enum TestKind<'tcx> {
 
     // test for equality
     Eq {
-        value: Literal<'tcx>,
+        value: ConstVal,
         ty: Ty<'tcx>,
     },
 
index 7591e80e85f19666e91e0383aa37b6c466ef8643..ec67429379f952e8e3a402728d5203782060184d 100644 (file)
@@ -37,7 +37,7 @@ pub fn test<'pat>(&mut self, match_pair: &MatchPair<'pat, 'tcx>) -> Test<'tcx> {
                 }
             }
 
-            PatternKind::Constant { value: Literal::Value { .. } }
+            PatternKind::Constant { .. }
             if is_switch_ty(match_pair.pattern.ty) => {
                 // for integers, we use a SwitchInt match, which allows
                 // us to handle more cases
@@ -55,12 +55,11 @@ pub fn test<'pat>(&mut self, match_pair: &MatchPair<'pat, 'tcx>) -> Test<'tcx> {
             }
 
             PatternKind::Constant { ref value } => {
-                // for other types, we use an equality comparison
                 Test {
                     span: match_pair.pattern.span,
                     kind: TestKind::Eq {
                         value: value.clone(),
-                        ty: match_pair.pattern.ty.clone(),
+                        ty: match_pair.pattern.ty.clone()
                     }
                 }
             }
@@ -113,7 +112,7 @@ pub fn add_cases_to_switch<'pat>(&mut self,
         };
 
         match *match_pair.pattern.kind {
-            PatternKind::Constant { value: Literal::Value { ref value } } => {
+            PatternKind::Constant { ref value } => {
                 // if the lvalues match, the type should match
                 assert_eq!(match_pair.pattern.ty, switch_ty);
 
@@ -126,7 +125,6 @@ pub fn add_cases_to_switch<'pat>(&mut self,
             }
 
             PatternKind::Range { .. } |
-            PatternKind::Constant { .. } |
             PatternKind::Variant { .. } |
             PatternKind::Slice { .. } |
             PatternKind::Array { .. } |
@@ -177,11 +175,13 @@ pub fn perform_test(&mut self,
             }
 
             TestKind::Eq { ref value, ty } => {
-                // call PartialEq::eq(discrim, constant)
-                let constant = self.literal_operand(test.span, ty.clone(), value.clone());
-                let item_ref = self.hir.partial_eq(ty);
-                self.call_comparison_fn(block, test.span, item_ref,
-                                        Operand::Consume(lvalue.clone()), constant)
+                let expect = self.literal_operand(test.span, ty.clone(), Literal::Value {
+                    value: value.clone()
+                });
+                let val = Operand::Consume(lvalue.clone());
+                let fail = self.cfg.start_new_block();
+                let block = self.compare(block, fail, test.span, BinOp::Eq, expect, val.clone());
+                vec![block, fail]
             }
 
             TestKind::Range { ref lo, ref hi, ty } => {
@@ -251,39 +251,6 @@ fn compare(&mut self,
         target_block
     }
 
-    fn call_comparison_fn(&mut self,
-                          block: BasicBlock,
-                          span: Span,
-                          item_ref: ItemRef<'tcx>,
-                          lvalue1: Operand<'tcx>,
-                          lvalue2: Operand<'tcx>)
-                          -> Vec<BasicBlock> {
-        let target_blocks = vec![self.cfg.start_new_block(), self.cfg.start_new_block()];
-
-        let bool_ty = self.hir.bool_ty();
-        let eq_result = self.temp(bool_ty);
-        let func = self.item_ref_operand(span, item_ref);
-        let call_blocks = (self.cfg.start_new_block(), self.diverge_cleanup());
-        self.cfg.terminate(block,
-                           Terminator::Call {
-                               data: CallData {
-                                   destination: eq_result.clone(),
-                                   func: func,
-                                   args: vec![lvalue1, lvalue2],
-                               },
-                               targets: call_blocks,
-                           });
-
-        // check the result
-        self.cfg.terminate(call_blocks.0,
-                           Terminator::If {
-                               cond: Operand::Consume(eq_result),
-                               targets: (target_blocks[0], target_blocks[1]),
-                           });
-
-        target_blocks
-    }
-
     /// Given that we are performing `test` against `test_lvalue`,
     /// this job sorts out what the status of `candidate` will be
     /// after the test. The `resulting_candidates` vector stores, for
@@ -368,7 +335,7 @@ pub fn sort_candidate<'pat>(&mut self,
             // things out here, in some cases.
             TestKind::SwitchInt { switch_ty: _, options: _, ref indices } => {
                 match *match_pair.pattern.kind {
-                    PatternKind::Constant { value: Literal::Value { ref value } }
+                    PatternKind::Constant { ref value }
                     if is_switch_ty(match_pair.pattern.ty) => {
                         let index = indices[value];
                         let new_candidate = self.candidate_without_match_pair(match_pair_index,
index d24d0985355c2bb24e0d87208a0714f86c421232..f2bc5fec2ff5d0c70fe81f8e07f87551c41c40e0 100644 (file)
 use rustc::mir::repr::*;
 
 use rustc::middle::const_eval::{self, ConstVal};
-use rustc::middle::def_id::DefId;
 use rustc::middle::infer::InferCtxt;
-use rustc::middle::subst::{Subst, Substs};
 use rustc::middle::ty::{self, Ty};
 use syntax::codemap::Span;
-use syntax::parse::token;
 use rustc_front::hir;
 
 #[derive(Copy, Clone)]
@@ -83,11 +80,6 @@ pub fn try_const_eval_literal(&mut self, e: &hir::Expr) -> Option<Literal<'tcx>>
             .map(|v| Literal::Value { value: v })
     }
 
-    pub fn partial_eq(&mut self, ty: Ty<'tcx>) -> ItemRef<'tcx> {
-        let eq_def_id = self.tcx.lang_items.eq_trait().unwrap();
-        self.cmp_method_ref(eq_def_id, "eq", ty)
-    }
-
     pub fn num_variants(&mut self, adt_def: ty::AdtDef<'tcx>) -> usize {
         adt_def.variants.len()
     }
@@ -118,35 +110,6 @@ pub fn span_bug(&mut self, span: Span, message: &str) -> ! {
     pub fn tcx(&self) -> &'a ty::ctxt<'tcx> {
         self.tcx
     }
-
-    fn cmp_method_ref(&mut self,
-                      trait_def_id: DefId,
-                      method_name: &str,
-                      arg_ty: Ty<'tcx>)
-                      -> ItemRef<'tcx> {
-        let method_name = token::intern(method_name);
-        let substs = Substs::new_trait(vec![arg_ty], vec![], arg_ty);
-        for trait_item in self.tcx.trait_items(trait_def_id).iter() {
-            match *trait_item {
-                ty::ImplOrTraitItem::MethodTraitItem(ref method) => {
-                    if method.name == method_name {
-                        let method_ty = self.tcx.lookup_item_type(method.def_id);
-                        let method_ty = method_ty.ty.subst(self.tcx, &substs);
-                        return ItemRef {
-                            ty: method_ty,
-                            kind: ItemKind::Method,
-                            def_id: method.def_id,
-                            substs: self.tcx.mk_substs(substs),
-                        };
-                    }
-                }
-                ty::ImplOrTraitItem::ConstTraitItem(..) |
-                ty::ImplOrTraitItem::TypeTraitItem(..) => {}
-            }
-        }
-
-        self.tcx.sess.bug(&format!("found no method `{}` in `{:?}`", method_name, trait_def_id));
-    }
 }
 
 mod block;
index 8f3a1c17440fc56204f3fec8a20da367f5b4ed0c..1f425aafa256adc40f4cc4bd740c242235c053c1 100644 (file)
 use rustc::middle::const_eval;
 use rustc::middle::def;
 use rustc::middle::pat_util::{pat_is_resolved_const, pat_is_binding};
-use rustc::middle::subst::Substs;
 use rustc::middle::ty::{self, Ty};
 use rustc::mir::repr::*;
 use rustc_front::hir;
 use syntax::ast;
+use syntax::codemap::Span;
 use syntax::ptr::P;
 
 /// When there are multiple patterns in a single arm, each one has its
@@ -40,15 +40,15 @@ struct PatCx<'patcx, 'cx: 'patcx, 'tcx: 'cx> {
 }
 
 impl<'cx, 'tcx> Cx<'cx, 'tcx> {
-    pub fn irrefutable_pat(&mut self, pat: &'tcx hir::Pat) -> Pattern<'tcx> {
-        PatCx::new(self, None).to_pat(pat)
+    pub fn irrefutable_pat(&mut self, pat: &hir::Pat) -> Pattern<'tcx> {
+        PatCx::new(self, None).to_pattern(pat)
     }
 
     pub fn refutable_pat(&mut self,
                          binding_map: Option<&FnvHashMap<ast::Name, ast::NodeId>>,
-                         pat: &'tcx hir::Pat)
+                         pat: &hir::Pat)
                          -> Pattern<'tcx> {
-        PatCx::new(self, binding_map).to_pat(pat)
+        PatCx::new(self, binding_map).to_pattern(pat)
     }
 }
 
@@ -62,13 +62,12 @@ fn new(cx: &'patcx mut Cx<'cx, 'tcx>,
         }
     }
 
-    fn to_pat(&mut self, pat: &'tcx hir::Pat) -> Pattern<'tcx> {
+    fn to_pattern(&mut self, pat: &hir::Pat) -> Pattern<'tcx> {
         let kind = match pat.node {
             hir::PatWild => PatternKind::Wild,
 
             hir::PatLit(ref value) => {
                 let value = const_eval::eval_const_expr(self.cx.tcx, value);
-                let value = Literal::Value { value: value };
                 PatternKind::Constant { value: value }
             }
 
@@ -88,22 +87,9 @@ fn to_pat(&mut self, pat: &'tcx hir::Pat) -> Pattern<'tcx> {
                     def::DefConst(def_id) | def::DefAssociatedConst(def_id) =>
                         match const_eval::lookup_const_by_id(self.cx.tcx, def_id, Some(pat.id)) {
                             Some(const_expr) => {
-                                let opt_value =
-                                    const_eval::eval_const_expr_partial(
-                                        self.cx.tcx, const_expr,
-                                        const_eval::EvalHint::ExprTypeChecked,
-                                        None);
-                                let literal = if let Ok(value) = opt_value {
-                                    Literal::Value { value: value }
-                                } else {
-                                    let substs = self.cx.tcx.mk_substs(Substs::empty());
-                                    Literal::Item {
-                                        def_id: def_id,
-                                        kind: ItemKind::Constant,
-                                        substs: substs
-                                    }
-                                };
-                                PatternKind::Constant { value: literal }
+                                let pat = const_eval::const_expr_to_pat(self.cx.tcx, const_expr,
+                                                                        pat.span);
+                                return self.to_pattern(&*pat);
                             }
                             None => {
                                 self.cx.tcx.sess.span_bug(
@@ -120,7 +106,7 @@ fn to_pat(&mut self, pat: &'tcx hir::Pat) -> Pattern<'tcx> {
 
             hir::PatRegion(ref subpattern, _) |
             hir::PatBox(ref subpattern) => {
-                PatternKind::Deref { subpattern: self.to_pat(subpattern) }
+                PatternKind::Deref { subpattern: self.to_pattern(subpattern) }
             }
 
             hir::PatVec(ref prefix, ref slice, ref suffix) => {
@@ -131,14 +117,14 @@ fn to_pat(&mut self, pat: &'tcx hir::Pat) -> Pattern<'tcx> {
                             subpattern: Pattern {
                                 ty: mt.ty,
                                 span: pat.span,
-                                kind: Box::new(self.slice_or_array_pattern(pat, mt.ty, prefix,
+                                kind: Box::new(self.slice_or_array_pattern(pat.span, mt.ty, prefix,
                                                                            slice, suffix)),
                             },
                         },
 
                     ty::TySlice(..) |
                     ty::TyArray(..) =>
-                        self.slice_or_array_pattern(pat, ty, prefix, slice, suffix),
+                        self.slice_or_array_pattern(pat.span, ty, prefix, slice, suffix),
 
                     ref sty =>
                         self.cx.tcx.sess.span_bug(
@@ -153,7 +139,7 @@ fn to_pat(&mut self, pat: &'tcx hir::Pat) -> Pattern<'tcx> {
                                .enumerate()
                                .map(|(i, subpattern)| FieldPattern {
                                    field: Field::new(i),
-                                   pattern: self.to_pat(subpattern),
+                                   pattern: self.to_pattern(subpattern),
                                })
                                .collect();
 
@@ -188,7 +174,7 @@ fn to_pat(&mut self, pat: &'tcx hir::Pat) -> Pattern<'tcx> {
                     name: ident.node.name,
                     var: id,
                     ty: var_ty,
-                    subpattern: self.to_opt_pat(sub),
+                    subpattern: self.to_opt_pattern(sub),
                 }
             }
 
@@ -203,7 +189,7 @@ fn to_pat(&mut self, pat: &'tcx hir::Pat) -> Pattern<'tcx> {
                                    .enumerate()
                                    .map(|(i, field)| FieldPattern {
                                        field: Field::new(i),
-                                       pattern: self.to_pat(field),
+                                       pattern: self.to_pattern(field),
                                    })
                                    .collect();
                 self.variant_or_leaf(pat, subpatterns)
@@ -234,7 +220,7 @@ fn to_pat(&mut self, pat: &'tcx hir::Pat) -> Pattern<'tcx> {
                               });
                               FieldPattern {
                                   field: Field::new(index),
-                                  pattern: self.to_pat(&field.node.pat),
+                                  pattern: self.to_pattern(&field.node.pat),
                               }
                           })
                           .collect();
@@ -256,28 +242,28 @@ fn to_pat(&mut self, pat: &'tcx hir::Pat) -> Pattern<'tcx> {
         }
     }
 
-    fn to_pats(&mut self, pats: &'tcx [P<hir::Pat>]) -> Vec<Pattern<'tcx>> {
-        pats.iter().map(|p| self.to_pat(p)).collect()
+    fn to_patterns(&mut self, pats: &[P<hir::Pat>]) -> Vec<Pattern<'tcx>> {
+        pats.iter().map(|p| self.to_pattern(p)).collect()
     }
 
-    fn to_opt_pat(&mut self, pat: &'tcx Option<P<hir::Pat>>) -> Option<Pattern<'tcx>> {
-        pat.as_ref().map(|p| self.to_pat(p))
+    fn to_opt_pattern(&mut self, pat: &Option<P<hir::Pat>>) -> Option<Pattern<'tcx>> {
+        pat.as_ref().map(|p| self.to_pattern(p))
     }
 
     fn slice_or_array_pattern(&mut self,
-                              pat: &'tcx hir::Pat,
+                              span: Span,
                               ty: Ty<'tcx>,
-                              prefix: &'tcx [P<hir::Pat>],
-                              slice: &'tcx Option<P<hir::Pat>>,
-                              suffix: &'tcx [P<hir::Pat>])
+                              prefix: &[P<hir::Pat>],
+                              slice: &Option<P<hir::Pat>>,
+                              suffix: &[P<hir::Pat>])
                               -> PatternKind<'tcx> {
         match ty.sty {
             ty::TySlice(..) => {
                 // matching a slice or fixed-length array
                 PatternKind::Slice {
-                    prefix: self.to_pats(prefix),
-                    slice: self.to_opt_pat(slice),
-                    suffix: self.to_pats(suffix),
+                    prefix: self.to_patterns(prefix),
+                    slice: self.to_opt_pattern(slice),
+                    suffix: self.to_patterns(suffix),
                 }
             }
 
@@ -285,20 +271,20 @@ fn slice_or_array_pattern(&mut self,
                 // fixed-length array
                 assert!(len >= prefix.len() + suffix.len());
                 PatternKind::Array {
-                    prefix: self.to_pats(prefix),
-                    slice: self.to_opt_pat(slice),
-                    suffix: self.to_pats(suffix),
+                    prefix: self.to_patterns(prefix),
+                    slice: self.to_opt_pattern(slice),
+                    suffix: self.to_patterns(suffix),
                 }
             }
 
             _ => {
-                self.cx.tcx.sess.span_bug(pat.span, "unexpanded macro or bad constant etc");
+                self.cx.tcx.sess.span_bug(span, "unexpanded macro or bad constant etc");
             }
         }
     }
 
     fn variant_or_leaf(&mut self,
-                       pat: &'tcx hir::Pat,
+                       pat: &hir::Pat,
                        subpatterns: Vec<FieldPattern<'tcx>>)
                        -> PatternKind<'tcx> {
         let def = self.cx.tcx.def_map.borrow().get(&pat.id).unwrap().full_def();
index 99e6d6633f2889c767839c57491b86f5b92f21ed..6363ddf1e1477b987cec22959da84df1bd8e5655 100644 (file)
@@ -15,6 +15,7 @@
 //! structures.
 
 use rustc::mir::repr::{BinOp, BorrowKind, Field, Literal, Mutability, UnOp, ItemKind};
+use rustc::middle::const_eval::ConstVal;
 use rustc::middle::def_id::DefId;
 use rustc::middle::region::CodeExtent;
 use rustc::middle::subst::Substs;
@@ -305,7 +306,7 @@ pub enum PatternKind<'tcx> {
     }, // box P, &P, &mut P, etc
 
     Constant {
-        value: Literal<'tcx>,
+        value: ConstVal,
     },
 
     Range {
diff --git a/src/test/run-pass/mir_build_match_comparisons.rs b/src/test/run-pass/mir_build_match_comparisons.rs
new file mode 100644 (file)
index 0000000..1926619
--- /dev/null
@@ -0,0 +1,55 @@
+// Copyright 2015 The Rust Project Developers. See the COPYRIGHT
+// file at the top-level directory of this distribution and at
+// http://rust-lang.org/COPYRIGHT.
+//
+// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
+// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
+// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
+// option. This file may not be copied, modified, or distributed
+// except according to those terms.
+
+#![feature(rustc_attrs)]
+
+#[rustc_mir]
+pub fn test1(x: i8) -> i32 {
+  match x {
+    1...10 => 0,
+    _ => 1,
+  }
+}
+
+const U: Option<i8> = Some(10);
+const S: &'static str = "hello";
+
+#[rustc_mir]
+pub fn test2(x: i8) -> i32 {
+  match Some(x) {
+    U => 0,
+    _ => 1,
+  }
+}
+
+#[rustc_mir]
+pub fn test3(x: &'static str) -> i32 {
+  match x {
+    S => 0,
+    _ => 1,
+  }
+}
+
+fn main() {
+  assert_eq!(test1(0), 1);
+  assert_eq!(test1(1), 0);
+  assert_eq!(test1(2), 0);
+  assert_eq!(test1(5), 0);
+  assert_eq!(test1(9), 0);
+  assert_eq!(test1(10), 0);
+  assert_eq!(test1(11), 1);
+  assert_eq!(test1(20), 1);
+  assert_eq!(test2(10), 0);
+  assert_eq!(test2(0), 1);
+  assert_eq!(test2(20), 1);
+  assert_eq!(test3("hello"), 0);
+  assert_eq!(test3(""), 1);
+  assert_eq!(test3("world"), 1);
+}
diff --git a/src/test/run-pass/mir_trans_match_range.rs b/src/test/run-pass/mir_trans_match_range.rs
deleted file mode 100644 (file)
index 14184bd..0000000
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2015 The Rust Project Developers. See the COPYRIGHT
-// file at the top-level directory of this distribution and at
-// http://rust-lang.org/COPYRIGHT.
-//
-// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
-// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
-// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
-// option. This file may not be copied, modified, or distributed
-// except according to those terms.
-
-#![feature(rustc_attrs)]
-
-#[rustc_mir]
-pub fn foo(x: i8) -> i32 {
-  match x {
-    1...10 => 0,
-    _ => 1,
-  }
-}
-
-fn main() {
-  assert_eq!(foo(0), 1);
-  assert_eq!(foo(1), 0);
-  assert_eq!(foo(2), 0);
-  assert_eq!(foo(5), 0);
-  assert_eq!(foo(9), 0);
-  assert_eq!(foo(10), 0);
-  assert_eq!(foo(11), 1);
-  assert_eq!(foo(20), 1);
-}