// hyper/common/drain.rs
1use std::future::Future;
2use std::mem;
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use pin_project_lite::pin_project;
7use tokio::sync::watch;
8
9pub(crate) fn channel() -> (Signal, Watch) {
10 let (tx, rx) = watch::channel(());
11 (Signal { tx }, Watch { rx })
12}
13
14pub(crate) struct Signal {
15 tx: watch::Sender<()>,
16}
17
/// Future handed out when a drain is started.
///
/// It is a thin wrapper around a boxed, type-erased unit future and completes
/// exactly when that inner future completes.
pub(crate) struct Draining(Pin<Box<dyn Future<Output = ()> + Send + Sync>>);
19
20#[derive(Clone)]
21pub(crate) struct Watch {
22 rx: watch::Receiver<()>,
23}
24
25pin_project! {
26 #[allow(missing_debug_implementations)]
27 pub struct Watching<F, FN> {
28 #[pin]
29 future: F,
30 state: State<FN>,
31 watch: Pin<Box<dyn Future<Output = ()> + Send + Sync>>,
32 _rx: watch::Receiver<()>,
33 }
34}
35
/// Progress of a watched future with respect to the drain signal.
enum State<F> {
    /// Drain has not been observed yet; carries the callback to run when it is.
    Watch(F),
    /// The callback has been taken; only the wrapped future remains to be polled.
    Draining,
}
40
41impl Signal {
42 pub(crate) fn drain(self) -> Draining {
43 let _ = self.tx.send(());
44 Draining(Box::pin(async move { self.tx.closed().await }))
45 }
46}
47
48impl Future for Draining {
49 type Output = ();
50
51 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
52 Pin::new(&mut self.as_mut().0).poll(cx)
53 }
54}
55
56impl Watch {
57 pub(crate) fn watch<F, FN>(self, future: F, on_drain: FN) -> Watching<F, FN>
58 where
59 F: Future,
60 FN: FnOnce(Pin<&mut F>),
61 {
62 let Self { mut rx } = self;
63 let _rx = rx.clone();
64 Watching {
65 future,
66 state: State::Watch(on_drain),
67 watch: Box::pin(async move {
68 let _ = rx.changed().await;
69 }),
70 _rx,
73 }
74 }
75}
76
77impl<F, FN> Future for Watching<F, FN>
78where
79 F: Future,
80 FN: FnOnce(Pin<&mut F>),
81{
82 type Output = F::Output;
83
84 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
85 let mut me = self.project();
86 loop {
87 match mem::replace(me.state, State::Draining) {
88 State::Watch(on_drain) => {
89 match Pin::new(&mut me.watch).poll(cx) {
90 Poll::Ready(()) => {
91 on_drain(me.future.as_mut());
93 }
94 Poll::Pending => {
95 *me.state = State::Watch(on_drain);
96 return me.future.poll(cx);
97 }
98 }
99 }
100 State::Draining => return me.future.poll(cx),
101 }
102 }
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Test future that counts how often it is polled and completes only once
    /// `finished` has been set.
    struct TestMe {
        draining: bool,
        finished: bool,
        poll_cnt: usize,
    }

    impl TestMe {
        /// A fresh future: not draining, not finished, never polled.
        fn new() -> Self {
            TestMe {
                draining: false,
                finished: false,
                poll_cnt: 0,
            }
        }
    }

    impl Future for TestMe {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
            self.poll_cnt += 1;
            if !self.finished {
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    /// Drain callback shared by the tests: records that draining was requested.
    fn mark_draining(mut fut: Pin<&mut TestMe>) {
        fut.draining = true;
    }

    #[test]
    fn watch() {
        let mut mock = tokio_test::task::spawn(());
        mock.enter(|cx, _| {
            let (tx, rx) = channel();
            let mut watch = rx.watch(TestMe::new(), mark_draining);

            // Nothing has been polled yet.
            assert_eq!(watch.future.poll_cnt, 0);

            // Each poll of the wrapper polls the inner future once.
            assert!(Pin::new(&mut watch).poll(cx).is_pending());
            assert_eq!(watch.future.poll_cnt, 1);

            assert!(Pin::new(&mut watch).poll(cx).is_pending());
            assert_eq!(watch.future.poll_cnt, 2);

            // Signalling alone touches neither the callback nor the future.
            let mut draining = tx.drain();
            assert!(!watch.future.draining);
            assert_eq!(watch.future.poll_cnt, 2);

            // The next poll observes the signal and runs the callback.
            assert!(Pin::new(&mut watch).poll(cx).is_pending());
            assert_eq!(watch.future.poll_cnt, 3);
            assert!(watch.future.draining);

            // The watcher is still alive, so draining is not complete.
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            watch.future.finished = true;
            assert!(Pin::new(&mut watch).poll(cx).is_ready());
            assert_eq!(watch.future.poll_cnt, 4);
            drop(watch);

            // With the watcher gone, draining completes.
            assert!(Pin::new(&mut draining).poll(cx).is_ready());
        })
    }

    #[test]
    fn watch_clones() {
        let mut mock = tokio_test::task::spawn(());
        mock.enter(|cx, _| {
            let (tx, rx) = channel();

            let watch1 = rx.clone().watch(TestMe::new(), mark_draining);
            let watch2 = rx.watch(TestMe::new(), mark_draining);

            let mut draining = tx.drain();

            // Both watchers are alive.
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            drop(watch1);

            // One watcher is still alive.
            assert!(Pin::new(&mut draining).poll(cx).is_pending());

            drop(watch2);

            // Every watcher is gone.
            assert!(Pin::new(&mut draining).poll(cx).is_ready());
        });
    }
}