]> git.lizzy.rs Git - rust.git/blob - src/librand/distributions/mod.rs
fb0c9f17cf36737e78f8cd39501b6115800b1ea5
[rust.git] / src / librand / distributions / mod.rs
1 // Copyright 2013 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11 //! Sampling from random distributions.
12 //!
13 //! This is a generalization of `Rand` to allow parameters to control the
14 //! exact properties of the generated values, e.g. the mean and standard
15 //! deviation of a normal distribution. The `Sample` trait is the most
16 //! general, and allows for generating values that change some state
17 //! internally. The `IndependentSample` trait is for generating values
18 //! that do not need to record state.
19
20 use core::fmt;
21
22 #[cfg(not(test))] // only necessary for no_std
23 use core::num::Float;
24
25 use core::marker::PhantomData;
26
27 use {Rand, Rng};
28
29 pub use self::range::Range;
30 pub use self::gamma::{ChiSquared, FisherF, Gamma, StudentT};
31 pub use self::normal::{LogNormal, Normal};
32 pub use self::exponential::Exp;
33
34 pub mod range;
35 pub mod gamma;
36 pub mod normal;
37 pub mod exponential;
38
39 /// Types that can be used to create a random instance of `Support`.
40 pub trait Sample<Support> {
41     /// Generate a random value of `Support`, using `rng` as the
42     /// source of randomness.
43     fn sample<R: Rng>(&mut self, rng: &mut R) -> Support;
44 }
45
46 /// `Sample`s that do not require keeping track of state.
47 ///
48 /// Since no state is recorded, each sample is (statistically)
49 /// independent of all others, assuming the `Rng` used has this
50 /// property.
51 // FIXME maybe having this separate is overkill (the only reason is to
52 // take &self rather than &mut self)? or maybe this should be the
53 // trait called `Sample` and the other should be `DependentSample`.
54 pub trait IndependentSample<Support>: Sample<Support> {
55     /// Generate a random value.
56     fn ind_sample<R: Rng>(&self, &mut R) -> Support;
57 }
58
59 /// A wrapper for generating types that implement `Rand` via the
60 /// `Sample` & `IndependentSample` traits.
61 pub struct RandSample<Sup> {
62     _marker: PhantomData<Sup>,
63 }
64
65 impl<Sup> RandSample<Sup> {
66     pub fn new() -> RandSample<Sup> {
67         RandSample { _marker: PhantomData }
68     }
69 }
70
71 impl<Sup: Rand> Sample<Sup> for RandSample<Sup> {
72     fn sample<R: Rng>(&mut self, rng: &mut R) -> Sup {
73         self.ind_sample(rng)
74     }
75 }
76
77 impl<Sup: Rand> IndependentSample<Sup> for RandSample<Sup> {
78     fn ind_sample<R: Rng>(&self, rng: &mut R) -> Sup {
79         rng.gen()
80     }
81 }
82
83 impl<Sup> fmt::Debug for RandSample<Sup> {
84     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
85         f.pad("RandSample { .. }")
86     }
87 }
88
89 /// A value with a particular weight for use with `WeightedChoice`.
90 pub struct Weighted<T> {
91     /// The numerical weight of this item
92     pub weight: usize,
93     /// The actual item which is being weighted
94     pub item: T,
95 }
96
97 impl<T> fmt::Debug for Weighted<T> {
98     default fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
99         f.pad("Weighted")
100     }
101 }
102
103 impl<T: fmt::Debug> fmt::Debug for Weighted<T> {
104     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
105         f.debug_struct("Weighted")
106          .field("weight", &self.weight)
107          .field("item", &self.item)
108          .finish()
109     }
110 }
111
112 /// A distribution that selects from a finite collection of weighted items.
113 ///
114 /// Each item has an associated weight that influences how likely it
115 /// is to be chosen: higher weight is more likely.
116 ///
117 /// The `Clone` restriction is a limitation of the `Sample` and
118 /// `IndependentSample` traits. Note that `&T` is (cheaply) `Clone` for
119 /// all `T`, as is `usize`, so one can store references or indices into
120 /// another vector.
121 pub struct WeightedChoice<'a, T: 'a> {
122     items: &'a mut [Weighted<T>],
123     weight_range: Range<usize>,
124 }
125
126 impl<'a, T: Clone> WeightedChoice<'a, T> {
127     /// Create a new `WeightedChoice`.
128     ///
129     /// Panics if:
130     /// - `v` is empty
131     /// - the total weight is 0
132     /// - the total weight is larger than a `usize` can contain.
133     pub fn new(items: &'a mut [Weighted<T>]) -> WeightedChoice<'a, T> {
134         // strictly speaking, this is subsumed by the total weight == 0 case
135         assert!(!items.is_empty(),
136                 "WeightedChoice::new called with no items");
137
138         let mut running_total = 0_usize;
139
140         // we convert the list from individual weights to cumulative
141         // weights so we can binary search. This *could* drop elements
142         // with weight == 0 as an optimisation.
143         for item in &mut *items {
144             running_total = match running_total.checked_add(item.weight) {
145                 Some(n) => n,
146                 None => {
147                     panic!("WeightedChoice::new called with a total weight larger than a usize \
148                             can contain")
149                 }
150             };
151
152             item.weight = running_total;
153         }
154         assert!(running_total != 0,
155                 "WeightedChoice::new called with a total weight of 0");
156
157         WeightedChoice {
158             items: items,
159             // we're likely to be generating numbers in this range
160             // relatively often, so might as well cache it
161             weight_range: Range::new(0, running_total),
162         }
163     }
164 }
165
166 impl<'a, T: Clone> Sample<T> for WeightedChoice<'a, T> {
167     fn sample<R: Rng>(&mut self, rng: &mut R) -> T {
168         self.ind_sample(rng)
169     }
170 }
171
172 impl<'a, T: Clone> IndependentSample<T> for WeightedChoice<'a, T> {
173     fn ind_sample<R: Rng>(&self, rng: &mut R) -> T {
174         // we want to find the first element that has cumulative
175         // weight > sample_weight, which we do by binary since the
176         // cumulative weights of self.items are sorted.
177
178         // choose a weight in [0, total_weight)
179         let sample_weight = self.weight_range.ind_sample(rng);
180
181         // short circuit when it's the first item
182         if sample_weight < self.items[0].weight {
183             return self.items[0].item.clone();
184         }
185
186         let mut idx = 0;
187         let mut modifier = self.items.len();
188
189         // now we know that every possibility has an element to the
190         // left, so we can just search for the last element that has
191         // cumulative weight <= sample_weight, then the next one will
192         // be "it". (Note that this greatest element will never be the
193         // last element of the vector, since sample_weight is chosen
194         // in [0, total_weight) and the cumulative weight of the last
195         // one is exactly the total weight.)
196         while modifier > 1 {
197             let i = idx + modifier / 2;
198             if self.items[i].weight <= sample_weight {
199                 // we're small, so look to the right, but allow this
200                 // exact element still.
201                 idx = i;
202                 // we need the `/ 2` to round up otherwise we'll drop
203                 // the trailing elements when `modifier` is odd.
204                 modifier += 1;
205             } else {
206                 // otherwise we're too big, so go left. (i.e. do
207                 // nothing)
208             }
209             modifier /= 2;
210         }
211         return self.items[idx + 1].item.clone();
212     }
213 }
214
215 impl<'a, T> fmt::Debug for WeightedChoice<'a, T> {
216     default fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
217         f.pad("WeightedChoice")
218     }
219 }
220
221 impl<'a, T: fmt::Debug> fmt::Debug for WeightedChoice<'a, T> {
222     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
223         f.debug_struct("WeightedChoice")
224          .field("items", &self.items)
225          .field("weight_range", &self.weight_range)
226          .finish()
227     }
228 }
229
230 mod ziggurat_tables;
231
232 /// Sample a random number using the Ziggurat method (specifically the
233 /// ZIGNOR variant from Doornik 2005). Most of the arguments are
234 /// directly from the paper:
235 ///
236 /// * `rng`: source of randomness
237 /// * `symmetric`: whether this is a symmetric distribution, or one-sided with P(x < 0) = 0.
238 /// * `X`: the $x_i$ abscissae.
239 /// * `F`: precomputed values of the PDF at the $x_i$, (i.e. $f(x_i)$)
240 /// * `F_DIFF`: precomputed values of $f(x_i) - f(x_{i+1})$
241 /// * `pdf`: the probability density function
242 /// * `zero_case`: manual sampling from the tail when we chose the
243 ///    bottom box (i.e. i == 0)
244 // the perf improvement (25-50%) is definitely worth the extra code
245 // size from force-inlining.
246 #[inline(always)]
247 fn ziggurat<R: Rng, P, Z>(rng: &mut R,
248                           symmetric: bool,
249                           x_tab: ziggurat_tables::ZigTable,
250                           f_tab: ziggurat_tables::ZigTable,
251                           mut pdf: P,
252                           mut zero_case: Z)
253                           -> f64
254     where P: FnMut(f64) -> f64,
255           Z: FnMut(&mut R, f64) -> f64
256 {
257     const SCALE: f64 = (1u64 << 53) as f64;
258     loop {
259         // reimplement the f64 generation as an optimisation suggested
260         // by the Doornik paper: we have a lot of precision-space
261         // (i.e. there are 11 bits of the 64 of a u64 to use after
262         // creating a f64), so we might as well reuse some to save
263         // generating a whole extra random number. (Seems to be 15%
264         // faster.)
265         //
266         // This unfortunately misses out on the benefits of direct
267         // floating point generation if an RNG like dSMFT is
268         // used. (That is, such RNGs create floats directly, highly
269         // efficiently and overload next_f32/f64, so by not calling it
270         // this may be slower than it would be otherwise.)
271         // FIXME: investigate/optimise for the above.
272         let bits: u64 = rng.gen();
273         let i = (bits & 0xff) as usize;
274         let f = (bits >> 11) as f64 / SCALE;
275
276         // u is either U(-1, 1) or U(0, 1) depending on if this is a
277         // symmetric distribution or not.
278         let u = if symmetric { 2.0 * f - 1.0 } else { f };
279         let x = u * x_tab[i];
280
281         let test_x = if symmetric { x.abs() } else { x };
282
283         // algebraically equivalent to |u| < x_tab[i+1]/x_tab[i] (or u < x_tab[i+1]/x_tab[i])
284         if test_x < x_tab[i + 1] {
285             return x;
286         }
287         if i == 0 {
288             return zero_case(rng, u);
289         }
290         // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1
291         if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.gen::<f64>() < pdf(x) {
292             return x;
293         }
294     }
295 }
296
297 #[cfg(test)]
298 mod tests {
299     use {Rand, Rng};
300     use super::{IndependentSample, RandSample, Sample, Weighted, WeightedChoice};
301
302     #[derive(PartialEq, Debug)]
303     struct ConstRand(usize);
304     impl Rand for ConstRand {
305         fn rand<R: Rng>(_: &mut R) -> ConstRand {
306             ConstRand(0)
307         }
308     }
309
310     // 0, 1, 2, 3, ...
311     struct CountingRng {
312         i: u32,
313     }
314     impl Rng for CountingRng {
315         fn next_u32(&mut self) -> u32 {
316             self.i += 1;
317             self.i - 1
318         }
319         fn next_u64(&mut self) -> u64 {
320             self.next_u32() as u64
321         }
322     }
323
324     #[test]
325     fn test_rand_sample() {
326         let mut rand_sample = RandSample::<ConstRand>::new();
327
328         assert_eq!(rand_sample.sample(&mut ::test::rng()), ConstRand(0));
329         assert_eq!(rand_sample.ind_sample(&mut ::test::rng()), ConstRand(0));
330     }
331     #[test]
332     #[rustfmt_skip]
333     fn test_weighted_choice() {
334         // this makes assumptions about the internal implementation of
335         // WeightedChoice, specifically: it doesn't reorder the items,
336         // it doesn't do weird things to the RNG (so 0 maps to 0, 1 to
337         // 1, internally; modulo a modulo operation).
338
339         macro_rules! t {
340             ($items:expr, $expected:expr) => {{
341                 let mut items = $items;
342                 let wc = WeightedChoice::new(&mut items);
343                 let expected = $expected;
344
345                 let mut rng = CountingRng { i: 0 };
346
347                 for &val in &expected {
348                     assert_eq!(wc.ind_sample(&mut rng), val)
349                 }
350             }}
351         }
352
353         t!(vec![Weighted { weight: 1, item: 10 }],
354            [10]);
355
356         // skip some
357         t!(vec![Weighted { weight: 0, item: 20 },
358                 Weighted { weight: 2, item: 21 },
359                 Weighted { weight: 0, item: 22 },
360                 Weighted { weight: 1, item: 23 }],
361            [21, 21, 23]);
362
363         // different weights
364         t!(vec![Weighted { weight: 4, item: 30 },
365                 Weighted { weight: 3, item: 31 }],
366            [30, 30, 30, 30, 31, 31, 31]);
367
368         // check that we're binary searching
369         // correctly with some vectors of odd
370         // length.
371         t!(vec![Weighted { weight: 1, item: 40 },
372                 Weighted { weight: 1, item: 41 },
373                 Weighted { weight: 1, item: 42 },
374                 Weighted { weight: 1, item: 43 },
375                 Weighted { weight: 1, item: 44 }],
376            [40, 41, 42, 43, 44]);
377         t!(vec![Weighted { weight: 1, item: 50 },
378                 Weighted { weight: 1, item: 51 },
379                 Weighted { weight: 1, item: 52 },
380                 Weighted { weight: 1, item: 53 },
381                 Weighted { weight: 1, item: 54 },
382                 Weighted { weight: 1, item: 55 },
383                 Weighted { weight: 1, item: 56 }],
384            [50, 51, 52, 53, 54, 55, 56]);
385     }
386
387     #[test]
388     #[should_panic]
389     fn test_weighted_choice_no_items() {
390         WeightedChoice::<isize>::new(&mut []);
391     }
392     #[test]
393     #[should_panic]
394     #[rustfmt_skip]
395     fn test_weighted_choice_zero_weight() {
396         WeightedChoice::new(&mut [Weighted { weight: 0, item: 0 },
397                                   Weighted { weight: 0, item: 1 }]);
398     }
399     #[test]
400     #[should_panic]
401     #[rustfmt_skip]
402     fn test_weighted_choice_weight_overflows() {
403         let x = (!0) as usize / 2; // x + x + 2 is the overflow
404         WeightedChoice::new(&mut [Weighted { weight: x, item: 0 },
405                                   Weighted { weight: 1, item: 1 },
406                                   Weighted { weight: x, item: 2 },
407                                   Weighted { weight: 1, item: 3 }]);
408     }
409 }