rand_distr/
gamma.rs

1// Copyright 2018 Developers of the Rand project.
2// Copyright 2013 The Rust Project Developers.
3//
4// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
5// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
6// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
7// option. This file may not be copied, modified, or distributed
8// except according to those terms.
9
10//! The Gamma distribution.
11
12use self::GammaRepr::*;
13
14use crate::{Distribution, Exp, Exp1, Open01, StandardNormal};
15use core::fmt;
16use num_traits::Float;
17use rand::Rng;
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20
21/// The [Gamma distribution](https://en.wikipedia.org/wiki/Gamma_distribution) `Gamma(k, θ)`.
22///
23/// The Gamma distribution is a continuous probability distribution
24/// with shape parameter `k > 0` (number of events) and
25/// scale parameter `θ > 0` (mean waiting time between events).
26/// It describes the time until `k` events occur in a Poisson
27/// process with rate `1/θ`. It is the generalization of the
28/// [`Exponential`](crate::Exp) distribution.
29///
30/// # Density function
31///
32/// `f(x) =  x^(k - 1) * exp(-x / θ) / (Γ(k) * θ^k)` for `x > 0`,
33/// where `Γ` is the [gamma function](https://en.wikipedia.org/wiki/Gamma_function).
34///
35/// # Plot
36///
37/// The following plot illustrates the Gamma distribution with
38/// various values of `k` and `θ`.
39/// Curves with `θ = 1` are more saturated, while corresponding
40/// curves with `θ = 2` have a lighter color.
41///
42/// ![Gamma distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/gamma.svg)
43///
44/// # Example
45///
46/// ```
47/// use rand_distr::{Distribution, Gamma};
48///
49/// let gamma = Gamma::new(2.0, 5.0).unwrap();
50/// let v = gamma.sample(&mut rand::rng());
51/// println!("{} is from a Gamma(2, 5) distribution", v);
52/// ```
53///
54/// # Notes
55///
56/// The algorithm used is that described by Marsaglia & Tsang 2000[^1],
57/// falling back to directly sampling from an Exponential for `shape
58/// == 1`, and using the boosting technique described in that paper for
59/// `shape < 1`.
60///
61/// [^1]: George Marsaglia and Wai Wan Tsang. 2000. "A Simple Method for
62///       Generating Gamma Variables" *ACM Trans. Math. Softw.* 26, 3
63///       (September 2000), 363-372.
64///       DOI:[10.1145/358407.358414](https://doi.acm.org/10.1145/358407.358414)
65#[derive(Clone, Copy, Debug, PartialEq)]
66#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
67pub struct Gamma<F>
68where
69    F: Float,
70    StandardNormal: Distribution<F>,
71    Exp1: Distribution<F>,
72    Open01: Distribution<F>,
73{
74    repr: GammaRepr<F>,
75}
76
/// Error type returned from [`Gamma::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// `shape <= 0` or `nan`.
    ShapeTooSmall,
    /// `scale <= 0` or `nan`.
    ScaleTooSmall,
    /// `1 / scale == 0`, which for `f32`/`f64` means `scale` is infinite.
    ScaleTooLarge,
}

impl fmt::Display for Error {
    /// Writes a short, lowercase description of the rejected parameter.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::ShapeTooSmall => "shape is not positive in gamma distribution",
            Self::ScaleTooSmall => "scale is not positive in gamma distribution",
            Self::ScaleTooLarge => "scale is infinity in gamma distribution",
        };
        f.write_str(message)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Error {}
100
101#[derive(Clone, Copy, Debug, PartialEq)]
102#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
103enum GammaRepr<F>
104where
105    F: Float,
106    StandardNormal: Distribution<F>,
107    Exp1: Distribution<F>,
108    Open01: Distribution<F>,
109{
110    Large(GammaLargeShape<F>),
111    One(Exp<F>),
112    Small(GammaSmallShape<F>),
113}
114
// These two helpers could be made public, but calling them directly (e.g.
// when the shape is known to always be > 1) to skip the branch on the
// `GammaRepr` enum does not appear to be noticeably faster.
119
120/// Gamma distribution where the shape parameter is less than 1.
121///
122/// Note, samples from this require a compulsory floating-point `pow`
123/// call, which makes it significantly slower than sampling from a
124/// gamma distribution where the shape parameter is greater than or
125/// equal to 1.
126///
127/// See `Gamma` for sampling from a Gamma distribution with general
128/// shape parameters.
129#[derive(Clone, Copy, Debug, PartialEq)]
130#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
131struct GammaSmallShape<F>
132where
133    F: Float,
134    StandardNormal: Distribution<F>,
135    Open01: Distribution<F>,
136{
137    inv_shape: F,
138    large_shape: GammaLargeShape<F>,
139}
140
141/// Gamma distribution where the shape parameter is larger than 1.
142///
143/// See `Gamma` for sampling from a Gamma distribution with general
144/// shape parameters.
145#[derive(Clone, Copy, Debug, PartialEq)]
146#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
147struct GammaLargeShape<F>
148where
149    F: Float,
150    StandardNormal: Distribution<F>,
151    Open01: Distribution<F>,
152{
153    scale: F,
154    c: F,
155    d: F,
156}
157
158impl<F> Gamma<F>
159where
160    F: Float,
161    StandardNormal: Distribution<F>,
162    Exp1: Distribution<F>,
163    Open01: Distribution<F>,
164{
165    /// Construct an object representing the `Gamma(shape, scale)`
166    /// distribution.
167    #[inline]
168    pub fn new(shape: F, scale: F) -> Result<Gamma<F>, Error> {
169        if !(shape > F::zero()) {
170            return Err(Error::ShapeTooSmall);
171        }
172        if !(scale > F::zero()) {
173            return Err(Error::ScaleTooSmall);
174        }
175
176        let repr = if shape == F::one() {
177            One(Exp::new(F::one() / scale).map_err(|_| Error::ScaleTooLarge)?)
178        } else if shape < F::one() {
179            Small(GammaSmallShape::new_raw(shape, scale))
180        } else {
181            Large(GammaLargeShape::new_raw(shape, scale))
182        };
183        Ok(Gamma { repr })
184    }
185}
186
187impl<F> GammaSmallShape<F>
188where
189    F: Float,
190    StandardNormal: Distribution<F>,
191    Open01: Distribution<F>,
192{
193    fn new_raw(shape: F, scale: F) -> GammaSmallShape<F> {
194        GammaSmallShape {
195            inv_shape: F::one() / shape,
196            large_shape: GammaLargeShape::new_raw(shape + F::one(), scale),
197        }
198    }
199}
200
201impl<F> GammaLargeShape<F>
202where
203    F: Float,
204    StandardNormal: Distribution<F>,
205    Open01: Distribution<F>,
206{
207    fn new_raw(shape: F, scale: F) -> GammaLargeShape<F> {
208        let d = shape - F::from(1. / 3.).unwrap();
209        GammaLargeShape {
210            scale,
211            c: F::one() / (F::from(9.).unwrap() * d).sqrt(),
212            d,
213        }
214    }
215}
216
217impl<F> Distribution<F> for Gamma<F>
218where
219    F: Float,
220    StandardNormal: Distribution<F>,
221    Exp1: Distribution<F>,
222    Open01: Distribution<F>,
223{
224    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
225        match self.repr {
226            Small(ref g) => g.sample(rng),
227            One(ref g) => g.sample(rng),
228            Large(ref g) => g.sample(rng),
229        }
230    }
231}
232impl<F> Distribution<F> for GammaSmallShape<F>
233where
234    F: Float,
235    StandardNormal: Distribution<F>,
236    Open01: Distribution<F>,
237{
238    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
239        let u: F = rng.sample(Open01);
240
241        self.large_shape.sample(rng) * u.powf(self.inv_shape)
242    }
243}
244impl<F> Distribution<F> for GammaLargeShape<F>
245where
246    F: Float,
247    StandardNormal: Distribution<F>,
248    Open01: Distribution<F>,
249{
250    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
251        // Marsaglia & Tsang method, 2000
252        loop {
253            let x: F = rng.sample(StandardNormal);
254            let v_cbrt = F::one() + self.c * x;
255            if v_cbrt <= F::zero() {
256                // a^3 <= 0 iff a <= 0
257                continue;
258            }
259
260            let v = v_cbrt * v_cbrt * v_cbrt;
261            let u: F = rng.sample(Open01);
262
263            let x_sqr = x * x;
264            if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr
265                || u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln())
266            {
267                return self.d * v * self.scale;
268            }
269        }
270    }
271}
272
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn gamma_distributions_can_be_compared() {
        assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0));
    }

    /// Non-positive and NaN shapes are rejected with `ShapeTooSmall`.
    #[test]
    fn gamma_rejects_invalid_shape() {
        for shape in [0.0, -0.0, -1.0, f64::NEG_INFINITY, f64::NAN] {
            assert_eq!(Gamma::<f64>::new(shape, 1.0), Err(Error::ShapeTooSmall));
        }
    }

    /// Non-positive and NaN scales are rejected with `ScaleTooSmall`.
    #[test]
    fn gamma_rejects_invalid_scale() {
        for scale in [0.0, -0.0, -1.0, f64::NEG_INFINITY, f64::NAN] {
            assert_eq!(Gamma::<f64>::new(2.0, scale), Err(Error::ScaleTooSmall));
        }
    }

    /// The shape is validated before the scale.
    #[test]
    fn gamma_checks_shape_before_scale() {
        assert_eq!(Gamma::<f64>::new(-1.0, -1.0), Err(Error::ShapeTooSmall));
    }

    /// `shape < 1`, `shape == 1` and `shape > 1` each take a different internal path.
    #[test]
    fn gamma_accepts_every_shape_regime() {
        for shape in [0.5, 1.0, 2.5] {
            assert!(Gamma::<f64>::new(shape, 2.0).is_ok());
            assert!(Gamma::<f32>::new(shape as f32, 2.0).is_ok());
        }
    }
}