]> git.lizzy.rs Git - rust.git/commitdiff
Improve `match` MIR generation for ranges
authorShotaro Yamada <sinkuu@sinkuu.xyz>
Thu, 13 Dec 2018 13:35:54 +0000 (22:35 +0900)
committerShotaro Yamada <sinkuu@sinkuu.xyz>
Fri, 14 Dec 2018 23:39:40 +0000 (08:39 +0900)
Makes testing a range rule out ranges/constant
covered by the range that is being tested

src/librustc_mir/build/matches/test.rs
src/librustc_mir/hair/pattern/mod.rs
src/test/run-pass/mir/mir_match_test.rs [new file with mode: 0644]

index a95804e05c9061771373ed08b7722597a521d6c8..77db74685cd2f8424219b8d2b88ee59488074e2c 100644 (file)
@@ -18,6 +18,7 @@
 use build::Builder;
 use build::matches::{Candidate, MatchPair, Test, TestKind};
 use hair::*;
+use hair::pattern::compare_const_vals;
 use rustc_data_structures::bit_set::BitSet;
 use rustc_data_structures::fx::FxHashMap;
 use rustc::ty::{self, Ty};
@@ -136,7 +137,15 @@ pub fn add_cases_to_switch<'pat>(&mut self,
             PatternKind::Variant { .. } => {
                 panic!("you should have called add_variants_to_switch instead!");
             }
-            PatternKind::Range { .. } |
+            PatternKind::Range { ty, lo, hi, end } => {
+                indices
+                    .keys()
+                    .all(|value| {
+                        !self
+                            .const_range_contains(ty, lo, hi, end, value)
+                            .unwrap_or(true)
+                    })
+            }
             PatternKind::Slice { .. } |
             PatternKind::Array { .. } |
             PatternKind::Wild |
@@ -529,6 +538,28 @@ pub fn sort_candidate<'pat>(&mut self,
                 resulting_candidates[index].push(new_candidate);
                 true
             }
+
+            (&TestKind::SwitchInt { switch_ty: _, ref options, ref indices },
+             &PatternKind::Range { ty, lo, hi, end }) => {
+                let not_contained = indices
+                    .keys()
+                    .all(|value| {
+                        !self
+                            .const_range_contains(ty, lo, hi, end, value)
+                            .unwrap_or(true)
+                    });
+
+                if not_contained {
+                    // No values are contained in the pattern range,
+                    // so the pattern can be matched only if this test fails.
+                    let otherwise = options.len();
+                    resulting_candidates[otherwise].push(candidate.clone());
+                    true
+                } else {
+                    false
+                }
+            }
+
             (&TestKind::SwitchInt { .. }, _) => false,
 
 
@@ -607,8 +638,70 @@ pub fn sort_candidate<'pat>(&mut self,
                 }
             }
 
+            (&TestKind::Range {
+                lo: test_lo, hi: test_hi, ty: test_ty, end: test_end,
+            }, &PatternKind::Range {
+                lo: pat_lo, hi: pat_hi, ty: _, end: pat_end,
+            }) => {
+                if (test_lo, test_hi, test_end) == (pat_lo, pat_hi, pat_end) {
+                    resulting_candidates[0]
+                        .push(self.candidate_without_match_pair(
+                            match_pair_index,
+                            candidate,
+                        ));
+                    return true;
+                }
+
+                let no_overlap = (|| {
+                    use std::cmp::Ordering::*;
+                    use rustc::hir::RangeEnd::*;
+
+                    let param_env = ty::ParamEnv::empty().and(test_ty);
+                    let tcx = self.hir.tcx();
+
+                    let lo = compare_const_vals(tcx, test_lo, pat_hi, param_env)?;
+                    let hi = compare_const_vals(tcx, test_hi, pat_lo, param_env)?;
+
+                    match (test_end, pat_end, lo, hi) {
+                        // pat < test
+                        (_, _, Greater, _) |
+                        (_, Excluded, Equal, _) |
+                        // pat > test
+                        (_, _, _, Less) |
+                        (Excluded, _, _, Equal) => Some(true),
+                        _ => Some(false),
+                    }
+                })();
+
+                if no_overlap == Some(true) {
+                    // Testing range does not overlap with pattern range,
+                    // so the pattern can be matched only if this test fails.
+                    resulting_candidates[1].push(candidate.clone());
+                    true
+                } else {
+                    false
+                }
+            }
+
+            (&TestKind::Range {
+                lo, hi, ty, end
+            }, &PatternKind::Constant {
+                ref value
+            }) => {
+                if self.const_range_contains(ty, lo, hi, end, value) == Some(false) {
+                    // `value` is not contained in the testing range,
+                    // so `value` can be matched only if this test fails.
+                    resulting_candidates[1].push(candidate.clone());
+                    true
+                } else {
+                    false
+                }
+            }
+
+            (&TestKind::Range { .. }, _) => false,
+
+
             (&TestKind::Eq { .. }, _) |
-            (&TestKind::Range { .. }, _) |
             (&TestKind::Len { .. }, _) => {
                 // These are all binary tests.
                 //
@@ -719,6 +812,29 @@ fn error_simplifyable<'pat>(&mut self, match_pair: &MatchPair<'pat, 'tcx>) -> !
                   "simplifyable pattern found: {:?}",
                   match_pair.pattern)
     }
+
+    fn const_range_contains(
+        &self,
+        ty: Ty<'tcx>,
+        lo: &'tcx ty::Const<'tcx>,
+        hi: &'tcx ty::Const<'tcx>,
+        end: RangeEnd,
+        value: &'tcx ty::Const<'tcx>,
+    ) -> Option<bool> {
+        use std::cmp::Ordering::*;
+
+        let param_env = ty::ParamEnv::empty().and(ty);
+        let tcx = self.hir.tcx();
+
+        let a = compare_const_vals(tcx, lo, value, param_env)?;
+        let b = compare_const_vals(tcx, value, hi, param_env)?;
+
+        match (b, end) {
+            (Less, _) |
+            (Equal, RangeEnd::Included) if a != Greater => Some(true),
+            _ => Some(false),
+        }
+    }
 }
 
 fn is_switch_ty<'tcx>(ty: Ty<'tcx>) -> bool {
index d695a64f62a08e1e2a1a1a211996e5807529cb5f..864f242a304e01b5455c968ce1db9abeca801f5f 100644 (file)
@@ -24,7 +24,7 @@
 use rustc::mir::{fmt_const_val, Field, BorrowKind, Mutability};
 use rustc::mir::{ProjectionElem, UserTypeAnnotation, UserTypeProjection, UserTypeProjections};
 use rustc::mir::interpret::{Scalar, GlobalId, ConstValue, sign_extend};
-use rustc::ty::{self, Region, TyCtxt, AdtDef, Ty};
+use rustc::ty::{self, Region, TyCtxt, AdtDef, Ty, Lift};
 use rustc::ty::subst::{Substs, Kind};
 use rustc::ty::layout::VariantIdx;
 use rustc::hir::{self, PatKind, RangeEnd};
@@ -1210,8 +1210,8 @@ fn super_fold_with<F: PatternFolder<'tcx>>(&self, folder: &mut F) -> Self {
     }
 }
 
-pub fn compare_const_vals<'a, 'tcx>(
-    tcx: TyCtxt<'a, 'tcx, 'tcx>,
+pub fn compare_const_vals<'a, 'gcx, 'tcx>(
+    tcx: TyCtxt<'a, 'gcx, 'tcx>,
     a: &'tcx ty::Const<'tcx>,
     b: &'tcx ty::Const<'tcx>,
     ty: ty::ParamEnvAnd<'tcx, Ty<'tcx>>,
@@ -1233,6 +1233,9 @@ pub fn compare_const_vals<'a, 'tcx>(
         return fallback();
     }
 
+    let tcx = tcx.global_tcx();
+    let (a, b, ty) = (a, b, ty).lift_to_tcx(tcx).unwrap();
+
     // FIXME: This should use assert_bits(ty) instead of use_bits
     // but triggers possibly bugs due to mismatching of arrays and slices
     if let (Some(a), Some(b)) = (a.to_bits(tcx, ty), b.to_bits(tcx, ty)) {
diff --git a/src/test/run-pass/mir/mir_match_test.rs b/src/test/run-pass/mir/mir_match_test.rs
new file mode 100644 (file)
index 0000000..1f96d67
--- /dev/null
@@ -0,0 +1,83 @@
+#![feature(exclusive_range_pattern)]
+
+// run-pass
+
+fn main() {
+    let incl_range = |x, b| {
+        match x {
+            0..=5 if b => 0,
+            5..=10 if b => 1,
+            1..=4 if !b => 2,
+            _ => 3,
+        }
+    };
+    assert_eq!(incl_range(3, false), 2);
+    assert_eq!(incl_range(3, true), 0);
+    assert_eq!(incl_range(5, false), 3);
+    assert_eq!(incl_range(5, true), 0);
+
+    let excl_range = |x, b| {
+        match x {
+            0..5 if b => 0,
+            5..10 if b => 1,
+            1..4 if !b => 2,
+            _ => 3,
+        }
+    };
+    assert_eq!(excl_range(3, false), 2);
+    assert_eq!(excl_range(3, true), 0);
+    assert_eq!(excl_range(5, false), 3);
+    assert_eq!(excl_range(5, true), 1);
+
+    let incl_range_vs_const = |x, b| {
+        match x {
+            0..=5 if b => 0,
+            7 => 1,
+            3 => 2,
+            _ => 3,
+        }
+    };
+    assert_eq!(incl_range_vs_const(5, false), 3);
+    assert_eq!(incl_range_vs_const(5, true), 0);
+    assert_eq!(incl_range_vs_const(3, false), 2);
+    assert_eq!(incl_range_vs_const(3, true), 0);
+    assert_eq!(incl_range_vs_const(7, false), 1);
+    assert_eq!(incl_range_vs_const(7, true), 1);
+
+    let excl_range_vs_const = |x, b| {
+        match x {
+            0..5 if b => 0,
+            7 => 1,
+            3 => 2,
+            _ => 3,
+        }
+    };
+    assert_eq!(excl_range_vs_const(5, false), 3);
+    assert_eq!(excl_range_vs_const(5, true), 3);
+    assert_eq!(excl_range_vs_const(3, false), 2);
+    assert_eq!(excl_range_vs_const(3, true), 0);
+    assert_eq!(excl_range_vs_const(7, false), 1);
+    assert_eq!(excl_range_vs_const(7, true), 1);
+
+    let const_vs_incl_range = |x, b| {
+        match x {
+            3 if b => 0,
+            5..=7 => 2,
+            1..=4 => 1,
+            _ => 3,
+        }
+    };
+    assert_eq!(const_vs_incl_range(3, false), 1);
+    assert_eq!(const_vs_incl_range(3, true), 0);
+
+    let const_vs_excl_range = |x, b| {
+        match x {
+            3 if b => 0,
+            5..7 => 2,
+            1..4 => 1,
+            _ => 3,
+        }
+    };
+    assert_eq!(const_vs_excl_range(3, false), 1);
+    assert_eq!(const_vs_excl_range(3, true), 0);
+}