1use std::{
2 mem,
3 pin::Pin,
4 sync::{Arc, Mutex},
5 task::{Context, Poll, Waker},
6 time::Duration,
7};
8
9use tokio::time::{self, Instant, Interval};
10
/// Index of a session's slot in `State::descriptors`.
type Id = usize;
12
/// Per-session slot stored in `State::descriptors`, indexed by `Id`.
#[derive(Debug)]
enum Descriptor {
    /// Free slot; `State::new_session` reuses the first one it finds.
    Vaccant,
    /// A live session that is not currently waiting for a tick.
    NotWaiting,
    /// A live session waiting for the next tick, with the waker to notify
    /// once that tick (or a wake-up to drive the interval) is due.
    Waiting(Waker),
}
19
/// Bookkeeping shared by every session of one timer; lives behind the mutex
/// in `Shared::state`.
#[derive(Debug)]
struct State {
    /// Tokio interval that produces the actual ticks; it is only polled once
    /// every live session is waiting.
    interval: Interval,
    /// One slot per session id; trailing vacant slots are trimmed when a
    /// session is dropped.
    descriptors: Vec<Descriptor>,
    /// Number of live (non-vacant) sessions.
    sessions: usize,
    /// Number of sessions currently in the `Waiting` state.
    waiting: usize,
    /// Instant of the most recent tick; initially the timer's creation time.
    last_tick: Instant,
}
28
29impl State {
30 pub fn new(interval: Interval, start: Instant) -> Self {
31 Self {
32 interval,
33 descriptors: Vec::new(),
34 sessions: 0,
35 waiting: 0,
36 last_tick: start,
37 }
38 }
39
40 pub fn new_session(&mut self) -> Id {
41 self.sessions += 1;
42 for (id, descriptor) in self.descriptors.iter_mut().enumerate() {
43 if matches!(descriptor, Descriptor::Vaccant) {
44 *descriptor = Descriptor::NotWaiting;
45 return id;
46 }
47 }
48 let id = self.descriptors.len();
49 self.descriptors.push(Descriptor::NotWaiting);
50 id
51 }
52
53 pub fn drop_session(&mut self, id: Id) {
54 let descriptor = &mut self.descriptors[id];
55 let waiting = match descriptor {
56 Descriptor::Vaccant => {
57 debug_assert!(false, "cannot drop vaccant session");
58 false
59 },
60 Descriptor::NotWaiting => false,
61 Descriptor::Waiting(_) => true,
62 };
63 *descriptor = Descriptor::Vaccant;
64 self.sessions -= 1;
65 if waiting {
66 self.cancel_one_waiting();
67 } else {
68 self.wake_one_if_complete();
69 }
70 while let Some(Descriptor::Vaccant) = self.descriptors.last() {
71 self.descriptors.pop();
72 }
73 }
74
75 pub fn poll_tick(
76 &mut self,
77 cx: &mut Context<'_>,
78 id: Id,
79 last_known_tick: Instant,
80 ) -> Poll<Instant> {
81 let descriptor =
82 mem::replace(&mut self.descriptors[id], Descriptor::NotWaiting);
83
84 match descriptor {
85 Descriptor::Vaccant => {
86 debug_assert!(false, "cannot poll on vaccant descriptor");
87 Poll::Pending
88 },
89 Descriptor::NotWaiting => {
90 if last_known_tick >= self.last_tick {
91 self.waiting += 1;
92 self.descriptors[id] =
93 Descriptor::Waiting(cx.waker().clone());
94 self.poll_interval(cx)
95 } else {
96 Poll::Ready(self.last_tick)
97 }
98 },
99 Descriptor::Waiting(_) => {
100 if last_known_tick >= self.last_tick {
101 self.descriptors[id] =
102 Descriptor::Waiting(cx.waker().clone());
103 self.poll_interval(cx)
104 } else {
105 debug_assert!(
106 false,
107 "last known tick should be up to date"
108 );
109 Poll::Ready(self.last_tick)
110 }
111 },
112 }
113 }
114
115 pub fn cancel_tick(&mut self, id: Id) -> bool {
116 let descriptor = &mut self.descriptors[id];
117 debug_assert_eq!(
118 false,
119 matches!(descriptor, Descriptor::Vaccant),
120 "cannot cancel on vaccant descriptor",
121 );
122 if matches!(descriptor, Descriptor::Waiting(_)) {
123 *descriptor = Descriptor::NotWaiting;
124 self.cancel_one_waiting();
125 true
126 } else {
127 false
128 }
129 }
130
131 fn poll_interval(&mut self, cx: &mut Context<'_>) -> Poll<Instant> {
132 if self.sessions <= self.waiting {
133 debug_assert_eq!(self.sessions, self.waiting);
134 let poll = self.interval.poll_tick(cx);
135 if let Poll::Ready(instant) = poll {
136 self.last_tick = instant;
137 self.waiting = 0;
138 for descriptor in &mut self.descriptors {
139 if let Descriptor::Waiting(waker) =
140 mem::replace(descriptor, Descriptor::NotWaiting)
141 {
142 waker.wake();
143 }
144 }
145 }
146 poll
147 } else {
148 Poll::Pending
149 }
150 }
151
152 fn cancel_one_waiting(&mut self) {
153 self.waiting -= 1;
154 self.wake_one_if_complete();
155 }
156
157 fn wake_one_if_complete(&mut self) {
158 if self.sessions <= self.waiting {
159 debug_assert_eq!(self.sessions, self.waiting);
160 for descriptor in &mut self.descriptors {
161 match mem::replace(descriptor, Descriptor::NotWaiting) {
162 Descriptor::Waiting(waker) => {
163 waker.wake();
164 self.waiting -= 1;
165 break;
166 },
167 value => *descriptor = value,
168 }
169 }
170 }
171 }
172}
173
/// Data shared, via `Arc`, between `Timer` handles and their `TickSession`s.
#[derive(Debug)]
struct Shared {
    /// Tick period the timer was created with.
    period: Duration,
    /// Session bookkeeping; accessed through `Shared::with_state`.
    state: Mutex<State>,
}
179
180impl Shared {
181 pub fn with_state<F, T>(&self, scope: F) -> T
182 where
183 F: FnOnce(&mut State) -> T,
184 {
185 let mut state = self.state.lock().expect("poisoned lock");
186 scope(&mut state)
187 }
188}
189
/// Handle to a timer whose ticks are only delivered once every registered
/// session is waiting, so all sessions observe the same ticks.
///
/// Cloning the handle shares the same underlying timer.
#[derive(Debug, Clone)]
pub struct Timer {
    shared: Arc<Shared>,
}
194
195impl Timer {
196 pub fn new(period: Duration) -> Self {
197 let start = Instant::now();
198 let interval = time::interval(period);
199 let state = Mutex::new(State::new(interval, start));
200 let shared = Shared { period, state };
201 Self { shared: Arc::new(shared) }
202 }
203
204 fn from_session(session: &TickSession) -> Self {
205 Self { shared: session.shared.clone() }
206 }
207
208 pub fn new_session(&self) -> TickSession {
209 let last_known_tick = self.shared.with_state(|state| state.last_tick);
210 TickSession::new(&self.shared, last_known_tick)
211 }
212}
213
/// A participant in a `Timer`: the shared interval is only polled once every
/// live session is waiting for a tick.
#[derive(Debug)]
pub struct TickSession {
    /// Slot index of this session in the shared state.
    id: Id,
    /// The most recent tick this session has observed.
    last_known_tick: Instant,
    shared: Arc<Shared>,
}
220
221impl TickSession {
222 fn new(shared: &Arc<Shared>, last_known_tick: Instant) -> Self {
223 let id = shared.with_state(|state| state.new_session());
224 Self { id, last_known_tick, shared: shared.clone() }
225 }
226
227 pub fn tick(&mut self) -> Tick<'_> {
228 Tick { timer: self }
229 }
230
231 pub fn poll_tick(&mut self, cx: &mut Context<'_>) -> Poll<Instant> {
232 let poll = self.shared.with_state(|state| {
233 state.poll_tick(cx, self.id, self.last_known_tick)
234 });
235 if let Poll::Ready(instant) = poll {
236 self.last_known_tick = instant;
237 }
238 poll
239 }
240
241 pub fn cancel_tick(&mut self) -> bool {
242 self.shared.with_state(|state| state.cancel_tick(self.id))
243 }
244
245 pub fn period(&self) -> Duration {
246 self.shared.period
247 }
248
249 pub fn last_tick(&self) -> Instant {
250 self.shared.with_state(|state| state.last_tick)
251 }
252
253 pub fn elapsed(&self) -> Duration {
254 self.last_known_tick.elapsed()
255 }
256
257 pub fn time_left(&self) -> Duration {
258 self.period().saturating_sub(self.elapsed())
259 }
260
261 pub fn timer(&self) -> Timer {
262 Timer::from_session(self)
263 }
264}
265
266impl Clone for TickSession {
267 fn clone(&self) -> Self {
268 Self::new(&self.shared, self.last_known_tick)
269 }
270}
271
272impl Drop for TickSession {
273 fn drop(&mut self) {
274 self.shared.with_state(|state| state.drop_session(self.id))
275 }
276}
277
/// Future returned by `TickSession::tick`. It resolves to the first tick
/// newer than the session's last known tick and withdraws the session from
/// the wait if dropped before completing.
#[derive(Debug)]
pub struct Tick<'a> {
    timer: &'a mut TickSession,
}
282
283impl Future for Tick<'_> {
284 type Output = Instant;
285
286 fn poll(
287 mut self: Pin<&mut Self>,
288 cx: &mut Context<'_>,
289 ) -> Poll<Self::Output> {
290 self.timer.poll_tick(cx)
291 }
292}
293
294impl Drop for Tick<'_> {
295 fn drop(&mut self) {
296 self.timer.cancel_tick();
297 }
298}
299
#[cfg(test)]
mod test {
    use std::time::Duration;

    use tokio::{task::JoinSet, time};

    // `super` is the module that defines `Timer`, so the tests do not depend
    // on this file being mounted at `crate::timer`.
    use super::Timer;

    /// The original session and sixteen clones all observe the same tick.
    #[tokio::test]
    async fn sync_once() {
        let mut join_set = JoinSet::new();
        let timer = Timer::new(Duration::from_micros(100));
        let mut participant = timer.new_session();
        let timers = (0 .. 16).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move { participant.tick().await });
        }
        let answer = participant.tick().await;
        while let Some(alternative) = join_set.join_next().await {
            assert_eq!(answer, alternative.unwrap());
        }
    }

    /// Two consecutive ticks are observed identically by every session.
    #[tokio::test]
    async fn sync_twice() {
        let mut join_set = JoinSet::new();
        let timer = Timer::new(Duration::from_micros(100));
        let mut participant = timer.new_session();
        let timers = (0 .. 16).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move {
                let first = participant.tick().await;
                let second = participant.tick().await;
                (first, second)
            });
        }
        let first_answer = participant.tick().await;
        let second_answer = participant.tick().await;
        while let Some(alternative) = join_set.join_next().await {
            let (first_alternative, second_alternative) = alternative.unwrap();
            assert_eq!(first_answer, first_alternative);
            assert_eq!(second_answer, second_alternative);
        }
    }

    /// Sessions cloned after the first tick join in from the second tick.
    #[tokio::test]
    async fn sync_twice_with_novices() {
        let mut join_set = JoinSet::new();
        let timer = Timer::new(Duration::from_micros(100));
        let mut participant = timer.new_session();
        let timers = (0 .. 16).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move {
                let first = participant.tick().await;
                let second = participant.tick().await;
                (Some(first), second)
            });
        }
        let first_answer = participant.tick().await;
        let timers = (0 .. 7).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move {
                let second = participant.tick().await;
                (None, second)
            });
        }
        let second_answer = participant.tick().await;

        let mut has_first = 0;
        while let Some(alternative) = join_set.join_next().await {
            let (maybe_first_alternative, second_alternative) =
                alternative.unwrap();
            if let Some(first_alternative) = maybe_first_alternative {
                assert_eq!(first_answer, first_alternative);
                has_first += 1;
            }
            assert_eq!(second_answer, second_alternative);
        }
        assert_ne!(has_first, 0);
    }

    /// Newcomers join after the first tick and some original sessions leave
    /// after the second; the rest stay in sync for the third.
    #[tokio::test]
    async fn sync_thrice_with_novices_and_leavos() {
        let mut join_set = JoinSet::new();
        let timer = Timer::new(Duration::from_micros(100));
        let mut participant = timer.new_session();
        let timers = (0 .. 16).map(|_| participant.clone()).collect::<Vec<_>>();
        for (i, mut participant) in timers.into_iter().enumerate() {
            join_set.spawn(async move {
                let first = participant.tick().await;
                let second = participant.tick().await;
                let third =
                    if i < 4 { None } else { Some(participant.tick().await) };
                (Some(first), second, third)
            });
        }
        let first_answer = participant.tick().await;
        let timers = (0 .. 7).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move {
                let second = participant.tick().await;
                let third = participant.tick().await;
                (None, second, Some(third))
            });
        }
        let second_answer = participant.tick().await;
        let third_answer = participant.tick().await;

        let mut has_first = 0;
        let mut has_third = 0;
        while let Some(alternative) = join_set.join_next().await {
            let (
                maybe_first_alternative,
                second_alternative,
                maybe_third_alternative,
            ) = alternative.unwrap();
            if let Some(first_alternative) = maybe_first_alternative {
                assert_eq!(first_answer, first_alternative);
                has_first += 1;
            }
            assert_eq!(second_answer, second_alternative);
            if let Some(third_alternative) = maybe_third_alternative {
                assert_eq!(third_answer, third_alternative);
                has_third += 1;
            }
        }
        assert_ne!(has_first, 0);
        assert_ne!(has_third, 0);
    }

    /// Like `sync_thrice_with_novices_and_leavos`, but some original
    /// sessions also leave right after the first tick.
    #[tokio::test]
    async fn sync_thrice_with_novices_and_leavos_mixed() {
        let mut join_set = JoinSet::new();
        let timer = Timer::new(Duration::from_micros(100));
        let mut participant = timer.new_session();
        let timers = (0 .. 16).map(|_| participant.clone()).collect::<Vec<_>>();
        for (i, mut participant) in timers.into_iter().enumerate() {
            join_set.spawn(async move {
                let first = participant.tick().await;
                let (second, third) = if i < 13 {
                    let second = participant.tick().await;
                    let third = if i < 4 {
                        None
                    } else {
                        Some(participant.tick().await)
                    };
                    (Some(second), third)
                } else {
                    (None, None)
                };
                (Some(first), second, third)
            });
        }
        let first_answer = participant.tick().await;
        let timers = (0 .. 7).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move {
                let second = participant.tick().await;
                let third = participant.tick().await;
                (None, Some(second), Some(third))
            });
        }
        let second_answer = participant.tick().await;
        let third_answer = participant.tick().await;

        let mut has_first = 0;
        let mut has_second = 0;
        let mut has_third = 0;
        while let Some(alternative) = join_set.join_next().await {
            let (
                maybe_first_alternative,
                maybe_second_alternative,
                maybe_third_alternative,
            ) = alternative.unwrap();
            if let Some(first_alternative) = maybe_first_alternative {
                assert_eq!(first_answer, first_alternative);
                has_first += 1;
            }
            if let Some(second_alternative) = maybe_second_alternative {
                assert_eq!(second_answer, second_alternative);
                has_second += 1;
            }
            if let Some(third_alternative) = maybe_third_alternative {
                assert_eq!(third_answer, third_alternative);
                has_third += 1;
            }
        }
        assert_ne!(has_first, 0);
        assert_ne!(has_second, 0);
        assert_ne!(has_third, 0);
    }

    /// Sessions join, some leave, and another group joins before the fourth
    /// tick; every tick is still observed identically.
    #[tokio::test]
    async fn sync_four_times_with_novices_and_leavos_and_novices_again() {
        let mut join_set = JoinSet::new();
        let timer = Timer::new(Duration::from_micros(100));
        let mut participant = timer.new_session();
        let timers = (0 .. 16).map(|_| participant.clone()).collect::<Vec<_>>();
        for (i, mut participant) in timers.into_iter().enumerate() {
            join_set.spawn(async move {
                let first = participant.tick().await;
                let second = participant.tick().await;
                let third =
                    if i < 9 { None } else { Some(participant.tick().await) };
                (Some(first), Some(second), third, None)
            });
        }
        let first_answer = participant.tick().await;
        let timers = (0 .. 7).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move {
                let second = participant.tick().await;
                let third = participant.tick().await;
                let fourth = participant.tick().await;
                (None, Some(second), Some(third), Some(fourth))
            });
        }
        let second_answer = participant.tick().await;
        let third_answer = participant.tick().await;
        let timers = (0 .. 11).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move {
                let fourth = participant.tick().await;
                (None, None, None, Some(fourth))
            });
        }
        let fourth_answer = participant.tick().await;

        let mut has_first = 0;
        let mut has_second = 0;
        let mut has_third = 0;
        let mut has_fourth = 0;
        while let Some(alternative) = join_set.join_next().await {
            let (
                maybe_first_alternative,
                maybe_second_alternative,
                maybe_third_alternative,
                maybe_fourth_alternative,
            ) = alternative.unwrap();
            if let Some(first_alternative) = maybe_first_alternative {
                assert_eq!(first_answer, first_alternative);
                has_first += 1;
            }
            if let Some(second_alternative) = maybe_second_alternative {
                assert_eq!(second_answer, second_alternative);
                has_second += 1;
            }
            if let Some(third_alternative) = maybe_third_alternative {
                assert_eq!(third_answer, third_alternative);
                has_third += 1;
            }
            if let Some(fourth_alternative) = maybe_fourth_alternative {
                assert_eq!(fourth_answer, fourth_alternative);
                has_fourth += 1;
            }
        }
        assert_ne!(has_first, 0);
        assert_ne!(has_second, 0);
        assert_ne!(has_third, 0);
        assert_ne!(has_fourth, 0);
    }

    /// Sessions that leave after one tick do not block the remaining
    /// session's second tick.
    #[tokio::test]
    async fn sync_twice_only_one_left() {
        let mut join_set = JoinSet::new();
        let timer = Timer::new(Duration::from_micros(100));
        let mut participant = timer.new_session();
        let timers = (0 .. 16).map(|_| participant.clone()).collect::<Vec<_>>();
        for mut participant in timers {
            join_set.spawn(async move { participant.tick().await });
        }
        let first_answer = participant.tick().await;
        let _second_answer = participant.tick().await;
        while let Some(alternative) = join_set.join_next().await {
            assert_eq!(first_answer, alternative.unwrap());
        }
    }

    /// A session that abandons its pending tick (dropping the `Tick` future)
    /// and then leaves does not block the others.
    #[tokio::test]
    async fn sync_with_cancel() {
        let mut join_set = JoinSet::new();
        let timer = Timer::new(Duration::from_millis(10));
        let mut participant = timer.new_session();
        participant.tick().await;
        let timers = (0 .. 16).map(|_| participant.clone()).collect::<Vec<_>>();
        for (i, mut participant) in timers.into_iter().enumerate() {
            join_set.spawn(async move {
                if i == 7 {
                    let sleep_first = tokio::select! {
                        _ = participant.tick() => false,
                        _ = time::sleep(Duration::from_micros(1)) => true,
                    };
                    assert!(sleep_first);
                    None
                } else {
                    Some(participant.tick().await)
                }
            });
        }
        let answer = participant.tick().await;
        let mut other_completion = 0;
        while let Some(maybe_alternative) = join_set.join_next().await {
            if let Some(alternative) = maybe_alternative.unwrap() {
                assert_eq!(answer, alternative);
                other_completion += 1;
            }
        }
        assert_eq!(other_completion, 15);
    }
}