]> git.lizzy.rs Git - rust.git/commitdiff
Fix overflow checking when multiplying two i64
authorbjorn3 <bjorn3@users.noreply.github.com>
Fri, 16 Apr 2021 12:36:07 +0000 (14:36 +0200)
committerbjorn3 <bjorn3@users.noreply.github.com>
Fri, 16 Apr 2021 12:36:07 +0000 (14:36 +0200)
Fixes #1162

example/std_example.rs
src/num.rs

index 437b7726980bf02633868fc6f359ccf13cd64643..77ba72df8ef371ddbc1163c7b0f6f0aca38b0d7e 100644 (file)
@@ -48,6 +48,8 @@ fn main() {
     assert_eq!(2.3f32.copysign(-1.0), -2.3f32);
     println!("{}", 2.3f32.powf(2.0));
 
+    assert_eq!(i64::MAX.checked_mul(2), None);
+
     assert_eq!(-128i8, (-128i8).saturating_sub(1));
     assert_eq!(127i8, 127i8.saturating_sub(-128));
     assert_eq!(-128i8, (-128i8).saturating_add(-128));
index 2ebf30da2d8ba930e973995dac9bc173eea636da..b6d378a5fe10ae59b53ca8c4656923902250fc05 100644 (file)
@@ -271,14 +271,17 @@ pub(crate) fn codegen_checked_int_binop<'tcx>(
                         let val_hi = fx.bcx.ins().umulhi(lhs, rhs);
                         fx.bcx.ins().icmp_imm(IntCC::NotEqual, val_hi, 0)
                     } else {
+                        // Based on LLVM's instruction sequence for compiling
+                        // a.checked_mul(b).is_some() to riscv64gc:
+                        // mulh    a2, a0, a1
+                        // mul     a0, a0, a1
+                        // srai    a0, a0, 63
+                        // xor     a0, a0, a2
+                        // snez    a0, a0
                         let val_hi = fx.bcx.ins().smulhi(lhs, rhs);
-                        let not_all_zero = fx.bcx.ins().icmp_imm(IntCC::NotEqual, val_hi, 0);
-                        let not_all_ones = fx.bcx.ins().icmp_imm(
-                            IntCC::NotEqual,
-                            val_hi,
-                            u64::try_from((1u128 << ty.bits()) - 1).unwrap() as i64,
-                        );
-                        fx.bcx.ins().band(not_all_zero, not_all_ones)
+                        let val_sign = fx.bcx.ins().sshr_imm(val, i64::from(ty.bits() - 1));
+                        let xor = fx.bcx.ins().bxor(val_hi, val_sign);
+                        fx.bcx.ins().icmp_imm(IntCC::NotEqual, xor, 0)
                     };
                     (val, has_overflow)
                 }