tokio_rustls/
lib.rs

1//! Asynchronous TLS/SSL streams for Tokio using [Rustls](https://github.com/rustls/rustls).
2//!
3//! # Why do I need to call `poll_flush`?
4//!
5//! Most TLS implementations will have an internal buffer to improve throughput,
6//! and rustls is no exception.
7//!
8//! When we write data to `TlsStream`, we always write it into the rustls buffer first,
9//! then take the encrypted TLS records out of rustls and write them to the data channel (like `TcpStream`).
10//! When the data channel is pending, some data may remain in the rustls buffer.
11//!
12//! To keep it simple and correct, `tokio-rustls`'s [TlsStream] will behave like `BufWriter`.
13//! For `TlsStream<TcpStream>`, this means that data written by `poll_write` is not guaranteed to be written to `TcpStream`.
14//! You must call `poll_flush` to ensure that it is written to `TcpStream`.
15//!
16//! You should call `poll_flush` at the appropriate time,
17//! such as when a series of `poll_write` calls is complete and there is no more data to write.
18//!
19//! ## Why don't we write during `poll_read`?
20//!
21//! We did this in the early days of `tokio-rustls`, but it caused some bugs.
22//! These bugs can be worked around, but the workarounds cause performance degradation (reverse false wakeup).
23//!
24//! Reverse writes (writing during `poll_read`) would also prevent us from implementing full duplex in the future.
25//!
26//! see <https://github.com/tokio-rs/tls/issues/40>
27//!
28//! ## Why can't we handle it like `native-tls`?
29//!
30//! When the data channel returns pending, `native-tls` will falsely report the number of bytes it consumed.
31//! This means that if data written by `poll_write` is not actually written to the data channel, it will not return `Ready`,
32//! thus avoiding the need to call `poll_flush`.
33//!
34//! However, this does not conform to the convention of the `AsyncWrite` trait.
35//! This means that if you pass inconsistent data in two consecutive `poll_write` calls, it may cause unexpected behavior.
36//!
37//! see <https://github.com/tokio-rs/tls/issues/41>
38
/// Evaluates a `Poll` expression: yields the inner value when it is
/// `Poll::Ready`, otherwise returns `Poll::Pending` from the enclosing function.
macro_rules! ready {
    ( $poll:expr ) => {
        match $poll {
            // Not ready yet: propagate `Pending` to the caller immediately.
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}
47
48pub mod client;
49mod common;
50pub mod server;
51
52use common::{MidHandshake, Stream, TlsState};
53use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
54use std::future::Future;
55use std::io;
56#[cfg(unix)]
57use std::os::unix::io::{AsRawFd, RawFd};
58#[cfg(windows)]
59use std::os::windows::io::{AsRawSocket, RawSocket};
60use std::pin::Pin;
61use std::sync::Arc;
62use std::task::{Context, Poll};
63use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
64
65pub use rustls;
66
67/// A wrapper around a `rustls::ClientConfig`, providing an async `connect` method.
68#[derive(Clone)]
69pub struct TlsConnector {
70    inner: Arc<ClientConfig>,
71    #[cfg(feature = "early-data")]
72    early_data: bool,
73}
74
75/// A wrapper around a `rustls::ServerConfig`, providing an async `accept` method.
76#[derive(Clone)]
77pub struct TlsAcceptor {
78    inner: Arc<ServerConfig>,
79}
80
81impl From<Arc<ClientConfig>> for TlsConnector {
82    fn from(inner: Arc<ClientConfig>) -> TlsConnector {
83        TlsConnector {
84            inner,
85            #[cfg(feature = "early-data")]
86            early_data: false,
87        }
88    }
89}
90
91impl From<Arc<ServerConfig>> for TlsAcceptor {
92    fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
93        TlsAcceptor { inner }
94    }
95}
96
97impl TlsConnector {
98    /// Enable 0-RTT.
99    ///
100    /// If you want to use 0-RTT,
101    /// You must also set `ClientConfig.enable_early_data` to `true`.
102    #[cfg(feature = "early-data")]
103    pub fn early_data(mut self, flag: bool) -> TlsConnector {
104        self.early_data = flag;
105        self
106    }
107
108    #[inline]
109    pub fn connect<IO>(&self, domain: rustls::ServerName, stream: IO) -> Connect<IO>
110    where
111        IO: AsyncRead + AsyncWrite + Unpin,
112    {
113        self.connect_with(domain, stream, |_| ())
114    }
115
116    pub fn connect_with<IO, F>(&self, domain: rustls::ServerName, stream: IO, f: F) -> Connect<IO>
117    where
118        IO: AsyncRead + AsyncWrite + Unpin,
119        F: FnOnce(&mut ClientConnection),
120    {
121        let mut session = match ClientConnection::new(self.inner.clone(), domain) {
122            Ok(session) => session,
123            Err(error) => {
124                return Connect(MidHandshake::Error {
125                    io: stream,
126                    // TODO(eliza): should this really return an `io::Error`?
127                    // Probably not...
128                    error: io::Error::new(io::ErrorKind::Other, error),
129                });
130            }
131        };
132        f(&mut session);
133
134        Connect(MidHandshake::Handshaking(client::TlsStream {
135            io: stream,
136
137            #[cfg(not(feature = "early-data"))]
138            state: TlsState::Stream,
139
140            #[cfg(feature = "early-data")]
141            state: if self.early_data && session.early_data().is_some() {
142                TlsState::EarlyData(0, Vec::new())
143            } else {
144                TlsState::Stream
145            },
146
147            #[cfg(feature = "early-data")]
148            early_waker: None,
149
150            session,
151        }))
152    }
153}
154
155impl TlsAcceptor {
156    #[inline]
157    pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
158    where
159        IO: AsyncRead + AsyncWrite + Unpin,
160    {
161        self.accept_with(stream, |_| ())
162    }
163
164    pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
165    where
166        IO: AsyncRead + AsyncWrite + Unpin,
167        F: FnOnce(&mut ServerConnection),
168    {
169        let mut session = match ServerConnection::new(self.inner.clone()) {
170            Ok(session) => session,
171            Err(error) => {
172                return Accept(MidHandshake::Error {
173                    io: stream,
174                    // TODO(eliza): should this really return an `io::Error`?
175                    // Probably not...
176                    error: io::Error::new(io::ErrorKind::Other, error),
177                });
178            }
179        };
180        f(&mut session);
181
182        Accept(MidHandshake::Handshaking(server::TlsStream {
183            session,
184            io: stream,
185            state: TlsState::Stream,
186        }))
187    }
188}
189
190pub struct LazyConfigAcceptor<IO> {
191    acceptor: rustls::server::Acceptor,
192    io: Option<IO>,
193}
194
195impl<IO> LazyConfigAcceptor<IO>
196where
197    IO: AsyncRead + AsyncWrite + Unpin,
198{
199    #[inline]
200    pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
201        Self {
202            acceptor,
203            io: Some(io),
204        }
205    }
206
207    /// Takes back the client connection. Will return `None` if called more than once or if the
208    /// connection has been accepted.
209    ///
210    /// # Example
211    ///
212    /// ```no_run
213    /// # fn choose_server_config(
214    /// #     _: rustls::server::ClientHello,
215    /// # ) -> std::sync::Arc<rustls::ServerConfig> {
216    /// #     unimplemented!();
217    /// # }
218    /// # #[allow(unused_variables)]
219    /// # async fn listen() {
220    /// use tokio::io::AsyncWriteExt;
221    /// let listener = tokio::net::TcpListener::bind("127.0.0.1:4443").await.unwrap();
222    /// let (stream, _) = listener.accept().await.unwrap();
223    ///
224    /// let acceptor = tokio_rustls::LazyConfigAcceptor::new(rustls::server::Acceptor::default(), stream);
225    /// futures_util::pin_mut!(acceptor);
226    ///
227    /// match acceptor.as_mut().await {
228    ///     Ok(start) => {
229    ///         let clientHello = start.client_hello();
230    ///         let config = choose_server_config(clientHello);
231    ///         let stream = start.into_stream(config).await.unwrap();
232    ///         // Proceed with handling the ServerConnection...
233    ///     }
234    ///     Err(err) => {
235    ///         if let Some(mut stream) = acceptor.take_io() {
236    ///             stream
237    ///                 .write_all(
238    ///                     format!("HTTP/1.1 400 Invalid Input\r\n\r\n\r\n{:?}\n", err)
239    ///                         .as_bytes()
240    ///                 )
241    ///                 .await
242    ///                 .unwrap();
243    ///         }
244    ///     }
245    /// }
246    /// # }
247    /// ```
248    pub fn take_io(&mut self) -> Option<IO> {
249        self.io.take()
250    }
251}
252
253impl<IO> Future for LazyConfigAcceptor<IO>
254where
255    IO: AsyncRead + AsyncWrite + Unpin,
256{
257    type Output = Result<StartHandshake<IO>, io::Error>;
258
259    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
260        let this = self.get_mut();
261        loop {
262            let io = match this.io.as_mut() {
263                Some(io) => io,
264                None => {
265                    return Poll::Ready(Err(io::Error::new(
266                        io::ErrorKind::Other,
267                        "acceptor cannot be polled after acceptance",
268                    )))
269                }
270            };
271
272            let mut reader = common::SyncReadAdapter { io, cx };
273            match this.acceptor.read_tls(&mut reader) {
274                Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
275                Ok(_) => {}
276                Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
277                Err(e) => return Err(e).into(),
278            }
279
280            match this.acceptor.accept() {
281                Ok(Some(accepted)) => {
282                    let io = this.io.take().unwrap();
283                    return Poll::Ready(Ok(StartHandshake { accepted, io }));
284                }
285                Ok(None) => continue,
286                Err(err) => {
287                    return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err)))
288                }
289            }
290        }
291    }
292}
293
294pub struct StartHandshake<IO> {
295    accepted: rustls::server::Accepted,
296    io: IO,
297}
298
299impl<IO> StartHandshake<IO>
300where
301    IO: AsyncRead + AsyncWrite + Unpin,
302{
303    pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
304        self.accepted.client_hello()
305    }
306
307    pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
308        self.into_stream_with(config, |_| ())
309    }
310
311    pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
312    where
313        F: FnOnce(&mut ServerConnection),
314    {
315        let mut conn = match self.accepted.into_connection(config) {
316            Ok(conn) => conn,
317            Err(error) => {
318                return Accept(MidHandshake::Error {
319                    io: self.io,
320                    // TODO(eliza): should this really return an `io::Error`?
321                    // Probably not...
322                    error: io::Error::new(io::ErrorKind::Other, error),
323                });
324            }
325        };
326        f(&mut conn);
327
328        Accept(MidHandshake::Handshaking(server::TlsStream {
329            session: conn,
330            io: self.io,
331            state: TlsState::Stream,
332        }))
333    }
334}
335
336/// Future returned from `TlsConnector::connect` which will resolve
337/// once the connection handshake has finished.
338pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
339
340/// Future returned from `TlsAcceptor::accept` which will resolve
341/// once the accept handshake has finished.
342pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
343
344/// Like [Connect], but returns `IO` on failure.
345pub struct FallibleConnect<IO>(MidHandshake<client::TlsStream<IO>>);
346
347/// Like [Accept], but returns `IO` on failure.
348pub struct FallibleAccept<IO>(MidHandshake<server::TlsStream<IO>>);
349
350impl<IO> Connect<IO> {
351    #[inline]
352    pub fn into_fallible(self) -> FallibleConnect<IO> {
353        FallibleConnect(self.0)
354    }
355
356    pub fn get_ref(&self) -> Option<&IO> {
357        match &self.0 {
358            MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
359            MidHandshake::Error { io, .. } => Some(io),
360            MidHandshake::End => None,
361        }
362    }
363
364    pub fn get_mut(&mut self) -> Option<&mut IO> {
365        match &mut self.0 {
366            MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
367            MidHandshake::Error { io, .. } => Some(io),
368            MidHandshake::End => None,
369        }
370    }
371}
372
373impl<IO> Accept<IO> {
374    #[inline]
375    pub fn into_fallible(self) -> FallibleAccept<IO> {
376        FallibleAccept(self.0)
377    }
378
379    pub fn get_ref(&self) -> Option<&IO> {
380        match &self.0 {
381            MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
382            MidHandshake::Error { io, .. } => Some(io),
383            MidHandshake::End => None,
384        }
385    }
386
387    pub fn get_mut(&mut self) -> Option<&mut IO> {
388        match &mut self.0 {
389            MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
390            MidHandshake::Error { io, .. } => Some(io),
391            MidHandshake::End => None,
392        }
393    }
394}
395
396impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
397    type Output = io::Result<client::TlsStream<IO>>;
398
399    #[inline]
400    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
401        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
402    }
403}
404
405impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
406    type Output = io::Result<server::TlsStream<IO>>;
407
408    #[inline]
409    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
410        Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
411    }
412}
413
414impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
415    type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
416
417    #[inline]
418    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
419        Pin::new(&mut self.0).poll(cx)
420    }
421}
422
423impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
424    type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
425
426    #[inline]
427    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
428        Pin::new(&mut self.0).poll(cx)
429    }
430}
431
432/// Unified TLS stream type
433///
434/// This abstracts over the inner `client::TlsStream` and `server::TlsStream`, so you can use
435/// a single type to keep both client- and server-initiated TLS-encrypted connections.
436#[allow(clippy::large_enum_variant)] // https://github.com/rust-lang/rust-clippy/issues/9798
437#[derive(Debug)]
438pub enum TlsStream<T> {
439    Client(client::TlsStream<T>),
440    Server(server::TlsStream<T>),
441}
442
443impl<T> TlsStream<T> {
444    pub fn get_ref(&self) -> (&T, &CommonState) {
445        use TlsStream::*;
446        match self {
447            Client(io) => {
448                let (io, session) = io.get_ref();
449                (io, session)
450            }
451            Server(io) => {
452                let (io, session) = io.get_ref();
453                (io, session)
454            }
455        }
456    }
457
458    pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
459        use TlsStream::*;
460        match self {
461            Client(io) => {
462                let (io, session) = io.get_mut();
463                (io, &mut *session)
464            }
465            Server(io) => {
466                let (io, session) = io.get_mut();
467                (io, &mut *session)
468            }
469        }
470    }
471}
472
473impl<T> From<client::TlsStream<T>> for TlsStream<T> {
474    fn from(s: client::TlsStream<T>) -> Self {
475        Self::Client(s)
476    }
477}
478
479impl<T> From<server::TlsStream<T>> for TlsStream<T> {
480    fn from(s: server::TlsStream<T>) -> Self {
481        Self::Server(s)
482    }
483}
484
485#[cfg(unix)]
486impl<S> AsRawFd for TlsStream<S>
487where
488    S: AsRawFd,
489{
490    fn as_raw_fd(&self) -> RawFd {
491        self.get_ref().0.as_raw_fd()
492    }
493}
494
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Returns the raw socket of the underlying I/O object.
    fn as_raw_socket(&self) -> RawSocket {
        let (io, _) = self.get_ref();
        io.as_raw_socket()
    }
}
504
505impl<T> AsyncRead for TlsStream<T>
506where
507    T: AsyncRead + AsyncWrite + Unpin,
508{
509    #[inline]
510    fn poll_read(
511        self: Pin<&mut Self>,
512        cx: &mut Context<'_>,
513        buf: &mut ReadBuf<'_>,
514    ) -> Poll<io::Result<()>> {
515        match self.get_mut() {
516            TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
517            TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
518        }
519    }
520}
521
522impl<T> AsyncWrite for TlsStream<T>
523where
524    T: AsyncRead + AsyncWrite + Unpin,
525{
526    #[inline]
527    fn poll_write(
528        self: Pin<&mut Self>,
529        cx: &mut Context<'_>,
530        buf: &[u8],
531    ) -> Poll<io::Result<usize>> {
532        match self.get_mut() {
533            TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
534            TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
535        }
536    }
537
538    #[inline]
539    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
540        match self.get_mut() {
541            TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
542            TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
543        }
544    }
545
546    #[inline]
547    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
548        match self.get_mut() {
549            TlsStream::Client(x) => Pin::new(x).poll_shutdown(cx),
550            TlsStream::Server(x) => Pin::new(x).poll_shutdown(cx),
551        }
552    }
553}