1use self::GammaRepr::*;
13
14use crate::{Distribution, Exp, Exp1, Open01, StandardNormal};
15use core::fmt;
16use num_traits::Float;
17use rand::Rng;
18#[cfg(feature = "serde")]
19use serde::{Deserialize, Serialize};
20
21#[derive(Clone, Copy, Debug, PartialEq)]
66#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
67pub struct Gamma<F>
68where
69 F: Float,
70 StandardNormal: Distribution<F>,
71 Exp1: Distribution<F>,
72 Open01: Distribution<F>,
73{
74 repr: GammaRepr<F>,
75}
76
/// Reasons why constructing a gamma distribution can fail.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The `shape > 0` check failed (a non-positive or NaN shape).
    ShapeTooSmall,
    /// The `scale > 0` check failed (a non-positive or NaN scale).
    ScaleTooSmall,
    /// The exponential distribution used for `shape == 1` could not be
    /// built from the rate `1 / scale`.
    ScaleTooLarge,
}
87
88impl fmt::Display for Error {
89 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90 f.write_str(match self {
91 Error::ShapeTooSmall => "shape is not positive in gamma distribution",
92 Error::ScaleTooSmall => "scale is not positive in gamma distribution",
93 Error::ScaleTooLarge => "scale is infinity in gamma distribution",
94 })
95 }
96}
97
// Only compiled with the `std` feature; the trait's default methods are
// used as-is, so the impl body is empty.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
100
101#[derive(Clone, Copy, Debug, PartialEq)]
102#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
103enum GammaRepr<F>
104where
105 F: Float,
106 StandardNormal: Distribution<F>,
107 Exp1: Distribution<F>,
108 Open01: Distribution<F>,
109{
110 Large(GammaLargeShape<F>),
111 One(Exp<F>),
112 Small(GammaSmallShape<F>),
113}
114
115#[derive(Clone, Copy, Debug, PartialEq)]
130#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
131struct GammaSmallShape<F>
132where
133 F: Float,
134 StandardNormal: Distribution<F>,
135 Open01: Distribution<F>,
136{
137 inv_shape: F,
138 large_shape: GammaLargeShape<F>,
139}
140
141#[derive(Clone, Copy, Debug, PartialEq)]
146#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
147struct GammaLargeShape<F>
148where
149 F: Float,
150 StandardNormal: Distribution<F>,
151 Open01: Distribution<F>,
152{
153 scale: F,
154 c: F,
155 d: F,
156}
157
158impl<F> Gamma<F>
159where
160 F: Float,
161 StandardNormal: Distribution<F>,
162 Exp1: Distribution<F>,
163 Open01: Distribution<F>,
164{
165 #[inline]
168 pub fn new(shape: F, scale: F) -> Result<Gamma<F>, Error> {
169 if !(shape > F::zero()) {
170 return Err(Error::ShapeTooSmall);
171 }
172 if !(scale > F::zero()) {
173 return Err(Error::ScaleTooSmall);
174 }
175
176 let repr = if shape == F::one() {
177 One(Exp::new(F::one() / scale).map_err(|_| Error::ScaleTooLarge)?)
178 } else if shape < F::one() {
179 Small(GammaSmallShape::new_raw(shape, scale))
180 } else {
181 Large(GammaLargeShape::new_raw(shape, scale))
182 };
183 Ok(Gamma { repr })
184 }
185}
186
187impl<F> GammaSmallShape<F>
188where
189 F: Float,
190 StandardNormal: Distribution<F>,
191 Open01: Distribution<F>,
192{
193 fn new_raw(shape: F, scale: F) -> GammaSmallShape<F> {
194 GammaSmallShape {
195 inv_shape: F::one() / shape,
196 large_shape: GammaLargeShape::new_raw(shape + F::one(), scale),
197 }
198 }
199}
200
201impl<F> GammaLargeShape<F>
202where
203 F: Float,
204 StandardNormal: Distribution<F>,
205 Open01: Distribution<F>,
206{
207 fn new_raw(shape: F, scale: F) -> GammaLargeShape<F> {
208 let d = shape - F::from(1. / 3.).unwrap();
209 GammaLargeShape {
210 scale,
211 c: F::one() / (F::from(9.).unwrap() * d).sqrt(),
212 d,
213 }
214 }
215}
216
217impl<F> Distribution<F> for Gamma<F>
218where
219 F: Float,
220 StandardNormal: Distribution<F>,
221 Exp1: Distribution<F>,
222 Open01: Distribution<F>,
223{
224 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
225 match self.repr {
226 Small(ref g) => g.sample(rng),
227 One(ref g) => g.sample(rng),
228 Large(ref g) => g.sample(rng),
229 }
230 }
231}
232impl<F> Distribution<F> for GammaSmallShape<F>
233where
234 F: Float,
235 StandardNormal: Distribution<F>,
236 Open01: Distribution<F>,
237{
238 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
239 let u: F = rng.sample(Open01);
240
241 self.large_shape.sample(rng) * u.powf(self.inv_shape)
242 }
243}
244impl<F> Distribution<F> for GammaLargeShape<F>
245where
246 F: Float,
247 StandardNormal: Distribution<F>,
248 Open01: Distribution<F>,
249{
250 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
251 loop {
253 let x: F = rng.sample(StandardNormal);
254 let v_cbrt = F::one() + self.c * x;
255 if v_cbrt <= F::zero() {
256 continue;
258 }
259
260 let v = v_cbrt * v_cbrt * v_cbrt;
261 let u: F = rng.sample(Open01);
262
263 let x_sqr = x * x;
264 if u < F::one() - F::from(0.0331).unwrap() * x_sqr * x_sqr
265 || u.ln() < F::from(0.5).unwrap() * x_sqr + self.d * (F::one() - v + v.ln())
266 {
267 return self.d * v * self.scale;
268 }
269 }
270 }
271}
272
#[cfg(test)]
mod test {
    use super::*;

    /// Two construction results with identical parameters compare equal.
    #[test]
    fn gamma_distributions_can_be_compared() {
        let first = Gamma::<f64>::new(1.0, 2.0);
        let second = Gamma::<f64>::new(1.0, 2.0);
        assert_eq!(first, second);
    }
}