thedes_async_util/non_blocking/spsc/
bounded.rs1use std::{
2 cell::{Cell, UnsafeCell},
3 fmt,
4 mem::MaybeUninit,
5 ops::Deref,
6 ptr::NonNull,
7 sync::atomic::{AtomicBool, AtomicUsize, Ordering::*},
8};
9
10use thiserror::Error;
11
12pub fn channel<T>(buf_size: usize) -> (Sender<T>, Receiver<T>) {
13 if buf_size == 0 {
14 panic!(
15 "non-blocking SPSC bounded channel cannot have zero-sized buffer"
16 );
17 }
18 let (sender_shared, receiver_shared) = SharedPtr::new(buf_size);
19 (Sender::new(sender_shared), Receiver::new(receiver_shared))
20}
21
22#[derive(Debug, Clone, Copy, PartialEq, Eq, Error)]
23pub enum SendErrorKind {
24 #[error("Buffer is full")]
25 Overflow,
26 #[error("Receiver disconnected")]
27 Disconnected,
28}
29
30#[derive(Debug, Clone, Error)]
31#[error("Receiver disconnected")]
32pub struct SendError<T> {
33 #[source]
34 kind: SendErrorKind,
35 message: T,
36}
37
38impl<T> SendError<T> {
39 fn new(kind: SendErrorKind, message: T) -> Self {
40 Self { kind, message }
41 }
42
43 pub fn kind(&self) -> SendErrorKind {
44 self.kind
45 }
46
47 pub fn message(&self) -> &T {
48 &self.message
49 }
50
51 pub fn into_message(self) -> T {
52 self.message
53 }
54}
55
56#[derive(Debug, Clone, Error)]
57#[error("Senders disconnected")]
58pub struct RecvError {
59 _private: (),
60}
61
62impl RecvError {
63 fn new() -> Self {
64 Self { _private: () }
65 }
66}
67
/// State shared between the single sender and the single receiver.
///
/// `front` is only touched on the receiving side and `back` only on the
/// sending side; `unread` is the hand-off counter between the two.
#[repr(C)]
struct Shared<T> {
    /// Set to `false` by whichever handle is dropped first.
    connected: AtomicBool,
    /// Number of initialized slots not yet consumed by the receiver.
    unread: AtomicUsize,
    /// Index of the next slot the receiver will read.
    front: Cell<usize>,
    /// Index of the next slot the sender will write.
    back: Cell<usize>,
    /// Ring buffer; only the `unread` slots starting at `front` (wrapping)
    /// are initialized.
    buf: Box<[UnsafeCell<MaybeUninit<T>>]>,
}
76
77impl<T> fmt::Debug for Shared<T> {
78 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
79 f.debug_struct("Shared")
80 .field("connected", &self.connected)
81 .field("unread", &self.unread)
82 .field("buf", &self.buf)
83 .finish()
84 }
85}
86
87impl<T> Shared<T> {
88 pub fn new_connected(buf_size: usize) -> Self {
89 Self {
90 connected: AtomicBool::new(true),
91 unread: AtomicUsize::new(0),
92 front: Cell::new(0),
93 back: Cell::new(0),
94 buf: (0 .. buf_size)
95 .map(|_| UnsafeCell::new(MaybeUninit::uninit()))
96 .collect(),
97 }
98 }
99
100 pub fn is_connected_weak(&self) -> bool {
101 self.connected.load(Relaxed)
102 }
103
104 pub unsafe fn send(&self, message: T) -> Result<(), SendError<T>> {
105 let unread = self.unread.load(Acquire);
106 let connected = self.connected.load(Acquire);
107 if !connected {
108 return Err(SendError::new(SendErrorKind::Disconnected, message));
109 }
110 if unread == self.buf.len() {
111 return Err(SendError::new(SendErrorKind::Overflow, message));
112 }
113 unsafe {
114 let back = self.back.get();
115 (*self.buf[back].get()).write(message);
116 self.back.set((back + 1) % self.buf.len());
117 self.unread.fetch_add(1, Release);
118 }
119 Ok(())
120 }
121
122 pub unsafe fn recv_one(&self) -> Result<Option<T>, RecvError> {
123 let unread = self.unread.load(Acquire);
124 if unread == 0 {
125 if self.connected.load(Acquire) {
126 return Ok(None);
127 }
128 return Err(RecvError::new());
129 }
130 unsafe {
131 let front = self.front.get();
132 let message = (*self.buf[front].get()).as_ptr().read();
133 self.front.set((front + 1) % self.buf.len());
134 self.unread.fetch_sub(1, Release);
135 Ok(Some(message))
136 }
137 }
138
139 pub unsafe fn recv_many<'a>(
140 &'a self,
141 ) -> Result<RecvMany<'a, T>, RecvError> {
142 let unread = self.unread.load(Acquire);
143 if unread == 0 && !self.connected.load(Acquire) {
144 Err(RecvError::new())?
145 }
146 Ok(RecvMany { shared: self, count: unread })
147 }
148}
149
150impl<T> Drop for Shared<T> {
151 fn drop(&mut self) {
152 unsafe { while let Ok(Some(_)) = self.recv_one() {} }
153 }
154}
155
156struct SharedPtr<T> {
157 inner: NonNull<Shared<T>>,
158}
159
160impl<T> SharedPtr<T> {
161 pub fn new(buf_size: usize) -> (Self, Self) {
162 let shared = Shared::new_connected(buf_size);
163 let shared_boxed = Box::new(shared);
164
165 let inner =
166 unsafe { NonNull::new_unchecked(Box::into_raw(shared_boxed)) };
167
168 (Self { inner }, Self { inner })
169 }
170}
171
172impl<T> fmt::Debug for SharedPtr<T> {
173 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
174 f.debug_struct("SharedPtr").field("inner", &**self).finish()
175 }
176}
177
178impl<T> Deref for SharedPtr<T> {
179 type Target = Shared<T>;
180
181 fn deref(&self) -> &Self::Target {
182 unsafe { self.inner.as_ref() }
183 }
184}
185
186impl<T> Drop for SharedPtr<T> {
187 fn drop(&mut self) {
188 unsafe {
189 let last = self
190 .connected
191 .compare_exchange(true, false, AcqRel, Acquire)
192 .is_err();
193 if last {
194 let _ = Box::from_raw(self.inner.as_ptr());
195 }
196 }
197 }
198}
199
200#[derive(Debug)]
201pub struct RecvMany<'a, T> {
202 shared: &'a Shared<T>,
203 count: usize,
204}
205
206unsafe impl<'a, T> Send for RecvMany<'a, T> where T: Send {}
207unsafe impl<'a, T> Sync for RecvMany<'a, T> where T: Send {}
208
209impl<'a, T> Iterator for RecvMany<'a, T> {
210 type Item = T;
211
212 fn next(&mut self) -> Option<Self::Item> {
213 if self.count == 0 {
214 None?
215 }
216 self.count -= 1;
217 unsafe { self.shared.recv_one().ok().flatten() }
218 }
219}
220
221#[derive(Debug)]
222pub struct Sender<T> {
223 shared: SharedPtr<T>,
224}
225
226impl<T> Sender<T> {
227 fn new(shared: SharedPtr<T>) -> Self {
228 Self { shared }
229 }
230
231 pub fn is_connected(&self) -> bool {
232 self.shared.is_connected_weak()
233 }
234
235 pub fn send(&mut self, message: T) -> Result<(), SendError<T>> {
236 unsafe { self.shared.send(message) }
237 }
238}
239
240unsafe impl<T> Send for Sender<T> where T: Send {}
241unsafe impl<T> Sync for Sender<T> where T: Send {}
242
243#[derive(Debug)]
244pub struct Receiver<T> {
245 shared: SharedPtr<T>,
246}
247
248impl<T> Receiver<T> {
249 fn new(shared: SharedPtr<T>) -> Self {
250 Self { shared }
251 }
252
253 pub fn is_connected(&self) -> bool {
254 self.shared.is_connected_weak()
255 }
256
257 pub fn recv_one(&mut self) -> Result<Option<T>, RecvError> {
258 unsafe { self.shared.recv_one() }
259 }
260
261 pub fn recv_many<'a>(&'a mut self) -> Result<RecvMany<'a, T>, RecvError> {
262 unsafe { self.shared.recv_many() }
263 }
264}
265
266unsafe impl<T> Send for Receiver<T> where T: Send {}
267unsafe impl<T> Sync for Receiver<T> where T: Send {}
268
#[cfg(test)]
mod test {
    use std::rc::Rc;

    use super::{channel, SendErrorKind};

    #[test]
    #[should_panic(expected = "zero-sized buffer")]
    fn zero_sized_buf_panics() {
        channel::<u64>(0);
    }

    #[test]
    fn recv_empty() {
        let (_sender, mut receiver) = channel::<u64>(3);
        assert_eq!(receiver.recv_one().unwrap(), None);
    }

    #[test]
    fn simple_send_recv_ok() {
        let (mut sender, mut receiver) = channel::<u64>(3);
        sender.send(1).unwrap();
        sender.send(2).unwrap();
        sender.send(3).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 1);
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 2);
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 3);
    }

    #[test]
    fn simple_send_recv_ok_interleaved() {
        let (mut sender, mut receiver) = channel::<u64>(3);
        sender.send(1).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 1);
        sender.send(2).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 2);
        sender.send(3).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 3);
    }

    #[test]
    fn send_recv_ok_interleaved_wrap_around() {
        let (mut sender, mut receiver) = channel::<u64>(3);
        sender.send(1).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 1);
        sender.send(2).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 2);
        sender.send(3).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 3);
        sender.send(4).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 4);
        sender.send(5).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 5);
    }

    #[test]
    fn wrap_around_at_full_capacity() {
        let (mut sender, mut receiver) = channel::<u64>(2);
        sender.send(1).unwrap();
        sender.send(2).unwrap();
        assert_eq!(receiver.recv_one().unwrap(), Some(1));
        sender.send(3).unwrap();
        assert_eq!(sender.send(4).unwrap_err().kind(), SendErrorKind::Overflow);
        assert_eq!(receiver.recv_one().unwrap(), Some(2));
        assert_eq!(receiver.recv_one().unwrap(), Some(3));
        assert_eq!(receiver.recv_one().unwrap(), None);
    }

    #[test]
    fn send_overflow() {
        let (mut sender, mut receiver) = channel::<u64>(3);
        sender.send(1).unwrap();
        sender.send(2).unwrap();
        sender.send(3).unwrap();
        assert_eq!(sender.send(4).unwrap_err().kind(), SendErrorKind::Overflow);
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 1);
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 2);
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 3);
        assert_eq!(receiver.recv_one().unwrap(), None);
        sender.send(5).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 5);
    }

    #[test]
    fn overflow_returns_message() {
        let (mut sender, _receiver) = channel::<u64>(1);
        sender.send(1).unwrap();
        let error = sender.send(2).unwrap_err();
        assert_eq!(error.kind(), SendErrorKind::Overflow);
        assert_eq!(*error.message(), 2);
        assert_eq!(error.into_message(), 2);
    }

    #[test]
    fn sender_disconnected() {
        let (_, mut receiver) = channel::<u64>(3);
        receiver.recv_one().unwrap_err();
    }

    #[test]
    fn receiver_disconnected() {
        let (mut sender, _) = channel::<u64>(3);
        let error = sender.send(4).unwrap_err();
        assert_eq!(error.kind(), SendErrorKind::Disconnected);
        assert_eq!(error.into_message(), 4);
    }

    #[test]
    fn messages_survive_sender_drop() {
        let (mut sender, mut receiver) = channel::<u64>(3);
        sender.send(1).unwrap();
        sender.send(2).unwrap();
        drop(sender);
        assert!(!receiver.is_connected());
        assert_eq!(receiver.recv_one().unwrap(), Some(1));
        assert_eq!(receiver.recv_one().unwrap(), Some(2));
        receiver.recv_one().unwrap_err();
    }

    #[test]
    fn recv_many() {
        let (mut sender, mut receiver) = channel::<u64>(6);
        sender.send(1).unwrap();
        sender.send(2).unwrap();
        sender.send(3).unwrap();

        let recv_many = receiver.recv_many().unwrap();

        sender.send(4).unwrap();

        let messages: Vec<_> = recv_many.collect();

        assert_eq!(messages, vec![1, 2, 3]);
    }

    #[test]
    fn recv_many_drains_after_sender_dropped() {
        let (mut sender, mut receiver) = channel::<u64>(3);
        sender.send(1).unwrap();
        sender.send(2).unwrap();
        drop(sender);
        let messages: Vec<_> = receiver.recv_many().unwrap().collect();
        assert_eq!(messages, vec![1, 2]);
        receiver.recv_many().unwrap_err();
    }

    #[test]
    fn recv_many_but_sender_dropped() {
        let (mut sender, mut receiver) = channel::<u64>(3);
        sender.send(1).unwrap();
        assert_eq!(receiver.recv_one().unwrap().unwrap(), 1);
        drop(sender);
        receiver.recv_many().unwrap_err();
    }

    #[test]
    fn unread_messages_dropped_sender_first() {
        let tracker = Rc::new(());
        let (mut sender, receiver) = channel::<Rc<()>>(3);
        sender.send(Rc::clone(&tracker)).unwrap();
        sender.send(Rc::clone(&tracker)).unwrap();
        assert_eq!(Rc::strong_count(&tracker), 3);
        drop(sender);
        assert_eq!(Rc::strong_count(&tracker), 3);
        drop(receiver);
        assert_eq!(Rc::strong_count(&tracker), 1);
    }

    #[test]
    fn unread_messages_dropped_receiver_first() {
        let tracker = Rc::new(());
        let (mut sender, receiver) = channel::<Rc<()>>(3);
        sender.send(Rc::clone(&tracker)).unwrap();
        drop(receiver);
        assert_eq!(Rc::strong_count(&tracker), 2);
        drop(sender);
        assert_eq!(Rc::strong_count(&tracker), 1);
    }
}