rand_distr/
skew_normal.rs

// Copyright 2021 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! The skew normal distribution `SN(ξ, ω, α)`.

11use crate::{Distribution, StandardNormal};
12use core::fmt;
13use num_traits::Float;
14use rand::Rng;
15
16/// The [skew normal distribution](https://en.wikipedia.org/wiki/Skew_normal_distribution) `SN(ξ, ω, α)`.
17///
18/// The skew normal distribution is a generalization of the
19/// [`Normal`](crate::Normal) distribution to allow for non-zero skewness.
20/// It has location parameter `ξ` (`xi`), scale parameter `ω` (`omega`),
21/// and shape parameter `α` (`alpha`).
22///
23/// The `ξ` and `ω` parameters correspond to the mean `μ` and standard
24/// deviation `σ` of the normal distribution, respectively.
25/// The `α` parameter controls the skewness.
26///
27/// # Density function
28///
29/// It has the density function, for `scale > 0`,
30/// `f(x) = 2 / scale * phi((x - location) / scale) * Phi(alpha * (x - location) / scale)`
31/// where `phi` and `Phi` are the density and distribution of a standard normal variable.
32///
33/// # Plot
34///
35/// The following plot shows the skew normal distribution with `location = 0`, `scale = 1`
36/// (corresponding to the [`standard normal distribution`](crate::StandardNormal)), and
37/// various values of `shape`.
38///
39/// ![Skew normal distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/skew_normal.svg)
40///
41/// # Example
42///
43/// ```
44/// use rand_distr::{SkewNormal, Distribution};
45///
46/// // location 2, scale 3, shape 1
47/// let skew_normal = SkewNormal::new(2.0, 3.0, 1.0).unwrap();
48/// let v = skew_normal.sample(&mut rand::rng());
49/// println!("{} is from a SN(2, 3, 1) distribution", v)
50/// ```
51///
52/// # Implementation details
53///
54/// We are using the algorithm from [A Method to Simulate the Skew Normal Distribution].
55///
56/// [skew normal distribution]: https://en.wikipedia.org/wiki/Skew_normal_distribution
57/// [`Normal`]: struct.Normal.html
58/// [A Method to Simulate the Skew Normal Distribution]: https://dx.doi.org/10.4236/am.2014.513201
59#[derive(Clone, Copy, Debug, PartialEq)]
60#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
61pub struct SkewNormal<F>
62where
63    F: Float,
64    StandardNormal: Distribution<F>,
65{
66    location: F,
67    scale: F,
68    shape: F,
69}
70
/// Error type returned from [`SkewNormal::new`] when a parameter is invalid.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The scale parameter is NaN, infinite, zero or negative.
    ScaleTooSmall,
    /// The shape parameter is NaN or infinite.
    BadShape,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Each variant maps to one fixed, human-readable description.
        let description = match *self {
            Self::ScaleTooSmall => {
                "scale parameter is either non-finite or it is less or equal to zero in skew normal distribution"
            }
            Self::BadShape => "shape parameter is non-finite in skew normal distribution",
        };
        f.write_str(description)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for Error {}
93
94impl<F> SkewNormal<F>
95where
96    F: Float,
97    StandardNormal: Distribution<F>,
98{
99    /// Construct, from location, scale and shape.
100    ///
101    /// Parameters:
102    ///
103    /// -   location (unrestricted)
104    /// -   scale (must be finite and larger than zero)
105    /// -   shape (must be finite)
106    #[inline]
107    pub fn new(location: F, scale: F, shape: F) -> Result<SkewNormal<F>, Error> {
108        if !scale.is_finite() || !(scale > F::zero()) {
109            return Err(Error::ScaleTooSmall);
110        }
111        if !shape.is_finite() {
112            return Err(Error::BadShape);
113        }
114        Ok(SkewNormal {
115            location,
116            scale,
117            shape,
118        })
119    }
120
121    /// Returns the location of the distribution.
122    pub fn location(&self) -> F {
123        self.location
124    }
125
126    /// Returns the scale of the distribution.
127    pub fn scale(&self) -> F {
128        self.scale
129    }
130
131    /// Returns the shape of the distribution.
132    pub fn shape(&self) -> F {
133        self.shape
134    }
135}
136
137impl<F> Distribution<F> for SkewNormal<F>
138where
139    F: Float,
140    StandardNormal: Distribution<F>,
141{
142    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
143        let linear_map = |x: F| -> F { x * self.scale + self.location };
144        let u_1: F = rng.sample(StandardNormal);
145        if self.shape == F::zero() {
146            linear_map(u_1)
147        } else {
148            let u_2 = rng.sample(StandardNormal);
149            let (u, v) = (u_1.max(u_2), u_1.min(u_2));
150            if self.shape == -F::one() {
151                linear_map(v)
152            } else if self.shape == F::one() {
153                linear_map(u)
154            } else {
155                let normalized = ((F::one() + self.shape) * u + (F::one() - self.shape) * v)
156                    / ((F::one() + self.shape * self.shape).sqrt()
157                        * F::from(core::f64::consts::SQRT_2).unwrap());
158                linear_map(normalized)
159            }
160        }
161    }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws four samples from `distr` with a fixed seed and compares them exactly
    /// with `expected` (value-stability check).
    fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(distr: D, zero: F, expected: &[F]) {
        let mut rng = crate::test::rng(213);
        let mut buf = [zero; 4];
        for x in &mut buf {
            *x = rng.sample(&distr);
        }
        assert_eq!(buf, expected);
    }

    // The constructor tests assert the exact error variant instead of using
    // `#[should_panic]`, which would also pass if the wrong validation fired.

    #[test]
    fn invalid_scale_nan() {
        assert_eq!(
            SkewNormal::new(0.0, f64::NAN, 0.0),
            Err(Error::ScaleTooSmall)
        );
    }

    #[test]
    fn invalid_scale_zero() {
        assert_eq!(SkewNormal::new(0.0, 0.0, 0.0), Err(Error::ScaleTooSmall));
    }

    #[test]
    fn invalid_scale_negative() {
        assert_eq!(SkewNormal::new(0.0, -1.0, 0.0), Err(Error::ScaleTooSmall));
    }

    #[test]
    fn invalid_scale_infinite() {
        assert_eq!(
            SkewNormal::new(0.0, f64::INFINITY, 0.0),
            Err(Error::ScaleTooSmall)
        );
        assert_eq!(
            SkewNormal::new(0.0, f64::NEG_INFINITY, 0.0),
            Err(Error::ScaleTooSmall)
        );
    }

    #[test]
    fn invalid_shape_nan() {
        assert_eq!(SkewNormal::new(0.0, 1.0, f64::NAN), Err(Error::BadShape));
    }

    #[test]
    fn invalid_shape_infinite() {
        assert_eq!(
            SkewNormal::new(0.0, 1.0, f64::INFINITY),
            Err(Error::BadShape)
        );
        assert_eq!(
            SkewNormal::new(0.0, 1.0, f64::NEG_INFINITY),
            Err(Error::BadShape)
        );
    }

    #[test]
    fn valid_location_nan() {
        let skew_normal = SkewNormal::new(f64::NAN, 1.0, 0.0).unwrap();
        assert!(skew_normal.location().is_nan());
    }

    #[test]
    fn skew_normal_accessors_return_parameters() {
        let skew_normal = SkewNormal::new(1.0, 2.0, 3.0).unwrap();
        assert_eq!(skew_normal.location(), 1.0);
        assert_eq!(skew_normal.scale(), 2.0);
        assert_eq!(skew_normal.shape(), 3.0);
    }

    #[test]
    fn skew_normal_value_stability() {
        test_samples(
            SkewNormal::new(0.0, 1.0, 0.0).unwrap(),
            0f32,
            &[-0.11844189, 0.781378, 0.06563994, -1.1932899],
        );
        test_samples(
            SkewNormal::new(0.0, 1.0, 0.0).unwrap(),
            0f64,
            &[
                -0.11844188827977231,
                0.7813779637772346,
                0.06563993969580051,
                -1.1932899004186373,
            ],
        );
        test_samples(
            SkewNormal::new(f64::INFINITY, 1.0, 0.0).unwrap(),
            0f64,
            &[f64::INFINITY, f64::INFINITY, f64::INFINITY, f64::INFINITY],
        );
        test_samples(
            SkewNormal::new(f64::NEG_INFINITY, 1.0, 0.0).unwrap(),
            0f64,
            &[
                f64::NEG_INFINITY,
                f64::NEG_INFINITY,
                f64::NEG_INFINITY,
                f64::NEG_INFINITY,
            ],
        );
    }

    #[test]
    fn skew_normal_shape_plus_minus_one_are_max_and_min() {
        // With shape 1 the sampler returns the larger of two normal draws and with
        // shape -1 the smaller. Identically seeded RNGs give both the same pairs.
        let upper = SkewNormal::new(0.0, 1.0, 1.0).unwrap();
        let lower = SkewNormal::new(0.0, 1.0, -1.0).unwrap();
        let mut rng_upper = crate::test::rng(213);
        let mut rng_lower = crate::test::rng(213);
        for _ in 0..100 {
            let hi: f64 = rng_upper.sample(upper);
            let lo: f64 = rng_lower.sample(lower);
            assert!(hi >= lo);
        }
    }

    #[test]
    fn skew_normal_value_location_nan() {
        let skew_normal = SkewNormal::new(f64::NAN, 1.0, 0.0).unwrap();
        let mut rng = crate::test::rng(213);
        let mut buf = [0.0; 4];
        for x in &mut buf {
            *x = rng.sample(skew_normal);
        }
        for value in buf.iter() {
            assert!(value.is_nan());
        }
    }

    #[test]
    fn skew_normal_distributions_can_be_compared() {
        assert_eq!(
            SkewNormal::new(1.0, 2.0, 3.0),
            SkewNormal::new(1.0, 2.0, 3.0)
        );
    }
}