hyper/server/
tcp.rs

1use socket2::TcpKeepalive;
2use std::fmt;
3use std::future::Future;
4use std::io;
5use std::net::{SocketAddr, TcpListener as StdTcpListener};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use tokio::net::TcpListener;
11use tokio::time::Sleep;
12use tracing::{debug, error, trace};
13
14#[allow(unreachable_pub)] // https://github.com/rust-lang/rust/issues/57411
15pub use self::addr_stream::AddrStream;
16use super::accept::Accept;
17
/// Optional TCP keepalive parameters collected by the `AddrIncoming` setters.
///
/// Each field that is `None` is simply not forwarded when the configuration
/// is converted with `into_socket2`.
#[derive(Clone, Copy, Debug, Default)]
struct TcpKeepaliveConfig {
    /// Idle duration before keepalive probes are sent.
    time: Option<Duration>,
    /// Duration between two successive keepalive probes.
    interval: Option<Duration>,
    /// Number of unanswered probes before the peer is considered gone.
    retries: Option<u32>,
}
24
25impl TcpKeepaliveConfig {
26    /// Converts into a `socket2::TcpKeealive` if there is any keep alive configuration.
27    fn into_socket2(self) -> Option<TcpKeepalive> {
28        let mut dirty = false;
29        let mut ka = TcpKeepalive::new();
30        if let Some(time) = self.time {
31            ka = ka.with_time(time);
32            dirty = true
33        }
34        if let Some(interval) = self.interval {
35            ka = Self::ka_with_interval(ka, interval, &mut dirty)
36        };
37        if let Some(retries) = self.retries {
38            ka = Self::ka_with_retries(ka, retries, &mut dirty)
39        };
40        if dirty {
41            Some(ka)
42        } else {
43            None
44        }
45    }
46
47    #[cfg(any(
48        target_os = "android",
49        target_os = "dragonfly",
50        target_os = "freebsd",
51        target_os = "fuchsia",
52        target_os = "illumos",
53        target_os = "linux",
54        target_os = "netbsd",
55        target_vendor = "apple",
56        windows,
57    ))]
58    fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive {
59        *dirty = true;
60        ka.with_interval(interval)
61    }
62
63    #[cfg(not(any(
64        target_os = "android",
65        target_os = "dragonfly",
66        target_os = "freebsd",
67        target_os = "fuchsia",
68        target_os = "illumos",
69        target_os = "linux",
70        target_os = "netbsd",
71        target_vendor = "apple",
72        windows,
73    )))]
74    fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive {
75        ka // no-op as keepalive interval is not supported on this platform
76    }
77
78    #[cfg(any(
79        target_os = "android",
80        target_os = "dragonfly",
81        target_os = "freebsd",
82        target_os = "fuchsia",
83        target_os = "illumos",
84        target_os = "linux",
85        target_os = "netbsd",
86        target_vendor = "apple",
87    ))]
88    fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive {
89        *dirty = true;
90        ka.with_retries(retries)
91    }
92
93    #[cfg(not(any(
94        target_os = "android",
95        target_os = "dragonfly",
96        target_os = "freebsd",
97        target_os = "fuchsia",
98        target_os = "illumos",
99        target_os = "linux",
100        target_os = "netbsd",
101        target_vendor = "apple",
102    )))]
103    fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive {
104        ka // no-op as keepalive retries is not supported on this platform
105    }
106}
107
/// A stream of connections from binding to an address.
///
/// Accepted sockets have the configured TCP keepalive and `TCP_NODELAY`
/// options applied before they are handed out as `AddrStream`s.
#[must_use = "streams do nothing unless polled"]
pub struct AddrIncoming {
    /// Local address of `listener`, captured once at construction time.
    addr: SocketAddr,
    /// The listener that is polled for new connections.
    listener: TcpListener,
    /// Whether non-connection accept errors are logged and followed by a
    /// one second pause instead of being returned to the caller.
    sleep_on_errors: bool,
    /// Keepalive options applied to every accepted socket.
    tcp_keepalive_config: TcpKeepaliveConfig,
    /// Value of `TCP_NODELAY` set on every accepted socket.
    tcp_nodelay: bool,
    /// Pause armed after an accept error; `None` when no pause is pending.
    timeout: Option<Pin<Box<Sleep>>>,
}
118
119impl AddrIncoming {
120    pub(super) fn new(addr: &SocketAddr) -> crate::Result<Self> {
121        let std_listener = StdTcpListener::bind(addr).map_err(crate::Error::new_listen)?;
122
123        AddrIncoming::from_std(std_listener)
124    }
125
126    pub(super) fn from_std(std_listener: StdTcpListener) -> crate::Result<Self> {
127        // TcpListener::from_std doesn't set O_NONBLOCK
128        std_listener
129            .set_nonblocking(true)
130            .map_err(crate::Error::new_listen)?;
131        let listener = TcpListener::from_std(std_listener).map_err(crate::Error::new_listen)?;
132        AddrIncoming::from_listener(listener)
133    }
134
135    /// Creates a new `AddrIncoming` binding to provided socket address.
136    pub fn bind(addr: &SocketAddr) -> crate::Result<Self> {
137        AddrIncoming::new(addr)
138    }
139
140    /// Creates a new `AddrIncoming` from an existing `tokio::net::TcpListener`.
141    pub fn from_listener(listener: TcpListener) -> crate::Result<Self> {
142        let addr = listener.local_addr().map_err(crate::Error::new_listen)?;
143        Ok(AddrIncoming {
144            listener,
145            addr,
146            sleep_on_errors: true,
147            tcp_keepalive_config: TcpKeepaliveConfig::default(),
148            tcp_nodelay: false,
149            timeout: None,
150        })
151    }
152
153    /// Get the local address bound to this listener.
154    pub fn local_addr(&self) -> SocketAddr {
155        self.addr
156    }
157
158    /// Set the duration to remain idle before sending TCP keepalive probes.
159    ///
160    /// If `None` is specified, keepalive is disabled.
161    pub fn set_keepalive(&mut self, time: Option<Duration>) -> &mut Self {
162        self.tcp_keepalive_config.time = time;
163        self
164    }
165
166    /// Set the duration between two successive TCP keepalive retransmissions,
167    /// if acknowledgement to the previous keepalive transmission is not received.
168    pub fn set_keepalive_interval(&mut self, interval: Option<Duration>) -> &mut Self {
169        self.tcp_keepalive_config.interval = interval;
170        self
171    }
172
173    /// Set the number of retransmissions to be carried out before declaring that remote end is not available.
174    pub fn set_keepalive_retries(&mut self, retries: Option<u32>) -> &mut Self {
175        self.tcp_keepalive_config.retries = retries;
176        self
177    }
178
179    /// Set the value of `TCP_NODELAY` option for accepted connections.
180    pub fn set_nodelay(&mut self, enabled: bool) -> &mut Self {
181        self.tcp_nodelay = enabled;
182        self
183    }
184
185    /// Set whether to sleep on accept errors.
186    ///
187    /// A possible scenario is that the process has hit the max open files
188    /// allowed, and so trying to accept a new connection will fail with
189    /// `EMFILE`. In some cases, it's preferable to just wait for some time, if
190    /// the application will likely close some files (or connections), and try
191    /// to accept the connection again. If this option is `true`, the error
192    /// will be logged at the `error` level, since it is still a big deal,
193    /// and then the listener will sleep for 1 second.
194    ///
195    /// In other cases, hitting the max open files should be treat similarly
196    /// to being out-of-memory, and simply error (and shutdown). Setting
197    /// this option to `false` will allow that.
198    ///
199    /// Default is `true`.
200    pub fn set_sleep_on_errors(&mut self, val: bool) {
201        self.sleep_on_errors = val;
202    }
203
204    fn poll_next_(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<AddrStream>> {
205        // Check if a previous timeout is active that was set by IO errors.
206        if let Some(ref mut to) = self.timeout {
207            ready!(Pin::new(to).poll(cx));
208        }
209        self.timeout = None;
210
211        loop {
212            match ready!(self.listener.poll_accept(cx)) {
213                Ok((socket, remote_addr)) => {
214                    if let Some(tcp_keepalive) = &self.tcp_keepalive_config.into_socket2() {
215                        let sock_ref = socket2::SockRef::from(&socket);
216                        if let Err(e) = sock_ref.set_tcp_keepalive(tcp_keepalive) {
217                            trace!("error trying to set TCP keepalive: {}", e);
218                        }
219                    }
220                    if let Err(e) = socket.set_nodelay(self.tcp_nodelay) {
221                        trace!("error trying to set TCP nodelay: {}", e);
222                    }
223                    let local_addr = socket.local_addr()?;
224                    return Poll::Ready(Ok(AddrStream::new(socket, remote_addr, local_addr)));
225                }
226                Err(e) => {
227                    // Connection errors can be ignored directly, continue by
228                    // accepting the next request.
229                    if is_connection_error(&e) {
230                        debug!("accepted connection already errored: {}", e);
231                        continue;
232                    }
233
234                    if self.sleep_on_errors {
235                        error!("accept error: {}", e);
236
237                        // Sleep 1s.
238                        let mut timeout = Box::pin(tokio::time::sleep(Duration::from_secs(1)));
239
240                        match timeout.as_mut().poll(cx) {
241                            Poll::Ready(()) => {
242                                // Wow, it's been a second already? Ok then...
243                                continue;
244                            }
245                            Poll::Pending => {
246                                self.timeout = Some(timeout);
247                                return Poll::Pending;
248                            }
249                        }
250                    } else {
251                        return Poll::Ready(Err(e));
252                    }
253                }
254            }
255        }
256    }
257}
258
259impl Accept for AddrIncoming {
260    type Conn = AddrStream;
261    type Error = io::Error;
262
263    fn poll_accept(
264        mut self: Pin<&mut Self>,
265        cx: &mut Context<'_>,
266    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
267        let result = ready!(self.poll_next_(cx));
268        Poll::Ready(Some(result))
269    }
270}
271
/// Tells whether an `accept()` failure is specific to a single connection.
///
/// For such errors the next connection might already be ready to be
/// accepted, so the caller can retry immediately.
///
/// All other errors will incur a timeout before next `accept()` is performed.
/// The timeout is useful to handle resource exhaustion errors like ENFILE
/// and EMFILE. Otherwise, could enter into tight loop.
fn is_connection_error(e: &io::Error) -> bool {
    // Error kinds that concern one peer only, not the listener itself.
    const PER_CONNECTION_KINDS: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];

    PER_CONNECTION_KINDS.contains(&e.kind())
}
287
288impl fmt::Debug for AddrIncoming {
289    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
290        f.debug_struct("AddrIncoming")
291            .field("addr", &self.addr)
292            .field("sleep_on_errors", &self.sleep_on_errors)
293            .field("tcp_keepalive_config", &self.tcp_keepalive_config)
294            .field("tcp_nodelay", &self.tcp_nodelay)
295            .finish()
296    }
297}
298
299mod addr_stream {
300    use std::io;
301    use std::net::SocketAddr;
302    #[cfg(unix)]
303    use std::os::unix::io::{AsRawFd, RawFd};
304    use std::pin::Pin;
305    use std::task::{Context, Poll};
306    use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
307    use tokio::net::TcpStream;
308
309    pin_project_lite::pin_project! {
310        /// A transport returned yieled by `AddrIncoming`.
311        #[derive(Debug)]
312        pub struct AddrStream {
313            #[pin]
314            inner: TcpStream,
315            pub(super) remote_addr: SocketAddr,
316            pub(super) local_addr: SocketAddr
317        }
318    }
319
320    impl AddrStream {
321        pub(super) fn new(
322            tcp: TcpStream,
323            remote_addr: SocketAddr,
324            local_addr: SocketAddr,
325        ) -> AddrStream {
326            AddrStream {
327                inner: tcp,
328                remote_addr,
329                local_addr,
330            }
331        }
332
333        /// Returns the remote (peer) address of this connection.
334        #[inline]
335        pub fn remote_addr(&self) -> SocketAddr {
336            self.remote_addr
337        }
338
339        /// Returns the local address of this connection.
340        #[inline]
341        pub fn local_addr(&self) -> SocketAddr {
342            self.local_addr
343        }
344
345        /// Consumes the AddrStream and returns the underlying IO object
346        #[inline]
347        pub fn into_inner(self) -> TcpStream {
348            self.inner
349        }
350
351        /// Attempt to receive data on the socket, without removing that data
352        /// from the queue, registering the current task for wakeup if data is
353        /// not yet available.
354        pub fn poll_peek(
355            &mut self,
356            cx: &mut Context<'_>,
357            buf: &mut tokio::io::ReadBuf<'_>,
358        ) -> Poll<io::Result<usize>> {
359            self.inner.poll_peek(cx, buf)
360        }
361    }
362
363    impl AsyncRead for AddrStream {
364        #[inline]
365        fn poll_read(
366            self: Pin<&mut Self>,
367            cx: &mut Context<'_>,
368            buf: &mut ReadBuf<'_>,
369        ) -> Poll<io::Result<()>> {
370            self.project().inner.poll_read(cx, buf)
371        }
372    }
373
374    impl AsyncWrite for AddrStream {
375        #[inline]
376        fn poll_write(
377            self: Pin<&mut Self>,
378            cx: &mut Context<'_>,
379            buf: &[u8],
380        ) -> Poll<io::Result<usize>> {
381            self.project().inner.poll_write(cx, buf)
382        }
383
384        #[inline]
385        fn poll_write_vectored(
386            self: Pin<&mut Self>,
387            cx: &mut Context<'_>,
388            bufs: &[io::IoSlice<'_>],
389        ) -> Poll<io::Result<usize>> {
390            self.project().inner.poll_write_vectored(cx, bufs)
391        }
392
393        #[inline]
394        fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
395            // TCP flush is a noop
396            Poll::Ready(Ok(()))
397        }
398
399        #[inline]
400        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
401            self.project().inner.poll_shutdown(cx)
402        }
403
404        #[inline]
405        fn is_write_vectored(&self) -> bool {
406            // Note that since `self.inner` is a `TcpStream`, this could
407            // *probably* be hard-coded to return `true`...but it seems more
408            // correct to ask it anyway (maybe we're on some platform without
409            // scatter-gather IO?)
410            self.inner.is_write_vectored()
411        }
412    }
413
414    #[cfg(unix)]
415    impl AsRawFd for AddrStream {
416        fn as_raw_fd(&self) -> RawFd {
417            self.inner.as_raw_fd()
418        }
419    }
420}
421
#[cfg(test)]
mod tests {
    use super::TcpKeepaliveConfig;
    use std::time::Duration;

    // With nothing configured there is nothing to convert.
    #[test]
    fn no_tcp_keepalive_config() {
        let converted = TcpKeepaliveConfig::default().into_socket2();
        assert!(converted.is_none());
    }

    // The idle time alone is enough to produce a keepalive value.
    #[test]
    fn tcp_keepalive_time_config() {
        let kac = TcpKeepaliveConfig {
            time: Some(Duration::from_secs(60)),
            ..TcpKeepaliveConfig::default()
        };
        let tcp_keepalive = kac.into_socket2().expect("test failed");
        assert!(format!("{tcp_keepalive:?}").contains("time: Some(60s)"));
    }

    // Only meaningful on targets where the interval option is compiled in.
    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "linux",
        target_os = "netbsd",
        target_vendor = "apple",
        windows,
    ))]
    #[test]
    fn tcp_keepalive_interval_config() {
        let kac = TcpKeepaliveConfig {
            interval: Some(Duration::from_secs(1)),
            ..TcpKeepaliveConfig::default()
        };
        let tcp_keepalive = kac.into_socket2().expect("test failed");
        assert!(format!("{tcp_keepalive:?}").contains("interval: Some(1s)"));
    }

    // Only meaningful on targets where the retries option is compiled in.
    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "linux",
        target_os = "netbsd",
        target_vendor = "apple",
    ))]
    #[test]
    fn tcp_keepalive_retries_config() {
        let kac = TcpKeepaliveConfig {
            retries: Some(3),
            ..TcpKeepaliveConfig::default()
        };
        let tcp_keepalive = kac.into_socket2().expect("test failed");
        assert!(format!("{tcp_keepalive:?}").contains("retries: Some(3)"));
    }
}