]> git.lizzy.rs Git - rust.git/blob - src/librand/distributions/mod.rs
Auto merge of #22541 - Manishearth:rollup, r=Gankro
[rust.git] / src / librand / distributions / mod.rs
1 // Copyright 2013 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11 //! Sampling from random distributions.
12 //!
13 //! This is a generalization of `Rand` to allow parameters to control the
14 //! exact properties of the generated values, e.g. the mean and standard
15 //! deviation of a normal distribution. The `Sample` trait is the most
16 //! general, and allows for generating values that change some state
17 //! internally. The `IndependentSample` trait is for generating values
18 //! that do not need to record state.
19
20 #![unstable(feature = "rand")]
21
22 use core::prelude::*;
23 use core::num::{Float, Int};
24 use core::marker::PhantomData;
25
26 use {Rng, Rand};
27
28 pub use self::range::Range;
29 pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT};
30 pub use self::normal::{Normal, LogNormal};
31 pub use self::exponential::Exp;
32
33 pub mod range;
34 pub mod gamma;
35 pub mod normal;
36 pub mod exponential;
37
38 /// Types that can be used to create a random instance of `Support`.
39 pub trait Sample<Support> {
40     /// Generate a random value of `Support`, using `rng` as the
41     /// source of randomness.
42     fn sample<R: Rng>(&mut self, rng: &mut R) -> Support;
43 }
44
45 /// `Sample`s that do not require keeping track of state.
46 ///
47 /// Since no state is recorded, each sample is (statistically)
48 /// independent of all others, assuming the `Rng` used has this
49 /// property.
50 // FIXME maybe having this separate is overkill (the only reason is to
51 // take &self rather than &mut self)? or maybe this should be the
52 // trait called `Sample` and the other should be `DependentSample`.
53 pub trait IndependentSample<Support>: Sample<Support> {
54     /// Generate a random value.
55     fn ind_sample<R: Rng>(&self, &mut R) -> Support;
56 }
57
58 /// A wrapper for generating types that implement `Rand` via the
59 /// `Sample` & `IndependentSample` traits.
60 pub struct RandSample<Sup> { _marker: PhantomData<Sup> }
61
62 impl<Sup> RandSample<Sup> {
63     pub fn new() -> RandSample<Sup> {
64         RandSample { _marker: PhantomData }
65     }
66 }
67
68 impl<Sup: Rand> Sample<Sup> for RandSample<Sup> {
69     fn sample<R: Rng>(&mut self, rng: &mut R) -> Sup { self.ind_sample(rng) }
70 }
71
72 impl<Sup: Rand> IndependentSample<Sup> for RandSample<Sup> {
73     fn ind_sample<R: Rng>(&self, rng: &mut R) -> Sup {
74         rng.gen()
75     }
76 }
77
78 /// A value with a particular weight for use with `WeightedChoice`.
79 pub struct Weighted<T> {
80     /// The numerical weight of this item
81     pub weight: uint,
82     /// The actual item which is being weighted
83     pub item: T,
84 }
85
86 /// A distribution that selects from a finite collection of weighted items.
87 ///
88 /// Each item has an associated weight that influences how likely it
89 /// is to be chosen: higher weight is more likely.
90 ///
91 /// The `Clone` restriction is a limitation of the `Sample` and
92 /// `IndependentSample` traits. Note that `&T` is (cheaply) `Clone` for
93 /// all `T`, as is `uint`, so one can store references or indices into
94 /// another vector.
95 ///
96 /// # Example
97 ///
98 /// ```rust
99 /// use std::rand;
100 /// use std::rand::distributions::{Weighted, WeightedChoice, IndependentSample};
101 ///
102 /// let mut items = vec!(Weighted { weight: 2, item: 'a' },
103 ///                      Weighted { weight: 4, item: 'b' },
104 ///                      Weighted { weight: 1, item: 'c' });
105 /// let wc = WeightedChoice::new(items.as_mut_slice());
106 /// let mut rng = rand::thread_rng();
107 /// for _ in 0..16 {
108 ///      // on average prints 'a' 4 times, 'b' 8 and 'c' twice.
109 ///      println!("{}", wc.ind_sample(&mut rng));
110 /// }
111 /// ```
112 pub struct WeightedChoice<'a, T:'a> {
113     items: &'a mut [Weighted<T>],
114     weight_range: Range<uint>
115 }
116
117 impl<'a, T: Clone> WeightedChoice<'a, T> {
118     /// Create a new `WeightedChoice`.
119     ///
120     /// Panics if:
121     /// - `v` is empty
122     /// - the total weight is 0
123     /// - the total weight is larger than a `uint` can contain.
124     pub fn new(items: &'a mut [Weighted<T>]) -> WeightedChoice<'a, T> {
125         // strictly speaking, this is subsumed by the total weight == 0 case
126         assert!(!items.is_empty(), "WeightedChoice::new called with no items");
127
128         let mut running_total = 0;
129
130         // we convert the list from individual weights to cumulative
131         // weights so we can binary search. This *could* drop elements
132         // with weight == 0 as an optimisation.
133         for item in &mut *items {
134             running_total = match running_total.checked_add(item.weight) {
135                 Some(n) => n,
136                 None => panic!("WeightedChoice::new called with a total weight \
137                                larger than a uint can contain")
138             };
139
140             item.weight = running_total;
141         }
142         assert!(running_total != 0, "WeightedChoice::new called with a total weight of 0");
143
144         WeightedChoice {
145             items: items,
146             // we're likely to be generating numbers in this range
147             // relatively often, so might as well cache it
148             weight_range: Range::new(0, running_total)
149         }
150     }
151 }
152
153 impl<'a, T: Clone> Sample<T> for WeightedChoice<'a, T> {
154     fn sample<R: Rng>(&mut self, rng: &mut R) -> T { self.ind_sample(rng) }
155 }
156
157 impl<'a, T: Clone> IndependentSample<T> for WeightedChoice<'a, T> {
158     fn ind_sample<R: Rng>(&self, rng: &mut R) -> T {
159         // we want to find the first element that has cumulative
160         // weight > sample_weight, which we do by binary since the
161         // cumulative weights of self.items are sorted.
162
163         // choose a weight in [0, total_weight)
164         let sample_weight = self.weight_range.ind_sample(rng);
165
166         // short circuit when it's the first item
167         if sample_weight < self.items[0].weight {
168             return self.items[0].item.clone();
169         }
170
171         let mut idx = 0;
172         let mut modifier = self.items.len();
173
174         // now we know that every possibility has an element to the
175         // left, so we can just search for the last element that has
176         // cumulative weight <= sample_weight, then the next one will
177         // be "it". (Note that this greatest element will never be the
178         // last element of the vector, since sample_weight is chosen
179         // in [0, total_weight) and the cumulative weight of the last
180         // one is exactly the total weight.)
181         while modifier > 1 {
182             let i = idx + modifier / 2;
183             if self.items[i].weight <= sample_weight {
184                 // we're small, so look to the right, but allow this
185                 // exact element still.
186                 idx = i;
187                 // we need the `/ 2` to round up otherwise we'll drop
188                 // the trailing elements when `modifier` is odd.
189                 modifier += 1;
190             } else {
191                 // otherwise we're too big, so go left. (i.e. do
192                 // nothing)
193             }
194             modifier /= 2;
195         }
196         return self.items[idx + 1].item.clone();
197     }
198 }
199
200 mod ziggurat_tables;
201
202 /// Sample a random number using the Ziggurat method (specifically the
203 /// ZIGNOR variant from Doornik 2005). Most of the arguments are
204 /// directly from the paper:
205 ///
206 /// * `rng`: source of randomness
207 /// * `symmetric`: whether this is a symmetric distribution, or one-sided with P(x < 0) = 0.
208 /// * `X`: the $x_i$ abscissae.
209 /// * `F`: precomputed values of the PDF at the $x_i$, (i.e. $f(x_i)$)
210 /// * `F_DIFF`: precomputed values of $f(x_i) - f(x_{i+1})$
211 /// * `pdf`: the probability density function
212 /// * `zero_case`: manual sampling from the tail when we chose the
213 ///    bottom box (i.e. i == 0)
214
215 // the perf improvement (25-50%) is definitely worth the extra code
216 // size from force-inlining.
217 #[inline(always)]
218 fn ziggurat<R: Rng, P, Z>(
219             rng: &mut R,
220             symmetric: bool,
221             x_tab: ziggurat_tables::ZigTable,
222             f_tab: ziggurat_tables::ZigTable,
223             mut pdf: P,
224             mut zero_case: Z)
225             -> f64 where P: FnMut(f64) -> f64, Z: FnMut(&mut R, f64) -> f64 {
226     static SCALE: f64 = (1u64 << 53) as f64;
227     loop {
228         // reimplement the f64 generation as an optimisation suggested
229         // by the Doornik paper: we have a lot of precision-space
230         // (i.e. there are 11 bits of the 64 of a u64 to use after
231         // creating a f64), so we might as well reuse some to save
232         // generating a whole extra random number. (Seems to be 15%
233         // faster.)
234         //
235         // This unfortunately misses out on the benefits of direct
236         // floating point generation if an RNG like dSMFT is
237         // used. (That is, such RNGs create floats directly, highly
238         // efficiently and overload next_f32/f64, so by not calling it
239         // this may be slower than it would be otherwise.)
240         // FIXME: investigate/optimise for the above.
241         let bits: u64 = rng.gen();
242         let i = (bits & 0xff) as uint;
243         let f = (bits >> 11) as f64 / SCALE;
244
245         // u is either U(-1, 1) or U(0, 1) depending on if this is a
246         // symmetric distribution or not.
247         let u = if symmetric {2.0 * f - 1.0} else {f};
248         let x = u * x_tab[i];
249
250         let test_x = if symmetric { x.abs() } else {x};
251
252         // algebraically equivalent to |u| < x_tab[i+1]/x_tab[i] (or u < x_tab[i+1]/x_tab[i])
253         if test_x < x_tab[i + 1] {
254             return x;
255         }
256         if i == 0 {
257             return zero_case(rng, u);
258         }
259         // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1
260         if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.gen() < pdf(x) {
261             return x;
262         }
263     }
264 }
265
266 #[cfg(test)]
267 mod tests {
268     use std::prelude::v1::*;
269
270     use {Rng, Rand};
271     use super::{RandSample, WeightedChoice, Weighted, Sample, IndependentSample};
272
273     #[derive(PartialEq, Debug)]
274     struct ConstRand(uint);
275     impl Rand for ConstRand {
276         fn rand<R: Rng>(_: &mut R) -> ConstRand {
277             ConstRand(0)
278         }
279     }
280
281     // 0, 1, 2, 3, ...
282     struct CountingRng { i: u32 }
283     impl Rng for CountingRng {
284         fn next_u32(&mut self) -> u32 {
285             self.i += 1;
286             self.i - 1
287         }
288         fn next_u64(&mut self) -> u64 {
289             self.next_u32() as u64
290         }
291     }
292
293     #[test]
294     fn test_rand_sample() {
295         let mut rand_sample = RandSample::<ConstRand>::new();
296
297         assert_eq!(rand_sample.sample(&mut ::test::rng()), ConstRand(0));
298         assert_eq!(rand_sample.ind_sample(&mut ::test::rng()), ConstRand(0));
299     }
300     #[test]
301     fn test_weighted_choice() {
302         // this makes assumptions about the internal implementation of
303         // WeightedChoice, specifically: it doesn't reorder the items,
304         // it doesn't do weird things to the RNG (so 0 maps to 0, 1 to
305         // 1, internally; modulo a modulo operation).
306
307         macro_rules! t {
308             ($items:expr, $expected:expr) => {{
309                 let mut items = $items;
310                 let wc = WeightedChoice::new(&mut items);
311                 let expected = $expected;
312
313                 let mut rng = CountingRng { i: 0 };
314
315                 for &val in &expected {
316                     assert_eq!(wc.ind_sample(&mut rng), val)
317                 }
318             }}
319         }
320
321         t!(vec!(Weighted { weight: 1, item: 10}), [10]);
322
323         // skip some
324         t!(vec!(Weighted { weight: 0, item: 20},
325                 Weighted { weight: 2, item: 21},
326                 Weighted { weight: 0, item: 22},
327                 Weighted { weight: 1, item: 23}),
328            [21,21, 23]);
329
330         // different weights
331         t!(vec!(Weighted { weight: 4, item: 30},
332                 Weighted { weight: 3, item: 31}),
333            [30,30,30,30, 31,31,31]);
334
335         // check that we're binary searching
336         // correctly with some vectors of odd
337         // length.
338         t!(vec!(Weighted { weight: 1, item: 40},
339                 Weighted { weight: 1, item: 41},
340                 Weighted { weight: 1, item: 42},
341                 Weighted { weight: 1, item: 43},
342                 Weighted { weight: 1, item: 44}),
343            [40, 41, 42, 43, 44]);
344         t!(vec!(Weighted { weight: 1, item: 50},
345                 Weighted { weight: 1, item: 51},
346                 Weighted { weight: 1, item: 52},
347                 Weighted { weight: 1, item: 53},
348                 Weighted { weight: 1, item: 54},
349                 Weighted { weight: 1, item: 55},
350                 Weighted { weight: 1, item: 56}),
351            [50, 51, 52, 53, 54, 55, 56]);
352     }
353
354     #[test] #[should_fail]
355     fn test_weighted_choice_no_items() {
356         WeightedChoice::<int>::new(&mut []);
357     }
358     #[test] #[should_fail]
359     fn test_weighted_choice_zero_weight() {
360         WeightedChoice::new(&mut [Weighted { weight: 0, item: 0},
361                                   Weighted { weight: 0, item: 1}]);
362     }
363     #[test] #[should_fail]
364     fn test_weighted_choice_weight_overflows() {
365         let x = (-1) as uint / 2; // x + x + 2 is the overflow
366         WeightedChoice::new(&mut [Weighted { weight: x, item: 0 },
367                                   Weighted { weight: 1, item: 1 },
368                                   Weighted { weight: x, item: 2 },
369                                   Weighted { weight: 1, item: 3 }]);
370     }
371 }