]> git.lizzy.rs Git - rust.git/blob - src/librand/distributions/mod.rs
Only retain external static symbols across LTO
[rust.git] / src / librand / distributions / mod.rs
1 // Copyright 2013 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11 //! Sampling from random distributions.
12 //!
13 //! This is a generalization of `Rand` to allow parameters to control the
14 //! exact properties of the generated values, e.g. the mean and standard
15 //! deviation of a normal distribution. The `Sample` trait is the most
16 //! general, and allows for generating values that change some state
17 //! internally. The `IndependentSample` trait is for generating values
18 //! that do not need to record state.
19
20 use core::num::Float;
21 use core::marker::PhantomData;
22
23 use {Rng, Rand};
24
25 pub use self::range::Range;
26 pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT};
27 pub use self::normal::{Normal, LogNormal};
28 pub use self::exponential::Exp;
29
30 pub mod range;
31 pub mod gamma;
32 pub mod normal;
33 pub mod exponential;
34
35 /// Types that can be used to create a random instance of `Support`.
36 pub trait Sample<Support> {
37     /// Generate a random value of `Support`, using `rng` as the
38     /// source of randomness.
39     fn sample<R: Rng>(&mut self, rng: &mut R) -> Support;
40 }
41
42 /// `Sample`s that do not require keeping track of state.
43 ///
44 /// Since no state is recorded, each sample is (statistically)
45 /// independent of all others, assuming the `Rng` used has this
46 /// property.
47 // FIXME maybe having this separate is overkill (the only reason is to
48 // take &self rather than &mut self)? or maybe this should be the
49 // trait called `Sample` and the other should be `DependentSample`.
50 pub trait IndependentSample<Support>: Sample<Support> {
51     /// Generate a random value.
52     fn ind_sample<R: Rng>(&self, &mut R) -> Support;
53 }
54
55 /// A wrapper for generating types that implement `Rand` via the
56 /// `Sample` & `IndependentSample` traits.
57 pub struct RandSample<Sup> {
58     _marker: PhantomData<Sup>,
59 }
60
61 impl<Sup> RandSample<Sup> {
62     pub fn new() -> RandSample<Sup> {
63         RandSample { _marker: PhantomData }
64     }
65 }
66
67 impl<Sup: Rand> Sample<Sup> for RandSample<Sup> {
68     fn sample<R: Rng>(&mut self, rng: &mut R) -> Sup {
69         self.ind_sample(rng)
70     }
71 }
72
73 impl<Sup: Rand> IndependentSample<Sup> for RandSample<Sup> {
74     fn ind_sample<R: Rng>(&self, rng: &mut R) -> Sup {
75         rng.gen()
76     }
77 }
78
79 /// A value with a particular weight for use with `WeightedChoice`.
80 pub struct Weighted<T> {
81     /// The numerical weight of this item
82     pub weight: usize,
83     /// The actual item which is being weighted
84     pub item: T,
85 }
86
87 /// A distribution that selects from a finite collection of weighted items.
88 ///
89 /// Each item has an associated weight that influences how likely it
90 /// is to be chosen: higher weight is more likely.
91 ///
92 /// The `Clone` restriction is a limitation of the `Sample` and
93 /// `IndependentSample` traits. Note that `&T` is (cheaply) `Clone` for
94 /// all `T`, as is `usize`, so one can store references or indices into
95 /// another vector.
96 pub struct WeightedChoice<'a, T: 'a> {
97     items: &'a mut [Weighted<T>],
98     weight_range: Range<usize>,
99 }
100
101 impl<'a, T: Clone> WeightedChoice<'a, T> {
102     /// Create a new `WeightedChoice`.
103     ///
104     /// Panics if:
105     /// - `v` is empty
106     /// - the total weight is 0
107     /// - the total weight is larger than a `usize` can contain.
108     pub fn new(items: &'a mut [Weighted<T>]) -> WeightedChoice<'a, T> {
109         // strictly speaking, this is subsumed by the total weight == 0 case
110         assert!(!items.is_empty(),
111                 "WeightedChoice::new called with no items");
112
113         let mut running_total = 0_usize;
114
115         // we convert the list from individual weights to cumulative
116         // weights so we can binary search. This *could* drop elements
117         // with weight == 0 as an optimisation.
118         for item in &mut *items {
119             running_total = match running_total.checked_add(item.weight) {
120                 Some(n) => n,
121                 None => {
122                     panic!("WeightedChoice::new called with a total weight larger than a usize \
123                             can contain")
124                 }
125             };
126
127             item.weight = running_total;
128         }
129         assert!(running_total != 0,
130                 "WeightedChoice::new called with a total weight of 0");
131
132         WeightedChoice {
133             items: items,
134             // we're likely to be generating numbers in this range
135             // relatively often, so might as well cache it
136             weight_range: Range::new(0, running_total),
137         }
138     }
139 }
140
141 impl<'a, T: Clone> Sample<T> for WeightedChoice<'a, T> {
142     fn sample<R: Rng>(&mut self, rng: &mut R) -> T {
143         self.ind_sample(rng)
144     }
145 }
146
147 impl<'a, T: Clone> IndependentSample<T> for WeightedChoice<'a, T> {
148     fn ind_sample<R: Rng>(&self, rng: &mut R) -> T {
149         // we want to find the first element that has cumulative
150         // weight > sample_weight, which we do by binary since the
151         // cumulative weights of self.items are sorted.
152
153         // choose a weight in [0, total_weight)
154         let sample_weight = self.weight_range.ind_sample(rng);
155
156         // short circuit when it's the first item
157         if sample_weight < self.items[0].weight {
158             return self.items[0].item.clone();
159         }
160
161         let mut idx = 0;
162         let mut modifier = self.items.len();
163
164         // now we know that every possibility has an element to the
165         // left, so we can just search for the last element that has
166         // cumulative weight <= sample_weight, then the next one will
167         // be "it". (Note that this greatest element will never be the
168         // last element of the vector, since sample_weight is chosen
169         // in [0, total_weight) and the cumulative weight of the last
170         // one is exactly the total weight.)
171         while modifier > 1 {
172             let i = idx + modifier / 2;
173             if self.items[i].weight <= sample_weight {
174                 // we're small, so look to the right, but allow this
175                 // exact element still.
176                 idx = i;
177                 // we need the `/ 2` to round up otherwise we'll drop
178                 // the trailing elements when `modifier` is odd.
179                 modifier += 1;
180             } else {
181                 // otherwise we're too big, so go left. (i.e. do
182                 // nothing)
183             }
184             modifier /= 2;
185         }
186         return self.items[idx + 1].item.clone();
187     }
188 }
189
190 mod ziggurat_tables;
191
192 /// Sample a random number using the Ziggurat method (specifically the
193 /// ZIGNOR variant from Doornik 2005). Most of the arguments are
194 /// directly from the paper:
195 ///
196 /// * `rng`: source of randomness
197 /// * `symmetric`: whether this is a symmetric distribution, or one-sided with P(x < 0) = 0.
198 /// * `X`: the $x_i$ abscissae.
199 /// * `F`: precomputed values of the PDF at the $x_i$, (i.e. $f(x_i)$)
200 /// * `F_DIFF`: precomputed values of $f(x_i) - f(x_{i+1})$
201 /// * `pdf`: the probability density function
202 /// * `zero_case`: manual sampling from the tail when we chose the
203 ///    bottom box (i.e. i == 0)
204 // the perf improvement (25-50%) is definitely worth the extra code
205 // size from force-inlining.
206 #[inline(always)]
207 fn ziggurat<R: Rng, P, Z>(rng: &mut R,
208                           symmetric: bool,
209                           x_tab: ziggurat_tables::ZigTable,
210                           f_tab: ziggurat_tables::ZigTable,
211                           mut pdf: P,
212                           mut zero_case: Z)
213                           -> f64
214     where P: FnMut(f64) -> f64,
215           Z: FnMut(&mut R, f64) -> f64
216 {
217     const SCALE: f64 = (1u64 << 53) as f64;
218     loop {
219         // reimplement the f64 generation as an optimisation suggested
220         // by the Doornik paper: we have a lot of precision-space
221         // (i.e. there are 11 bits of the 64 of a u64 to use after
222         // creating a f64), so we might as well reuse some to save
223         // generating a whole extra random number. (Seems to be 15%
224         // faster.)
225         //
226         // This unfortunately misses out on the benefits of direct
227         // floating point generation if an RNG like dSMFT is
228         // used. (That is, such RNGs create floats directly, highly
229         // efficiently and overload next_f32/f64, so by not calling it
230         // this may be slower than it would be otherwise.)
231         // FIXME: investigate/optimise for the above.
232         let bits: u64 = rng.gen();
233         let i = (bits & 0xff) as usize;
234         let f = (bits >> 11) as f64 / SCALE;
235
236         // u is either U(-1, 1) or U(0, 1) depending on if this is a
237         // symmetric distribution or not.
238         let u = if symmetric {
239             2.0 * f - 1.0
240         } else {
241             f
242         };
243         let x = u * x_tab[i];
244
245         let test_x = if symmetric {
246             x.abs()
247         } else {
248             x
249         };
250
251         // algebraically equivalent to |u| < x_tab[i+1]/x_tab[i] (or u < x_tab[i+1]/x_tab[i])
252         if test_x < x_tab[i + 1] {
253             return x;
254         }
255         if i == 0 {
256             return zero_case(rng, u);
257         }
258         // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1
259         if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.gen::<f64>() < pdf(x) {
260             return x;
261         }
262     }
263 }
264
265 #[cfg(test)]
266 mod tests {
267     use {Rng, Rand};
268     use super::{RandSample, WeightedChoice, Weighted, Sample, IndependentSample};
269
270     #[derive(PartialEq, Debug)]
271     struct ConstRand(usize);
272     impl Rand for ConstRand {
273         fn rand<R: Rng>(_: &mut R) -> ConstRand {
274             ConstRand(0)
275         }
276     }
277
278     // 0, 1, 2, 3, ...
279     struct CountingRng {
280         i: u32,
281     }
282     impl Rng for CountingRng {
283         fn next_u32(&mut self) -> u32 {
284             self.i += 1;
285             self.i - 1
286         }
287         fn next_u64(&mut self) -> u64 {
288             self.next_u32() as u64
289         }
290     }
291
292     #[test]
293     fn test_rand_sample() {
294         let mut rand_sample = RandSample::<ConstRand>::new();
295
296         assert_eq!(rand_sample.sample(&mut ::test::rng()), ConstRand(0));
297         assert_eq!(rand_sample.ind_sample(&mut ::test::rng()), ConstRand(0));
298     }
299     #[test]
300     #[rustfmt_skip]
301     fn test_weighted_choice() {
302         // this makes assumptions about the internal implementation of
303         // WeightedChoice, specifically: it doesn't reorder the items,
304         // it doesn't do weird things to the RNG (so 0 maps to 0, 1 to
305         // 1, internally; modulo a modulo operation).
306
307         macro_rules! t {
308             ($items:expr, $expected:expr) => {{
309                 let mut items = $items;
310                 let wc = WeightedChoice::new(&mut items);
311                 let expected = $expected;
312
313                 let mut rng = CountingRng { i: 0 };
314
315                 for &val in &expected {
316                     assert_eq!(wc.ind_sample(&mut rng), val)
317                 }
318             }}
319         }
320
321         t!(vec!(Weighted { weight: 1, item: 10 }),
322            [10]);
323
324         // skip some
325         t!(vec!(Weighted { weight: 0, item: 20 },
326                 Weighted { weight: 2, item: 21 },
327                 Weighted { weight: 0, item: 22 },
328                 Weighted { weight: 1, item: 23 }),
329            [21, 21, 23]);
330
331         // different weights
332         t!(vec!(Weighted { weight: 4, item: 30 },
333                 Weighted { weight: 3, item: 31 }),
334            [30, 30, 30, 30, 31, 31, 31]);
335
336         // check that we're binary searching
337         // correctly with some vectors of odd
338         // length.
339         t!(vec!(Weighted { weight: 1, item: 40 },
340                 Weighted { weight: 1, item: 41 },
341                 Weighted { weight: 1, item: 42 },
342                 Weighted { weight: 1, item: 43 },
343                 Weighted { weight: 1, item: 44 }),
344            [40, 41, 42, 43, 44]);
345         t!(vec!(Weighted { weight: 1, item: 50 },
346                 Weighted { weight: 1, item: 51 },
347                 Weighted { weight: 1, item: 52 },
348                 Weighted { weight: 1, item: 53 },
349                 Weighted { weight: 1, item: 54 },
350                 Weighted { weight: 1, item: 55 },
351                 Weighted { weight: 1, item: 56 }),
352            [50, 51, 52, 53, 54, 55, 56]);
353     }
354
355     #[test]
356     #[should_panic]
357     fn test_weighted_choice_no_items() {
358         WeightedChoice::<isize>::new(&mut []);
359     }
360     #[test]
361     #[should_panic]
362     #[rustfmt_skip]
363     fn test_weighted_choice_zero_weight() {
364         WeightedChoice::new(&mut [Weighted { weight: 0, item: 0 },
365                                   Weighted { weight: 0, item: 1 }]);
366     }
367     #[test]
368     #[should_panic]
369     #[rustfmt_skip]
370     fn test_weighted_choice_weight_overflows() {
371         let x = (!0) as usize / 2; // x + x + 2 is the overflow
372         WeightedChoice::new(&mut [Weighted { weight: x, item: 0 },
373                                   Weighted { weight: 1, item: 1 },
374                                   Weighted { weight: x, item: 2 },
375                                   Weighted { weight: 1, item: 3 }]);
376     }
377 }