Skip to main content

thedes_async_util/non_blocking/spsc/
unbounded.rs

1use std::{
2    cell::{Cell, UnsafeCell},
3    fmt,
4    mem::MaybeUninit,
5    ops::Deref,
6    ptr::{NonNull, null_mut},
7    sync::atomic::{AtomicBool, AtomicPtr, Ordering::*},
8};
9
10use thiserror::Error;
11
12pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
13    let (sender_shared, receiver_shared) = SharedPtr::new();
14    (Sender::new(sender_shared), Receiver::new(receiver_shared))
15}
16
17#[derive(Debug, Clone, Error)]
18#[error("Receiver disconnected")]
19pub struct SendError<T> {
20    message: T,
21}
22
23impl<T> SendError<T> {
24    fn new(message: T) -> Self {
25        Self { message }
26    }
27
28    pub fn message(&self) -> &T {
29        &self.message
30    }
31
32    pub fn into_message(self) -> T {
33        self.message
34    }
35}
36
37#[derive(Debug, Clone, Error)]
38#[error("Senders disconnected")]
39pub struct RecvError {
40    _private: (),
41}
42
43impl RecvError {
44    fn new() -> Self {
45        Self { _private: () }
46    }
47}
48
/// A single link of the queue's singly linked list.
///
/// `next` is the atomically published pointer to the following node (null
/// while there is none). `data` holds the payload; it is wrapped in
/// `MaybeUninit` because a node's payload may be uninitialised (the
/// placeholder node) or already moved out by a read.
#[derive(Debug)]
#[repr(C)]
struct Node<T> {
    /// Pointer to the successor node, or null if this is the last node.
    next: AtomicPtr<Self>,
    /// Payload slot; not guaranteed to be initialised.
    data: UnsafeCell<MaybeUninit<T>>,
}
55
56struct Shared<T> {
57    connected: AtomicBool,
58    front: Cell<NonNull<Node<T>>>,
59    back: AtomicPtr<Node<T>>,
60}
61
62impl<T> Shared<T> {
63    pub fn new_connected() -> Self {
64        let dummy = Box::new(Node {
65            data: UnsafeCell::new(MaybeUninit::uninit()),
66            next: AtomicPtr::new(null_mut()),
67        });
68
69        let dummy_non_null =
70            unsafe { NonNull::new_unchecked(Box::into_raw(dummy)) };
71
72        let this = Self {
73            front: Cell::new(dummy_non_null),
74            back: AtomicPtr::new(dummy_non_null.as_ptr()),
75            connected: AtomicBool::new(true),
76        };
77        this
78    }
79
80    pub fn is_connected_weak(&self) -> bool {
81        self.connected.load(Relaxed)
82    }
83
84    pub unsafe fn send(&self, message: T) -> Result<(), SendError<T>> {
85        if self.connected.load(Acquire) {
86            let new_node = Box::new(Node {
87                data: UnsafeCell::new(MaybeUninit::new(message)),
88                next: AtomicPtr::new(null_mut()),
89            });
90            unsafe {
91                let node_non_null =
92                    NonNull::new_unchecked(Box::into_raw(new_node));
93                let back = self.back.load(Relaxed);
94                (&mut *back).next.store(node_non_null.as_ptr(), Release);
95                self.back.store(node_non_null.as_ptr(), Release);
96            }
97            Ok(())
98        } else {
99            Err(SendError::new(message))
100        }
101    }
102
103    pub unsafe fn recv_one(&self) -> Result<Option<T>, RecvError> {
104        unsafe {
105            let next_ptr = self.front.get().as_ref().next.load(Acquire);
106            match NonNull::new(next_ptr) {
107                Some(next_non_null) => {
108                    let data =
109                        (&*next_non_null.as_ref().data.get()).as_ptr().read();
110                    let _ = Box::from_raw(self.front.get().as_ptr());
111                    self.front.set(next_non_null);
112                    Ok(Some(data))
113                },
114                None => {
115                    if self.connected.load(Acquire) {
116                        Ok(None)
117                    } else {
118                        Err(RecvError::new())
119                    }
120                },
121            }
122        }
123    }
124
125    pub unsafe fn recv_many<'a>(
126        &'a self,
127    ) -> Result<RecvMany<'a, T>, RecvError> {
128        let back = unsafe { NonNull::new_unchecked(self.back.load(Acquire)) };
129        if back == self.front.get() && !self.connected.load(Acquire) {
130            Err(RecvError::new())?
131        }
132        Ok(RecvMany { shared: self, back_limit: back })
133    }
134}
135
136impl<T> fmt::Debug for Shared<T> {
137    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138        f.debug_struct("Shared")
139            .field("front", &self.front)
140            .field("back", &self.back)
141            .field("connected", &self.connected)
142            .finish()
143    }
144}
145
146impl<T> Drop for Shared<T> {
147    fn drop(&mut self) {
148        unsafe { while let Ok(Some(_)) = self.recv_one() {} }
149        unsafe {
150            let _ = Box::from_raw(self.front.get().as_ptr());
151        }
152    }
153}
154
155struct SharedPtr<T> {
156    inner: NonNull<Shared<T>>,
157}
158
159impl<T> SharedPtr<T> {
160    pub fn new() -> (Self, Self) {
161        let shared = Shared::new_connected();
162        let shared_boxed = Box::new(shared);
163
164        let inner =
165            unsafe { NonNull::new_unchecked(Box::into_raw(shared_boxed)) };
166
167        (Self { inner }, Self { inner })
168    }
169}
170
171impl<T> fmt::Debug for SharedPtr<T> {
172    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173        f.debug_struct("SharedPtr").field("inner", &**self).finish()
174    }
175}
176
177impl<T> Deref for SharedPtr<T> {
178    type Target = Shared<T>;
179
180    fn deref(&self) -> &Self::Target {
181        unsafe { self.inner.as_ref() }
182    }
183}
184
185impl<T> Drop for SharedPtr<T> {
186    fn drop(&mut self) {
187        unsafe {
188            let last = self
189                .connected
190                .compare_exchange(true, false, AcqRel, Acquire)
191                .is_err();
192            if last {
193                let _ = Box::from_raw(self.inner.as_ptr());
194            }
195        }
196    }
197}
198
199#[derive(Debug)]
200pub struct RecvMany<'a, T> {
201    shared: &'a Shared<T>,
202    back_limit: NonNull<Node<T>>,
203}
204
205unsafe impl<'a, T> Send for RecvMany<'a, T> where T: Send {}
206unsafe impl<'a, T> Sync for RecvMany<'a, T> where T: Send {}
207
208impl<'a, T> Iterator for RecvMany<'a, T> {
209    type Item = T;
210
211    fn next(&mut self) -> Option<Self::Item> {
212        unsafe {
213            if self.shared.front.get() == self.back_limit {
214                None?
215            }
216            self.shared.recv_one().ok().flatten()
217        }
218    }
219}
220
221#[derive(Debug)]
222pub struct Sender<T> {
223    shared: SharedPtr<T>,
224}
225
226impl<T> Sender<T> {
227    fn new(shared: SharedPtr<T>) -> Self {
228        Self { shared }
229    }
230
231    pub fn is_connected(&self) -> bool {
232        self.shared.is_connected_weak()
233    }
234
235    pub fn send(&mut self, message: T) -> Result<(), SendError<T>> {
236        unsafe { self.shared.send(message) }
237    }
238}
239
240unsafe impl<T> Send for Sender<T> where T: Send {}
241unsafe impl<T> Sync for Sender<T> where T: Send {}
242
243#[derive(Debug)]
244pub struct Receiver<T> {
245    shared: SharedPtr<T>,
246}
247
248impl<T> Receiver<T> {
249    fn new(shared: SharedPtr<T>) -> Self {
250        Self { shared }
251    }
252
253    pub fn is_connected(&self) -> bool {
254        self.shared.is_connected_weak()
255    }
256
257    pub fn recv_one(&mut self) -> Result<Option<T>, RecvError> {
258        unsafe { self.shared.recv_one() }
259    }
260
261    pub fn recv_many<'a>(&'a mut self) -> Result<RecvMany<'a, T>, RecvError> {
262        unsafe { self.shared.recv_many() }
263    }
264}
265
266unsafe impl<T> Send for Receiver<T> where T: Send {}
267unsafe impl<T> Sync for Receiver<T> where T: Send {}
268
#[cfg(test)]
mod test {
    use super::channel;

    /// An empty, connected channel keeps answering `None`.
    #[test]
    fn recv_empty() {
        let (_tx, mut rx) = channel::<u64>();
        for _ in 0..2 {
            assert_eq!(rx.recv_one().unwrap(), None);
        }
    }

    /// An empty channel reports an error once the sender is gone.
    #[test]
    fn recv_empty_disconnected() {
        let (tx, mut rx) = channel::<u64>();
        assert_eq!(rx.recv_one().unwrap(), None);
        drop(tx);
        assert!(rx.recv_one().is_err());
    }

    /// A single message is delivered exactly once.
    #[test]
    fn send_recv_once() {
        let (mut tx, mut rx) = channel::<u64>();
        tx.send(93).unwrap();
        assert_eq!(rx.recv_one().unwrap(), Some(93));
        assert_eq!(rx.recv_one().unwrap(), None);
    }

    /// Alternating sends and receives deliver each message in turn.
    #[test]
    fn send_recv_twice_interleaved() {
        let (mut tx, mut rx) = channel::<u64>();
        for value in [4, 234] {
            tx.send(value).unwrap();
            assert_eq!(rx.recv_one().unwrap(), Some(value));
        }
        assert_eq!(rx.recv_one().unwrap(), None);
    }

    /// Two queued messages come out in the order they were sent.
    #[test]
    fn send_recv_twice_consecutive() {
        let (mut tx, mut rx) = channel::<u64>();
        tx.send(4).unwrap();
        tx.send(234).unwrap();
        assert_eq!(rx.recv_one().unwrap(), Some(4));
        assert_eq!(rx.recv_one().unwrap(), Some(234));
        assert_eq!(rx.recv_one().unwrap(), None);
    }

    /// A message sent before the sender was dropped is still delivered.
    #[test]
    fn send_recv_dropped() {
        let (mut tx, mut rx) = channel::<u64>();
        tx.send(9452).unwrap();
        drop(tx);
        assert_eq!(rx.recv_one().unwrap(), Some(9452));
        assert!(rx.recv_one().is_err());
    }

    /// All messages sent before the drop are delivered before the error.
    #[test]
    fn send_recv_twice_dropped() {
        let (mut tx, mut rx) = channel::<u64>();
        tx.send(9452).unwrap();
        tx.send(12).unwrap();
        drop(tx);
        assert_eq!(rx.recv_one().unwrap(), Some(9452));
        assert_eq!(rx.recv_one().unwrap(), Some(12));
        assert!(rx.recv_one().is_err());
    }

    /// Same as above, with a receive in between the two sends.
    #[test]
    fn send_recv_twice_dropped_interleaved() {
        let (mut tx, mut rx) = channel::<u64>();
        tx.send(9452).unwrap();
        assert_eq!(rx.recv_one().unwrap(), Some(9452));
        tx.send(12).unwrap();
        drop(tx);
        assert_eq!(rx.recv_one().unwrap(), Some(12));
        assert!(rx.recv_one().is_err());
    }

    /// A batch only contains what was queued when the batch was started.
    #[test]
    fn recv_many() {
        let (mut tx, mut rx) = channel::<u64>();
        for value in 1..=4 {
            tx.send(value).unwrap();
        }
        let batch = rx.recv_many().unwrap();
        // Sent after the batch was created: must not show up in it.
        tx.send(5).unwrap();
        tx.send(6).unwrap();
        let received: Vec<_> = batch.collect();
        assert_eq!(received, vec![1, 2, 3, 4]);
    }

    /// Starting a batch on an empty, disconnected channel is an error.
    #[test]
    fn recv_many_dropped() {
        let (mut tx, mut rx) = channel::<u64>();
        tx.send(1).unwrap();
        assert_eq!(rx.recv_one().unwrap(), Some(1));
        drop(tx);
        rx.recv_many().unwrap_err();
    }

    /// A disconnected channel with pending messages still yields a batch.
    #[test]
    fn recv_many_dropped_with_messages() {
        let (mut tx, mut rx) = channel::<u64>();
        tx.send(1).unwrap();
        assert_eq!(rx.recv_one().unwrap(), Some(1));
        for value in 2..=4 {
            tx.send(value).unwrap();
        }
        drop(tx);
        let batch = rx.recv_many().unwrap();
        let received: Vec<_> = batch.collect();
        assert_eq!(received, vec![2, 3, 4]);
    }
}
383}