Skip to main content

thedes_async_util/non_blocking/spsc/
bounded.rs

1use std::{
2    cell::{Cell, UnsafeCell},
3    fmt,
4    mem::MaybeUninit,
5    ops::Deref,
6    ptr::NonNull,
7    sync::atomic::{AtomicBool, AtomicUsize, Ordering::*},
8};
9
10use thiserror::Error;
11
12pub fn channel<T>(buf_size: usize) -> (Sender<T>, Receiver<T>) {
13    if buf_size == 0 {
14        panic!(
15            "non-blocking SPSC bounded channel cannot have zero-sized buffer"
16        );
17    }
18    let (sender_shared, receiver_shared) = SharedPtr::new(buf_size);
19    (Sender::new(sender_shared), Receiver::new(receiver_shared))
20}
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
23pub enum SendErrorKind {
24    #[error("Buffer is full")]
25    Overflow,
26    #[error("Receiver disconnected")]
27    Disconnected,
28}
29
30#[derive(Debug, Clone, Error)]
31#[error("Receiver disconnected")]
32pub struct SendError<T> {
33    #[source]
34    kind: SendErrorKind,
35    message: T,
36}
37
38impl<T> SendError<T> {
39    fn new(kind: SendErrorKind, message: T) -> Self {
40        Self { kind, message }
41    }
42
43    pub fn kind(&self) -> SendErrorKind {
44        self.kind
45    }
46
47    pub fn message(&self) -> &T {
48        &self.message
49    }
50
51    pub fn into_message(self) -> T {
52        self.message
53    }
54}
55
56#[derive(Debug, Clone, Error)]
57#[error("Senders disconnected")]
58pub struct RecvError {
59    _private: (),
60}
61
62impl RecvError {
63    fn new() -> Self {
64        Self { _private: () }
65    }
66}
67
/// State shared by both ends of a channel: a fixed-size ring buffer plus the
/// bookkeeping used to hand messages from the producer to the consumer.
#[repr(C)]
struct Shared<T> {
    /// `true` while both ends are alive; the first end to be dropped clears it.
    connected: AtomicBool,
    /// Number of slots holding messages that were sent but not yet received.
    unread: AtomicUsize,
    /// Index of the oldest unread slot; advanced only by the receive path.
    front: Cell<usize>,
    /// Index of the next slot to write; advanced only by the send path.
    back: Cell<usize>,
    /// Ring buffer slots. A slot is initialized only while it holds an unread
    /// message.
    buf: Box<[UnsafeCell<MaybeUninit<T>>]>,
}
76
77impl<T> fmt::Debug for Shared<T> {
78    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79        f.debug_struct("Shared")
80            .field("connected", &self.connected)
81            .field("unread", &self.unread)
82            .field("buf", &self.buf)
83            .finish()
84    }
85}
86
87impl<T> Shared<T> {
88    pub fn new_connected(buf_size: usize) -> Self {
89        Self {
90            connected: AtomicBool::new(true),
91            unread: AtomicUsize::new(0),
92            front: Cell::new(0),
93            back: Cell::new(0),
94            buf: (0 .. buf_size)
95                .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
96                .collect(),
97        }
98    }
99
100    pub fn is_connected_weak(&self) -> bool {
101        self.connected.load(Relaxed)
102    }
103
104    pub unsafe fn send(&self, message: T) -> Result<(), SendError<T>> {
105        let unread = self.unread.load(Acquire);
106        let connected = self.connected.load(Acquire);
107        if !connected {
108            return Err(SendError::new(SendErrorKind::Disconnected, message));
109        }
110        if unread == self.buf.len() {
111            return Err(SendError::new(SendErrorKind::Overflow, message));
112        }
113        unsafe {
114            let back = self.back.get();
115            (*self.buf[back].get()).write(message);
116            self.back.set((back + 1) % self.buf.len());
117            self.unread.fetch_add(1, Release);
118        }
119        Ok(())
120    }
121
122    pub unsafe fn recv_one(&self) -> Result<Option<T>, RecvError> {
123        let unread = self.unread.load(Acquire);
124        if unread == 0 {
125            if self.connected.load(Acquire) {
126                return Ok(None);
127            }
128            return Err(RecvError::new());
129        }
130        unsafe {
131            let front = self.front.get();
132            let message = (*self.buf[front].get()).as_ptr().read();
133            self.front.set((front + 1) % self.buf.len());
134            self.unread.fetch_sub(1, Release);
135            Ok(Some(message))
136        }
137    }
138
139    pub unsafe fn recv_many<'a>(
140        &'a self,
141    ) -> Result<RecvMany<'a, T>, RecvError> {
142        let unread = self.unread.load(Acquire);
143        if unread == 0 && !self.connected.load(Acquire) {
144            Err(RecvError::new())?
145        }
146        Ok(RecvMany { shared: self, count: unread })
147    }
148}
149
150impl<T> Drop for Shared<T> {
151    fn drop(&mut self) {
152        unsafe { while let Ok(Some(_)) = self.recv_one() {} }
153    }
154}
155
156struct SharedPtr<T> {
157    inner: NonNull<Shared<T>>,
158}
159
160impl<T> SharedPtr<T> {
161    pub fn new(buf_size: usize) -> (Self, Self) {
162        let shared = Shared::new_connected(buf_size);
163        let shared_boxed = Box::new(shared);
164
165        let inner =
166            unsafe { NonNull::new_unchecked(Box::into_raw(shared_boxed)) };
167
168        (Self { inner }, Self { inner })
169    }
170}
171
172impl<T> fmt::Debug for SharedPtr<T> {
173    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174        f.debug_struct("SharedPtr").field("inner", &**self).finish()
175    }
176}
177
178impl<T> Deref for SharedPtr<T> {
179    type Target = Shared<T>;
180
181    fn deref(&self) -> &Self::Target {
182        unsafe { self.inner.as_ref() }
183    }
184}
185
186impl<T> Drop for SharedPtr<T> {
187    fn drop(&mut self) {
188        unsafe {
189            let last = self
190                .connected
191                .compare_exchange(true, false, AcqRel, Acquire)
192                .is_err();
193            if last {
194                let _ = Box::from_raw(self.inner.as_ptr());
195            }
196        }
197    }
198}
199
200#[derive(Debug)]
201pub struct RecvMany<'a, T> {
202    shared: &'a Shared<T>,
203    count: usize,
204}
205
206unsafe impl<'a, T> Send for RecvMany<'a, T> where T: Send {}
207unsafe impl<'a, T> Sync for RecvMany<'a, T> where T: Send {}
208
209impl<'a, T> Iterator for RecvMany<'a, T> {
210    type Item = T;
211
212    fn next(&mut self) -> Option<Self::Item> {
213        if self.count == 0 {
214            None?
215        }
216        self.count -= 1;
217        unsafe { self.shared.recv_one().ok().flatten() }
218    }
219}
220
221#[derive(Debug)]
222pub struct Sender<T> {
223    shared: SharedPtr<T>,
224}
225
226impl<T> Sender<T> {
227    fn new(shared: SharedPtr<T>) -> Self {
228        Self { shared }
229    }
230
231    pub fn is_connected(&self) -> bool {
232        self.shared.is_connected_weak()
233    }
234
235    pub fn send(&mut self, message: T) -> Result<(), SendError<T>> {
236        unsafe { self.shared.send(message) }
237    }
238}
239
240unsafe impl<T> Send for Sender<T> where T: Send {}
241unsafe impl<T> Sync for Sender<T> where T: Send {}
242
243#[derive(Debug)]
244pub struct Receiver<T> {
245    shared: SharedPtr<T>,
246}
247
248impl<T> Receiver<T> {
249    fn new(shared: SharedPtr<T>) -> Self {
250        Self { shared }
251    }
252
253    pub fn is_connected(&self) -> bool {
254        self.shared.is_connected_weak()
255    }
256
257    pub fn recv_one(&mut self) -> Result<Option<T>, RecvError> {
258        unsafe { self.shared.recv_one() }
259    }
260
261    pub fn recv_many<'a>(&'a mut self) -> Result<RecvMany<'a, T>, RecvError> {
262        unsafe { self.shared.recv_many() }
263    }
264}
265
266unsafe impl<T> Send for Receiver<T> where T: Send {}
267unsafe impl<T> Sync for Receiver<T> where T: Send {}
268
#[cfg(test)]
mod test {
    use std::{rc::Rc, thread};

    use super::{channel, SendErrorKind};

    #[test]
    #[should_panic]
    fn zero_sized_buf_panics() {
        channel::<u64>(0);
    }

    #[test]
    fn recv_empty() {
        let (_sender, mut receiver) = channel::<u64>(3);
        assert_eq!(receiver.recv_one().unwrap(), None);
    }

    #[test]
    fn simple_send_recv_ok() {
        let (mut sender, mut receiver) = channel::<u64>(3);
        sender.send(1).unwrap();
        sender.send(2).unwrap();
        sender.send(3).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 1);
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 2);
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 3);
    }

    #[test]
    fn simple_send_recv_ok_interleaved() {
        let (mut sender, mut receiver) = channel::<u64>(3);
        sender.send(1).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 1);
        sender.send(2).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 2);
        sender.send(3).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 3);
    }

    #[test]
    fn send_recv_ok_interleaved_wrap_around() {
        let (mut sender, mut receiver) = channel::<u64>(3);
        sender.send(1).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 1);
        sender.send(2).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 2);
        sender.send(3).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 3);
        sender.send(4).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 4);
        sender.send(5).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 5);
    }

    #[test]
    fn send_overflow() {
        let (mut sender, mut receiver) = channel::<u64>(3);
        sender.send(1).unwrap();
        sender.send(2).unwrap();
        sender.send(3).unwrap();
        assert_eq!(sender.send(4).unwrap_err().kind(), SendErrorKind::Overflow);
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 1);
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 2);
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 3);
        assert_eq!(receiver.recv_one().unwrap(), None);
        sender.send(5).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 5);
    }

    #[test]
    fn sender_disconnected() {
        let (_, mut receiver) = channel::<u64>(3);
        receiver.recv_one().unwrap_err();
    }

    #[test]
    fn receiver_disconnected() {
        let (mut sender, _) = channel::<u64>(3);
        assert_eq!(
            sender.send(4).unwrap_err().kind(),
            SendErrorKind::Disconnected
        );
    }

    #[test]
    fn recv_many() {
        let (mut sender, mut receiver) = channel::<u64>(6);
        sender.send(1).unwrap();
        sender.send(2).unwrap();
        sender.send(3).unwrap();

        let recv_many = receiver.recv_many().unwrap();

        sender.send(4).unwrap();

        let messages: Vec<_> = recv_many.collect();

        assert_eq!(messages, vec![1, 2, 3]);
    }

    #[test]
    fn recv_many_but_sender_dropped() {
        let (mut sender, mut receiver) = channel::<u64>(3);
        sender.send(1).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 1);
        drop(sender);
        receiver.recv_many().unwrap_err();
    }

    #[test]
    fn send_error_returns_message() {
        let (mut sender, _receiver) = channel::<u64>(1);
        sender.send(1).unwrap();
        let error = sender.send(2).unwrap_err();
        assert_eq!(error.kind(), SendErrorKind::Overflow);
        assert_eq!(*error.message(), 2);
        assert_eq!(error.into_message(), 2);
    }

    #[test]
    fn recv_buffered_after_sender_dropped() {
        let (mut sender, mut receiver) = channel::<u64>(3);
        sender.send(1).unwrap();
        sender.send(2).unwrap();
        drop(sender);
        let messages: Vec<_> = receiver.recv_many().unwrap().collect();
        assert_eq!(messages, vec![1, 2]);
        receiver.recv_one().unwrap_err();
    }

    #[test]
    fn buffered_messages_dropped_on_teardown() {
        let token = Rc::new(());
        let (mut sender, receiver) = channel::<Rc<()>>(3);
        sender.send(Rc::clone(&token)).unwrap();
        sender.send(Rc::clone(&token)).unwrap();
        assert_eq!(Rc::strong_count(&token), 3);
        drop(sender);
        drop(receiver);
        assert_eq!(Rc::strong_count(&token), 1);
    }

    #[test]
    fn threaded_send_recv_delivers_everything_in_order() {
        const COUNT: u64 = 10_000;
        let (mut sender, mut receiver) = channel::<u64>(4);
        let producer = thread::spawn(move || {
            for value in 0 .. COUNT {
                let mut pending = value;
                loop {
                    match sender.send(pending) {
                        Ok(()) => break,
                        Err(error) => {
                            assert_eq!(error.kind(), SendErrorKind::Overflow);
                            pending = error.into_message();
                            thread::yield_now();
                        },
                    }
                }
            }
        });
        let mut received = Vec::new();
        loop {
            match receiver.recv_one() {
                Ok(Some(value)) => received.push(value),
                Ok(None) => thread::yield_now(),
                Err(_) => break,
            }
        }
        producer.join().unwrap();
        assert_eq!(received, (0 .. COUNT).collect::<Vec<_>>());
    }
}