rand_distr/weighted/
weighted_tree.rs

1// Copyright 2024 Developers of the Rand project.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8
//! This module contains an implementation of a tree structure for sampling random
//! indices with probabilities proportional to a collection of weights.
11
12use core::ops::SubAssign;
13
14use super::{Error, Weight};
15use crate::Distribution;
16use alloc::vec::Vec;
17use rand::distr::uniform::{SampleBorrow, SampleUniform};
18use rand::Rng;
19#[cfg(feature = "serde")]
20use serde::{Deserialize, Serialize};
21
22/// A distribution using weighted sampling to pick a discretely selected item.
23///
24/// Sampling a [`WeightedTreeIndex<W>`] distribution returns the index of a randomly
25/// selected element from the vector used to create the [`WeightedTreeIndex<W>`].
26/// The chance of a given element being picked is proportional to the value of
27/// the element. The weights can have any type `W` for which an implementation of
28/// [`Weight`] exists.
29///
30/// # Key differences
31///
32/// The main distinction between [`WeightedTreeIndex<W>`] and [`WeightedIndex<W>`]
33/// lies in the internal representation of weights. In [`WeightedTreeIndex<W>`],
34/// weights are structured as a tree, which is optimized for frequent updates of the weights.
35///
36/// # Caution: Floating point types
37///
38/// When utilizing [`WeightedTreeIndex<W>`] with floating point types (such as f32 or f64),
39/// exercise caution due to the inherent nature of floating point arithmetic. Floating point types
40/// are susceptible to numerical rounding errors. Since operations on floating point weights are
41/// repeated numerous times, rounding errors can accumulate, potentially leading to noticeable
42/// deviations from the expected behavior.
43///
44/// Ideally, use fixed point or integer types whenever possible.
45///
46/// # Performance
47///
48/// A [`WeightedTreeIndex<W>`] with `n` elements requires `O(n)` memory.
49///
50/// Time complexity for the operations of a [`WeightedTreeIndex<W>`] are:
51/// * Constructing: Building the initial tree from an iterator of weights takes `O(n)` time.
52/// * Sampling: Choosing an index (traversing down the tree) requires `O(log n)` time.
53/// * Weight Update: Modifying a weight (traversing up the tree), requires `O(log n)` time.
54/// * Weight Addition (Pushing): Adding a new weight (traversing up the tree), requires `O(log n)` time.
55/// * Weight Removal (Popping): Removing a weight (traversing up the tree), requires `O(log n)` time.
56///
57/// # Example
58///
59/// ```
60/// use rand_distr::weighted::WeightedTreeIndex;
61/// use rand::prelude::*;
62///
63/// let choices = vec!['a', 'b', 'c'];
64/// let weights = vec![2, 0];
65/// let mut dist = WeightedTreeIndex::new(&weights).unwrap();
66/// dist.push(1).unwrap();
67/// dist.update(1, 1).unwrap();
68/// let mut rng = rand::rng();
69/// let mut samples = [0; 3];
70/// for _ in 0..100 {
71///     // 50% chance to print 'a', 25% chance to print 'b', 25% chance to print 'c'
72///     let i = dist.sample(&mut rng);
73///     samples[i] += 1;
74/// }
75/// println!("Results: {:?}", choices.iter().zip(samples.iter()).collect::<Vec<_>>());
76/// ```
77///
78/// [`WeightedTreeIndex<W>`]: WeightedTreeIndex
79/// [`WeightedIndex<W>`]: super::WeightedIndex
80#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
81#[cfg_attr(
82    feature = "serde",
83    serde(bound(serialize = "W: Serialize, W::Sampler: Serialize"))
84)]
85#[cfg_attr(
86    feature = "serde",
87    serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>"))
88)]
89#[derive(Clone, Default, Debug, PartialEq)]
90pub struct WeightedTreeIndex<
91    W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight,
92> {
93    subtotals: Vec<W>,
94}
95
96impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
97    WeightedTreeIndex<W>
98{
99    /// Creates a new [`WeightedTreeIndex`] from a slice of weights.
100    ///
101    /// Error cases:
102    /// -   [`Error::InvalidWeight`] when a weight is not-a-number or negative.
103    /// -   [`Error::Overflow`] when the sum of all weights overflows.
104    pub fn new<I>(weights: I) -> Result<Self, Error>
105    where
106        I: IntoIterator,
107        I::Item: SampleBorrow<W>,
108    {
109        let mut subtotals: Vec<W> = weights.into_iter().map(|x| x.borrow().clone()).collect();
110        for weight in subtotals.iter() {
111            if !(*weight >= W::ZERO) {
112                return Err(Error::InvalidWeight);
113            }
114        }
115        let n = subtotals.len();
116        for i in (1..n).rev() {
117            let w = subtotals[i].clone();
118            let parent = (i - 1) / 2;
119            subtotals[parent]
120                .checked_add_assign(&w)
121                .map_err(|()| Error::Overflow)?;
122        }
123        Ok(Self { subtotals })
124    }
125
126    /// Returns `true` if the tree contains no weights.
127    pub fn is_empty(&self) -> bool {
128        self.subtotals.is_empty()
129    }
130
131    /// Returns the number of weights.
132    pub fn len(&self) -> usize {
133        self.subtotals.len()
134    }
135
136    /// Returns `true` if we can sample.
137    ///
138    /// This is the case if the total weight of the tree is greater than zero.
139    pub fn is_valid(&self) -> bool {
140        if let Some(weight) = self.subtotals.first() {
141            *weight > W::ZERO
142        } else {
143            false
144        }
145    }
146
147    /// Gets the weight at an index.
148    pub fn get(&self, index: usize) -> W {
149        let left_index = 2 * index + 1;
150        let right_index = 2 * index + 2;
151        let mut w = self.subtotals[index].clone();
152        w -= self.subtotal(left_index);
153        w -= self.subtotal(right_index);
154        w
155    }
156
157    /// Removes the last weight and returns it, or [`None`] if it is empty.
158    pub fn pop(&mut self) -> Option<W> {
159        self.subtotals.pop().map(|weight| {
160            let mut index = self.len();
161            while index != 0 {
162                index = (index - 1) / 2;
163                self.subtotals[index] -= weight.clone();
164            }
165            weight
166        })
167    }
168
169    /// Appends a new weight at the end.
170    ///
171    /// Error cases:
172    /// -   [`Error::InvalidWeight`] when a weight is not-a-number or negative.
173    /// -   [`Error::Overflow`] when the sum of all weights overflows.
174    pub fn push(&mut self, weight: W) -> Result<(), Error> {
175        if !(weight >= W::ZERO) {
176            return Err(Error::InvalidWeight);
177        }
178        if let Some(total) = self.subtotals.first() {
179            let mut total = total.clone();
180            if total.checked_add_assign(&weight).is_err() {
181                return Err(Error::Overflow);
182            }
183        }
184        let mut index = self.len();
185        self.subtotals.push(weight.clone());
186        while index != 0 {
187            index = (index - 1) / 2;
188            self.subtotals[index].checked_add_assign(&weight).unwrap();
189        }
190        Ok(())
191    }
192
193    /// Updates the weight at an index.
194    ///
195    /// Error cases:
196    /// -   [`Error::InvalidWeight`] when a weight is not-a-number or negative.
197    /// -   [`Error::Overflow`] when the sum of all weights overflows.
198    pub fn update(&mut self, mut index: usize, weight: W) -> Result<(), Error> {
199        if !(weight >= W::ZERO) {
200            return Err(Error::InvalidWeight);
201        }
202        let old_weight = self.get(index);
203        if weight > old_weight {
204            let mut difference = weight;
205            difference -= old_weight;
206            if let Some(total) = self.subtotals.first() {
207                let mut total = total.clone();
208                if total.checked_add_assign(&difference).is_err() {
209                    return Err(Error::Overflow);
210                }
211            }
212            self.subtotals[index]
213                .checked_add_assign(&difference)
214                .unwrap();
215            while index != 0 {
216                index = (index - 1) / 2;
217                self.subtotals[index]
218                    .checked_add_assign(&difference)
219                    .unwrap();
220            }
221        } else if weight < old_weight {
222            let mut difference = old_weight;
223            difference -= weight;
224            self.subtotals[index] -= difference.clone();
225            while index != 0 {
226                index = (index - 1) / 2;
227                self.subtotals[index] -= difference.clone();
228            }
229        }
230        Ok(())
231    }
232
233    fn subtotal(&self, index: usize) -> W {
234        if index < self.subtotals.len() {
235            self.subtotals[index].clone()
236        } else {
237            W::ZERO
238        }
239    }
240}
241
242impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight>
243    WeightedTreeIndex<W>
244{
245    /// Samples a randomly selected index from the weighted distribution.
246    ///
247    /// Returns an error if there are no elements or all weights are zero. This
248    /// is unlike [`Distribution::sample`], which panics in those cases.
249    pub fn try_sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<usize, Error> {
250        let total_weight = self.subtotals.first().cloned().unwrap_or(W::ZERO);
251        if total_weight == W::ZERO {
252            return Err(Error::InsufficientNonZero);
253        }
254        let mut target_weight = rng.random_range(W::ZERO..total_weight);
255        let mut index = 0;
256        loop {
257            // Maybe descend into the left sub tree.
258            let left_index = 2 * index + 1;
259            let left_subtotal = self.subtotal(left_index);
260            if target_weight < left_subtotal {
261                index = left_index;
262                continue;
263            }
264            target_weight -= left_subtotal;
265
266            // Maybe descend into the right sub tree.
267            let right_index = 2 * index + 2;
268            let right_subtotal = self.subtotal(right_index);
269            if target_weight < right_subtotal {
270                index = right_index;
271                continue;
272            }
273            target_weight -= right_subtotal;
274
275            // Otherwise we found the index with the target weight.
276            break;
277        }
278        assert!(target_weight >= W::ZERO);
279        assert!(target_weight < self.get(index));
280        Ok(index)
281    }
282}
283
284/// Samples a randomly selected index from the weighted distribution.
285///
286/// Caution: This method panics if there are no elements or all weights are zero. However,
287/// it is guaranteed that this method will not panic if a call to [`WeightedTreeIndex::is_valid`]
288/// returns `true`.
289impl<W: Clone + PartialEq + PartialOrd + SampleUniform + SubAssign<W> + Weight> Distribution<usize>
290    for WeightedTreeIndex<W>
291{
292    #[track_caller]
293    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
294        self.try_sample(rng).unwrap()
295    }
296}
297
#[cfg(test)]
mod test {
    use super::*;

    /// Sampling from a tree without any weights reports an error.
    #[test]
    fn test_no_item_error() {
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
        #[allow(clippy::needless_borrows_for_generic_args)]
        let tree = WeightedTreeIndex::<f64>::new(&[]).unwrap();
        assert_eq!(tree.try_sample(&mut rng), Err(Error::InsufficientNonZero));
    }

    /// Construction, `push` and `update` all detect an overflowing total.
    #[test]
    fn test_overflow_error() {
        assert_eq!(
            WeightedTreeIndex::<i32>::new([i32::MAX, 2]).unwrap_err(),
            Error::Overflow
        );
        let mut tree = WeightedTreeIndex::new([i32::MAX - 2, 1]).unwrap();
        assert_eq!(tree.push(3).unwrap_err(), Error::Overflow);
        assert_eq!(tree.update(1, 4).unwrap_err(), Error::Overflow);
        // Raising the total to exactly `i32::MAX` is still allowed.
        assert_eq!(tree.update(1, 2), Ok(()));
    }

    /// Sampling from a tree whose weights are all zero reports an error.
    #[test]
    fn test_all_weights_zero_error() {
        let tree = WeightedTreeIndex::<f64>::new([0.0, 0.0]).unwrap();
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
        assert_eq!(tree.try_sample(&mut rng), Err(Error::InsufficientNonZero));
    }

    /// Negative weights are rejected by `new`, `push` and `update`.
    #[test]
    fn test_invalid_weight_error() {
        assert_eq!(
            WeightedTreeIndex::<i32>::new([1, -1]),
            Err(Error::InvalidWeight)
        );
        #[allow(clippy::needless_borrows_for_generic_args)]
        let mut tree = WeightedTreeIndex::<i32>::new(&[]).unwrap();
        assert_eq!(tree.push(-1), Err(Error::InvalidWeight));
        tree.push(1).unwrap();
        assert_eq!(tree.update(0, -1), Err(Error::InvalidWeight));
    }

    /// A tree modified step by step equals one built directly from the
    /// resulting weights.
    #[test]
    fn test_tree_modifications() {
        let mut tree = WeightedTreeIndex::new([9, 1, 2]).unwrap();
        for weight in [3, 5] {
            tree.push(weight).unwrap();
        }
        tree.update(0, 0).unwrap();
        assert_eq!(tree.pop(), Some(5));
        assert_eq!(tree, WeightedTreeIndex::new([0, 1, 2, 3]).unwrap());
    }

    /// Observed sample frequencies stay close to the configured weights.
    #[test]
    fn test_sample_counts_match_probabilities() {
        let (start, end, samples) = (1, 3, 20);
        let mut rng = crate::test::rng(0x9c9fa0b0580a7031);

        // Start from random weights, then overwrite them so that index `i`
        // carries weight `i`.
        let initial: Vec<f64> = (0..end).map(|_| rng.random()).collect();
        let mut tree = WeightedTreeIndex::new(initial).unwrap();
        let mut weights: Vec<f64> = (0..end).map(|i| i as f64).collect();
        for (i, &weight) in weights.iter().enumerate() {
            tree.update(i, weight).unwrap();
        }

        // Zero out the first `start` indices; they must never be sampled.
        for (i, weight) in weights.iter_mut().enumerate().take(start) {
            tree.update(i, 0.0).unwrap();
            *weight = 0.0;
        }
        let total_weight: f64 = weights.iter().sum();

        let mut counts = alloc::vec![0_usize; end];
        for _ in 0..samples {
            counts[tree.sample(&mut rng)] += 1;
        }

        assert!(counts[..start].iter().all(|&count| count == 0));
        for (&count, &weight) in counts.iter().zip(&weights).skip(start) {
            let observed = count as f64 / samples as f64;
            let expected = weight / total_weight;
            assert!((observed - expected).abs() < 0.05);
        }
    }
}