1use crate::{Beta, Distribution, Exp1, Open01, StandardNormal};
11use core::fmt;
12use num_traits::Float;
13use rand::Rng;
14
15#[derive(Clone, Copy, Debug, PartialEq)]
41#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
42pub struct Pert<F>
43where
44 F: Float,
45 StandardNormal: Distribution<F>,
46 Exp1: Distribution<F>,
47 Open01: Distribution<F>,
48{
49 min: F,
50 range: F,
51 beta: Beta<F>,
52}
53
/// Error type returned when building a [`Pert`] distribution.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PertError {
    /// `min < max` does not hold (this includes NaN bounds). Also reported
    /// when the derived Beta parameters are rejected.
    RangeTooSmall,
    /// `mode` does not lie within `[min, max]` (or is NaN).
    ModeRange,
    /// `shape` is negative or NaN.
    ShapeTooSmall,
}
64
65impl fmt::Display for PertError {
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 f.write_str(match self {
68 PertError::RangeTooSmall => "requirement min < max is not met in PERT distribution",
69 PertError::ModeRange => "mode is outside [min, max] in PERT distribution",
70 PertError::ShapeTooSmall => "shape < 0 or is NaN in PERT distribution",
71 })
72 }
73}
74
75#[cfg(feature = "std")]
76impl std::error::Error for PertError {}
77
78impl<F> Pert<F>
79where
80 F: Float,
81 StandardNormal: Distribution<F>,
82 Exp1: Distribution<F>,
83 Open01: Distribution<F>,
84{
85 #[allow(clippy::new_ret_no_self)]
98 #[inline]
99 pub fn new(min: F, max: F) -> PertBuilder<F> {
100 let shape = F::from(4.0).unwrap();
101 PertBuilder { min, max, shape }
102 }
103}
104
/// Builder for a [`Pert`] distribution, obtained from [`Pert::new`].
///
/// Stores the support bounds and the shape parameter. The distribution itself
/// is produced by [`PertBuilder::with_mode`] or [`PertBuilder::with_mean`].
/// Both of those consume the builder, so it is `Copy`: one configured builder
/// can then be reused to build several distributions.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct PertBuilder<F> {
    /// Lower bound of the support.
    min: F,
    /// Upper bound of the support.
    max: F,
    /// Weight given to the mode; validated only when the distribution is built.
    shape: F,
}
112
113impl<F> PertBuilder<F>
114where
115 F: Float,
116 StandardNormal: Distribution<F>,
117 Exp1: Distribution<F>,
118 Open01: Distribution<F>,
119{
120 #[inline]
124 pub fn with_shape(mut self, shape: F) -> PertBuilder<F> {
125 self.shape = shape;
126 self
127 }
128
129 #[inline]
131 pub fn with_mean(self, mean: F) -> Result<Pert<F>, PertError> {
132 let two = F::from(2.0).unwrap();
133 let mode = ((self.shape + two) * mean - self.min - self.max) / self.shape;
134 self.with_mode(mode)
135 }
136
137 #[inline]
139 pub fn with_mode(self, mode: F) -> Result<Pert<F>, PertError> {
140 if !(self.max > self.min) {
141 return Err(PertError::RangeTooSmall);
142 }
143 if !(mode >= self.min && self.max >= mode) {
144 return Err(PertError::ModeRange);
145 }
146 if !(self.shape >= F::from(0.).unwrap()) {
147 return Err(PertError::ShapeTooSmall);
148 }
149
150 let (min, max, shape) = (self.min, self.max, self.shape);
151 let range = max - min;
152 let v = F::from(1.0).unwrap() + shape * (mode - min) / range;
153 let w = F::from(1.0).unwrap() + shape * (max - mode) / range;
154 let beta = Beta::new(v, w).map_err(|_| PertError::RangeTooSmall)?;
155 Ok(Pert { min, range, beta })
156 }
157}
158
159impl<F> Distribution<F> for Pert<F>
160where
161 F: Float,
162 StandardNormal: Distribution<F>,
163 Exp1: Distribution<F>,
164 Open01: Distribution<F>,
165{
166 #[inline]
167 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
168 self.beta.sample(rng) * self.range + self.min
169 }
170}
171
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_pert() {
        // (min, max, mode) triples that must be accepted.
        let accepted = [(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)];
        for &(min, max, mode) in accepted.iter() {
            let _built = Pert::new(min, max).with_mode(mode).unwrap();
        }

        // Mode above max, mode below min, and an inverted range.
        let rejected = [(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)];
        for &(min, max, mode) in rejected.iter() {
            let outcome = Pert::new(min, max).with_mode(mode);
            assert!(outcome.is_err());
        }
    }

    #[test]
    fn distributions_can_be_compared() {
        let (min, mode, max, shape) = (1.0, 2.0, 3.0, 4.0);
        let from_mode = Pert::new(min, max).with_mode(mode).unwrap();
        // Mean implied by the default shape of 4.
        let mean = (min + shape * mode + max) / (shape + 2.0);
        let from_mean = Pert::new(min, max).with_mean(mean).unwrap();
        assert_eq!(from_mode, from_mean);
    }

    #[test]
    fn mode_almost_half_range() {
        let outcome = Pert::new(0.0f32, 0.48258883).with_mode(0.24129441);
        assert!(outcome.is_ok());
    }

    #[test]
    fn almost_symmetric_about_zero() {
        let outcome = Pert::new(-10f32, 10f32).with_mode(f32::EPSILON);
        assert!(outcome.is_ok());
    }

    #[test]
    fn almost_symmetric() {
        let outcome = Pert::new(0f32, 2f32).with_mode(1f32 + f32::EPSILON);
        assert!(outcome.is_ok());
    }
}