]> git.lizzy.rs Git - rust.git/blob - src/librand/distributions/mod.rs
2e3c0394747535a1c7fd66562e90cef69cb9d230
[rust.git] / src / librand / distributions / mod.rs
1 // Copyright 2013 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11 /*!
12 Sampling from random distributions.
13
14 This is a generalization of `Rand` to allow parameters to control the
15 exact properties of the generated values, e.g. the mean and standard
16 deviation of a normal distribution. The `Sample` trait is the most
17 general, and allows for generating values that change some state
18 internally. The `IndependentSample` trait is for generating values
19 that do not need to record state.
20
21 */
22
23 #![experimental]
24
25 use core::prelude::*;
26 use core::num;
27 use core::num::CheckedAdd;
28
29 use {Rng, Rand};
30
31 pub use self::range::Range;
32 pub use self::gamma::{Gamma, ChiSquared, FisherF, StudentT};
33 pub use self::normal::{Normal, LogNormal};
34 pub use self::exponential::Exp;
35
36 pub mod range;
37 pub mod gamma;
38 pub mod normal;
39 pub mod exponential;
40
41 /// Types that can be used to create a random instance of `Support`.
42 pub trait Sample<Support> {
43     /// Generate a random value of `Support`, using `rng` as the
44     /// source of randomness.
45     fn sample<R: Rng>(&mut self, rng: &mut R) -> Support;
46 }
47
48 /// `Sample`s that do not require keeping track of state.
49 ///
50 /// Since no state is recorded, each sample is (statistically)
51 /// independent of all others, assuming the `Rng` used has this
52 /// property.
53 // FIXME maybe having this separate is overkill (the only reason is to
54 // take &self rather than &mut self)? or maybe this should be the
55 // trait called `Sample` and the other should be `DependentSample`.
56 pub trait IndependentSample<Support>: Sample<Support> {
57     /// Generate a random value.
58     fn ind_sample<R: Rng>(&self, &mut R) -> Support;
59 }
60
61 /// A wrapper for generating types that implement `Rand` via the
62 /// `Sample` & `IndependentSample` traits.
63 pub struct RandSample<Sup>;
64
65 impl<Sup: Rand> Sample<Sup> for RandSample<Sup> {
66     fn sample<R: Rng>(&mut self, rng: &mut R) -> Sup { self.ind_sample(rng) }
67 }
68
69 impl<Sup: Rand> IndependentSample<Sup> for RandSample<Sup> {
70     fn ind_sample<R: Rng>(&self, rng: &mut R) -> Sup {
71         rng.gen()
72     }
73 }
74
75 /// A value with a particular weight for use with `WeightedChoice`.
76 pub struct Weighted<T> {
77     /// The numerical weight of this item
78     pub weight: uint,
79     /// The actual item which is being weighted
80     pub item: T,
81 }
82
83 /// A distribution that selects from a finite collection of weighted items.
84 ///
85 /// Each item has an associated weight that influences how likely it
86 /// is to be chosen: higher weight is more likely.
87 ///
88 /// The `Clone` restriction is a limitation of the `Sample` and
89 /// `IndependentSample` traits. Note that `&T` is (cheaply) `Clone` for
90 /// all `T`, as is `uint`, so one can store references or indices into
91 /// another vector.
92 ///
93 /// # Example
94 ///
95 /// ```rust
96 /// use std::rand;
97 /// use std::rand::distributions::{Weighted, WeightedChoice, IndependentSample};
98 ///
99 /// let mut items = vec!(Weighted { weight: 2, item: 'a' },
100 ///                      Weighted { weight: 4, item: 'b' },
101 ///                      Weighted { weight: 1, item: 'c' });
102 /// let wc = WeightedChoice::new(items.as_mut_slice());
103 /// let mut rng = rand::task_rng();
104 /// for _ in range(0, 16) {
105 ///      // on average prints 'a' 4 times, 'b' 8 and 'c' twice.
106 ///      println!("{}", wc.ind_sample(&mut rng));
107 /// }
108 /// ```
109 pub struct WeightedChoice<'a, T> {
110     items: &'a mut [Weighted<T>],
111     weight_range: Range<uint>
112 }
113
114 impl<'a, T: Clone> WeightedChoice<'a, T> {
115     /// Create a new `WeightedChoice`.
116     ///
117     /// Fails if:
118     /// - `v` is empty
119     /// - the total weight is 0
120     /// - the total weight is larger than a `uint` can contain.
121     pub fn new<'a>(items: &'a mut [Weighted<T>]) -> WeightedChoice<'a, T> {
122         // strictly speaking, this is subsumed by the total weight == 0 case
123         assert!(!items.is_empty(), "WeightedChoice::new called with no items");
124
125         let mut running_total = 0u;
126
127         // we convert the list from individual weights to cumulative
128         // weights so we can binary search. This *could* drop elements
129         // with weight == 0 as an optimisation.
130         for item in items.mut_iter() {
131             running_total = match running_total.checked_add(&item.weight) {
132                 Some(n) => n,
133                 None => fail!("WeightedChoice::new called with a total weight \
134                                larger than a uint can contain")
135             };
136
137             item.weight = running_total;
138         }
139         assert!(running_total != 0, "WeightedChoice::new called with a total weight of 0");
140
141         WeightedChoice {
142             items: items,
143             // we're likely to be generating numbers in this range
144             // relatively often, so might as well cache it
145             weight_range: Range::new(0, running_total)
146         }
147     }
148 }
149
150 impl<'a, T: Clone> Sample<T> for WeightedChoice<'a, T> {
151     fn sample<R: Rng>(&mut self, rng: &mut R) -> T { self.ind_sample(rng) }
152 }
153
154 impl<'a, T: Clone> IndependentSample<T> for WeightedChoice<'a, T> {
155     fn ind_sample<R: Rng>(&self, rng: &mut R) -> T {
156         // we want to find the first element that has cumulative
157         // weight > sample_weight, which we do by binary since the
158         // cumulative weights of self.items are sorted.
159
160         // choose a weight in [0, total_weight)
161         let sample_weight = self.weight_range.ind_sample(rng);
162
163         // short circuit when it's the first item
164         if sample_weight < self.items[0].weight {
165             return self.items[0].item.clone();
166         }
167
168         let mut idx = 0;
169         let mut modifier = self.items.len();
170
171         // now we know that every possibility has an element to the
172         // left, so we can just search for the last element that has
173         // cumulative weight <= sample_weight, then the next one will
174         // be "it". (Note that this greatest element will never be the
175         // last element of the vector, since sample_weight is chosen
176         // in [0, total_weight) and the cumulative weight of the last
177         // one is exactly the total weight.)
178         while modifier > 1 {
179             let i = idx + modifier / 2;
180             if self.items[i].weight <= sample_weight {
181                 // we're small, so look to the right, but allow this
182                 // exact element still.
183                 idx = i;
184                 // we need the `/ 2` to round up otherwise we'll drop
185                 // the trailing elements when `modifier` is odd.
186                 modifier += 1;
187             } else {
188                 // otherwise we're too big, so go left. (i.e. do
189                 // nothing)
190             }
191             modifier /= 2;
192         }
193         return self.items[idx + 1].item.clone();
194     }
195 }
196
197 mod ziggurat_tables;
198
199 /// Sample a random number using the Ziggurat method (specifically the
200 /// ZIGNOR variant from Doornik 2005). Most of the arguments are
201 /// directly from the paper:
202 ///
203 /// * `rng`: source of randomness
204 /// * `symmetric`: whether this is a symmetric distribution, or one-sided with P(x < 0) = 0.
205 /// * `X`: the $x_i$ abscissae.
206 /// * `F`: precomputed values of the PDF at the $x_i$, (i.e. $f(x_i)$)
207 /// * `F_DIFF`: precomputed values of $f(x_i) - f(x_{i+1})$
208 /// * `pdf`: the probability density function
209 /// * `zero_case`: manual sampling from the tail when we chose the
210 ///    bottom box (i.e. i == 0)
211
212 // the perf improvement (25-50%) is definitely worth the extra code
213 // size from force-inlining.
214 #[inline(always)]
215 fn ziggurat<R:Rng>(
216             rng: &mut R,
217             symmetric: bool,
218             x_tab: ziggurat_tables::ZigTable,
219             f_tab: ziggurat_tables::ZigTable,
220             pdf: |f64|: 'static -> f64,
221             zero_case: |&mut R, f64|: 'static -> f64)
222             -> f64 {
223     static SCALE: f64 = (1u64 << 53) as f64;
224     loop {
225         // reimplement the f64 generation as an optimisation suggested
226         // by the Doornik paper: we have a lot of precision-space
227         // (i.e. there are 11 bits of the 64 of a u64 to use after
228         // creating a f64), so we might as well reuse some to save
229         // generating a whole extra random number. (Seems to be 15%
230         // faster.)
231         let bits: u64 = rng.gen();
232         let i = (bits & 0xff) as uint;
233         let f = (bits >> 11) as f64 / SCALE;
234
235         // u is either U(-1, 1) or U(0, 1) depending on if this is a
236         // symmetric distribution or not.
237         let u = if symmetric {2.0 * f - 1.0} else {f};
238         let x = u * x_tab[i];
239
240         let test_x = if symmetric {num::abs(x)} else {x};
241
242         // algebraically equivalent to |u| < x_tab[i+1]/x_tab[i] (or u < x_tab[i+1]/x_tab[i])
243         if test_x < x_tab[i + 1] {
244             return x;
245         }
246         if i == 0 {
247             return zero_case(rng, u);
248         }
249         // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1
250         if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.gen() < pdf(x) {
251             return x;
252         }
253     }
254 }
255
256 #[cfg(test)]
257 mod tests {
258     use std::prelude::*;
259
260     use {Rng, Rand};
261     use super::{RandSample, WeightedChoice, Weighted, Sample, IndependentSample};
262
263     #[deriving(PartialEq, Show)]
264     struct ConstRand(uint);
265     impl Rand for ConstRand {
266         fn rand<R: Rng>(_: &mut R) -> ConstRand {
267             ConstRand(0)
268         }
269     }
270
271     // 0, 1, 2, 3, ...
272     struct CountingRng { i: u32 }
273     impl Rng for CountingRng {
274         fn next_u32(&mut self) -> u32 {
275             self.i += 1;
276             self.i - 1
277         }
278         fn next_u64(&mut self) -> u64 {
279             self.next_u32() as u64
280         }
281     }
282
283     #[test]
284     fn test_rand_sample() {
285         let mut rand_sample = RandSample::<ConstRand>;
286
287         assert_eq!(rand_sample.sample(&mut ::test::rng()), ConstRand(0));
288         assert_eq!(rand_sample.ind_sample(&mut ::test::rng()), ConstRand(0));
289     }
290     #[test]
291     fn test_weighted_choice() {
292         // this makes assumptions about the internal implementation of
293         // WeightedChoice, specifically: it doesn't reorder the items,
294         // it doesn't do weird things to the RNG (so 0 maps to 0, 1 to
295         // 1, internally; modulo a modulo operation).
296
297         macro_rules! t (
298             ($items:expr, $expected:expr) => {{
299                 let mut items = $items;
300                 let wc = WeightedChoice::new(items.as_mut_slice());
301                 let expected = $expected;
302
303                 let mut rng = CountingRng { i: 0 };
304
305                 for &val in expected.iter() {
306                     assert_eq!(wc.ind_sample(&mut rng), val)
307                 }
308             }}
309         );
310
311         t!(vec!(Weighted { weight: 1, item: 10}), [10]);
312
313         // skip some
314         t!(vec!(Weighted { weight: 0, item: 20},
315                 Weighted { weight: 2, item: 21},
316                 Weighted { weight: 0, item: 22},
317                 Weighted { weight: 1, item: 23}),
318            [21,21, 23]);
319
320         // different weights
321         t!(vec!(Weighted { weight: 4, item: 30},
322                 Weighted { weight: 3, item: 31}),
323            [30,30,30,30, 31,31,31]);
324
325         // check that we're binary searching
326         // correctly with some vectors of odd
327         // length.
328         t!(vec!(Weighted { weight: 1, item: 40},
329                 Weighted { weight: 1, item: 41},
330                 Weighted { weight: 1, item: 42},
331                 Weighted { weight: 1, item: 43},
332                 Weighted { weight: 1, item: 44}),
333            [40, 41, 42, 43, 44]);
334         t!(vec!(Weighted { weight: 1, item: 50},
335                 Weighted { weight: 1, item: 51},
336                 Weighted { weight: 1, item: 52},
337                 Weighted { weight: 1, item: 53},
338                 Weighted { weight: 1, item: 54},
339                 Weighted { weight: 1, item: 55},
340                 Weighted { weight: 1, item: 56}),
341            [50, 51, 52, 53, 54, 55, 56]);
342     }
343
344     #[test] #[should_fail]
345     fn test_weighted_choice_no_items() {
346         WeightedChoice::<int>::new([]);
347     }
348     #[test] #[should_fail]
349     fn test_weighted_choice_zero_weight() {
350         WeightedChoice::new(&mut [Weighted { weight: 0, item: 0},
351                                   Weighted { weight: 0, item: 1}]);
352     }
353     #[test] #[should_fail]
354     fn test_weighted_choice_weight_overflows() {
355         let x = (-1) as uint / 2; // x + x + 2 is the overflow
356         WeightedChoice::new(&mut [Weighted { weight: x, item: 0 },
357                                   Weighted { weight: 1, item: 1 },
358                                   Weighted { weight: x, item: 2 },
359                                   Weighted { weight: 1, item: 3 }]);
360     }
361 }