thedes_async_util/non_blocking/spsc/
watch.rs1use std::{
2 fmt,
3 ops::Deref,
4 ptr::{NonNull, null_mut},
5 sync::atomic::{
6 AtomicBool,
7 AtomicPtr,
8 Ordering::{self, *},
9 },
10};
11
12use thiserror::Error;
13
14pub fn channel<T>() -> (Sender<T>, Receiver<T>)
15where
16 T: AtomicMessage,
17{
18 let (sender_shared, receiver_shared) = SharedPtr::new();
19 (Sender::new(sender_shared), Receiver::new(receiver_shared))
20}
21
22#[derive(Debug, Clone, Error)]
23#[error("Receiver disconnected")]
24pub struct SendError<T> {
25 message: T,
26}
27
28impl<T> SendError<T> {
29 fn new(message: T) -> Self {
30 Self { message }
31 }
32
33 pub fn message(&self) -> &T {
34 &self.message
35 }
36
37 pub fn into_message(self) -> T {
38 self.message
39 }
40}
41
42#[derive(Debug, Clone, Error)]
43#[error("Senders disconnected")]
44pub struct RecvError {
45 _private: (),
46}
47
48impl RecvError {
49 fn new() -> Self {
50 Self { _private: () }
51 }
52}
53
/// A slot that can hold at most one pending message and is accessed through
/// shared references with caller-chosen atomic orderings.
///
/// The channel in this module calls `store` from the sending half and `take`
/// from the receiving half on the same slot.
pub trait AtomicMessage: Sized {
    /// The message payload moved in and out of the slot.
    type Data;

    /// Creates a slot with no pending message.
    fn empty() -> Self;

    /// Removes the pending message, if there is one, using `ordering` for the
    /// underlying atomic access. Returns `None` when the slot is empty.
    fn take(&self, ordering: Ordering) -> Option<Self::Data>;

    /// Publishes `value` as the pending message, using `ordering` for the
    /// underlying atomic access.
    fn store(&self, value: Self::Data, ordering: Ordering);
}
63
64struct Shared<M> {
65 connected: AtomicBool,
66 current: M,
67}
68
69impl<M> Shared<M>
70where
71 M: AtomicMessage,
72{
73 pub fn new_connected() -> Self {
74 Self { current: M::empty(), connected: AtomicBool::new(true) }
75 }
76
77 pub fn is_connected_weak(&self) -> bool {
78 self.connected.load(Relaxed)
79 }
80
81 pub fn send(&self, message: M::Data) -> Result<(), SendError<M::Data>> {
82 if self.connected.load(Acquire) {
83 self.current.store(message, Release);
84 Ok(())
85 } else {
86 Err(SendError::new(message))
87 }
88 }
89
90 pub fn recv(&self) -> Result<Option<M::Data>, RecvError> {
91 match self.current.take(Acquire) {
92 Some(data) => Ok(Some(data)),
93 None => {
94 if self.connected.load(Acquire) {
95 Ok(None)
96 } else {
97 Err(RecvError::new())
98 }
99 },
100 }
101 }
102}
103
104impl<M> fmt::Debug for Shared<M>
105where
106 M: fmt::Debug,
107{
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 f.debug_struct("Shared")
110 .field("current", &self.current)
111 .field("connected", &self.connected)
112 .finish()
113 }
114}
115
116impl<M> Drop for Shared<M> {
117 fn drop(&mut self) {}
118}
119
120struct SharedPtr<M> {
121 inner: NonNull<Shared<M>>,
122}
123
124impl<M> SharedPtr<M>
125where
126 M: AtomicMessage,
127{
128 pub fn new() -> (Self, Self) {
129 let shared = Shared::new_connected();
130 let shared_boxed = Box::new(shared);
131
132 let inner =
133 unsafe { NonNull::new_unchecked(Box::into_raw(shared_boxed)) };
134
135 (Self { inner }, Self { inner })
136 }
137}
138
139impl<M> fmt::Debug for SharedPtr<M>
140where
141 M: fmt::Debug,
142{
143 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144 f.debug_struct("SharedPtr").field("inner", &**self).finish()
145 }
146}
147
148impl<M> Deref for SharedPtr<M> {
149 type Target = Shared<M>;
150
151 fn deref(&self) -> &Self::Target {
152 unsafe { self.inner.as_ref() }
153 }
154}
155
156impl<M> Drop for SharedPtr<M> {
157 fn drop(&mut self) {
158 unsafe {
159 let last = self
160 .connected
161 .compare_exchange(true, false, AcqRel, Acquire)
162 .is_err();
163 if last {
164 let _ = Box::from_raw(self.inner.as_ptr());
165 }
166 }
167 }
168}
169
170#[derive(Debug)]
171pub struct Sender<M> {
172 shared: SharedPtr<M>,
173}
174
175impl<M> Sender<M>
176where
177 M: AtomicMessage,
178{
179 fn new(shared: SharedPtr<M>) -> Self {
180 Self { shared }
181 }
182
183 pub fn is_connected(&self) -> bool {
184 self.shared.is_connected_weak()
185 }
186
187 pub fn send(&mut self, message: M::Data) -> Result<(), SendError<M::Data>> {
188 self.shared.send(message)
189 }
190}
191
192unsafe impl<M> Send for Sender<M>
193where
194 M: AtomicMessage + Send,
195 M::Data: Send,
196{
197}
198
199unsafe impl<M> Sync for Sender<M>
200where
201 M: AtomicMessage + Send,
202 M::Data: Send,
203{
204}
205
206#[derive(Debug)]
207pub struct Receiver<M> {
208 shared: SharedPtr<M>,
209}
210
211impl<M> Receiver<M>
212where
213 M: AtomicMessage,
214{
215 fn new(shared: SharedPtr<M>) -> Self {
216 Self { shared }
217 }
218
219 pub fn is_connected(&self) -> bool {
220 self.shared.is_connected_weak()
221 }
222
223 pub fn recv(&mut self) -> Result<Option<M::Data>, RecvError> {
224 self.shared.recv()
225 }
226}
227
228unsafe impl<M> Send for Receiver<M>
229where
230 M: AtomicMessage + Send,
231 M::Data: Send,
232{
233}
234
235unsafe impl<M> Sync for Receiver<M>
236where
237 M: AtomicMessage + Send,
238 M::Data: Send,
239{
240}
241
242#[derive(Debug)]
243pub struct MessageBox<T> {
244 inner: AtomicPtr<T>,
245}
246
247impl<T> AtomicMessage for MessageBox<T> {
248 type Data = T;
249
250 fn empty() -> Self {
251 Self { inner: AtomicPtr::new(null_mut()) }
252 }
253
254 fn take(&self, ordering: Ordering) -> Option<Self::Data> {
255 let ptr = self.inner.swap(null_mut(), ordering);
256 if ptr.is_null() { None } else { unsafe { Some(*Box::from_raw(ptr)) } }
257 }
258
259 fn store(&self, value: Self::Data, ordering: Ordering) {
260 let ptr = Box::into_raw(Box::new(value));
261 self.inner.store(ptr, ordering);
262 }
263}
264
265impl<T> Drop for MessageBox<T> {
266 fn drop(&mut self) {
267 unsafe {
268 let ptr: *mut T = *self.inner.get_mut();
269 if !ptr.is_null() {
270 let _ = Box::from_raw(ptr);
271 }
272 }
273 }
274}
275
#[cfg(test)]
mod test {
    use super::{MessageBox, channel};

    /// Polling before anything was sent yields `None` while the sender lives.
    #[test]
    fn recv_empty() {
        let (_sender, mut rx) = channel::<MessageBox<u64>>();
        let polled = rx.recv().unwrap();
        assert_eq!(polled, None);
    }

    /// A single sent message is received.
    #[test]
    fn recv_one() {
        let (mut tx, mut rx) = channel::<MessageBox<u64>>();
        tx.send(12).unwrap();
        let polled = rx.recv().unwrap();
        assert_eq!(polled, Some(12));
    }

    /// Once the message was taken, the slot is empty again.
    #[test]
    fn recv_one_then_none() {
        let (mut tx, mut rx) = channel::<MessageBox<u64>>();
        tx.send(12).unwrap();
        let first = rx.recv().unwrap();
        assert_eq!(first, Some(12));
        let second = rx.recv().unwrap();
        assert_eq!(second, None);
    }

    /// Alternating send/recv delivers each message, then the slot stays empty.
    #[test]
    fn recv_twice_then_none() {
        let (mut tx, mut rx) = channel::<MessageBox<u64>>();
        tx.send(12).unwrap();
        assert_eq!(rx.recv().unwrap(), Some(12));
        tx.send(13).unwrap();
        assert_eq!(rx.recv().unwrap(), Some(13));
        for _ in 0..2 {
            assert_eq!(rx.recv().unwrap(), None);
        }
    }

    /// Receiving from an empty channel whose sender is gone is an error.
    #[test]
    fn sender_disconnected() {
        let (tx, mut rx) = channel::<MessageBox<u64>>();
        drop(tx);
        rx.recv().unwrap_err();
    }

    /// Sending after the receiver is gone is an error.
    #[test]
    fn receiver_disconnected() {
        let (mut tx, rx) = channel::<MessageBox<u64>>();
        drop(rx);
        tx.send(32).unwrap_err();
    }
}
323}