]> git.lizzy.rs Git - rust.git/commitdiff
Fix mod_inv termination for the last iteration
authorSimonas Kazlauskas <git@kazlauskas.me>
Sat, 22 Oct 2022 00:03:48 +0000 (03:03 +0300)
committerSimonas Kazlauskas <git@kazlauskas.me>
Sat, 22 Oct 2022 00:46:48 +0000 (03:46 +0300)
On usize=u64 platforms, the 4th iteration would overflow the `mod_gate`
back to 0. Similarly for usize=u32 platforms, the 3rd iteration would
overflow much the same way.

I tested various approaches to resolving this, including approaches with
`saturating_mul` and `widening_mul` to a double usize. Turns out LLVM
likes `mul_with_overflow` the best. In fact now, that LLVM can see the
iteration count is limited, it will happily unroll the loop into a nice
linear sequence.

You will also notice that the code around the loop got simplified
somewhat. Now that LLVM is handling the loop nicely, there isn’t any
more reasons to manually unroll the first iteration out of the loop
(though looking at the code today I’m not sure all that complexity was
necessary in the first place).

Fixes #103361

library/core/src/ptr/mod.rs
library/core/tests/ptr.rs

index 8e2bad35993dfb6b9552d45fe38588647d2fbddd..8a9bcf38b906cf4c19ce11215d1287d48f95c0ce 100644 (file)
@@ -1571,8 +1571,8 @@ pub(crate) unsafe fn align_offset<T: Sized>(p: *const T, a: usize) -> usize {
     // FIXME(#75598): Direct use of these intrinsics improves codegen significantly at opt-level <=
     // 1, where the method versions of these operations are not inlined.
     use intrinsics::{
-        cttz_nonzero, exact_div, unchecked_rem, unchecked_shl, unchecked_shr, unchecked_sub,
-        wrapping_add, wrapping_mul, wrapping_sub,
+        cttz_nonzero, exact_div, mul_with_overflow, unchecked_rem, unchecked_shl, unchecked_shr,
+        unchecked_sub, wrapping_add, wrapping_mul, wrapping_sub,
     };
 
     /// Calculate multiplicative modular inverse of `x` modulo `m`.
@@ -1592,36 +1592,38 @@ unsafe fn mod_inv(x: usize, m: usize) -> usize {
         const INV_TABLE_MOD_16: [u8; 8] = [1, 11, 13, 7, 9, 3, 5, 15];
         /// Modulo for which the `INV_TABLE_MOD_16` is intended.
         const INV_TABLE_MOD: usize = 16;
-        /// INV_TABLE_MOD²
-        const INV_TABLE_MOD_SQUARED: usize = INV_TABLE_MOD * INV_TABLE_MOD;
 
-        let table_inverse = INV_TABLE_MOD_16[(x & (INV_TABLE_MOD - 1)) >> 1] as usize;
         // SAFETY: `m` is required to be a power-of-two, hence non-zero.
         let m_minus_one = unsafe { unchecked_sub(m, 1) };
-        if m <= INV_TABLE_MOD {
-            table_inverse & m_minus_one
-        } else {
-            // We iterate "up" using the following formula:
-            //
-            // $$ xy ≡ 1 (mod 2ⁿ) → xy (2 - xy) ≡ 1 (mod 2²ⁿ) $$
+        let mut inverse = INV_TABLE_MOD_16[(x & (INV_TABLE_MOD - 1)) >> 1] as usize;
+        let mut mod_gate = INV_TABLE_MOD;
+        // We iterate "up" using the following formula:
+        //
+        // $$ xy ≡ 1 (mod 2ⁿ) → xy (2 - xy) ≡ 1 (mod 2²ⁿ) $$
+        //
+        // This application needs to be applied at least until `2²ⁿ ≥ m`, at which point we can
+        // finally reduce the computation to our desired `m` by taking `inverse mod m`.
+        //
+        // This computation is `O(log log m)`, which is to say, that on 64-bit machines this loop
+        // will always finish in at most 4 iterations.
+        loop {
+            // y = y * (2 - xy) mod n
             //
-            // until 2²ⁿ ≥ m. Then we can reduce to our desired `m` by taking the result `mod m`.
-            let mut inverse = table_inverse;
-            let mut going_mod = INV_TABLE_MOD_SQUARED;
-            loop {
-                // y = y * (2 - xy) mod n
-                //
-                // Note, that we use wrapping operations here intentionally – the original formula
-                // uses e.g., subtraction `mod n`. It is entirely fine to do them `mod
-                // usize::MAX` instead, because we take the result `mod n` at the end
-                // anyway.
-                inverse = wrapping_mul(inverse, wrapping_sub(2usize, wrapping_mul(x, inverse)));
-                if going_mod >= m {
-                    return inverse & m_minus_one;
-                }
-                going_mod = wrapping_mul(going_mod, going_mod);
+            // Note, that we use wrapping operations here intentionally – the original formula
+            // uses e.g., subtraction `mod n`. It is entirely fine to do them `mod
+            // usize::MAX` instead, because we take the result `mod n` at the end
+            // anyway.
+            if mod_gate >= m {
+                break;
+            }
+            inverse = wrapping_mul(inverse, wrapping_sub(2usize, wrapping_mul(x, inverse)));
+            let (new_gate, overflow) = mul_with_overflow(mod_gate, mod_gate);
+            if overflow {
+                break;
             }
+            mod_gate = new_gate;
         }
+        inverse & m_minus_one
     }
 
     let addr = p.addr();
index 97a369810056dceefcdef14ad542fe4a0e15fd93..0977980ba47bf0e243c74adb4955907f67fc50b1 100644 (file)
@@ -455,6 +455,18 @@ unsafe fn test_stride<T>(ptr: *const T, align: usize) -> bool {
     assert!(!x);
 }
 
+#[test]
+fn align_offset_issue_103361() {
+    #[cfg(target_pointer_width = "64")]
+    const SIZE: usize = 1 << 47;
+    #[cfg(target_pointer_width = "32")]
+    const SIZE: usize = 1 << 30;
+    #[cfg(target_pointer_width = "16")]
+    const SIZE: usize = 1 << 13;
+    struct HugeSize([u8; SIZE - 1]);
+    let _ = (SIZE as *const HugeSize).align_offset(SIZE);
+}
+
 #[test]
 fn offset_from() {
     let mut a = [0; 5];