rand_distr/weighted/
weighted_tree.rs1use core::ops::SubAssign;
13
14use super::{Error, Weight};
15use crate::Distribution;
16use alloc::vec::Vec;
17use rand::distr::uniform::{SampleBorrow, SampleUniform};
18use rand::Rng;
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21
22#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
81#[cfg_attr(
82 feature = "serde",
83 serde(bound(serialize = "W: Serialize, W::Sampler: Serialize"))
84)]
85#[cfg_attr(
86 feature = "serde",
87 serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>"))
88)]
89#[derive(Clone, Default, Debug, PartialEq)]
90pub struct WeightedTreeIndex<
91 W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight,
92> {
93 subtotals: Vec<W>,
94}
95
96impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
97 WeightedTreeIndex<W>
98{
99 pub fn new<I>(weights: I) -> Result<Self, Error>
105 where
106 I: IntoIterator,
107 I::Item: SampleBorrow<W>,
108 {
109 let mut subtotals: Vec<W> = weights.into_iter().map(|x| x.borrow().clone()).collect();
110 for weight in subtotals.iter() {
111 if !(*weight >= W::ZERO) {
112 return Err(Error::InvalidWeight);
113 }
114 }
115 let n = subtotals.len();
116 for i in (1..n).rev() {
117 let w = subtotals[i].clone();
118 let parent = (i - 1) / 2;
119 subtotals[parent]
120 .checked_add_assign(&w)
121 .map_err(|()| Error::Overflow)?;
122 }
123 Ok(Self { subtotals })
124 }
125
126 pub fn is_empty(&self) -> bool {
128 self.subtotals.is_empty()
129 }
130
131 pub fn len(&self) -> usize {
133 self.subtotals.len()
134 }
135
136 pub fn is_valid(&self) -> bool {
140 if let Some(weight) = self.subtotals.first() {
141 *weight > W::ZERO
142 } else {
143 false
144 }
145 }
146
147 pub fn get(&self, index: usize) -> W {
149 let left_index = 2 * index + 1;
150 let right_index = 2 * index + 2;
151 let mut w = self.subtotals[index].clone();
152 w -= self.subtotal(left_index);
153 w -= self.subtotal(right_index);
154 w
155 }
156
157 pub fn pop(&mut self) -> Option<W> {
159 self.subtotals.pop().map(|weight| {
160 let mut index = self.len();
161 while index != 0 {
162 index = (index - 1) / 2;
163 self.subtotals[index] -= weight.clone();
164 }
165 weight
166 })
167 }
168
169 pub fn push(&mut self, weight: W) -> Result<(), Error> {
175 if !(weight >= W::ZERO) {
176 return Err(Error::InvalidWeight);
177 }
178 if let Some(total) = self.subtotals.first() {
179 let mut total = total.clone();
180 if total.checked_add_assign(&weight).is_err() {
181 return Err(Error::Overflow);
182 }
183 }
184 let mut index = self.len();
185 self.subtotals.push(weight.clone());
186 while index != 0 {
187 index = (index - 1) / 2;
188 self.subtotals[index].checked_add_assign(&weight).unwrap();
189 }
190 Ok(())
191 }
192
193 pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), Error> {
199 if !(weight >= W::ZERO) {
200 return Err(Error::InvalidWeight);
201 }
202 let old_weight = self.get(index);
203 if weight > old_weight {
204 let mut difference = weight;
205 difference -= old_weight;
206 if let Some(total) = self.subtotals.first() {
207 let mut total = total.clone();
208 if total.checked_add_assign(&difference).is_err() {
209 return Err(Error::Overflow);
210 }
211 }
212 self.subtotals[index]
213 .checked_add_assign(&difference)
214 .unwrap();
215 while index != 0 {
216 index = (index - 1) / 2;
217 self.subtotals[index]
218 .checked_add_assign(&difference)
219 .unwrap();
220 }
221 } else if weight < old_weight {
222 let mut difference = old_weight;
223 difference -= weight;
224 self.subtotals[index] -= difference.clone();
225 while index != 0 {
226 index = (index - 1) / 2;
227 self.subtotals[index] -= difference.clone();
228 }
229 }
230 Ok(())
231 }
232
233 fn subtotal(&self, index: usize) -> W {
234 if index < self.subtotals.len() {
235 self.subtotals[index].clone()
236 } else {
237 W::ZERO
238 }
239 }
240}
241
242impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
243 WeightedTreeIndex<W>
244{
245 pub fn try_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<usize, Error> {
250 let total_weight = self.subtotals.first().cloned().unwrap_or(W::ZERO);
251 if total_weight == W::ZERO {
252 return Err(Error::InsufficientNonZero);
253 }
254 let mut target_weight = rng.random_range(W::ZERO..total_weight);
255 let mut index = 0;
256 loop {
257 let left_index = 2 * index + 1;
259 let left_subtotal = self.subtotal(left_index);
260 if target_weight < left_subtotal {
261 index = left_index;
262 continue;
263 }
264 target_weight -= left_subtotal;
265
266 let right_index = 2 * index + 2;
268 let right_subtotal = self.subtotal(right_index);
269 if target_weight < right_subtotal {
270 index = right_index;
271 continue;
272 }
273 target_weight -= right_subtotal;
274
275 break;
277 }
278 assert!(target_weight >= W::ZERO);
279 assert!(target_weight < self.get(index));
280 Ok(index)
281 }
282}
283
284impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight> Distribution<usize>
290 for WeightedTreeIndex<W>
291{
292 #[track_caller]
293 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
294 self.try_sample(rng).unwrap()
295 }
296}
297
#[cfg(test)]
mod test {
    use super::*;

    // An empty tree has nothing to sample from.
    #[test]
    fn test_no_item_error() {
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
        // The `&` is kept deliberately even though clippy flags it. Presumably
        // this exercises `new` with by-reference items (`SampleBorrow` for
        // `&f64`). TODO confirm intent.
        #[allow(clippy::needless_borrows_for_generic_args)]
        let tree = WeightedTreeIndex::<f64>::new(&[]).unwrap();
        assert_eq!(
            tree.try_sample(&mut rng).unwrap_err(),
            Error::InsufficientNonZero
        );
    }

    // Integer overflow of the total is reported by `new`, `push` and
    // `update`, and a failed operation does not prevent later valid ones.
    #[test]
    fn test_overflow_error() {
        // i32::MAX + 2 does not fit in an i32.
        assert_eq!(WeightedTreeIndex::new([i32::MAX, 2]), Err(Error::Overflow));
        // Total is i32::MAX - 1, leaving room for exactly 1 more.
        let mut tree = WeightedTreeIndex::new([i32::MAX - 2, 1]).unwrap();
        assert_eq!(tree.push(3), Err(Error::Overflow));
        // Raising index 1 from 1 to 4 would add 3 to the total.
        assert_eq!(tree.update(1, 4), Err(Error::Overflow));
        // Raising it to 2 adds only 1, which brings the total to i32::MAX.
        tree.update(1, 2).unwrap();
    }

    // A tree whose weights are all zero can be built but not sampled.
    #[test]
    fn test_all_weights_zero_error() {
        let tree = WeightedTreeIndex::<f64>::new([0.0, 0.0]).unwrap();
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
        assert_eq!(
            tree.try_sample(&mut rng).unwrap_err(),
            Error::InsufficientNonZero
        );
    }

    // Negative weights are rejected by `new`, `push` and `update`.
    #[test]
    fn test_invalid_weight_error() {
        assert_eq!(
            WeightedTreeIndex::<i32>::new([1, -1]).unwrap_err(),
            Error::InvalidWeight
        );
        // See `test_no_item_error` for why the borrow is kept.
        #[allow(clippy::needless_borrows_for_generic_args)]
        let mut tree = WeightedTreeIndex::<i32>::new(&[]).unwrap();
        assert_eq!(tree.push(-1).unwrap_err(), Error::InvalidWeight);
        tree.push(1).unwrap();
        assert_eq!(tree.update(0, -1).unwrap_err(), Error::InvalidWeight);
    }

    // A sequence of push/update/pop must leave the same subtotals as building
    // the final weights from scratch (compared via the derived `PartialEq`).
    #[test]
    fn test_tree_modifications() {
        let mut tree = WeightedTreeIndex::new([9, 1, 2]).unwrap();
        tree.push(3).unwrap();
        tree.push(5).unwrap();
        tree.update(0, 0).unwrap();
        assert_eq!(tree.pop(), Some(5));
        let expected = WeightedTreeIndex::new([0, 1, 2, 3]).unwrap();
        assert_eq!(tree, expected);
    }

    // Zero-weight indices are never drawn, and empirical frequencies of the
    // others are close to weight / total_weight.
    #[test]
    // Index loops are kept because each iteration updates `tree`, `weights`
    // and `total_weight` together.
    #[allow(clippy::needless_range_loop)]
    fn test_sample_counts_match_probabilities() {
        let start = 1;
        let end = 3;
        let samples = 20;
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
        // Random initial weights; every one is overwritten below via `update`.
        let weights: Vec<f64> = (0..end).map(|_| rng.random()).collect();
        let mut tree = WeightedTreeIndex::new(weights).unwrap();
        let mut total_weight = 0.0;
        let mut weights = alloc::vec![0.0; end];
        // Set weight i = i, mirroring it in `weights` and `total_weight`.
        for i in 0..end {
            tree.update(i, i as f64).unwrap();
            weights[i] = i as f64;
            total_weight += i as f64;
        }
        // Zero out indices below `start`, keeping `total_weight` in sync.
        for i in 0..start {
            tree.update(i, 0.0).unwrap();
            weights[i] = 0.0;
            total_weight -= i as f64;
        }
        let mut counts = alloc::vec![0_usize; end];
        for _ in 0..samples {
            let i = tree.sample(&mut rng);
            counts[i] += 1;
        }
        for i in 0..start {
            assert_eq!(counts[i], 0);
        }
        // With only `samples` draws, the observed frequency moves in steps of
        // 1 / samples. This check therefore relies on the fixed seed above
        // rather than on a large sample size.
        for i in start..end {
            let diff = counts[i] as f64 / samples as f64 - weights[i] / total_weight;
            assert!(diff.abs() < 0.05);
        }
    }
}