hyper/common/
drain.rs

1use std::future::Future;
2use std::mem;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use pin_project_lite::pin_project;
7use tokio::sync::watch;
8
9pub(crate) fn channel() -> (Signal, Watch) {
10    let (tx, rx) = watch::channel(());
11    (Signal { tx }, Watch { rx })
12}
13
14pub(crate) struct Signal {
15    tx: watch::Sender<()>,
16}
17
/// Future returned by `Signal::drain`.
///
/// Wraps a boxed, already-pinned future; it resolves once every receiver of
/// the drain channel has been dropped.
pub(crate) struct Draining(Pin<Box<dyn Future<Output = ()> + Send + Sync>>);
19
20#[derive(Clone)]
21pub(crate) struct Watch {
22    rx: watch::Receiver<()>,
23}
24
25pin_project! {
26    #[allow(missing_debug_implementations)]
27    pub struct Watching<F, FN> {
28        #[pin]
29        future: F,
30        state: State<FN>,
31        watch: Pin<Box<dyn Future<Output = ()> + Send + Sync>>,
32        _rx: watch::Receiver<()>,
33    }
34}
35
/// Progress of a `Watching` future with respect to the drain signal.
enum State<FN> {
    /// Still waiting for the signal; holds the callback to run when it arrives.
    Watch(FN),
    /// The callback has been taken out (the signal was observed).
    Draining,
}
40
41impl Signal {
42    pub(crate) fn drain(self) -> Draining {
43        let _ = self.tx.send(());
44        Draining(Box::pin(async move { self.tx.closed().await }))
45    }
46}
47
48impl Future for Draining {
49    type Output = ();
50
51    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
52        Pin::new(&mut self.as_mut().0).poll(cx)
53    }
54}
55
56impl Watch {
57    pub(crate) fn watch<F, FN>(self, future: F, on_drain: FN) -> Watching<F, FN>
58    where
59        F: Future,
60        FN: FnOnce(Pin<&mut F>),
61    {
62        let Self { mut rx } = self;
63        let _rx = rx.clone();
64        Watching {
65            future,
66            state: State::Watch(on_drain),
67            watch: Box::pin(async move {
68                let _ = rx.changed().await;
69            }),
70            // Keep the receiver alive until the future completes, so that
71            // dropping it can signal that draining has completed.
72            _rx,
73        }
74    }
75}
76
77impl<F, FN> Future for Watching<F, FN>
78where
79    F: Future,
80    FN: FnOnce(Pin<&mut F>),
81{
82    type Output = F::Output;
83
84    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
85        let mut me = self.project();
86        loop {
87            match mem::replace(me.state, State::Draining) {
88                State::Watch(on_drain) => {
89                    match Pin::new(&mut me.watch).poll(cx) {
90                        Poll::Ready(()) => {
91                            // Drain has been triggered!
92                            on_drain(me.future.as_mut());
93                        }
94                        Poll::Pending => {
95                            *me.state = State::Watch(on_drain);
96                            return me.future.poll(cx);
97                        }
98                    }
99                }
100                State::Draining => return me.future.poll(cx),
101            }
102        }
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Hand-driven future: counts its polls, completes only once `finished`
    /// is set, and records whether the drain callback ran.
    #[derive(Default)]
    struct TestMe {
        draining: bool,
        finished: bool,
        poll_cnt: usize,
    }

    impl Future for TestMe {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
            self.poll_cnt += 1;
            if !self.finished {
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    /// Drain callback shared by the tests: flags the inner future as draining.
    fn mark_draining(mut fut: Pin<&mut TestMe>) {
        fut.draining = true;
    }

    #[test]
    fn watch() {
        let mut mock = tokio_test::task::spawn(());
        mock.enter(|cx, _| {
            let (tx, rx) = channel();
            let mut watch = rx.watch(TestMe::default(), mark_draining);
            assert_eq!(watch.future.poll_cnt, 0);

            // Before any signal, every poll is forwarded to the inner future.
            for expected_polls in 1..=2 {
                assert!(Pin::new(&mut watch).poll(cx).is_pending());
                assert_eq!(watch.future.poll_cnt, expected_polls);
            }

            let mut draining = tx.drain();
            // The signal has been sent, but the watcher only notices it when polled.
            assert!(!watch.future.draining);
            assert_eq!(watch.future.poll_cnt, 2);

            // The next poll runs the callback and still polls the inner future.
            assert!(Pin::new(&mut watch).poll(cx).is_pending());
            assert_eq!(watch.future.poll_cnt, 3);
            assert!(watch.future.draining);

            // Draining cannot finish while the watcher is still alive.
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            // Let the inner future finish, then release the watcher.
            watch.future.finished = true;
            assert!(Pin::new(&mut watch).poll(cx).is_ready());
            assert_eq!(watch.future.poll_cnt, 4);
            drop(watch);

            assert!(Pin::new(&mut draining).poll(cx).is_ready());
        });
    }

    #[test]
    fn watch_clones() {
        let mut mock = tokio_test::task::spawn(());
        mock.enter(|cx, _| {
            let (tx, rx) = channel();

            let watch1 = rx.clone().watch(TestMe::default(), mark_draining);
            let watch2 = rx.watch(TestMe::default(), mark_draining);

            let mut draining = tx.drain();

            // Both watchers are still outstanding.
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            // Dropping one is not enough; the other still holds the channel open.
            drop(watch1);
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            // Once the last watcher is gone, draining completes.
            drop(watch2);
            assert!(Pin::new(&mut draining).poll(cx).is_ready());
        });
    }
}