]> git.lizzy.rs Git - rust.git/blob - library/core/src/slice/rotate.rs
Auto merge of #91962 - matthiaskrgr:rollup-2g082jw, r=matthiaskrgr
[rust.git] / library / core / src / slice / rotate.rs
1 use crate::cmp;
2 use crate::mem::{self, MaybeUninit};
3 use crate::ptr;
4
5 /// Rotates the range `[mid-left, mid+right)` such that the element at `mid` becomes the first
6 /// element. Equivalently, rotates the range `left` elements to the left or `right` elements to the
7 /// right.
8 ///
9 /// # Safety
10 ///
11 /// The specified range must be valid for reading and writing.
12 ///
13 /// # Algorithm
14 ///
15 /// Algorithm 1 is used for small values of `left + right` or for large `T`. The elements are moved
16 /// into their final positions one at a time starting at `mid - left` and advancing by `right` steps
17 /// modulo `left + right`, such that only one temporary is needed. Eventually, we arrive back at
18 /// `mid - left`. However, if `gcd(left + right, right)` is not 1, the above steps skipped over
19 /// elements. For example:
20 /// ```text
21 /// left = 10, right = 6
22 /// the `^` indicates an element in its final place
23 /// 6 7 8 9 10 11 12 13 14 15 . 0 1 2 3 4 5
24 /// after using one step of the above algorithm (The X will be overwritten at the end of the round,
25 /// and 12 is stored in a temporary):
26 /// X 7 8 9 10 11 6 13 14 15 . 0 1 2 3 4 5
27 ///               ^
28 /// after using another step (now 2 is in the temporary):
29 /// X 7 8 9 10 11 6 13 14 15 . 0 1 12 3 4 5
30 ///               ^                 ^
31 /// after the third step (the steps wrap around, and 8 is in the temporary):
32 /// X 7 2 9 10 11 6 13 14 15 . 0 1 12 3 4 5
33 ///     ^         ^                 ^
34 /// after 7 more steps, the round ends with the temporary 0 getting put in the X:
35 /// 0 7 2 9 4 11 6 13 8 15 . 10 1 12 3 14 5
36 /// ^   ^   ^    ^    ^       ^    ^    ^
37 /// ```
38 /// Fortunately, the number of skipped over elements between finalized elements is always equal, so
39 /// we can just offset our starting position and do more rounds (the total number of rounds is the
40 /// `gcd(left + right, right)` value). The end result is that all elements are finalized once and
41 /// only once.
42 ///
43 /// Algorithm 2 is used if `left + right` is large but `min(left, right)` is small enough to
44 /// fit onto a stack buffer. The `min(left, right)` elements are copied onto the buffer, `memmove`
45 /// is applied to the others, and the ones on the buffer are moved back into the hole on the
46 /// opposite side of where they originated.
47 ///
48 /// Algorithms that can be vectorized outperform the above once `left + right` becomes large enough.
49 /// Algorithm 1 can be vectorized by chunking and performing many rounds at once, but there are too
50 /// few rounds on average until `left + right` is enormous, and the worst case of a single
51 /// round is always there. Instead, algorithm 3 utilizes repeated swapping of
52 /// `min(left, right)` elements until a smaller rotate problem is left.
53 ///
54 /// ```text
55 /// left = 11, right = 4
56 /// [4 5 6 7 8 9 10 11 12 13 14 . 0 1 2 3]
57 ///                  ^  ^  ^  ^   ^ ^ ^ ^ swapping the right most elements with elements to the left
58 /// [4 5 6 7 8 9 10 . 0 1 2 3] 11 12 13 14
59 ///        ^ ^ ^  ^   ^ ^ ^ ^ swapping these
60 /// [4 5 6 . 0 1 2 3] 7 8 9 10 11 12 13 14
61 /// we cannot swap any more, but a smaller rotation problem is left to solve
62 /// ```
63 /// when `left < right` the swapping happens from the left instead.
64 pub unsafe fn ptr_rotate<T>(mut left: usize, mut mid: *mut T, mut right: usize) {
65     type BufType = [usize; 32];
66     if mem::size_of::<T>() == 0 {
67         return;
68     }
69     loop {
70         // N.B. the below algorithms can fail if these cases are not checked
71         if (right == 0) || (left == 0) {
72             return;
73         }
74         if (left + right < 24) || (mem::size_of::<T>() > mem::size_of::<[usize; 4]>()) {
75             // Algorithm 1
76             // Microbenchmarks indicate that the average performance for random shifts is better all
77             // the way until about `left + right == 32`, but the worst case performance breaks even
78             // around 16. 24 was chosen as middle ground. If the size of `T` is larger than 4
79             // `usize`s, this algorithm also outperforms other algorithms.
80             // SAFETY: callers must ensure `mid - left` is valid for reading and writing.
81             let x = unsafe { mid.sub(left) };
82             // beginning of first round
83             // SAFETY: see previous comment.
84             let mut tmp: T = unsafe { x.read() };
85             let mut i = right;
86             // `gcd` can be found before hand by calculating `gcd(left + right, right)`,
87             // but it is faster to do one loop which calculates the gcd as a side effect, then
88             // doing the rest of the chunk
89             let mut gcd = right;
90             // benchmarks reveal that it is faster to swap temporaries all the way through instead
91             // of reading one temporary once, copying backwards, and then writing that temporary at
92             // the very end. This is possibly due to the fact that swapping or replacing temporaries
93             // uses only one memory address in the loop instead of needing to manage two.
94             loop {
95                 // [long-safety-expl]
96                 // SAFETY: callers must ensure `[left, left+mid+right)` are all valid for reading and
97                 // writing.
98                 //
99                 // - `i` start with `right` so `mid-left <= x+i = x+right = mid-left+right < mid+right`
100                 // - `i <= left+right-1` is always true
101                 //   - if `i < left`, `right` is added so `i < left+right` and on the next
102                 //     iteration `left` is removed from `i` so it doesn't go further
103                 //   - if `i >= left`, `left` is removed immediately and so it doesn't go further.
104                 // - overflows cannot happen for `i` since the function's safety contract ask for
105                 //   `mid+right-1 = x+left+right` to be valid for writing
106                 // - underflows cannot happen because `i` must be bigger or equal to `left` for
107                 //   a subtraction of `left` to happen.
108                 //
109                 // So `x+i` is valid for reading and writing if the caller respected the contract
110                 tmp = unsafe { x.add(i).replace(tmp) };
111                 // instead of incrementing `i` and then checking if it is outside the bounds, we
112                 // check if `i` will go outside the bounds on the next increment. This prevents
113                 // any wrapping of pointers or `usize`.
114                 if i >= left {
115                     i -= left;
116                     if i == 0 {
117                         // end of first round
118                         // SAFETY: tmp has been read from a valid source and x is valid for writing
119                         // according to the caller.
120                         unsafe { x.write(tmp) };
121                         break;
122                     }
123                     // this conditional must be here if `left + right >= 15`
124                     if i < gcd {
125                         gcd = i;
126                     }
127                 } else {
128                     i += right;
129                 }
130             }
131             // finish the chunk with more rounds
132             for start in 1..gcd {
133                 // SAFETY: `gcd` is at most equal to `right` so all values in `1..gcd` are valid for
134                 // reading and writing as per the function's safety contract, see [long-safety-expl]
135                 // above
136                 tmp = unsafe { x.add(start).read() };
137                 // [safety-expl-addition]
138                 //
139                 // Here `start < gcd` so `start < right` so `i < right+right`: `right` being the
140                 // greatest common divisor of `(left+right, right)` means that `left = right` so
141                 // `i < left+right` so `x+i = mid-left+i` is always valid for reading and writing
142                 // according to the function's safety contract.
143                 i = start + right;
144                 loop {
145                     // SAFETY: see [long-safety-expl] and [safety-expl-addition]
146                     tmp = unsafe { x.add(i).replace(tmp) };
147                     if i >= left {
148                         i -= left;
149                         if i == start {
150                             // SAFETY: see [long-safety-expl] and [safety-expl-addition]
151                             unsafe { x.add(start).write(tmp) };
152                             break;
153                         }
154                     } else {
155                         i += right;
156                     }
157                 }
158             }
159             return;
160         // `T` is not a zero-sized type, so it's okay to divide by its size.
161         } else if cmp::min(left, right) <= mem::size_of::<BufType>() / mem::size_of::<T>() {
162             // Algorithm 2
163             // The `[T; 0]` here is to ensure this is appropriately aligned for T
164             let mut rawarray = MaybeUninit::<(BufType, [T; 0])>::uninit();
165             let buf = rawarray.as_mut_ptr() as *mut T;
166             // SAFETY: `mid-left <= mid-left+right < mid+right`
167             let dim = unsafe { mid.sub(left).add(right) };
168             if left <= right {
169                 // SAFETY:
170                 //
171                 // 1) The `else if` condition about the sizes ensures `[mid-left; left]` will fit in
172                 //    `buf` without overflow and `buf` was created just above and so cannot be
173                 //    overlapped with any value of `[mid-left; left]`
174                 // 2) [mid-left, mid+right) are all valid for reading and writing and we don't care
175                 //    about overlaps here.
176                 // 3) The `if` condition about `left <= right` ensures writing `left` elements to
177                 //    `dim = mid-left+right` is valid because:
178                 //    - `buf` is valid and `left` elements were written in it in 1)
179                 //    - `dim+left = mid-left+right+left = mid+right` and we write `[dim, dim+left)`
180                 unsafe {
181                     // 1)
182                     ptr::copy_nonoverlapping(mid.sub(left), buf, left);
183                     // 2)
184                     ptr::copy(mid, mid.sub(left), right);
185                     // 3)
186                     ptr::copy_nonoverlapping(buf, dim, left);
187                 }
188             } else {
189                 // SAFETY: same reasoning as above but with `left` and `right` reversed
190                 unsafe {
191                     ptr::copy_nonoverlapping(mid, buf, right);
192                     ptr::copy(mid.sub(left), dim, left);
193                     ptr::copy_nonoverlapping(buf, mid.sub(left), right);
194                 }
195             }
196             return;
197         } else if left >= right {
198             // Algorithm 3
199             // There is an alternate way of swapping that involves finding where the last swap
200             // of this algorithm would be, and swapping using that last chunk instead of swapping
201             // adjacent chunks like this algorithm is doing, but this way is still faster.
202             loop {
203                 // SAFETY:
204                 // `left >= right` so `[mid-right, mid+right)` is valid for reading and writing
205                 // Subtracting `right` from `mid` each turn is counterbalanced by the addition and
206                 // check after it.
207                 unsafe {
208                     ptr::swap_nonoverlapping(mid.sub(right), mid, right);
209                     mid = mid.sub(right);
210                 }
211                 left -= right;
212                 if left < right {
213                     break;
214                 }
215             }
216         } else {
217             // Algorithm 3, `left < right`
218             loop {
219                 // SAFETY: `[mid-left, mid+left)` is valid for reading and writing because
220                 // `left < right` so `mid+left < mid+right`.
221                 // Adding `left` to `mid` each turn is counterbalanced by the subtraction and check
222                 // after it.
223                 unsafe {
224                     ptr::swap_nonoverlapping(mid.sub(left), mid, left);
225                     mid = mid.add(left);
226                 }
227                 right -= left;
228                 if right < left {
229                     break;
230                 }
231             }
232         }
233     }
234 }