1use crate::{Distribution, StandardNormal};
12use core::fmt;
13use num_traits::Float;
14use rand::Rng;
15
16#[derive(Clone, Copy, Debug, PartialEq)]
60#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
61pub struct SkewNormal<F>
62where
63 F: Float,
64 StandardNormal: Distribution<F>,
65{
66 location: F,
67 scale: F,
68 shape: F,
69}
70
/// Reasons why [`SkewNormal::new`] rejects its parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// The scale is NaN, infinite, or less than or equal to zero.
    ScaleTooSmall,
    /// The shape is NaN or infinite.
    BadShape,
}
79
80impl fmt::Display for Error {
81 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
82 f.write_str(match self {
83 Error::ScaleTooSmall => {
84 "scale parameter is either non-finite or it is less or equal to zero in skew normal distribution"
85 }
86 Error::BadShape => "shape parameter is non-finite in skew normal distribution",
87 })
88 }
89}
90
91#[cfg(feature = "std")]
92impl std::error::Error for Error {}
93
94impl<F> SkewNormal<F>
95where
96 F: Float,
97 StandardNormal: Distribution<F>,
98{
99 #[inline]
107 pub fn new(location: F, scale: F, shape: F) -> Result<SkewNormal<F>, Error> {
108 if !scale.is_finite() || !(scale > F::zero()) {
109 return Err(Error::ScaleTooSmall);
110 }
111 if !shape.is_finite() {
112 return Err(Error::BadShape);
113 }
114 Ok(SkewNormal {
115 location,
116 scale,
117 shape,
118 })
119 }
120
121 pub fn location(&self) -> F {
123 self.location
124 }
125
126 pub fn scale(&self) -> F {
128 self.scale
129 }
130
131 pub fn shape(&self) -> F {
133 self.shape
134 }
135}
136
137impl<F> Distribution<F> for SkewNormal<F>
138where
139 F: Float,
140 StandardNormal: Distribution<F>,
141{
142 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
143 let linear_map = |x: F| -> F { x * self.scale + self.location };
144 let u_1: F = rng.sample(StandardNormal);
145 if self.shape == F::zero() {
146 linear_map(u_1)
147 } else {
148 let u_2 = rng.sample(StandardNormal);
149 let (u, v) = (u_1.max(u_2), u_1.min(u_2));
150 if self.shape == -F::one() {
151 linear_map(v)
152 } else if self.shape == F::one() {
153 linear_map(u)
154 } else {
155 let normalized = ((F::one() + self.shape) * u + (F::one() - self.shape) * v)
156 / ((F::one() + self.shape * self.shape).sqrt()
157 * F::from(core::f64::consts::SQRT_2).unwrap());
158 linear_map(normalized)
159 }
160 }
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws four samples from `distr` using an RNG seeded with 213. It asserts
    /// that they equal `expected` exactly (value-stability check).
    fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(distr: D, zero: F, expected: &[F]) {
        let mut rng = crate::test::rng(213);
        let mut buf = [zero; 4];
        for x in &mut buf {
            *x = rng.sample(&distr);
        }
        assert_eq!(buf, expected);
    }

    // The invalid-parameter tests pin the exact error variant instead of only
    // checking that `unwrap` panics.

    #[test]
    fn invalid_scale_nan() {
        assert_eq!(SkewNormal::new(0.0, f64::NAN, 0.0), Err(Error::ScaleTooSmall));
    }

    #[test]
    fn invalid_scale_zero() {
        assert_eq!(SkewNormal::new(0.0, 0.0, 0.0), Err(Error::ScaleTooSmall));
    }

    #[test]
    fn invalid_scale_negative() {
        assert_eq!(SkewNormal::new(0.0, -1.0, 0.0), Err(Error::ScaleTooSmall));
    }

    #[test]
    fn invalid_scale_infinite() {
        assert_eq!(SkewNormal::new(0.0, f64::INFINITY, 0.0), Err(Error::ScaleTooSmall));
    }

    #[test]
    fn invalid_shape_nan() {
        assert_eq!(SkewNormal::new(0.0, 1.0, f64::NAN), Err(Error::BadShape));
    }

    #[test]
    fn invalid_shape_infinite() {
        assert_eq!(SkewNormal::new(0.0, 1.0, f64::INFINITY), Err(Error::BadShape));
    }

    #[test]
    fn valid_location_nan() {
        let distr = SkewNormal::new(f64::NAN, 1.0, 0.0).unwrap();
        assert!(distr.location().is_nan());
    }

    #[test]
    fn accessors_return_parameters() {
        let distr = SkewNormal::new(1.5, 2.0, -3.0).unwrap();
        assert_eq!(distr.location(), 1.5);
        assert_eq!(distr.scale(), 2.0);
        assert_eq!(distr.shape(), -3.0);
    }

    #[test]
    fn skew_normal_value_stability() {
        test_samples(
            SkewNormal::new(0.0, 1.0, 0.0).unwrap(),
            0f32,
            &[-0.11844189, 0.781378, 0.06563994, -1.1932899],
        );
        test_samples(
            SkewNormal::new(0.0, 1.0, 0.0).unwrap(),
            0f64,
            &[
                -0.11844188827977231,
                0.7813779637772346,
                0.06563993969580051,
                -1.1932899004186373,
            ],
        );
        test_samples(
            SkewNormal::new(f64::INFINITY, 1.0, 0.0).unwrap(),
            0f64,
            &[f64::INFINITY, f64::INFINITY, f64::INFINITY, f64::INFINITY],
        );
        test_samples(
            SkewNormal::new(f64::NEG_INFINITY, 1.0, 0.0).unwrap(),
            0f64,
            &[
                f64::NEG_INFINITY,
                f64::NEG_INFINITY,
                f64::NEG_INFINITY,
                f64::NEG_INFINITY,
            ],
        );
    }

    #[test]
    fn skew_normal_value_location_nan() {
        let skew_normal = SkewNormal::new(f64::NAN, 1.0, 0.0).unwrap();
        let mut rng = crate::test::rng(213);
        let mut buf = [0.0; 4];
        for x in &mut buf {
            *x = rng.sample(skew_normal);
        }
        for value in buf.iter() {
            assert!(value.is_nan());
        }
    }

    #[test]
    fn unit_shapes_select_max_and_min() {
        // Shape +1 yields max(u_1, u_2) and shape -1 yields min(u_1, u_2).
        // Identically seeded RNGs draw the same pair, so the first never falls
        // below the second.
        let upper = SkewNormal::new(0.0, 1.0, 1.0).unwrap();
        let lower = SkewNormal::new(0.0, 1.0, -1.0).unwrap();
        let mut rng_upper = crate::test::rng(213);
        let mut rng_lower = crate::test::rng(213);
        for _ in 0..100 {
            let hi: f64 = rng_upper.sample(upper);
            let lo: f64 = rng_lower.sample(lower);
            assert!(hi >= lo);
        }
    }

    #[test]
    fn general_shape_samples_are_finite() {
        let distr = SkewNormal::new(2.0, 3.0, -2.5).unwrap();
        let mut rng = crate::test::rng(213);
        for _ in 0..100 {
            let x: f64 = rng.sample(distr);
            assert!(x.is_finite());
        }
    }

    #[test]
    fn huge_shape_does_not_overflow() {
        // As |shape| -> inf the sample tends to the signed half-normal
        // `±|u_1 - u_2| / sqrt(2)`. Regression: for |shape| > ~1.3e154,
        // `shape * shape` used to overflow, and every sample collapsed to
        // `location` (or became NaN).
        for &shape in &[1e200, f64::MAX, -1e200, -f64::MAX] {
            let distr = SkewNormal::new(0.0, 1.0, shape).unwrap();
            let mut rng = crate::test::rng(213);
            let mut reference_rng = crate::test::rng(213);
            for _ in 0..100 {
                let x: f64 = rng.sample(distr);
                let u_1: f64 = reference_rng.sample(StandardNormal);
                let u_2: f64 = reference_rng.sample(StandardNormal);
                let half_normal = (u_1 - u_2).abs() / core::f64::consts::SQRT_2;
                let expected = if shape > 0.0 { half_normal } else { -half_normal };
                assert!(
                    (x - expected).abs() <= 1e-12,
                    "shape {}: got {}, expected {}",
                    shape,
                    x,
                    expected
                );
            }
        }
    }

    #[test]
    fn skew_normal_distributions_can_be_compared() {
        assert_eq!(
            SkewNormal::new(1.0, 2.0, 3.0),
            SkewNormal::new(1.0, 2.0, 3.0)
        );
    }
}