rand_distr/
pert.rs

1// Copyright 2018 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8//! The PERT distribution.
9
10use crate::{Beta, Distribution, Exp1, Open01, StandardNormal};
11use core::fmt;
12use num_traits::Float;
13use rand::Rng;
14
/// The [PERT distribution](https://en.wikipedia.org/wiki/PERT_distribution) `PERT(min, max, mode, shape)`.
///
/// Similar to the [`Triangular`] distribution, the PERT distribution is
/// parameterised by a range and a mode within that range. Unlike the
/// [`Triangular`] distribution, the probability density function of the PERT
/// distribution is smooth, with a configurable weighting around the mode.
///
/// A `Pert` is created through the builder returned by [`Pert::new`], which is
/// finished by supplying either the mode or the mean of the distribution.
///
/// # Plot
///
/// The following plot shows the PERT distribution with `min = -1`, `max = 1`,
/// and various values of `mode` and `shape`.
///
/// ![PERT distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/pert.svg)
///
/// # Example
///
/// ```rust
/// use rand_distr::{Pert, Distribution};
///
/// let d = Pert::new(0., 5.).with_mode(2.5).unwrap();
/// let v = d.sample(&mut rand::rng());
/// println!("{} is from a PERT distribution", v);
/// ```
///
/// [`Triangular`]: crate::Triangular
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct Pert<F>
where
    F: Float,
    // NOTE(review): these three bounds presumably mirror what `Beta<F>`
    // requires of its parameter type — confirm against the `Beta` definition.
    StandardNormal: Distribution<F>,
    Exp1: Distribution<F>,
    Open01: Distribution<F>,
{
    // Lower bound of the range; added to every scaled Beta sample.
    min: F,
    // Width of the range, stored as `max - min` at construction time.
    range: F,
    // Underlying Beta distribution; its samples are multiplied by `range`
    // and shifted by `min` to produce PERT samples.
    beta: Beta<F>,
}
53
/// Error type returned when building a [`Pert`] distribution fails.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PertError {
    /// `max <= min`, or `min` or `max` is NaN.
    ///
    /// Also reported when the underlying Beta distribution cannot be
    /// constructed from the derived parameters.
    RangeTooSmall,
    /// `mode < min` or `mode > max` or `mode` is NaN.
    ModeRange,
    /// `shape < 0` or `shape` is NaN.
    ShapeTooSmall,
}

impl fmt::Display for PertError {
    /// Writes a fixed, human-readable description of the failed requirement.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            PertError::RangeTooSmall => "requirement min < max is not met in PERT distribution",
            PertError::ModeRange => "mode is outside [min, max] in PERT distribution",
            PertError::ShapeTooSmall => "shape < 0 or is NaN in PERT distribution",
        };
        f.write_str(message)
    }
}

// The `Error` trait is only implemented when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for PertError {}
77
78impl<F> Pert<F>
79where
80    F: Float,
81    StandardNormal: Distribution<F>,
82    Exp1: Distribution<F>,
83    Open01: Distribution<F>,
84{
85    /// Construct a PERT distribution with defined `min`, `max`
86    ///
87    /// # Example
88    ///
89    /// ```
90    /// use rand_distr::Pert;
91    /// let pert_dist = Pert::new(0.0, 10.0)
92    ///     .with_shape(3.5)
93    ///     .with_mean(3.0)
94    ///     .unwrap();
95    /// # let _unused: Pert<f64> = pert_dist;
96    /// ```
97    #[allow(clippy::new_ret_no_self)]
98    #[inline]
99    pub fn new(min: F, max: F) -> PertBuilder<F> {
100        let shape = F::from(4.0).unwrap();
101        PertBuilder { min, max, shape }
102    }
103}
104
/// Struct used to build a [`Pert`]
///
/// Holds the not-yet-validated `min`, `max` and `shape` parameters. The
/// builder is a plain value: it can be copied and compared, so one configured
/// builder may be finished several times (for example with different modes).
#[derive(Clone, Copy, Debug, PartialEq)]
#[must_use = "a builder does nothing until it is finished with `with_mode` or `with_mean`"]
pub struct PertBuilder<F> {
    // Lower bound of the range, as passed by the caller.
    min: F,
    // Upper bound of the range, as passed by the caller.
    max: F,
    // Shape parameter controlling the weighting around the mode.
    shape: F,
}
112
113impl<F> PertBuilder<F>
114where
115    F: Float,
116    StandardNormal: Distribution<F>,
117    Exp1: Distribution<F>,
118    Open01: Distribution<F>,
119{
120    /// Set the shape parameter
121    ///
122    /// If not specified, this defaults to 4.
123    #[inline]
124    pub fn with_shape(mut self, shape: F) -> PertBuilder<F> {
125        self.shape = shape;
126        self
127    }
128
129    /// Specify the mean
130    #[inline]
131    pub fn with_mean(self, mean: F) -> Result<Pert<F>, PertError> {
132        let two = F::from(2.0).unwrap();
133        let mode = ((self.shape + two) * mean - self.min - self.max) / self.shape;
134        self.with_mode(mode)
135    }
136
137    /// Specify the mode
138    #[inline]
139    pub fn with_mode(self, mode: F) -> Result<Pert<F>, PertError> {
140        if !(self.max > self.min) {
141            return Err(PertError::RangeTooSmall);
142        }
143        if !(mode >= self.min && self.max >= mode) {
144            return Err(PertError::ModeRange);
145        }
146        if !(self.shape >= F::from(0.).unwrap()) {
147            return Err(PertError::ShapeTooSmall);
148        }
149
150        let (min, max, shape) = (self.min, self.max, self.shape);
151        let range = max - min;
152        let v = F::from(1.0).unwrap() + shape * (mode - min) / range;
153        let w = F::from(1.0).unwrap() + shape * (max - mode) / range;
154        let beta = Beta::new(v, w).map_err(|_| PertError::RangeTooSmall)?;
155        Ok(Pert { min, range, beta })
156    }
157}
158
159impl<F> Distribution<F> for Pert<F>
160where
161    F: Float,
162    StandardNormal: Distribution<F>,
163    Exp1: Distribution<F>,
164    Open01: Distribution<F>,
165{
166    #[inline]
167    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
168        self.beta.sample(rng) * self.range + self.min
169    }
170}
171
#[cfg(test)]
mod test {
    use super::*;

    // Valid `(min, max, mode)` triples must build; triples with the mode
    // outside the range, or with an inverted range, must be rejected.
    #[test]
    fn test_pert() {
        let accepted = [(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)];
        for &(lo, hi, peak) in accepted.iter() {
            let _distr = Pert::new(lo, hi).with_mode(peak).unwrap();
            // TODO: test correctness
        }

        let rejected = [(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)];
        for &(lo, hi, peak) in rejected.iter() {
            assert!(Pert::new(lo, hi).with_mode(peak).is_err());
        }
    }

    // Building from the mode and from the matching mean must give equal
    // distributions.
    #[test]
    fn distributions_can_be_compared() {
        let (min, mode, max, shape) = (1.0, 2.0, 3.0, 4.0);
        let from_mode = Pert::new(min, max).with_mode(mode).unwrap();

        let mean = (min + shape * mode + max) / (shape + 2.0);
        let from_mean = Pert::new(min, max).with_mean(mean).unwrap();

        assert_eq!(from_mode, from_mean);
    }

    // A mode within rounding distance of the midpoint must be accepted.
    #[test]
    fn mode_almost_half_range() {
        let distr = Pert::new(0.0f32, 0.48258883).with_mode(0.24129441);
        assert!(distr.is_ok());
    }

    // A mode one epsilon away from zero on a range symmetric about zero.
    #[test]
    fn almost_symmetric_about_zero() {
        assert!(Pert::new(-10f32, 10f32).with_mode(f32::EPSILON).is_ok());
    }

    // A mode one epsilon above the midpoint of `[0, 2]`.
    #[test]
    fn almost_symmetric() {
        assert!(Pert::new(0f32, 2f32).with_mode(1f32 + f32::EPSILON).is_ok());
    }
}