// thedes_async_util/timer.rs

1use std::{
2    mem,
3    pin::Pin,
4    sync::{Arc, Mutex},
5    task::{Context, Poll, Waker},
6    time::Duration,
7};
8
9use tokio::time::{self, Instant, Interval};
10
/// Index of a session's slot in `State::descriptors`.
type Id = usize;
12
/// Per-session slot in `State::descriptors`, indexed by the session's `Id`.
#[derive(Debug)]
enum Descriptor {
    /// Free slot: no live session owns it and `new_session` may reuse it.
    /// NOTE(review): misspelling of "Vacant"; kept because every match on
    /// this enum refers to it by this name.
    Vaccant,
    /// Live session that is not currently waiting for a tick.
    NotWaiting,
    /// Live session waiting for the next tick; holds the waker from its most
    /// recent poll.
    Waiting(Waker),
}
19
/// Tick bookkeeping shared by all sessions of one timer; lives behind the
/// `Mutex` in `Shared`.
#[derive(Debug)]
struct State {
    /// Underlying tokio interval; only polled once every live session is
    /// waiting.
    interval: Interval,
    /// One slot per session id; trailing vacant slots are trimmed when a
    /// session is dropped.
    descriptors: Vec<Descriptor>,
    /// Number of live sessions (incremented by `new_session`, decremented by
    /// `drop_session`).
    sessions: usize,
    /// Number of sessions whose descriptor is currently `Waiting`.
    waiting: usize,
    /// Instant of the most recent tick; initialised from the `start` passed
    /// to `State::new`.
    last_tick: Instant,
}
28
29impl State {
30    pub fn new(interval: Interval, start: Instant) -> Self {
31        Self {
32            interval,
33            descriptors: Vec::new(),
34            sessions: 0,
35            waiting: 0,
36            last_tick: start,
37        }
38    }
39
40    pub fn new_session(&mut self) -> Id {
41        self.sessions += 1;
42        for (id, descriptor) in self.descriptors.iter_mut().enumerate() {
43            if matches!(descriptor, Descriptor::Vaccant) {
44                *descriptor = Descriptor::NotWaiting;
45                return id;
46            }
47        }
48        let id = self.descriptors.len();
49        self.descriptors.push(Descriptor::NotWaiting);
50        id
51    }
52
53    pub fn drop_session(&mut self, id: Id) {
54        let descriptor = &mut self.descriptors[id];
55        let waiting = match descriptor {
56            Descriptor::Vaccant => {
57                debug_assert!(false, "cannot drop vaccant session");
58                false
59            },
60            Descriptor::NotWaiting => false,
61            Descriptor::Waiting(_) => true,
62        };
63        *descriptor = Descriptor::Vaccant;
64        self.sessions -= 1;
65        if waiting {
66            self.cancel_one_waiting();
67        } else {
68            self.wake_one_if_complete();
69        }
70        while let Some(Descriptor::Vaccant) = self.descriptors.last() {
71            self.descriptors.pop();
72        }
73    }
74
75    pub fn poll_tick(
76        &mut self,
77        cx: &mut Context<'_>,
78        id: Id,
79        last_known_tick: Instant,
80    ) -> Poll<Instant> {
81        let descriptor =
82            mem::replace(&mut self.descriptors[id], Descriptor::NotWaiting);
83
84        match descriptor {
85            Descriptor::Vaccant => {
86                debug_assert!(false, "cannot poll on vaccant descriptor");
87                Poll::Pending
88            },
89            Descriptor::NotWaiting => {
90                if last_known_tick >= self.last_tick {
91                    self.waiting += 1;
92                    self.descriptors[id] =
93                        Descriptor::Waiting(cx.waker().clone());
94                    self.poll_interval(cx)
95                } else {
96                    Poll::Ready(self.last_tick)
97                }
98            },
99            Descriptor::Waiting(_) => {
100                if last_known_tick >= self.last_tick {
101                    self.descriptors[id] =
102                        Descriptor::Waiting(cx.waker().clone());
103                    self.poll_interval(cx)
104                } else {
105                    debug_assert!(
106                        false,
107                        "last known tick should be up to date"
108                    );
109                    Poll::Ready(self.last_tick)
110                }
111            },
112        }
113    }
114
115    pub fn cancel_tick(&mut self, id: Id) -> bool {
116        let descriptor = &mut self.descriptors[id];
117        debug_assert_eq!(
118            false,
119            matches!(descriptor, Descriptor::Vaccant),
120            "cannot cancel on vaccant descriptor",
121        );
122        if matches!(descriptor, Descriptor::Waiting(_)) {
123            *descriptor = Descriptor::NotWaiting;
124            self.cancel_one_waiting();
125            true
126        } else {
127            false
128        }
129    }
130
131    fn poll_interval(&mut self, cx: &mut Context<'_>) -> Poll<Instant> {
132        if self.sessions <= self.waiting {
133            debug_assert_eq!(self.sessions, self.waiting);
134            let poll = self.interval.poll_tick(cx);
135            if let Poll::Ready(instant) = poll {
136                self.last_tick = instant;
137                self.waiting = 0;
138                for descriptor in &mut self.descriptors {
139                    if let Descriptor::Waiting(waker) =
140                        mem::replace(descriptor, Descriptor::NotWaiting)
141                    {
142                        waker.wake();
143                    }
144                }
145            }
146            poll
147        } else {
148            Poll::Pending
149        }
150    }
151
152    fn cancel_one_waiting(&mut self) {
153        self.waiting -= 1;
154        self.wake_one_if_complete();
155    }
156
157    fn wake_one_if_complete(&mut self) {
158        if self.sessions <= self.waiting {
159            debug_assert_eq!(self.sessions, self.waiting);
160            for descriptor in &mut self.descriptors {
161                match mem::replace(descriptor, Descriptor::NotWaiting) {
162                    Descriptor::Waiting(waker) => {
163                        waker.wake();
164                        self.waiting -= 1;
165                        break;
166                    },
167                    value => *descriptor = value,
168                }
169            }
170        }
171    }
172}
173
/// Data shared (through an `Arc`) by every `Timer` and `TickSession` handle
/// of one timer.
#[derive(Debug)]
struct Shared {
    /// Tick period the timer was created with.
    period: Duration,
    /// Tick bookkeeping; accessed through `Shared::with_state`.
    state: Mutex<State>,
}
179
180impl Shared {
181    pub fn with_state<F, T>(&self, scope: F) -> T
182    where
183        F: FnOnce(&mut State) -> T,
184    {
185        let mut state = self.state.lock().expect("poisoned lock");
186        scope(&mut state)
187    }
188}
189
/// Cloneable handle to a periodic timer whose ticks are synchronised across
/// all of its [`TickSession`]s.
///
/// A tick is only taken from the underlying interval once every live session
/// is waiting for it, and every session that was waiting observes that tick
/// with the same instant.
#[derive(Debug, Clone)]
pub struct Timer {
    shared: Arc<Shared>,
}
194
195impl Timer {
196    pub fn new(period: Duration) -> Self {
197        let start = Instant::now();
198        let interval = time::interval(period);
199        let state = Mutex::new(State::new(interval, start));
200        let shared = Shared { period, state };
201        Self { shared: Arc::new(shared) }
202    }
203
204    fn from_session(session: &TickSession) -> Self {
205        Self { shared: session.shared.clone() }
206    }
207
208    pub fn new_session(&self) -> TickSession {
209        let last_known_tick = self.shared.with_state(|state| state.last_tick);
210        TickSession::new(&self.shared, last_known_tick)
211    }
212}
213
/// One participant in a [`Timer`]'s tick synchronisation.
///
/// Created by [`Timer::new_session`] or by cloning another session (the clone
/// starts from the original's last observed tick). Dropping a session removes
/// it from the set of sessions the timer waits for.
#[derive(Debug)]
pub struct TickSession {
    /// Index of this session's slot in the shared descriptor table.
    id: Id,
    /// Instant of the last tick this session observed.
    last_known_tick: Instant,
    shared: Arc<Shared>,
}
220
221impl TickSession {
222    fn new(shared: &Arc<Shared>, last_known_tick: Instant) -> Self {
223        let id = shared.with_state(|state| state.new_session());
224        Self { id, last_known_tick, shared: shared.clone() }
225    }
226
227    pub fn tick(&mut self) -> Tick<'_> {
228        Tick { timer: self }
229    }
230
231    pub fn poll_tick(&mut self, cx: &mut Context<'_>) -> Poll<Instant> {
232        let poll = self.shared.with_state(|state| {
233            state.poll_tick(cx, self.id, self.last_known_tick)
234        });
235        if let Poll::Ready(instant) = poll {
236            self.last_known_tick = instant;
237        }
238        poll
239    }
240
241    pub fn cancel_tick(&mut self) -> bool {
242        self.shared.with_state(|state| state.cancel_tick(self.id))
243    }
244
245    pub fn period(&self) -> Duration {
246        self.shared.period
247    }
248
249    pub fn last_tick(&self) -> Instant {
250        self.shared.with_state(|state| state.last_tick)
251    }
252
253    pub fn elapsed(&self) -> Duration {
254        self.last_known_tick.elapsed()
255    }
256
257    pub fn time_left(&self) -> Duration {
258        self.period().saturating_sub(self.elapsed())
259    }
260
261    pub fn timer(&self) -> Timer {
262        Timer::from_session(self)
263    }
264}
265
266impl Clone for TickSession {
267    fn clone(&self) -> Self {
268        Self::new(&self.shared, self.last_known_tick)
269    }
270}
271
272impl Drop for TickSession {
273    fn drop(&mut self) {
274        self.shared.with_state(|state| state.drop_session(self.id))
275    }
276}
277
278#[derive(Debug)]
279pub struct Tick<'a> {
280    timer: &'a mut TickSession,
281}
282
283impl Future for Tick<'_> {
284    type Output = Instant;
285
286    fn poll(
287        mut self: Pin<&mut Self>,
288        cx: &mut Context<'_>,
289    ) -> Poll<Self::Output> {
290        self.timer.poll_tick(cx)
291    }
292}
293
294impl Drop for Tick<'_> {
295    fn drop(&mut self) {
296        self.timer.cancel_tick();
297    }
298}
299
// Tests checking that every session of a `Timer` observes the same tick
// instants, including when sessions join ("novices"), leave ("leavos"), or
// cancel a pending tick.
#[cfg(test)]
mod test {
    use std::time::Duration;

    use tokio::{task::JoinSet, time};

    use crate::timer::Timer;

    // 16 clones tick once on spawned tasks while the original ticks once too;
    // all must observe the same instant.
    #[tokio::test]
    async fn sync_once() {
        let mut join_set = JoinSet::new();
        let timer = Timer::new(Duration::from_micros(100));
        let mut participant = timer.new_session();
        let timers = (0 .. 16).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move { participant.tick().await });
        }
        let answer = participant.tick().await;
        while let Some(alternative) = join_set.join_next().await {
            assert_eq!(answer, alternative.unwrap());
        }
    }

    // Same as `sync_once`, but with two consecutive ticks per session; both
    // the first and second instants must agree across sessions.
    #[tokio::test]
    async fn sync_twice() {
        let mut join_set = JoinSet::new();
        let timer = Timer::new(Duration::from_micros(100));
        let mut participant = timer.new_session();
        let timers = (0 .. 16).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move {
                let first = participant.tick().await;
                let second = participant.tick().await;
                (first, second)
            });
        }
        let first_answer = participant.tick().await;
        let second_answer = participant.tick().await;
        while let Some(alternative) = join_set.join_next().await {
            let (first_alternative, second_alternative) = alternative.unwrap();
            assert_eq!(first_answer, first_alternative);
            assert_eq!(second_answer, second_alternative);
        }
    }

    // After the first tick, 7 "novice" sessions are cloned from the original
    // (so they start from the first tick) and join for the second tick only.
    #[tokio::test]
    async fn sync_twice_with_novices() {
        let mut join_set = JoinSet::new();
        let timer = Timer::new(Duration::from_micros(100));
        let mut participant = timer.new_session();
        let timers = (0 .. 16).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move {
                let first = participant.tick().await;
                let second = participant.tick().await;
                (Some(first), second)
            });
        }
        let first_answer = participant.tick().await;
        let timers = (0 .. 7).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move {
                let second = participant.tick().await;
                (None, second)
            });
        }
        let second_answer = participant.tick().await;

        let mut has_first = 0;
        while let Some(alternative) = join_set.join_next().await {
            let (maybe_first_alternative, second_alternative) =
                alternative.unwrap();
            if let Some(first_alternative) = maybe_first_alternative {
                assert_eq!(first_answer, first_alternative);
                has_first += 1;
            }
            assert_eq!(second_answer, second_alternative);
        }
        assert_ne!(has_first, 0);
    }

    // Like `sync_twice_with_novices`, plus a third tick from which the first
    // 4 original clones (i < 4) are absent: they drop out after the second.
    #[tokio::test]
    async fn sync_thrice_with_novices_and_leavos() {
        let mut join_set = JoinSet::new();
        let timer = Timer::new(Duration::from_micros(100));
        let mut participant = timer.new_session();
        let timers = (0 .. 16).map(|_| participant.clone()).collect::<Vec<_>>();
        for (i, mut participant) in timers.into_iter().enumerate() {
            join_set.spawn(async move {
                let first = participant.tick().await;
                let second = participant.tick().await;
                let third =
                    if i < 4 { None } else { Some(participant.tick().await) };
                (Some(first), second, third)
            });
        }
        let first_answer = participant.tick().await;
        let timers = (0 .. 7).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move {
                let second = participant.tick().await;
                let third = participant.tick().await;
                (None, second, Some(third))
            });
        }
        let second_answer = participant.tick().await;
        let third_answer = participant.tick().await;

        let mut has_first = 0;
        let mut has_third = 0;
        while let Some(alternative) = join_set.join_next().await {
            let (
                maybe_first_alternative,
                second_alternative,
                maybe_third_alternative,
            ) = alternative.unwrap();
            if let Some(first_alternative) = maybe_first_alternative {
                assert_eq!(first_answer, first_alternative);
                has_first += 1;
            }
            assert_eq!(second_answer, second_alternative);
            if let Some(third_alternative) = maybe_third_alternative {
                assert_eq!(third_answer, third_alternative);
                has_third += 1;
            }
        }
        assert_ne!(has_first, 0);
        assert_ne!(has_third, 0);
    }

    // Original clones leave at different points: 13..16 after the first
    // tick, 0..4 after the second; the rest plus 7 novices reach the third.
    #[tokio::test]
    async fn sync_thrice_with_novices_and_leavos_mixed() {
        let mut join_set = JoinSet::new();
        let timer = Timer::new(Duration::from_micros(100));
        let mut participant = timer.new_session();
        let timers = (0 .. 16).map(|_| participant.clone()).collect::<Vec<_>>();
        for (i, mut participant) in timers.into_iter().enumerate() {
            join_set.spawn(async move {
                let first = participant.tick().await;
                let (second, third) = if i < 13 {
                    let second = participant.tick().await;
                    let third = if i < 4 {
                        None
                    } else {
                        Some(participant.tick().await)
                    };
                    (Some(second), third)
                } else {
                    (None, None)
                };
                (Some(first), second, third)
            });
        }
        let first_answer = participant.tick().await;
        let timers = (0 .. 7).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move {
                let second = participant.tick().await;
                let third = participant.tick().await;
                (None, Some(second), Some(third))
            });
        }
        let second_answer = participant.tick().await;
        let third_answer = participant.tick().await;

        let mut has_first = 0;
        let mut has_second = 0;
        let mut has_third = 0;
        while let Some(alternative) = join_set.join_next().await {
            let (
                maybe_first_alternative,
                maybe_second_alternative,
                maybe_third_alternative,
            ) = alternative.unwrap();
            if let Some(first_alternative) = maybe_first_alternative {
                assert_eq!(first_answer, first_alternative);
                has_first += 1;
            }
            if let Some(second_alternative) = maybe_second_alternative {
                assert_eq!(second_answer, second_alternative);
                has_second += 1;
            }
            if let Some(third_alternative) = maybe_third_alternative {
                assert_eq!(third_answer, third_alternative);
                has_third += 1;
            }
        }
        assert_ne!(has_first, 0);
        assert_ne!(has_second, 0);
        assert_ne!(has_third, 0);
    }

    // Originals 0..9 leave after the second tick; 7 novices join from the
    // second through the fourth tick; 11 more novices, cloned after the
    // third tick, join for the fourth only.
    #[tokio::test]
    async fn sync_four_times_with_novices_and_leavos_and_novices_again() {
        let mut join_set = JoinSet::new();
        let timer = Timer::new(Duration::from_micros(100));
        let mut participant = timer.new_session();
        let timers = (0 .. 16).map(|_| participant.clone()).collect::<Vec<_>>();
        for (i, mut participant) in timers.into_iter().enumerate() {
            join_set.spawn(async move {
                let first = participant.tick().await;
                let second = participant.tick().await;
                let third =
                    if i < 9 { None } else { Some(participant.tick().await) };
                (Some(first), Some(second), third, None)
            });
        }
        let first_answer = participant.tick().await;
        let timers = (0 .. 7).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move {
                let second = participant.tick().await;
                let third = participant.tick().await;
                let fourth = participant.tick().await;
                (None, Some(second), Some(third), Some(fourth))
            });
        }
        let second_answer = participant.tick().await;
        let third_answer = participant.tick().await;
        let timers = (0 .. 11).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move {
                let fourth = participant.tick().await;
                (None, None, None, Some(fourth))
            });
        }
        let fourth_answer = participant.tick().await;

        let mut has_first = 0;
        let mut has_second = 0;
        let mut has_third = 0;
        let mut has_fourth = 0;
        while let Some(alternative) = join_set.join_next().await {
            let (
                maybe_first_alternative,
                maybe_second_alternative,
                maybe_third_alternative,
                maybe_fourth_alternative,
            ) = alternative.unwrap();
            if let Some(first_alternative) = maybe_first_alternative {
                assert_eq!(first_answer, first_alternative);
                has_first += 1;
            }
            if let Some(second_alternative) = maybe_second_alternative {
                assert_eq!(second_answer, second_alternative);
                has_second += 1;
            }
            if let Some(third_alternative) = maybe_third_alternative {
                assert_eq!(third_answer, third_alternative);
                has_third += 1;
            }
            if let Some(fourth_alternative) = maybe_fourth_alternative {
                assert_eq!(fourth_answer, fourth_alternative);
                has_fourth += 1;
            }
        }
        assert_ne!(has_first, 0);
        assert_ne!(has_second, 0);
        assert_ne!(has_third, 0);
        assert_ne!(has_fourth, 0);
    }

    // Every clone leaves after the first tick; the original alone must still
    // get a second tick (dropped sessions must not block it).
    #[tokio::test]
    async fn sync_twice_only_one_left() {
        let mut join_set = JoinSet::new();
        let timer = Timer::new(Duration::from_micros(100));
        let mut participant = timer.new_session();
        let timers = (0 .. 16).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move { participant.tick().await });
        }
        let first_answer = participant.tick().await;
        let _second_answer = participant.tick().await;
        while let Some(alternative) = join_set.join_next().await {
            assert_eq!(first_answer, alternative.unwrap());
        }
    }

    // Clone 7 races its tick against a 1 µs sleep and expects the sleep to
    // win (relies on the 10 ms period being much longer), which drops its
    // pending `Tick`; its session is then dropped when the task ends. The
    // other 15 clones and the original must still complete the same tick.
    #[tokio::test]
    async fn sync_with_cancel() {
        let mut join_set = JoinSet::new();
        let timer = Timer::new(Duration::from_millis(10));
        let mut participant = timer.new_session();
        participant.tick().await;
        let timers = (0 .. 16).map(|_| participant.clone()).collect::<Vec<_>>();
        for (i, mut participant) in timers.into_iter().enumerate() {
            join_set.spawn(async move {
                if i == 7 {
                    let sleep_first = tokio::select! {
                        _ = participant.tick() => false,
                        _ = time::sleep(Duration::from_micros(1)) => true,
                    };
                    assert!(sleep_first);
                    None
                } else {
                    Some(participant.tick().await)
                }
            });
        }
        let answer = participant.tick().await;
        let mut other_completion = 0;
        while let Some(maybe_alternative) = join_set.join_next().await {
            if let Some(alternative) = maybe_alternative.unwrap() {
                assert_eq!(answer, alternative);
                other_completion += 1;
            }
        }
        assert_eq!(other_completion, 15);
    }
}