]> git.lizzy.rs Git - rust.git/blob - src/librand/distributions/gamma.rs
fceda64cbb3f195e2d251bb74c07a5afd3164146
[rust.git] / src / librand / distributions / gamma.rs
1 // Copyright 2013 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10
11 //! The Gamma and derived distributions.
12
13 use self::GammaRepr::*;
14 use self::ChiSquaredRepr::*;
15
16 use FloatMath;
17
18 use {Rng, Open01};
19 use super::normal::StandardNormal;
20 use super::{IndependentSample, Sample, Exp};
21
22 /// The Gamma distribution `Gamma(shape, scale)` distribution.
23 ///
24 /// The density function of this distribution is
25 ///
26 /// ```text
27 /// f(x) =  x^(k - 1) * exp(-x / θ) / (Γ(k) * θ^k)
28 /// ```
29 ///
30 /// where `Γ` is the Gamma function, `k` is the shape and `θ` is the
31 /// scale and both `k` and `θ` are strictly positive.
32 ///
33 /// The algorithm used is that described by Marsaglia & Tsang 2000[1],
34 /// falling back to directly sampling from an Exponential for `shape
35 /// == 1`, and using the boosting technique described in [1] for
36 /// `shape < 1`.
37 ///
38 /// [1]: George Marsaglia and Wai Wan Tsang. 2000. "A Simple Method
39 /// for Generating Gamma Variables" *ACM Trans. Math. Softw.* 26, 3
40 /// (September 2000),
41 /// 363-372. DOI:[10.1145/358407.358414](http://doi.acm.org/10.1145/358407.358414)
42 pub struct Gamma {
43     repr: GammaRepr,
44 }
45
46 enum GammaRepr {
47     Large(GammaLargeShape),
48     One(Exp),
49     Small(GammaSmallShape),
50 }
51
52 // These two helpers could be made public, but saving the
53 // match-on-Gamma-enum branch from using them directly (e.g. if one
54 // knows that the shape is always > 1) doesn't appear to be much
55 // faster.
56
57 /// Gamma distribution where the shape parameter is less than 1.
58 ///
59 /// Note, samples from this require a compulsory floating-point `pow`
60 /// call, which makes it significantly slower than sampling from a
61 /// gamma distribution where the shape parameter is greater than or
62 /// equal to 1.
63 ///
64 /// See `Gamma` for sampling from a Gamma distribution with general
65 /// shape parameters.
66 struct GammaSmallShape {
67     inv_shape: f64,
68     large_shape: GammaLargeShape,
69 }
70
71 /// Gamma distribution where the shape parameter is larger than 1.
72 ///
73 /// See `Gamma` for sampling from a Gamma distribution with general
74 /// shape parameters.
75 struct GammaLargeShape {
76     scale: f64,
77     c: f64,
78     d: f64,
79 }
80
81 impl Gamma {
82     /// Construct an object representing the `Gamma(shape, scale)`
83     /// distribution.
84     ///
85     /// Panics if `shape <= 0` or `scale <= 0`.
86     pub fn new(shape: f64, scale: f64) -> Gamma {
87         assert!(shape > 0.0, "Gamma::new called with shape <= 0");
88         assert!(scale > 0.0, "Gamma::new called with scale <= 0");
89
90         let repr = match shape {
91             1.0 => One(Exp::new(1.0 / scale)),
92             0.0 ... 1.0 => Small(GammaSmallShape::new_raw(shape, scale)),
93             _ => Large(GammaLargeShape::new_raw(shape, scale)),
94         };
95         Gamma { repr: repr }
96     }
97 }
98
99 impl GammaSmallShape {
100     fn new_raw(shape: f64, scale: f64) -> GammaSmallShape {
101         GammaSmallShape {
102             inv_shape: 1. / shape,
103             large_shape: GammaLargeShape::new_raw(shape + 1.0, scale),
104         }
105     }
106 }
107
108 impl GammaLargeShape {
109     fn new_raw(shape: f64, scale: f64) -> GammaLargeShape {
110         let d = shape - 1. / 3.;
111         GammaLargeShape {
112             scale: scale,
113             c: 1. / (9. * d).sqrt(),
114             d: d,
115         }
116     }
117 }
118
119 impl Sample<f64> for Gamma {
120     fn sample<R: Rng>(&mut self, rng: &mut R) -> f64 {
121         self.ind_sample(rng)
122     }
123 }
124 impl Sample<f64> for GammaSmallShape {
125     fn sample<R: Rng>(&mut self, rng: &mut R) -> f64 {
126         self.ind_sample(rng)
127     }
128 }
129 impl Sample<f64> for GammaLargeShape {
130     fn sample<R: Rng>(&mut self, rng: &mut R) -> f64 {
131         self.ind_sample(rng)
132     }
133 }
134
135 impl IndependentSample<f64> for Gamma {
136     fn ind_sample<R: Rng>(&self, rng: &mut R) -> f64 {
137         match self.repr {
138             Small(ref g) => g.ind_sample(rng),
139             One(ref g) => g.ind_sample(rng),
140             Large(ref g) => g.ind_sample(rng),
141         }
142     }
143 }
144 impl IndependentSample<f64> for GammaSmallShape {
145     fn ind_sample<R: Rng>(&self, rng: &mut R) -> f64 {
146         let Open01(u) = rng.gen::<Open01<f64>>();
147
148         self.large_shape.ind_sample(rng) * u.powf(self.inv_shape)
149     }
150 }
151 impl IndependentSample<f64> for GammaLargeShape {
152     fn ind_sample<R: Rng>(&self, rng: &mut R) -> f64 {
153         loop {
154             let StandardNormal(x) = rng.gen::<StandardNormal>();
155             let v_cbrt = 1.0 + self.c * x;
156             if v_cbrt <= 0.0 { // a^3 <= 0 iff a <= 0
157                 continue;
158             }
159
160             let v = v_cbrt * v_cbrt * v_cbrt;
161             let Open01(u) = rng.gen::<Open01<f64>>();
162
163             let x_sqr = x * x;
164             if u < 1.0 - 0.0331 * x_sqr * x_sqr ||
165                u.ln() < 0.5 * x_sqr + self.d * (1.0 - v + v.ln()) {
166                 return self.d * v * self.scale;
167             }
168         }
169     }
170 }
171
172 /// The chi-squared distribution `χ²(k)`, where `k` is the degrees of
173 /// freedom.
174 ///
175 /// For `k > 0` integral, this distribution is the sum of the squares
176 /// of `k` independent standard normal random variables. For other
177 /// `k`, this uses the equivalent characterization `χ²(k) = Gamma(k/2,
178 /// 2)`.
179 pub struct ChiSquared {
180     repr: ChiSquaredRepr,
181 }
182
183 enum ChiSquaredRepr {
184     // k == 1, Gamma(alpha, ..) is particularly slow for alpha < 1,
185     // e.g. when alpha = 1/2 as it would be for this case, so special-
186     // casing and using the definition of N(0,1)^2 is faster.
187     DoFExactlyOne,
188     DoFAnythingElse(Gamma),
189 }
190
191 impl ChiSquared {
192     /// Create a new chi-squared distribution with degrees-of-freedom
193     /// `k`. Panics if `k < 0`.
194     pub fn new(k: f64) -> ChiSquared {
195         let repr = if k == 1.0 {
196             DoFExactlyOne
197         } else {
198             assert!(k > 0.0, "ChiSquared::new called with `k` < 0");
199             DoFAnythingElse(Gamma::new(0.5 * k, 2.0))
200         };
201         ChiSquared { repr: repr }
202     }
203 }
204 impl Sample<f64> for ChiSquared {
205     fn sample<R: Rng>(&mut self, rng: &mut R) -> f64 {
206         self.ind_sample(rng)
207     }
208 }
209 impl IndependentSample<f64> for ChiSquared {
210     fn ind_sample<R: Rng>(&self, rng: &mut R) -> f64 {
211         match self.repr {
212             DoFExactlyOne => {
213                 // k == 1 => N(0,1)^2
214                 let StandardNormal(norm) = rng.gen::<StandardNormal>();
215                 norm * norm
216             }
217             DoFAnythingElse(ref g) => g.ind_sample(rng),
218         }
219     }
220 }
221
222 /// The Fisher F distribution `F(m, n)`.
223 ///
224 /// This distribution is equivalent to the ratio of two normalised
225 /// chi-squared distributions, that is, `F(m,n) = (χ²(m)/m) /
226 /// (χ²(n)/n)`.
227 pub struct FisherF {
228     numer: ChiSquared,
229     denom: ChiSquared,
230     // denom_dof / numer_dof so that this can just be a straight
231     // multiplication, rather than a division.
232     dof_ratio: f64,
233 }
234
235 impl FisherF {
236     /// Create a new `FisherF` distribution, with the given
237     /// parameter. Panics if either `m` or `n` are not positive.
238     pub fn new(m: f64, n: f64) -> FisherF {
239         assert!(m > 0.0, "FisherF::new called with `m < 0`");
240         assert!(n > 0.0, "FisherF::new called with `n < 0`");
241
242         FisherF {
243             numer: ChiSquared::new(m),
244             denom: ChiSquared::new(n),
245             dof_ratio: n / m,
246         }
247     }
248 }
249 impl Sample<f64> for FisherF {
250     fn sample<R: Rng>(&mut self, rng: &mut R) -> f64 {
251         self.ind_sample(rng)
252     }
253 }
254 impl IndependentSample<f64> for FisherF {
255     fn ind_sample<R: Rng>(&self, rng: &mut R) -> f64 {
256         self.numer.ind_sample(rng) / self.denom.ind_sample(rng) * self.dof_ratio
257     }
258 }
259
260 /// The Student t distribution, `t(nu)`, where `nu` is the degrees of
261 /// freedom.
262 pub struct StudentT {
263     chi: ChiSquared,
264     dof: f64,
265 }
266
267 impl StudentT {
268     /// Create a new Student t distribution with `n` degrees of
269     /// freedom. Panics if `n <= 0`.
270     pub fn new(n: f64) -> StudentT {
271         assert!(n > 0.0, "StudentT::new called with `n <= 0`");
272         StudentT {
273             chi: ChiSquared::new(n),
274             dof: n,
275         }
276     }
277 }
278 impl Sample<f64> for StudentT {
279     fn sample<R: Rng>(&mut self, rng: &mut R) -> f64 {
280         self.ind_sample(rng)
281     }
282 }
283 impl IndependentSample<f64> for StudentT {
284     fn ind_sample<R: Rng>(&self, rng: &mut R) -> f64 {
285         let StandardNormal(norm) = rng.gen::<StandardNormal>();
286         norm * (self.dof / self.chi.ind_sample(rng)).sqrt()
287     }
288 }
289
290 #[cfg(test)]
291 mod tests {
292     use distributions::{Sample, IndependentSample};
293     use super::{ChiSquared, StudentT, FisherF};
294
295     #[test]
296     fn test_chi_squared_one() {
297         let mut chi = ChiSquared::new(1.0);
298         let mut rng = ::test::rng();
299         for _ in 0..1000 {
300             chi.sample(&mut rng);
301             chi.ind_sample(&mut rng);
302         }
303     }
304     #[test]
305     fn test_chi_squared_small() {
306         let mut chi = ChiSquared::new(0.5);
307         let mut rng = ::test::rng();
308         for _ in 0..1000 {
309             chi.sample(&mut rng);
310             chi.ind_sample(&mut rng);
311         }
312     }
313     #[test]
314     fn test_chi_squared_large() {
315         let mut chi = ChiSquared::new(30.0);
316         let mut rng = ::test::rng();
317         for _ in 0..1000 {
318             chi.sample(&mut rng);
319             chi.ind_sample(&mut rng);
320         }
321     }
322     #[test]
323     #[should_panic]
324     fn test_chi_squared_invalid_dof() {
325         ChiSquared::new(-1.0);
326     }
327
328     #[test]
329     fn test_f() {
330         let mut f = FisherF::new(2.0, 32.0);
331         let mut rng = ::test::rng();
332         for _ in 0..1000 {
333             f.sample(&mut rng);
334             f.ind_sample(&mut rng);
335         }
336     }
337
338     #[test]
339     fn test_t() {
340         let mut t = StudentT::new(11.0);
341         let mut rng = ::test::rng();
342         for _ in 0..1000 {
343             t.sample(&mut rng);
344             t.ind_sample(&mut rng);
345         }
346     }
347 }
348
349 #[cfg(test)]
350 mod bench {
351     extern crate test;
352     use self::test::Bencher;
353     use std::mem::size_of;
354     use distributions::IndependentSample;
355     use super::Gamma;
356
357
358     #[bench]
359     fn bench_gamma_large_shape(b: &mut Bencher) {
360         let gamma = Gamma::new(10., 1.0);
361         let mut rng = ::test::weak_rng();
362
363         b.iter(|| {
364             for _ in 0..::RAND_BENCH_N {
365                 gamma.ind_sample(&mut rng);
366             }
367         });
368         b.bytes = size_of::<f64>() as u64 * ::RAND_BENCH_N;
369     }
370
371     #[bench]
372     fn bench_gamma_small_shape(b: &mut Bencher) {
373         let gamma = Gamma::new(0.1, 1.0);
374         let mut rng = ::test::weak_rng();
375
376         b.iter(|| {
377             for _ in 0..::RAND_BENCH_N {
378                 gamma.ind_sample(&mut rng);
379             }
380         });
381         b.bytes = size_of::<f64>() as u64 * ::RAND_BENCH_N;
382     }
383 }