]> git.lizzy.rs Git - rust.git/commitdiff
impl the proper partial order between fn types
authorNiko Matsakis <niko@alum.mit.edu>
Fri, 16 Dec 2011 21:50:22 +0000 (13:50 -0800)
committerNiko Matsakis <niko@alum.mit.edu>
Mon, 19 Dec 2011 22:07:46 +0000 (14:07 -0800)
src/comp/middle/ty.rs
src/test/compile-fail/sendfn-is-not-a-lambda.rs [new file with mode: 0644]
src/test/run-pass/sendfn-is-a-block.rs [new file with mode: 0644]

index 19a4dcdff0b52910e3f78672923a7d2660f39fa3..5e8ccaf4f9044a2704116ea90edb4efcb27f81d1 100644 (file)
@@ -1918,20 +1918,19 @@ fn unify_fn_common(cx: @ctxt, _expected: t, _actual: t,
                        actual_inputs: [arg], actual_output: t,
                        variance: variance) ->
        fn_common_res {
-        let expected_len = vec::len::<arg>(expected_inputs);
-        let actual_len = vec::len::<arg>(actual_inputs);
-        if expected_len != actual_len {
+        if !vec::same_length(expected_inputs, actual_inputs) {
             ret fn_common_res_err(ures_err(terr_arg_count));
         }
-        // TODO: as above, we should have an iter2 iterator.
 
-        let result_ins: [arg] = [];
-        let i = 0u;
-        while i < expected_len {
+        // Would use vec::map2(), but for the need to return in case of
+        // error:
+        let i = 0u, n = vec::len(expected_inputs);
+        let result_ins = [];
+        while i < n {
             let expected_input = expected_inputs[i];
             let actual_input = actual_inputs[i];
-            // Unify the result modes.
 
+            // Unify the result modes.
             let result_mode = if expected_input.mode == ast::mode_infer {
                 actual_input.mode
             } else if actual_input.mode == ast::mode_infer {
@@ -1941,6 +1940,7 @@ fn unify_fn_common(cx: @ctxt, _expected: t, _actual: t,
                     (ures_err(terr_mode_mismatch(expected_input.mode,
                                                  actual_input.mode)));
             } else { expected_input.mode };
+
             // The variance changes (flips basically) when descending
             // into arguments of function types
             let result = unify_step(
@@ -1949,11 +1949,11 @@ fn unify_fn_common(cx: @ctxt, _expected: t, _actual: t,
             alt result {
               ures_ok(rty) { result_ins += [{mode: result_mode, ty: rty}]; }
               _ { ret fn_common_res_err(result); }
-            }
+            };
             i += 1u;
         }
-        // Check the output.
 
+        // Check the output.
         let result = unify_step(cx, expected_output, actual_output, variance);
         alt result {
           ures_ok(rty) { ret fn_common_res_ok(result_ins, rty); }
@@ -1962,38 +1962,33 @@ fn unify_fn_common(cx: @ctxt, _expected: t, _actual: t,
     }
     fn unify_fn_proto(e_proto: ast::proto, a_proto: ast::proto,
                       variance: variance) -> option::t<result> {
-        fn rank(proto: ast::proto) -> int {
-            ret alt proto {
-              ast::proto_block. { 0 }
-              ast::proto_shared(_) { 1 }
-              ast::proto_send. { 2 }
-              ast::proto_bare. { 3 }
+        // Prototypes form a diamond-shaped partial order:
+        //
+        //        block
+        //        ^   ^
+        //   shared   send
+        //        ^   ^
+        //        bare
+        //
+        // where "^" means "subtype of" (forgive the abuse of the term
+        // subtype).
+        fn sub_proto(p_sub: ast::proto, p_sup: ast::proto) -> bool {
+            ret alt (p_sub, p_sup) {
+              (_, ast::proto_block.) { true }
+              (ast::proto_bare., _) { true }
+
+              // Equal prototypes (modulo sugar) are always subprotos:
+              (ast::proto_shared(_), ast::proto_shared(_)) { true }
+              (_, _) { p_sub == p_sup }
             };
         }
 
-        fn gt(e_proto: ast::proto, a_proto: ast::proto) -> bool {
-            ret rank(e_proto) > rank(a_proto);
-        }
-
-        ret if e_proto == a_proto {
-            none
-        } else if variance == invariant {
-            some(ures_err(terr_mismatch))
-        } else if variance == covariant {
-            if gt(e_proto, a_proto) {
-                some(ures_err(terr_mismatch))
-            } else {
-                none
-            }
-        } else if variance == contravariant {
-            if gt(a_proto, e_proto) {
-                some(ures_err(terr_mismatch))
-            } else {
-                none
-            }
-        } else {
-            fail
-        }
+        ret alt variance {
+          invariant. when e_proto == a_proto { none }
+          covariant. when sub_proto(a_proto, e_proto) { none }
+          contravariant. when sub_proto(e_proto, a_proto) { none }
+          _ { some(ures_err(terr_mismatch)) }
+        };
     }
     fn unify_fn(cx: @ctxt, e_proto: ast::proto, a_proto: ast::proto,
                 expected: t, actual: t, expected_inputs: [arg],
diff --git a/src/test/compile-fail/sendfn-is-not-a-lambda.rs b/src/test/compile-fail/sendfn-is-not-a-lambda.rs
new file mode 100644 (file)
index 0000000..696f237
--- /dev/null
@@ -0,0 +1,10 @@
+// error-pattern: mismatched types: expected lambda(++uint) -> uint
+
+fn test(f: lambda(uint) -> uint) -> uint {
+    ret f(22u);
+}
+
+fn main() {
+    let f = sendfn(x: uint) -> uint { ret 4u; };
+    log test(f);
+}
\ No newline at end of file
diff --git a/src/test/run-pass/sendfn-is-a-block.rs b/src/test/run-pass/sendfn-is-a-block.rs
new file mode 100644 (file)
index 0000000..3761c1e
--- /dev/null
@@ -0,0 +1,8 @@
+fn test(f: block(uint) -> uint) -> uint {
+    ret f(22u);
+}
+
+fn main() {
+    let y = test(sendfn(x: uint) -> uint { ret 4u * x; });
+    assert y == 88u;
+}
\ No newline at end of file