thedes_async_util/non_blocking/spsc/
unbounded.rs1use std::{
2 cell::{Cell, UnsafeCell},
3 fmt,
4 mem::MaybeUninit,
5 ops::Deref,
6 ptr::{NonNull, null_mut},
7 sync::atomic::{AtomicBool, AtomicPtr, Ordering::*},
8};
9
10use thiserror::Error;
11
12pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
13 let (sender_shared, receiver_shared) = SharedPtr::new();
14 (Sender::new(sender_shared), Receiver::new(receiver_shared))
15}
16
17#[derive(Debug, Clone, Error)]
18#[error("Receiver disconnected")]
19pub struct SendError<T> {
20 message: T,
21}
22
23impl<T> SendError<T> {
24 fn new(message: T) -> Self {
25 Self { message }
26 }
27
28 pub fn message(&self) -> &T {
29 &self.message
30 }
31
32 pub fn into_message(self) -> T {
33 self.message
34 }
35}
36
37#[derive(Debug, Clone, Error)]
38#[error("Senders disconnected")]
39pub struct RecvError {
40 _private: (),
41}
42
43impl RecvError {
44 fn new() -> Self {
45 Self { _private: () }
46 }
47}
48
/// One link in the queue's singly linked list.
///
/// `data` is initialized when a node is created by `send`. It is left
/// uninitialized for the initial dummy node. The consumer reads it at most
/// once. Field order is kept as-is because the struct is `repr(C)`.
#[repr(C)]
#[derive(Debug)]
struct Node<T> {
    /// The next (newer) node, or null while this is the last node.
    next: AtomicPtr<Node<T>>,
    /// The message payload (see the struct docs for when it is initialized).
    data: UnsafeCell<MaybeUninit<T>>,
}
55
56struct Shared<T> {
57 connected: AtomicBool,
58 front: Cell<NonNull<Node<T>>>,
59 back: AtomicPtr<Node<T>>,
60}
61
62impl<T> Shared<T> {
63 pub fn new_connected() -> Self {
64 let dummy = Box::new(Node {
65 data: UnsafeCell::new(MaybeUninit::uninit()),
66 next: AtomicPtr::new(null_mut()),
67 });
68
69 let dummy_non_null =
70 unsafe { NonNull::new_unchecked(Box::into_raw(dummy)) };
71
72 let this = Self {
73 front: Cell::new(dummy_non_null),
74 back: AtomicPtr::new(dummy_non_null.as_ptr()),
75 connected: AtomicBool::new(true),
76 };
77 this
78 }
79
80 pub fn is_connected_weak(&self) -> bool {
81 self.connected.load(Relaxed)
82 }
83
84 pub unsafe fn send(&self, message: T) -> Result<(), SendError<T>> {
85 if self.connected.load(Acquire) {
86 let new_node = Box::new(Node {
87 data: UnsafeCell::new(MaybeUninit::new(message)),
88 next: AtomicPtr::new(null_mut()),
89 });
90 unsafe {
91 let node_non_null =
92 NonNull::new_unchecked(Box::into_raw(new_node));
93 let back = self.back.load(Relaxed);
94 (&mut *back).next.store(node_non_null.as_ptr(), Release);
95 self.back.store(node_non_null.as_ptr(), Release);
96 }
97 Ok(())
98 } else {
99 Err(SendError::new(message))
100 }
101 }
102
103 pub unsafe fn recv_one(&self) -> Result<Option<T>, RecvError> {
104 unsafe {
105 let next_ptr = self.front.get().as_ref().next.load(Acquire);
106 match NonNull::new(next_ptr) {
107 Some(next_non_null) => {
108 let data =
109 (&*next_non_null.as_ref().data.get()).as_ptr().read();
110 let _ = Box::from_raw(self.front.get().as_ptr());
111 self.front.set(next_non_null);
112 Ok(Some(data))
113 },
114 None => {
115 if self.connected.load(Acquire) {
116 Ok(None)
117 } else {
118 Err(RecvError::new())
119 }
120 },
121 }
122 }
123 }
124
125 pub unsafe fn recv_many<'a>(
126 &'a self,
127 ) -> Result<RecvMany<'a, T>, RecvError> {
128 let back = unsafe { NonNull::new_unchecked(self.back.load(Acquire)) };
129 if back == self.front.get() && !self.connected.load(Acquire) {
130 Err(RecvError::new())?
131 }
132 Ok(RecvMany { shared: self, back_limit: back })
133 }
134}
135
136impl<T> fmt::Debug for Shared<T> {
137 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138 f.debug_struct("Shared")
139 .field("front", &self.front)
140 .field("back", &self.back)
141 .field("connected", &self.connected)
142 .finish()
143 }
144}
145
146impl<T> Drop for Shared<T> {
147 fn drop(&mut self) {
148 unsafe { while let Ok(Some(_)) = self.recv_one() {} }
149 unsafe {
150 let _ = Box::from_raw(self.front.get().as_ptr());
151 }
152 }
153}
154
155struct SharedPtr<T> {
156 inner: NonNull<Shared<T>>,
157}
158
159impl<T> SharedPtr<T> {
160 pub fn new() -> (Self, Self) {
161 let shared = Shared::new_connected();
162 let shared_boxed = Box::new(shared);
163
164 let inner =
165 unsafe { NonNull::new_unchecked(Box::into_raw(shared_boxed)) };
166
167 (Self { inner }, Self { inner })
168 }
169}
170
171impl<T> fmt::Debug for SharedPtr<T> {
172 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173 f.debug_struct("SharedPtr").field("inner", &**self).finish()
174 }
175}
176
177impl<T> Deref for SharedPtr<T> {
178 type Target = Shared<T>;
179
180 fn deref(&self) -> &Self::Target {
181 unsafe { self.inner.as_ref() }
182 }
183}
184
185impl<T> Drop for SharedPtr<T> {
186 fn drop(&mut self) {
187 unsafe {
188 let last = self
189 .connected
190 .compare_exchange(true, false, AcqRel, Acquire)
191 .is_err();
192 if last {
193 let _ = Box::from_raw(self.inner.as_ptr());
194 }
195 }
196 }
197}
198
199#[derive(Debug)]
200pub struct RecvMany<'a, T> {
201 shared: &'a Shared<T>,
202 back_limit: NonNull<Node<T>>,
203}
204
205unsafe impl<'a, T> Send for RecvMany<'a, T> where T: Send {}
206unsafe impl<'a, T> Sync for RecvMany<'a, T> where T: Send {}
207
208impl<'a, T> Iterator for RecvMany<'a, T> {
209 type Item = T;
210
211 fn next(&mut self) -> Option<Self::Item> {
212 unsafe {
213 if self.shared.front.get() == self.back_limit {
214 None?
215 }
216 self.shared.recv_one().ok().flatten()
217 }
218 }
219}
220
221#[derive(Debug)]
222pub struct Sender<T> {
223 shared: SharedPtr<T>,
224}
225
226impl<T> Sender<T> {
227 fn new(shared: SharedPtr<T>) -> Self {
228 Self { shared }
229 }
230
231 pub fn is_connected(&self) -> bool {
232 self.shared.is_connected_weak()
233 }
234
235 pub fn send(&mut self, message: T) -> Result<(), SendError<T>> {
236 unsafe { self.shared.send(message) }
237 }
238}
239
240unsafe impl<T> Send for Sender<T> where T: Send {}
241unsafe impl<T> Sync for Sender<T> where T: Send {}
242
243#[derive(Debug)]
244pub struct Receiver<T> {
245 shared: SharedPtr<T>,
246}
247
248impl<T> Receiver<T> {
249 fn new(shared: SharedPtr<T>) -> Self {
250 Self { shared }
251 }
252
253 pub fn is_connected(&self) -> bool {
254 self.shared.is_connected_weak()
255 }
256
257 pub fn recv_one(&mut self) -> Result<Option<T>, RecvError> {
258 unsafe { self.shared.recv_one() }
259 }
260
261 pub fn recv_many<'a>(&'a mut self) -> Result<RecvMany<'a, T>, RecvError> {
262 unsafe { self.shared.recv_many() }
263 }
264}
265
266unsafe impl<T> Send for Receiver<T> where T: Send {}
267unsafe impl<T> Sync for Receiver<T> where T: Send {}
268
#[cfg(test)]
mod test {
    use std::{rc::Rc, thread};

    use super::channel;

    #[test]
    fn recv_empty() {
        let (_sender, mut receiver) = channel::<u64>();
        assert_eq!(receiver.recv_one().unwrap(), None);
        assert_eq!(receiver.recv_one().unwrap(), None);
    }

    #[test]
    fn recv_empty_disconnected() {
        let (sender, mut receiver) = channel::<u64>();
        assert_eq!(receiver.recv_one().unwrap(), None);
        drop(sender);
        assert!(receiver.recv_one().is_err());
    }

    #[test]
    fn send_recv_once() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(93).unwrap();
        assert_eq!(receiver.recv_one().unwrap(), Some(93));
        assert_eq!(receiver.recv_one().unwrap(), None);
    }

    #[test]
    fn send_recv_twice_interleaved() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(4).unwrap();
        assert_eq!(receiver.recv_one().unwrap(), Some(4));
        sender.send(234).unwrap();
        assert_eq!(receiver.recv_one().unwrap(), Some(234));
        assert_eq!(receiver.recv_one().unwrap(), None);
    }

    #[test]
    fn send_recv_twice_consecutive() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(4).unwrap();
        sender.send(234).unwrap();
        assert_eq!(receiver.recv_one().unwrap(), Some(4));
        assert_eq!(receiver.recv_one().unwrap(), Some(234));
        assert_eq!(receiver.recv_one().unwrap(), None);
    }

    #[test]
    fn send_recv_dropped() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(9452).unwrap();
        drop(sender);
        assert_eq!(receiver.recv_one().unwrap(), Some(9452));
        assert!(receiver.recv_one().is_err());
    }

    #[test]
    fn send_recv_twice_dropped() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(9452).unwrap();
        sender.send(12).unwrap();
        drop(sender);
        assert_eq!(receiver.recv_one().unwrap(), Some(9452));
        assert_eq!(receiver.recv_one().unwrap(), Some(12));
        assert!(receiver.recv_one().is_err());
    }

    #[test]
    fn send_recv_twice_dropped_interleaved() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(9452).unwrap();
        assert_eq!(receiver.recv_one().unwrap(), Some(9452));
        sender.send(12).unwrap();
        drop(sender);
        assert_eq!(receiver.recv_one().unwrap(), Some(12));
        assert!(receiver.recv_one().is_err());
    }

    #[test]
    fn recv_many() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(1).unwrap();
        sender.send(2).unwrap();
        sender.send(3).unwrap();
        sender.send(4).unwrap();
        let recv_many = receiver.recv_many().unwrap();
        sender.send(5).unwrap();
        sender.send(6).unwrap();
        let numbers: Vec<_> = recv_many.collect();
        assert_eq!(numbers, vec![1, 2, 3, 4]);
    }

    #[test]
    fn recv_many_dropped() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(1).unwrap();
        assert_eq!(receiver.recv_one().unwrap(), Some(1));
        drop(sender);
        receiver.recv_many().unwrap_err();
    }

    #[test]
    fn recv_many_dropped_with_messages() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(1).unwrap();
        assert_eq!(receiver.recv_one().unwrap(), Some(1));
        sender.send(2).unwrap();
        sender.send(3).unwrap();
        sender.send(4).unwrap();
        drop(sender);
        let recv_many = receiver.recv_many().unwrap();
        let numbers: Vec<_> = recv_many.collect();
        assert_eq!(numbers, vec![2, 3, 4]);
    }

    /// `recv_many` on an empty, still-connected channel yields nothing.
    #[test]
    fn recv_many_empty_connected() {
        let (_sender, mut receiver) = channel::<u64>();
        let numbers: Vec<_> = receiver.recv_many().unwrap().collect();
        assert!(numbers.is_empty());
    }

    /// Messages never received are dropped together with the channel.
    #[test]
    fn pending_messages_dropped_with_channel() {
        let token = Rc::new(());
        let (mut sender, receiver) = channel::<Rc<()>>();
        sender.send(Rc::clone(&token)).unwrap();
        sender.send(Rc::clone(&token)).unwrap();
        assert_eq!(Rc::strong_count(&token), 3);
        drop(receiver);
        drop(sender);
        assert_eq!(Rc::strong_count(&token), 1);
    }

    /// Sending after the receiver is gone hands the message back.
    #[test]
    fn send_after_receiver_dropped() {
        let (mut sender, receiver) = channel::<u64>();
        drop(receiver);
        assert!(!sender.is_connected());
        let error = sender.send(7).unwrap_err();
        assert_eq!(error.into_message(), 7);
    }

    /// A producer on another thread sends a sequence and then disconnects.
    /// The consumer must see every message, in order, before the disconnect
    /// error. This exercises the send-then-drop race in `recv_one`.
    #[test]
    fn threaded_send_recv_in_order() {
        const COUNT: u64 = 10_000;
        let (mut sender, mut receiver) = channel::<u64>();
        let producer = thread::spawn(move || {
            for i in 0..COUNT {
                sender.send(i).unwrap();
            }
        });
        let mut received = Vec::new();
        loop {
            match receiver.recv_one() {
                Ok(Some(n)) => received.push(n),
                Ok(None) => thread::yield_now(),
                Err(_) => break,
            }
        }
        producer.join().unwrap();
        assert_eq!(received, (0..COUNT).collect::<Vec<_>>());
    }
}