]> git.lizzy.rs Git - rust.git/blob - src/librand/distributions/mod.rs
Auto merge of #38561 - nagisa:rdrandseed, r=alexcrichton
[rust.git] / src / librand / distributions / mod.rs
1 // Copyright 2013 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11 //! Sampling from random distributions.
12 //!
13 //! This is a generalization of `Rand` to allow parameters to control the
14 //! exact properties of the generated values, e.g. the mean and standard
15 //! deviation of a normal distribution. The `Sample` trait is the most
16 //! general, and allows for generating values that change some state
17 //! internally. The `IndependentSample` trait is for generating values
18 //! that do not need to record state.
19
20 use core::fmt;
21
22 #[cfg(not(test))] // only necessary for no_std
23 use core::num::Float;
24
25 use core::marker::PhantomData;
26
27 use {Rand, Rng};
28
29 pub use self::range::Range;
30 pub use self::gamma::{ChiSquared, FisherF, Gamma, StudentT};
31 pub use self::normal::{LogNormal, Normal};
32 pub use self::exponential::Exp;
33
34 pub mod range;
35 pub mod gamma;
36 pub mod normal;
37 pub mod exponential;
38
39 /// Types that can be used to create a random instance of `Support`.
40 pub trait Sample<Support> {
41     /// Generate a random value of `Support`, using `rng` as the
42     /// source of randomness.
43     fn sample<R: Rng>(&mut self, rng: &mut R) -> Support;
44 }
45
46 /// `Sample`s that do not require keeping track of state.
47 ///
48 /// Since no state is recorded, each sample is (statistically)
49 /// independent of all others, assuming the `Rng` used has this
50 /// property.
51 // FIXME maybe having this separate is overkill (the only reason is to
52 // take &self rather than &mut self)? or maybe this should be the
53 // trait called `Sample` and the other should be `DependentSample`.
54 pub trait IndependentSample<Support>: Sample<Support> {
55     /// Generate a random value.
56     fn ind_sample<R: Rng>(&self, &mut R) -> Support;
57 }
58
59 /// A wrapper for generating types that implement `Rand` via the
60 /// `Sample` & `IndependentSample` traits.
61 pub struct RandSample<Sup> {
62     _marker: PhantomData<Sup>,
63 }
64
65 impl<Sup> RandSample<Sup> {
66     pub fn new() -> RandSample<Sup> {
67         RandSample { _marker: PhantomData }
68     }
69 }
70
71 impl<Sup: Rand> Sample<Sup> for RandSample<Sup> {
72     fn sample<R: Rng>(&mut self, rng: &mut R) -> Sup {
73         self.ind_sample(rng)
74     }
75 }
76
77 impl<Sup: Rand> IndependentSample<Sup> for RandSample<Sup> {
78     fn ind_sample<R: Rng>(&self, rng: &mut R) -> Sup {
79         rng.gen()
80     }
81 }
82
83 impl<Sup> fmt::Debug for RandSample<Sup> {
84     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
85         f.pad("RandSample { .. }")
86     }
87 }
88
89 /// A value with a particular weight for use with `WeightedChoice`.
90 pub struct Weighted<T> {
91     /// The numerical weight of this item
92     pub weight: usize,
93     /// The actual item which is being weighted
94     pub item: T,
95 }
96
97 impl<T: fmt::Debug> fmt::Debug for Weighted<T> {
98     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
99         f.debug_struct("Weighted")
100          .field("weight", &self.weight)
101          .field("item", &self.item)
102          .finish()
103     }
104 }
105
106 /// A distribution that selects from a finite collection of weighted items.
107 ///
108 /// Each item has an associated weight that influences how likely it
109 /// is to be chosen: higher weight is more likely.
110 ///
111 /// The `Clone` restriction is a limitation of the `Sample` and
112 /// `IndependentSample` traits. Note that `&T` is (cheaply) `Clone` for
113 /// all `T`, as is `usize`, so one can store references or indices into
114 /// another vector.
115 pub struct WeightedChoice<'a, T: 'a> {
116     items: &'a mut [Weighted<T>],
117     weight_range: Range<usize>,
118 }
119
120 impl<'a, T: Clone> WeightedChoice<'a, T> {
121     /// Create a new `WeightedChoice`.
122     ///
123     /// Panics if:
124     /// - `v` is empty
125     /// - the total weight is 0
126     /// - the total weight is larger than a `usize` can contain.
127     pub fn new(items: &'a mut [Weighted<T>]) -> WeightedChoice<'a, T> {
128         // strictly speaking, this is subsumed by the total weight == 0 case
129         assert!(!items.is_empty(),
130                 "WeightedChoice::new called with no items");
131
132         let mut running_total = 0_usize;
133
134         // we convert the list from individual weights to cumulative
135         // weights so we can binary search. This *could* drop elements
136         // with weight == 0 as an optimisation.
137         for item in &mut *items {
138             running_total = match running_total.checked_add(item.weight) {
139                 Some(n) => n,
140                 None => {
141                     panic!("WeightedChoice::new called with a total weight larger than a usize \
142                             can contain")
143                 }
144             };
145
146             item.weight = running_total;
147         }
148         assert!(running_total != 0,
149                 "WeightedChoice::new called with a total weight of 0");
150
151         WeightedChoice {
152             items: items,
153             // we're likely to be generating numbers in this range
154             // relatively often, so might as well cache it
155             weight_range: Range::new(0, running_total),
156         }
157     }
158 }
159
160 impl<'a, T: Clone> Sample<T> for WeightedChoice<'a, T> {
161     fn sample<R: Rng>(&mut self, rng: &mut R) -> T {
162         self.ind_sample(rng)
163     }
164 }
165
166 impl<'a, T: Clone> IndependentSample<T> for WeightedChoice<'a, T> {
167     fn ind_sample<R: Rng>(&self, rng: &mut R) -> T {
168         // we want to find the first element that has cumulative
169         // weight > sample_weight, which we do by binary since the
170         // cumulative weights of self.items are sorted.
171
172         // choose a weight in [0, total_weight)
173         let sample_weight = self.weight_range.ind_sample(rng);
174
175         // short circuit when it's the first item
176         if sample_weight < self.items[0].weight {
177             return self.items[0].item.clone();
178         }
179
180         let mut idx = 0;
181         let mut modifier = self.items.len();
182
183         // now we know that every possibility has an element to the
184         // left, so we can just search for the last element that has
185         // cumulative weight <= sample_weight, then the next one will
186         // be "it". (Note that this greatest element will never be the
187         // last element of the vector, since sample_weight is chosen
188         // in [0, total_weight) and the cumulative weight of the last
189         // one is exactly the total weight.)
190         while modifier > 1 {
191             let i = idx + modifier / 2;
192             if self.items[i].weight <= sample_weight {
193                 // we're small, so look to the right, but allow this
194                 // exact element still.
195                 idx = i;
196                 // we need the `/ 2` to round up otherwise we'll drop
197                 // the trailing elements when `modifier` is odd.
198                 modifier += 1;
199             } else {
200                 // otherwise we're too big, so go left. (i.e. do
201                 // nothing)
202             }
203             modifier /= 2;
204         }
205         return self.items[idx + 1].item.clone();
206     }
207 }
208
209 impl<'a, T: fmt::Debug> fmt::Debug for WeightedChoice<'a, T> {
210     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
211         f.debug_struct("WeightedChoice")
212          .field("items", &self.items)
213          .field("weight_range", &self.weight_range)
214          .finish()
215     }
216 }
217
218 mod ziggurat_tables;
219
220 /// Sample a random number using the Ziggurat method (specifically the
221 /// ZIGNOR variant from Doornik 2005). Most of the arguments are
222 /// directly from the paper:
223 ///
224 /// * `rng`: source of randomness
225 /// * `symmetric`: whether this is a symmetric distribution, or one-sided with P(x < 0) = 0.
226 /// * `X`: the $x_i$ abscissae.
227 /// * `F`: precomputed values of the PDF at the $x_i$, (i.e. $f(x_i)$)
228 /// * `F_DIFF`: precomputed values of $f(x_i) - f(x_{i+1})$
229 /// * `pdf`: the probability density function
230 /// * `zero_case`: manual sampling from the tail when we chose the
231 ///    bottom box (i.e. i == 0)
232 // the perf improvement (25-50%) is definitely worth the extra code
233 // size from force-inlining.
234 #[inline(always)]
235 fn ziggurat<R: Rng, P, Z>(rng: &mut R,
236                           symmetric: bool,
237                           x_tab: ziggurat_tables::ZigTable,
238                           f_tab: ziggurat_tables::ZigTable,
239                           mut pdf: P,
240                           mut zero_case: Z)
241                           -> f64
242     where P: FnMut(f64) -> f64,
243           Z: FnMut(&mut R, f64) -> f64
244 {
245     const SCALE: f64 = (1u64 << 53) as f64;
246     loop {
247         // reimplement the f64 generation as an optimisation suggested
248         // by the Doornik paper: we have a lot of precision-space
249         // (i.e. there are 11 bits of the 64 of a u64 to use after
250         // creating a f64), so we might as well reuse some to save
251         // generating a whole extra random number. (Seems to be 15%
252         // faster.)
253         //
254         // This unfortunately misses out on the benefits of direct
255         // floating point generation if an RNG like dSMFT is
256         // used. (That is, such RNGs create floats directly, highly
257         // efficiently and overload next_f32/f64, so by not calling it
258         // this may be slower than it would be otherwise.)
259         // FIXME: investigate/optimise for the above.
260         let bits: u64 = rng.gen();
261         let i = (bits & 0xff) as usize;
262         let f = (bits >> 11) as f64 / SCALE;
263
264         // u is either U(-1, 1) or U(0, 1) depending on if this is a
265         // symmetric distribution or not.
266         let u = if symmetric { 2.0 * f - 1.0 } else { f };
267         let x = u * x_tab[i];
268
269         let test_x = if symmetric { x.abs() } else { x };
270
271         // algebraically equivalent to |u| < x_tab[i+1]/x_tab[i] (or u < x_tab[i+1]/x_tab[i])
272         if test_x < x_tab[i + 1] {
273             return x;
274         }
275         if i == 0 {
276             return zero_case(rng, u);
277         }
278         // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1
279         if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.gen::<f64>() < pdf(x) {
280             return x;
281         }
282     }
283 }
284
285 #[cfg(test)]
286 mod tests {
287     use {Rand, Rng};
288     use super::{IndependentSample, RandSample, Sample, Weighted, WeightedChoice};
289
290     #[derive(PartialEq, Debug)]
291     struct ConstRand(usize);
292     impl Rand for ConstRand {
293         fn rand<R: Rng>(_: &mut R) -> ConstRand {
294             ConstRand(0)
295         }
296     }
297
298     // 0, 1, 2, 3, ...
299     struct CountingRng {
300         i: u32,
301     }
302     impl Rng for CountingRng {
303         fn next_u32(&mut self) -> u32 {
304             self.i += 1;
305             self.i - 1
306         }
307         fn next_u64(&mut self) -> u64 {
308             self.next_u32() as u64
309         }
310     }
311
312     #[test]
313     fn test_rand_sample() {
314         let mut rand_sample = RandSample::<ConstRand>::new();
315
316         assert_eq!(rand_sample.sample(&mut ::test::rng()), ConstRand(0));
317         assert_eq!(rand_sample.ind_sample(&mut ::test::rng()), ConstRand(0));
318     }
319     #[test]
320     #[rustfmt_skip]
321     fn test_weighted_choice() {
322         // this makes assumptions about the internal implementation of
323         // WeightedChoice, specifically: it doesn't reorder the items,
324         // it doesn't do weird things to the RNG (so 0 maps to 0, 1 to
325         // 1, internally; modulo a modulo operation).
326
327         macro_rules! t {
328             ($items:expr, $expected:expr) => {{
329                 let mut items = $items;
330                 let wc = WeightedChoice::new(&mut items);
331                 let expected = $expected;
332
333                 let mut rng = CountingRng { i: 0 };
334
335                 for &val in &expected {
336                     assert_eq!(wc.ind_sample(&mut rng), val)
337                 }
338             }}
339         }
340
341         t!(vec![Weighted { weight: 1, item: 10 }],
342            [10]);
343
344         // skip some
345         t!(vec![Weighted { weight: 0, item: 20 },
346                 Weighted { weight: 2, item: 21 },
347                 Weighted { weight: 0, item: 22 },
348                 Weighted { weight: 1, item: 23 }],
349            [21, 21, 23]);
350
351         // different weights
352         t!(vec![Weighted { weight: 4, item: 30 },
353                 Weighted { weight: 3, item: 31 }],
354            [30, 30, 30, 30, 31, 31, 31]);
355
356         // check that we're binary searching
357         // correctly with some vectors of odd
358         // length.
359         t!(vec![Weighted { weight: 1, item: 40 },
360                 Weighted { weight: 1, item: 41 },
361                 Weighted { weight: 1, item: 42 },
362                 Weighted { weight: 1, item: 43 },
363                 Weighted { weight: 1, item: 44 }],
364            [40, 41, 42, 43, 44]);
365         t!(vec![Weighted { weight: 1, item: 50 },
366                 Weighted { weight: 1, item: 51 },
367                 Weighted { weight: 1, item: 52 },
368                 Weighted { weight: 1, item: 53 },
369                 Weighted { weight: 1, item: 54 },
370                 Weighted { weight: 1, item: 55 },
371                 Weighted { weight: 1, item: 56 }],
372            [50, 51, 52, 53, 54, 55, 56]);
373     }
374
375     #[test]
376     #[should_panic]
377     fn test_weighted_choice_no_items() {
378         WeightedChoice::<isize>::new(&mut []);
379     }
380     #[test]
381     #[should_panic]
382     #[rustfmt_skip]
383     fn test_weighted_choice_zero_weight() {
384         WeightedChoice::new(&mut [Weighted { weight: 0, item: 0 },
385                                   Weighted { weight: 0, item: 1 }]);
386     }
387     #[test]
388     #[should_panic]
389     #[rustfmt_skip]
390     fn test_weighted_choice_weight_overflows() {
391         let x = (!0) as usize / 2; // x + x + 2 is the overflow
392         WeightedChoice::new(&mut [Weighted { weight: x, item: 0 },
393                                   Weighted { weight: 1, item: 1 },
394                                   Weighted { weight: x, item: 2 },
395                                   Weighted { weight: 1, item: 3 }]);
396     }
397 }