hyper_rustls/
acceptor.rs

1use core::task::{Context, Poll};
2use std::future::Future;
3use std::io;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use futures_util::ready;
8use hyper::server::{
9    accept::Accept,
10    conn::{AddrIncoming, AddrStream},
11};
12use rustls::{ServerConfig, ServerConnection};
13use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14
15mod builder;
16pub use builder::AcceptorBuilder;
17use builder::WantsTlsConfig;
18
/// A TLS acceptor that can be used with hyper servers.
///
/// The type parameter `A` is the inner acceptor that yields the plain connections;
/// it defaults to hyper's [`AddrIncoming`].
pub struct TlsAcceptor<A = AddrIncoming> {
    // Shared rustls server configuration; cloned (as an `Arc`) for every accepted connection.
    config: Arc<ServerConfig>,
    // Inner acceptor whose connections are wrapped in a `TlsStream`.
    acceptor: A,
}
24
25/// An Acceptor for the `https` scheme.
26impl TlsAcceptor {
27    /// Provides a builder for a `TlsAcceptor`.
28    pub fn builder() -> AcceptorBuilder<WantsTlsConfig> {
29        AcceptorBuilder::new()
30    }
31
32    /// Creates a new `TlsAcceptor` from a `ServerConfig` and an `AddrIncoming`.
33    pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> Self {
34        Self {
35            config,
36            acceptor: incoming,
37        }
38    }
39}
40
41impl<A> Accept for TlsAcceptor<A>
42where
43    A: Accept<Error = io::Error> + Unpin,
44    A::Conn: AsyncRead + AsyncWrite + Unpin,
45{
46    type Conn = TlsStream<A::Conn>;
47    type Error = io::Error;
48
49    fn poll_accept(
50        self: Pin<&mut Self>,
51        cx: &mut Context<'_>,
52    ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
53        let pin = self.get_mut();
54        Poll::Ready(match ready!(Pin::new(&mut pin.acceptor).poll_accept(cx)) {
55            Some(Ok(sock)) => Some(Ok(TlsStream::new(sock, pin.config.clone()))),
56            Some(Err(e)) => Some(Err(e)),
57            None => None,
58        })
59    }
60}
61
62impl<C, I> From<(C, I)> for TlsAcceptor
63where
64    C: Into<Arc<ServerConfig>>,
65    I: Into<AddrIncoming>,
66{
67    fn from((config, incoming): (C, I)) -> Self {
68        Self::new(config.into(), incoming.into())
69    }
70}
71
/// A TLS stream constructed by a [`TlsAcceptor`].
///
/// The TLS handshake is not performed when the stream is constructed: the stream
/// starts out handshaking and completes the handshake the first time it is polled
/// for reading or writing.
// `tokio_rustls::server::TlsStream` doesn't expose constructor methods, so the only way
// to obtain one is to call `tokio_rustls::TlsAcceptor::accept` and drive the returned
// handshake future. `TlsStream` therefore implements `AsyncRead`/`AsyncWrite` by polling
// the `tokio_rustls::Accept` future first.
pub struct TlsStream<C = AddrStream> {
    // Either the pending handshake future or the established TLS stream.
    state: State<C>,
}
79
80impl<C: AsyncRead + AsyncWrite + Unpin> TlsStream<C> {
81    fn new(stream: C, config: Arc<ServerConfig>) -> Self {
82        let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
83        Self {
84            state: State::Handshaking(accept),
85        }
86    }
87    /// Returns a reference to the underlying IO stream.
88    ///
89    /// This should always return `Some`, except if an error has already been yielded.
90    pub fn io(&self) -> Option<&C> {
91        match &self.state {
92            State::Handshaking(accept) => accept.get_ref(),
93            State::Streaming(stream) => Some(stream.get_ref().0),
94        }
95    }
96
97    /// Returns a reference to the underlying [`rustls::ServerConnection'].
98    ///
99    /// This will start yielding `Some` only after the handshake has completed.
100    pub fn connection(&self) -> Option<&ServerConnection> {
101        match &self.state {
102            State::Handshaking(_) => None,
103            State::Streaming(stream) => Some(stream.get_ref().1),
104        }
105    }
106}
107
108impl<C: AsyncRead + AsyncWrite + Unpin> AsyncRead for TlsStream<C> {
109    fn poll_read(
110        self: Pin<&mut Self>,
111        cx: &mut Context,
112        buf: &mut ReadBuf,
113    ) -> Poll<io::Result<()>> {
114        let pin = self.get_mut();
115        let accept = match &mut pin.state {
116            State::Handshaking(accept) => accept,
117            State::Streaming(stream) => return Pin::new(stream).poll_read(cx, buf),
118        };
119
120        let mut stream = match ready!(Pin::new(accept).poll(cx)) {
121            Ok(stream) => stream,
122            Err(err) => return Poll::Ready(Err(err)),
123        };
124
125        let result = Pin::new(&mut stream).poll_read(cx, buf);
126        pin.state = State::Streaming(stream);
127        result
128    }
129}
130
131impl<C: AsyncRead + AsyncWrite + Unpin> AsyncWrite for TlsStream<C> {
132    fn poll_write(
133        self: Pin<&mut Self>,
134        cx: &mut Context<'_>,
135        buf: &[u8],
136    ) -> Poll<io::Result<usize>> {
137        let pin = self.get_mut();
138        let accept = match &mut pin.state {
139            State::Handshaking(accept) => accept,
140            State::Streaming(stream) => return Pin::new(stream).poll_write(cx, buf),
141        };
142
143        let mut stream = match ready!(Pin::new(accept).poll(cx)) {
144            Ok(stream) => stream,
145            Err(err) => return Poll::Ready(Err(err)),
146        };
147
148        let result = Pin::new(&mut stream).poll_write(cx, buf);
149        pin.state = State::Streaming(stream);
150        result
151    }
152
153    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
154        match &mut self.state {
155            State::Handshaking(_) => Poll::Ready(Ok(())),
156            State::Streaming(stream) => Pin::new(stream).poll_flush(cx),
157        }
158    }
159
160    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
161        match &mut self.state {
162            State::Handshaking(_) => Poll::Ready(Ok(())),
163            State::Streaming(stream) => Pin::new(stream).poll_shutdown(cx),
164        }
165    }
166}
167
/// Internal state of a `TlsStream`: either still handshaking or fully established.
enum State<C> {
    /// The handshake future returned by `tokio_rustls::TlsAcceptor::accept`,
    /// not yet resolved.
    Handshaking(tokio_rustls::Accept<C>),
    /// The established TLS stream produced by a completed handshake.
    Streaming(tokio_rustls::server::TlsStream<C>),
}