thedes_async_util/non_blocking/spsc/
unbounded.rs1use std::{
2 cell::{Cell, UnsafeCell},
3 fmt,
4 mem::MaybeUninit,
5 ops::Deref,
6 ptr::{NonNull, null_mut},
7 sync::atomic::{AtomicBool, AtomicPtr, Ordering::*},
8};
9
10use thiserror::Error;
11
12pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
13 let (sender_shared, receiver_shared) = SharedPtr::new();
14 (Sender::new(sender_shared), Receiver::new(receiver_shared))
15}
16
17#[derive(Debug, Clone, Error)]
18#[error("Receiver disconnected")]
19pub struct SendError<T> {
20 message: T,
21}
22
23impl<T> SendError<T> {
24 fn new(message: T) -> Self {
25 Self { message }
26 }
27
28 pub fn message(&self) -> &T {
29 &self.message
30 }
31
32 pub fn into_message(self) -> T {
33 self.message
34 }
35}
36
37#[derive(Debug, Clone, Error)]
38#[error("Senders disconnected")]
39pub struct RecvError {
40 _private: (),
41}
42
43impl RecvError {
44 fn new() -> Self {
45 Self { _private: () }
46 }
47}
48
/// One link of the queue's singly linked list.
///
/// The list always holds at least one node. The current front node is a
/// sentinel whose payload is either uninitialized (the initial dummy) or has
/// already been moved out by the receiver.
#[repr(C)]
#[derive(Debug)]
struct Node<T> {
    /// Newer neighbour, or null while this is the last node.
    next: AtomicPtr<Self>,
    /// Payload written when the node is created by `send`; it is read out
    /// without being dropped in place, hence `MaybeUninit`.
    data: UnsafeCell<MaybeUninit<T>>,
}
55
56struct Shared<T> {
57 connected: AtomicBool,
58 front: Cell<NonNull<Node<T>>>,
59 back: AtomicPtr<Node<T>>,
60}
61
62impl<T> Shared<T> {
63 pub fn new_connected() -> Self {
64 let dummy = Box::new(Node {
65 data: UnsafeCell::new(MaybeUninit::uninit()),
66 next: AtomicPtr::new(null_mut()),
67 });
68
69 let dummy_non_null =
70 unsafe { NonNull::new_unchecked(Box::into_raw(dummy)) };
71
72 let this = Self {
73 front: Cell::new(dummy_non_null),
74 back: AtomicPtr::new(dummy_non_null.as_ptr()),
75 connected: AtomicBool::new(true),
76 };
77 this
78 }
79
80 pub fn is_connected_weak(&self) -> bool {
81 self.connected.load(Relaxed)
82 }
83
84 pub unsafe fn send(&self, message: T) -> Result<(), SendError<T>> {
85 if self.connected.load(Acquire) {
86 let new_node = Box::new(Node {
87 data: UnsafeCell::new(MaybeUninit::new(message)),
88 next: AtomicPtr::new(null_mut()),
89 });
90 unsafe {
91 let node_non_null =
92 NonNull::new_unchecked(Box::into_raw(new_node));
93 let back = self.back.load(Relaxed);
94 (&mut *back).next.store(node_non_null.as_ptr(), Release);
95 self.back.store(node_non_null.as_ptr(), Release);
96 }
97 Ok(())
98 } else {
99 Err(SendError::new(message))
100 }
101 }
102
103 pub unsafe fn recv_one(&self) -> Result<Option<T>, RecvError> {
104 unsafe {
105 let next_ptr = self.front.get().as_ref().next.load(Acquire);
106 match NonNull::new(next_ptr) {
107 Some(next_non_null) => {
108 let data =
109 (&*next_non_null.as_ref().data.get()).as_ptr().read();
110 let _ = Box::from_raw(self.front.get().as_ptr());
111 self.front.set(next_non_null);
112 Ok(Some(data))
113 },
114 None => {
115 if self.connected.load(Acquire) {
116 Ok(None)
117 } else {
118 Err(RecvError::new())
119 }
120 },
121 }
122 }
123 }
124
125 pub unsafe fn recv_many<'a>(
126 &'a self,
127 ) -> Result<RecvMany<'a, T>, RecvError> {
128 let back = unsafe { NonNull::new_unchecked(self.back.load(Acquire)) };
129 if back == self.front.get() && !self.connected.load(Acquire) {
130 Err(RecvError::new())?
131 }
132 Ok(RecvMany { shared: self, back_limit: back })
133 }
134}
135
136impl<T> fmt::Debug for Shared<T> {
137 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138 f.debug_struct("Shared")
139 .field("front", &self.front)
140 .field("back", &self.back)
141 .field("connected", &self.connected)
142 .finish()
143 }
144}
145
146impl<T> Drop for Shared<T> {
147 fn drop(&mut self) {
148 unsafe { while let Ok(Some(_)) = self.recv_one() {} }
149 unsafe {
150 let _ = Box::from_raw(self.front.get().as_ptr());
151 }
152 }
153}
154
155struct SharedPtr<T> {
156 inner: NonNull<Shared<T>>,
157}
158
159impl<T> SharedPtr<T> {
160 pub fn new() -> (Self, Self) {
161 let shared = Shared::new_connected();
162 let shared_boxed = Box::new(shared);
163
164 let inner =
165 unsafe { NonNull::new_unchecked(Box::into_raw(shared_boxed)) };
166
167 (Self { inner }, Self { inner })
168 }
169}
170
171impl<T> fmt::Debug for SharedPtr<T> {
172 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
173 f.debug_struct("SharedPtr").field("inner", &**self).finish()
174 }
175}
176
177impl<T> Deref for SharedPtr<T> {
178 type Target = Shared<T>;
179
180 fn deref(&self) -> &Self::Target {
181 unsafe { self.inner.as_ref() }
182 }
183}
184
185impl<T> Drop for SharedPtr<T> {
186 fn drop(&mut self) {
187 unsafe {
188 let last = self
189 .connected
190 .compare_exchange(true, false, AcqRel, Acquire)
191 .is_err();
192 if last {
193 let _ = Box::from_raw(self.inner.as_ptr());
194 }
195 }
196 }
197}
198
199#[derive(Debug)]
200pub struct RecvMany<'a, T> {
201 shared: &'a Shared<T>,
202 back_limit: NonNull<Node<T>>,
203}
204
205impl<'a, T> Iterator for RecvMany<'a, T> {
206 type Item = T;
207
208 fn next(&mut self) -> Option<Self::Item> {
209 unsafe {
210 if self.shared.front.get() == self.back_limit {
211 None?
212 }
213 self.shared.recv_one().ok().flatten()
214 }
215 }
216}
217
218#[derive(Debug)]
219pub struct Sender<T> {
220 shared: SharedPtr<T>,
221}
222
223impl<T> Sender<T> {
224 fn new(shared: SharedPtr<T>) -> Self {
225 Self { shared }
226 }
227
228 pub fn is_connected(&self) -> bool {
229 self.shared.is_connected_weak()
230 }
231
232 pub fn send(&mut self, message: T) -> Result<(), SendError<T>> {
233 unsafe { self.shared.send(message) }
234 }
235}
236
237unsafe impl<T> Send for Sender<T> where T: Send {}
238unsafe impl<T> Sync for Sender<T> where T: Send {}
239
240#[derive(Debug)]
241pub struct Receiver<T> {
242 shared: SharedPtr<T>,
243}
244
245impl<T> Receiver<T> {
246 fn new(shared: SharedPtr<T>) -> Self {
247 Self { shared }
248 }
249
250 pub fn is_connected(&self) -> bool {
251 self.shared.is_connected_weak()
252 }
253
254 pub fn recv_one(&mut self) -> Result<Option<T>, RecvError> {
255 unsafe { self.shared.recv_one() }
256 }
257
258 pub fn recv_many<'a>(&'a mut self) -> Result<RecvMany<'a, T>, RecvError> {
259 unsafe { self.shared.recv_many() }
260 }
261}
262
263unsafe impl<T> Send for Receiver<T> where T: Send {}
264unsafe impl<T> Sync for Receiver<T> where T: Send {}
265
/// Unit tests for the unbounded SPSC channel, including cross-thread delivery
/// and teardown.
#[cfg(test)]
mod test {
    use std::{sync::Arc, thread};

    use super::channel;

    #[test]
    fn recv_empty() {
        let (_sender, mut receiver) = channel::<u64>();
        assert_eq!(receiver.recv_one().unwrap(), None);
        assert_eq!(receiver.recv_one().unwrap(), None);
    }

    #[test]
    fn recv_empty_disconnected() {
        let (sender, mut receiver) = channel::<u64>();
        assert_eq!(receiver.recv_one().unwrap(), None);
        drop(sender);
        assert!(receiver.recv_one().is_err());
    }

    #[test]
    fn send_recv_once() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(93).unwrap();
        assert_eq!(receiver.recv_one().unwrap(), Some(93));
        assert_eq!(receiver.recv_one().unwrap(), None);
    }

    #[test]
    fn send_recv_twice_interleaved() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(4).unwrap();
        assert_eq!(receiver.recv_one().unwrap(), Some(4));
        sender.send(234).unwrap();
        assert_eq!(receiver.recv_one().unwrap(), Some(234));
        assert_eq!(receiver.recv_one().unwrap(), None);
    }

    #[test]
    fn send_recv_twice_consecutive() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(4).unwrap();
        sender.send(234).unwrap();
        assert_eq!(receiver.recv_one().unwrap(), Some(4));
        assert_eq!(receiver.recv_one().unwrap(), Some(234));
        assert_eq!(receiver.recv_one().unwrap(), None);
    }

    #[test]
    fn send_recv_dropped() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(9452).unwrap();
        drop(sender);
        assert_eq!(receiver.recv_one().unwrap(), Some(9452));
        assert!(receiver.recv_one().is_err());
    }

    #[test]
    fn send_recv_twice_dropped() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(9452).unwrap();
        sender.send(12).unwrap();
        drop(sender);
        assert_eq!(receiver.recv_one().unwrap(), Some(9452));
        assert_eq!(receiver.recv_one().unwrap(), Some(12));
        assert!(receiver.recv_one().is_err());
    }

    #[test]
    fn send_recv_twice_dropped_interleaved() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(9452).unwrap();
        assert_eq!(receiver.recv_one().unwrap(), Some(9452));
        sender.send(12).unwrap();
        drop(sender);
        assert_eq!(receiver.recv_one().unwrap(), Some(12));
        assert!(receiver.recv_one().is_err());
    }

    #[test]
    fn recv_many() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(1).unwrap();
        sender.send(2).unwrap();
        sender.send(3).unwrap();
        sender.send(4).unwrap();
        let recv_many = receiver.recv_many().unwrap();
        sender.send(5).unwrap();
        sender.send(6).unwrap();
        let numbers: Vec<_> = recv_many.collect();
        assert_eq!(numbers, vec![1, 2, 3, 4]);
    }

    #[test]
    fn recv_many_dropped() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(1).unwrap();
        assert_eq!(receiver.recv_one().unwrap(), Some(1));
        drop(sender);
        receiver.recv_many().unwrap_err();
    }

    #[test]
    fn recv_many_dropped_with_messages() {
        let (mut sender, mut receiver) = channel::<u64>();
        sender.send(1).unwrap();
        assert_eq!(receiver.recv_one().unwrap(), Some(1));
        sender.send(2).unwrap();
        sender.send(3).unwrap();
        sender.send(4).unwrap();
        drop(sender);
        let recv_many = receiver.recv_many().unwrap();
        let numbers: Vec<_> = recv_many.collect();
        assert_eq!(numbers, vec![2, 3, 4]);
    }

    #[test]
    fn send_after_receiver_dropped() {
        let (mut sender, receiver) = channel::<u64>();
        drop(receiver);
        assert!(!sender.is_connected());
        let error = sender.send(77).unwrap_err();
        assert_eq!(*error.message(), 77);
        assert_eq!(error.into_message(), 77);
    }

    #[test]
    fn pending_messages_dropped_with_channel() {
        let tracker = Arc::new(());
        let (mut sender, receiver) = channel::<Arc<()>>();
        sender.send(Arc::clone(&tracker)).unwrap();
        sender.send(Arc::clone(&tracker)).unwrap();
        assert_eq!(Arc::strong_count(&tracker), 3);
        drop(receiver);
        drop(sender);
        // Tearing down the channel must drop every undelivered message.
        assert_eq!(Arc::strong_count(&tracker), 1);
    }

    #[test]
    fn send_recv_across_threads() {
        const COUNT: u64 = 10_000;
        let (mut sender, mut receiver) = channel::<u64>();
        let producer = thread::spawn(move || {
            for value in 0..COUNT {
                sender.send(value).unwrap();
            }
            // `sender` is dropped here, disconnecting right after the last
            // push; this exercises the receiver's disconnect re-check.
        });
        let mut expected = 0;
        loop {
            match receiver.recv_one() {
                Ok(Some(value)) => {
                    assert_eq!(value, expected);
                    expected += 1;
                },
                Ok(None) => thread::yield_now(),
                Err(_) => break,
            }
        }
        producer.join().unwrap();
        // Every message sent before the disconnect must be delivered.
        assert_eq!(expected, COUNT);
    }
}