1use socket2::TcpKeepalive;
2use std::fmt;
3use std::future::Future;
4use std::io;
5use std::net::{SocketAddr, TcpListener as StdTcpListener};
6use std::pin::Pin;
7use std::task::{Context, Poll};
8use std::time::Duration;
9
10use tokio::net::TcpListener;
11use tokio::time::Sleep;
12use tracing::{debug, error, trace};
13
14#[allow(unreachable_pub)] pub use self::addr_stream::AddrStream;
16use super::accept::Accept;
17
/// Optional TCP keepalive parameters to apply to accepted sockets.
///
/// Every field is `None` by default; `into_socket2` turns whichever fields
/// are set into a `socket2::TcpKeepalive`.
#[derive(Clone, Copy, Debug, Default)]
struct TcpKeepaliveConfig {
    /// Passed to `TcpKeepalive::with_time` when set.
    time: Option<Duration>,
    /// Passed to `TcpKeepalive::with_interval` when set (only on targets that support it).
    interval: Option<Duration>,
    /// Passed to `TcpKeepalive::with_retries` when set (only on targets that support it).
    retries: Option<u32>,
}
24
25impl TcpKeepaliveConfig {
26 fn into_socket2(self) -> Option<TcpKeepalive> {
28 let mut dirty = false;
29 let mut ka = TcpKeepalive::new();
30 if let Some(time) = self.time {
31 ka = ka.with_time(time);
32 dirty = true
33 }
34 if let Some(interval) = self.interval {
35 ka = Self::ka_with_interval(ka, interval, &mut dirty)
36 };
37 if let Some(retries) = self.retries {
38 ka = Self::ka_with_retries(ka, retries, &mut dirty)
39 };
40 if dirty {
41 Some(ka)
42 } else {
43 None
44 }
45 }
46
47 #[cfg(any(
48 target_os = "android",
49 target_os = "dragonfly",
50 target_os = "freebsd",
51 target_os = "fuchsia",
52 target_os = "illumos",
53 target_os = "linux",
54 target_os = "netbsd",
55 target_vendor = "apple",
56 windows,
57 ))]
58 fn ka_with_interval(ka: TcpKeepalive, interval: Duration, dirty: &mut bool) -> TcpKeepalive {
59 *dirty = true;
60 ka.with_interval(interval)
61 }
62
63 #[cfg(not(any(
64 target_os = "android",
65 target_os = "dragonfly",
66 target_os = "freebsd",
67 target_os = "fuchsia",
68 target_os = "illumos",
69 target_os = "linux",
70 target_os = "netbsd",
71 target_vendor = "apple",
72 windows,
73 )))]
74 fn ka_with_interval(ka: TcpKeepalive, _: Duration, _: &mut bool) -> TcpKeepalive {
75 ka }
77
78 #[cfg(any(
79 target_os = "android",
80 target_os = "dragonfly",
81 target_os = "freebsd",
82 target_os = "fuchsia",
83 target_os = "illumos",
84 target_os = "linux",
85 target_os = "netbsd",
86 target_vendor = "apple",
87 ))]
88 fn ka_with_retries(ka: TcpKeepalive, retries: u32, dirty: &mut bool) -> TcpKeepalive {
89 *dirty = true;
90 ka.with_retries(retries)
91 }
92
93 #[cfg(not(any(
94 target_os = "android",
95 target_os = "dragonfly",
96 target_os = "freebsd",
97 target_os = "fuchsia",
98 target_os = "illumos",
99 target_os = "linux",
100 target_os = "netbsd",
101 target_vendor = "apple",
102 )))]
103 fn ka_with_retries(ka: TcpKeepalive, _: u32, _: &mut bool) -> TcpKeepalive {
104 ka }
106}
107
108#[must_use = "streams do nothing unless polled"]
110pub struct AddrIncoming {
111 addr: SocketAddr,
112 listener: TcpListener,
113 sleep_on_errors: bool,
114 tcp_keepalive_config: TcpKeepaliveConfig,
115 tcp_nodelay: bool,
116 timeout: Option<Pin<Box<Sleep>>>,
117}
118
119impl AddrIncoming {
120 pub(super) fn new(addr: &SocketAddr) -> crate::Result<Self> {
121 let std_listener = StdTcpListener::bind(addr).map_err(crate::Error::new_listen)?;
122
123 AddrIncoming::from_std(std_listener)
124 }
125
126 pub(super) fn from_std(std_listener: StdTcpListener) -> crate::Result<Self> {
127 std_listener
129 .set_nonblocking(true)
130 .map_err(crate::Error::new_listen)?;
131 let listener = TcpListener::from_std(std_listener).map_err(crate::Error::new_listen)?;
132 AddrIncoming::from_listener(listener)
133 }
134
135 pub fn bind(addr: &SocketAddr) -> crate::Result<Self> {
137 AddrIncoming::new(addr)
138 }
139
140 pub fn from_listener(listener: TcpListener) -> crate::Result<Self> {
142 let addr = listener.local_addr().map_err(crate::Error::new_listen)?;
143 Ok(AddrIncoming {
144 listener,
145 addr,
146 sleep_on_errors: true,
147 tcp_keepalive_config: TcpKeepaliveConfig::default(),
148 tcp_nodelay: false,
149 timeout: None,
150 })
151 }
152
153 pub fn local_addr(&self) -> SocketAddr {
155 self.addr
156 }
157
158 pub fn set_keepalive(&mut self, time: Option<Duration>) -> &mut Self {
162 self.tcp_keepalive_config.time = time;
163 self
164 }
165
166 pub fn set_keepalive_interval(&mut self, interval: Option<Duration>) -> &mut Self {
169 self.tcp_keepalive_config.interval = interval;
170 self
171 }
172
173 pub fn set_keepalive_retries(&mut self, retries: Option<u32>) -> &mut Self {
175 self.tcp_keepalive_config.retries = retries;
176 self
177 }
178
179 pub fn set_nodelay(&mut self, enabled: bool) -> &mut Self {
181 self.tcp_nodelay = enabled;
182 self
183 }
184
185 pub fn set_sleep_on_errors(&mut self, val: bool) {
201 self.sleep_on_errors = val;
202 }
203
204 fn poll_next_(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<AddrStream>> {
205 if let Some(ref mut to) = self.timeout {
207 ready!(Pin::new(to).poll(cx));
208 }
209 self.timeout = None;
210
211 loop {
212 match ready!(self.listener.poll_accept(cx)) {
213 Ok((socket, remote_addr)) => {
214 if let Some(tcp_keepalive) = &self.tcp_keepalive_config.into_socket2() {
215 let sock_ref = socket2::SockRef::from(&socket);
216 if let Err(e) = sock_ref.set_tcp_keepalive(tcp_keepalive) {
217 trace!("error trying to set TCP keepalive: {}", e);
218 }
219 }
220 if let Err(e) = socket.set_nodelay(self.tcp_nodelay) {
221 trace!("error trying to set TCP nodelay: {}", e);
222 }
223 let local_addr = socket.local_addr()?;
224 return Poll::Ready(Ok(AddrStream::new(socket, remote_addr, local_addr)));
225 }
226 Err(e) => {
227 if is_connection_error(&e) {
230 debug!("accepted connection already errored: {}", e);
231 continue;
232 }
233
234 if self.sleep_on_errors {
235 error!("accept error: {}", e);
236
237 let mut timeout = Box::pin(tokio::time::sleep(Duration::from_secs(1)));
239
240 match timeout.as_mut().poll(cx) {
241 Poll::Ready(()) => {
242 continue;
244 }
245 Poll::Pending => {
246 self.timeout = Some(timeout);
247 return Poll::Pending;
248 }
249 }
250 } else {
251 return Poll::Ready(Err(e));
252 }
253 }
254 }
255 }
256 }
257}
258
259impl Accept for AddrIncoming {
260 type Conn = AddrStream;
261 type Error = io::Error;
262
263 fn poll_accept(
264 mut self: Pin<&mut Self>,
265 cx: &mut Context<'_>,
266 ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
267 let result = ready!(self.poll_next_(cx));
268 Poll::Ready(Some(result))
269 }
270}
271
/// Returns `true` for accept errors that concern only the connection being
/// accepted (refused, aborted, or reset) rather than the listener itself.
///
/// The accept loop skips these errors and immediately tries the next
/// connection; other errors go through the `sleep_on_errors` handling.
fn is_connection_error(e: &io::Error) -> bool {
    const PER_CONNECTION_KINDS: [io::ErrorKind; 3] = [
        io::ErrorKind::ConnectionRefused,
        io::ErrorKind::ConnectionAborted,
        io::ErrorKind::ConnectionReset,
    ];
    PER_CONNECTION_KINDS.contains(&e.kind())
}
287
288impl fmt::Debug for AddrIncoming {
289 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
290 f.debug_struct("AddrIncoming")
291 .field("addr", &self.addr)
292 .field("sleep_on_errors", &self.sleep_on_errors)
293 .field("tcp_keepalive_config", &self.tcp_keepalive_config)
294 .field("tcp_nodelay", &self.tcp_nodelay)
295 .finish()
296 }
297}
298
299mod addr_stream {
300 use std::io;
301 use std::net::SocketAddr;
302 #[cfg(unix)]
303 use std::os::unix::io::{AsRawFd, RawFd};
304 use std::pin::Pin;
305 use std::task::{Context, Poll};
306 use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
307 use tokio::net::TcpStream;
308
309 pin_project_lite::pin_project! {
310 #[derive(Debug)]
312 pub struct AddrStream {
313 #[pin]
314 inner: TcpStream,
315 pub(super) remote_addr: SocketAddr,
316 pub(super) local_addr: SocketAddr
317 }
318 }
319
320 impl AddrStream {
321 pub(super) fn new(
322 tcp: TcpStream,
323 remote_addr: SocketAddr,
324 local_addr: SocketAddr,
325 ) -> AddrStream {
326 AddrStream {
327 inner: tcp,
328 remote_addr,
329 local_addr,
330 }
331 }
332
333 #[inline]
335 pub fn remote_addr(&self) -> SocketAddr {
336 self.remote_addr
337 }
338
339 #[inline]
341 pub fn local_addr(&self) -> SocketAddr {
342 self.local_addr
343 }
344
345 #[inline]
347 pub fn into_inner(self) -> TcpStream {
348 self.inner
349 }
350
351 pub fn poll_peek(
355 &mut self,
356 cx: &mut Context<'_>,
357 buf: &mut tokio::io::ReadBuf<'_>,
358 ) -> Poll<io::Result<usize>> {
359 self.inner.poll_peek(cx, buf)
360 }
361 }
362
363 impl AsyncRead for AddrStream {
364 #[inline]
365 fn poll_read(
366 self: Pin<&mut Self>,
367 cx: &mut Context<'_>,
368 buf: &mut ReadBuf<'_>,
369 ) -> Poll<io::Result<()>> {
370 self.project().inner.poll_read(cx, buf)
371 }
372 }
373
374 impl AsyncWrite for AddrStream {
375 #[inline]
376 fn poll_write(
377 self: Pin<&mut Self>,
378 cx: &mut Context<'_>,
379 buf: &[u8],
380 ) -> Poll<io::Result<usize>> {
381 self.project().inner.poll_write(cx, buf)
382 }
383
384 #[inline]
385 fn poll_write_vectored(
386 self: Pin<&mut Self>,
387 cx: &mut Context<'_>,
388 bufs: &[io::IoSlice<'_>],
389 ) -> Poll<io::Result<usize>> {
390 self.project().inner.poll_write_vectored(cx, bufs)
391 }
392
393 #[inline]
394 fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
395 Poll::Ready(Ok(()))
397 }
398
399 #[inline]
400 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
401 self.project().inner.poll_shutdown(cx)
402 }
403
404 #[inline]
405 fn is_write_vectored(&self) -> bool {
406 self.inner.is_write_vectored()
411 }
412 }
413
414 #[cfg(unix)]
415 impl AsRawFd for AddrStream {
416 fn as_raw_fd(&self) -> RawFd {
417 self.inner.as_raw_fd()
418 }
419 }
420}
421
#[cfg(test)]
mod tests {
    use crate::server::tcp::TcpKeepaliveConfig;
    use std::time::Duration;

    /// Converts `config` and renders the resulting keepalive with `Debug`,
    /// panicking with "test failed" if no keepalive was produced.
    fn keepalive_debug(config: TcpKeepaliveConfig) -> String {
        let keepalive = config.into_socket2().expect("test failed");
        format!("{keepalive:?}")
    }

    #[test]
    fn no_tcp_keepalive_config() {
        let config = TcpKeepaliveConfig::default();
        assert!(config.into_socket2().is_none());
    }

    #[test]
    fn tcp_keepalive_time_config() {
        let config = TcpKeepaliveConfig {
            time: Some(Duration::from_secs(60)),
            ..TcpKeepaliveConfig::default()
        };
        assert!(keepalive_debug(config).contains("time: Some(60s)"));
    }

    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "linux",
        target_os = "netbsd",
        target_vendor = "apple",
        windows,
    ))]
    #[test]
    fn tcp_keepalive_interval_config() {
        let config = TcpKeepaliveConfig {
            interval: Some(Duration::from_secs(1)),
            ..TcpKeepaliveConfig::default()
        };
        assert!(keepalive_debug(config).contains("interval: Some(1s)"));
    }

    #[cfg(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "fuchsia",
        target_os = "illumos",
        target_os = "linux",
        target_os = "netbsd",
        target_vendor = "apple",
    ))]
    #[test]
    fn tcp_keepalive_retries_config() {
        let config = TcpKeepaliveConfig {
            retries: Some(3),
            ..TcpKeepaliveConfig::default()
        };
        assert!(keepalive_debug(config).contains("retries: Some(3)"));
    }
}