// thedes_async_util/non_blocking/spsc/watch.rs

use std::{
    fmt,
    ops::Deref,
    ptr::{NonNull, null_mut},
    sync::atomic::{
        AtomicBool,
        AtomicPtr,
        Ordering::{self, *},
        fence,
    },
};

use thiserror::Error;

14pub fn channel<T>() -> (Sender<T>, Receiver<T>)
15where
16    T: AtomicMessage,
17{
18    let (sender_shared, receiver_shared) = SharedPtr::new();
19    (Sender::new(sender_shared), Receiver::new(receiver_shared))
20}
21
22#[derive(Debug, Clone, Error)]
23#[error("Receiver disconnected")]
24pub struct SendError<T> {
25    message: T,
26}
27
28impl<T> SendError<T> {
29    fn new(message: T) -> Self {
30        Self { message }
31    }
32
33    pub fn message(&self) -> &T {
34        &self.message
35    }
36
37    pub fn into_message(self) -> T {
38        self.message
39    }
40}
41
42#[derive(Debug, Clone, Error)]
43#[error("Senders disconnected")]
44pub struct RecvError {
45    _private: (),
46}
47
48impl RecvError {
49    fn new() -> Self {
50        Self { _private: () }
51    }
52}
53
/// A message slot that is written and read through shared references.
///
/// In this module the sending side calls [`store`](AtomicMessage::store)
/// and the receiving side calls [`take`](AtomicMessage::take), both through
/// `&self`, possibly from different threads.
pub trait AtomicMessage: Sized {
    /// The payload carried by the slot.
    type Data;

    /// Creates a slot that holds no value.
    fn empty() -> Self;

    /// Puts `value` into the slot using the given memory `ordering`.
    fn store(&self, value: Self::Data, ordering: Ordering);

    /// Removes and returns the slot's value, if any, using the given memory
    /// `ordering`.
    fn take(&self, ordering: Ordering) -> Option<Self::Data>;
}

64struct Shared<M> {
65    connected: AtomicBool,
66    current: M,
67}
68
69impl<M> Shared<M>
70where
71    M: AtomicMessage,
72{
73    pub fn new_connected() -> Self {
74        Self { current: M::empty(), connected: AtomicBool::new(true) }
75    }
76
77    pub fn is_connected_weak(&self) -> bool {
78        self.connected.load(Relaxed)
79    }
80
81    pub fn send(&self, message: M::Data) -> Result<(), SendError<M::Data>> {
82        if self.connected.load(Acquire) {
83            self.current.store(message, Release);
84            Ok(())
85        } else {
86            Err(SendError::new(message))
87        }
88    }
89
90    pub fn recv(&self) -> Result<Option<M::Data>, RecvError> {
91        match self.current.take(Acquire) {
92            Some(data) => Ok(Some(data)),
93            None => {
94                if self.connected.load(Acquire) {
95                    Ok(None)
96                } else {
97                    Err(RecvError::new())
98                }
99            },
100        }
101    }
102}
103
104impl<M> fmt::Debug for Shared<M>
105where
106    M: fmt::Debug,
107{
108    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109        f.debug_struct("Shared")
110            .field("current", &self.current)
111            .field("connected", &self.connected)
112            .finish()
113    }
114}
115
116impl<M> Drop for Shared<M> {
117    fn drop(&mut self) {}
118}
119
120struct SharedPtr<M> {
121    inner: NonNull<Shared<M>>,
122}
123
124impl<M> SharedPtr<M>
125where
126    M: AtomicMessage,
127{
128    pub fn new() -> (Self, Self) {
129        let shared = Shared::new_connected();
130        let shared_boxed = Box::new(shared);
131
132        let inner =
133            unsafe { NonNull::new_unchecked(Box::into_raw(shared_boxed)) };
134
135        (Self { inner }, Self { inner })
136    }
137}
138
139impl<M> fmt::Debug for SharedPtr<M>
140where
141    M: fmt::Debug,
142{
143    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
144        f.debug_struct("SharedPtr").field("inner", &**self).finish()
145    }
146}
147
148impl<M> Deref for SharedPtr<M> {
149    type Target = Shared<M>;
150
151    fn deref(&self) -> &Self::Target {
152        unsafe { self.inner.as_ref() }
153    }
154}
155
156impl<M> Drop for SharedPtr<M> {
157    fn drop(&mut self) {
158        unsafe {
159            let last = self
160                .connected
161                .compare_exchange(true, false, AcqRel, Acquire)
162                .is_err();
163            if last {
164                let _ = Box::from_raw(self.inner.as_ptr());
165            }
166        }
167    }
168}
169
170#[derive(Debug)]
171pub struct Sender<M> {
172    shared: SharedPtr<M>,
173}
174
175impl<M> Sender<M>
176where
177    M: AtomicMessage,
178{
179    fn new(shared: SharedPtr<M>) -> Self {
180        Self { shared }
181    }
182
183    pub fn is_connected(&self) -> bool {
184        self.shared.is_connected_weak()
185    }
186
187    pub fn send(&mut self, message: M::Data) -> Result<(), SendError<M::Data>> {
188        self.shared.send(message)
189    }
190}
191
192unsafe impl<M> Send for Sender<M>
193where
194    M: AtomicMessage + Send,
195    M::Data: Send,
196{
197}
198
199unsafe impl<M> Sync for Sender<M>
200where
201    M: AtomicMessage + Send,
202    M::Data: Send,
203{
204}
205
206#[derive(Debug)]
207pub struct Receiver<M> {
208    shared: SharedPtr<M>,
209}
210
211impl<M> Receiver<M>
212where
213    M: AtomicMessage,
214{
215    fn new(shared: SharedPtr<M>) -> Self {
216        Self { shared }
217    }
218
219    pub fn is_connected(&self) -> bool {
220        self.shared.is_connected_weak()
221    }
222
223    pub fn recv(&mut self) -> Result<Option<M::Data>, RecvError> {
224        self.shared.recv()
225    }
226}
227
228unsafe impl<M> Send for Receiver<M>
229where
230    M: AtomicMessage + Send,
231    M::Data: Send,
232{
233}
234
235unsafe impl<M> Sync for Receiver<M>
236where
237    M: AtomicMessage + Send,
238    M::Data: Send,
239{
240}
241
242#[derive(Debug)]
243pub struct MessageBox<T> {
244    inner: AtomicPtr<T>,
245}
246
247impl<T> AtomicMessage for MessageBox<T> {
248    type Data = T;
249
250    fn empty() -> Self {
251        Self { inner: AtomicPtr::new(null_mut()) }
252    }
253
254    fn take(&self, ordering: Ordering) -> Option<Self::Data> {
255        let ptr = self.inner.swap(null_mut(), ordering);
256        if ptr.is_null() { None } else { unsafe { Some(*Box::from_raw(ptr)) } }
257    }
258
259    fn store(&self, value: Self::Data, ordering: Ordering) {
260        let ptr = Box::into_raw(Box::new(value));
261        self.inner.store(ptr, ordering);
262    }
263}
264
265impl<T> Drop for MessageBox<T> {
266    fn drop(&mut self) {
267        unsafe {
268            let ptr: *mut T = *self.inner.get_mut();
269            if !ptr.is_null() {
270                let _ = Box::from_raw(ptr);
271            }
272        }
273    }
274}
275
#[cfg(test)]
mod test {
    use super::{MessageBox, channel};

    #[test]
    fn recv_empty() {
        let (_sender, mut receiver) = channel::<MessageBox<u64>>();
        assert_eq!(receiver.recv().unwrap(), None);
    }

    #[test]
    fn recv_one() {
        let (mut sender, mut receiver) = channel::<MessageBox<u64>>();
        sender.send(12).unwrap();
        assert_eq!(receiver.recv().unwrap(), Some(12));
    }

    #[test]
    fn recv_one_then_none() {
        let (mut sender, mut receiver) = channel::<MessageBox<u64>>();
        sender.send(12).unwrap();
        assert_eq!(receiver.recv().unwrap(), Some(12));
        assert_eq!(receiver.recv().unwrap(), None);
    }

    #[test]
    fn recv_twice_then_none() {
        let (mut sender, mut receiver) = channel::<MessageBox<u64>>();
        sender.send(12).unwrap();
        assert_eq!(receiver.recv().unwrap(), Some(12));
        sender.send(13).unwrap();
        assert_eq!(receiver.recv().unwrap(), Some(13));
        assert_eq!(receiver.recv().unwrap(), None);
        assert_eq!(receiver.recv().unwrap(), None);
    }

    /// A second send before the receiver looks replaces the first message.
    #[test]
    fn recv_latest_after_overwrite() {
        let (mut sender, mut receiver) = channel::<MessageBox<u64>>();
        sender.send(12).unwrap();
        sender.send(13).unwrap();
        assert_eq!(receiver.recv().unwrap(), Some(13));
        assert_eq!(receiver.recv().unwrap(), None);
    }

    /// A message sent right before the sender is dropped is still delivered.
    #[test]
    fn message_sent_before_sender_drop_is_received() {
        let (mut sender, mut receiver) = channel::<MessageBox<u64>>();
        sender.send(7).unwrap();
        drop(sender);
        assert_eq!(receiver.recv().unwrap(), Some(7));
        receiver.recv().unwrap_err();
    }

    /// Unreceived heap-owning messages are dropped with the channel
    /// (meaningful under Miri or a leak checker).
    #[test]
    fn unreceived_owned_messages_are_dropped() {
        let (mut sender, receiver) = channel::<MessageBox<String>>();
        sender.send("first".to_string()).unwrap();
        sender.send("second".to_string()).unwrap();
        drop(receiver);
        drop(sender);
    }

    #[test]
    fn connection_status_tracks_drops() {
        let (sender, receiver) = channel::<MessageBox<u64>>();
        assert!(sender.is_connected());
        assert!(receiver.is_connected());
        drop(receiver);
        assert!(!sender.is_connected());
    }

    #[test]
    fn sender_disconnected() {
        let (_, mut receiver) = channel::<MessageBox<u64>>();
        receiver.recv().unwrap_err();
    }

    #[test]
    fn receiver_disconnected() {
        let (mut sender, _) = channel::<MessageBox<u64>>();
        sender.send(32).unwrap_err();
    }
}