tokio_rustls/common/
mod.rs

1mod handshake;
2
3pub(crate) use handshake::{IoSession, MidHandshake};
4use rustls::{ConnectionCommon, SideData};
5use std::io::{self, IoSlice, Read, Write};
6use std::ops::{Deref, DerefMut};
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10
/// Tracks which directions of a TLS stream are still open (and, with the
/// `early-data` feature enabled, whether the stream is in its early-data phase).
#[derive(Debug)]
pub enum TlsState {
    /// Early-data phase. NOTE(review): the meaning of the `usize` and `Vec<u8>`
    /// payload is defined by the code that constructs this variant — confirm there.
    #[cfg(feature = "early-data")]
    EarlyData(usize, Vec<u8>),
    /// Both the read and the write direction are open.
    Stream,
    /// The read direction is closed; writing is still allowed.
    ReadShutdown,
    /// The write direction is closed; reading is still allowed.
    WriteShutdown,
    /// Both directions are closed.
    FullyShutdown,
}

impl TlsState {
    /// Closes the read direction.
    ///
    /// If the write direction was already closed the state becomes
    /// [`TlsState::FullyShutdown`], otherwise [`TlsState::ReadShutdown`].
    #[inline]
    pub fn shutdown_read(&mut self) {
        *self = if self.writeable() {
            TlsState::ReadShutdown
        } else {
            TlsState::FullyShutdown
        };
    }

    /// Closes the write direction.
    ///
    /// If the read direction was already closed the state becomes
    /// [`TlsState::FullyShutdown`], otherwise [`TlsState::WriteShutdown`].
    #[inline]
    pub fn shutdown_write(&mut self) {
        *self = if self.readable() {
            TlsState::WriteShutdown
        } else {
            TlsState::FullyShutdown
        };
    }

    /// Returns `true` unless the write direction has been shut down.
    #[inline]
    pub fn writeable(&self) -> bool {
        let write_closed = matches!(self, TlsState::WriteShutdown | TlsState::FullyShutdown);
        !write_closed
    }

    /// Returns `true` unless the read direction has been shut down.
    #[inline]
    pub fn readable(&self) -> bool {
        let read_closed = matches!(self, TlsState::ReadShutdown | TlsState::FullyShutdown);
        !read_closed
    }

    /// Returns `true` while in the [`TlsState::EarlyData`] state.
    #[inline]
    #[cfg(feature = "early-data")]
    pub fn is_early_data(&self) -> bool {
        matches!(self, TlsState::EarlyData(..))
    }

    /// Always `false`: the `early-data` feature is disabled, so no early-data state exists.
    #[inline]
    #[cfg(not(feature = "early-data"))]
    pub const fn is_early_data(&self) -> bool {
        false
    }
}
60
61pub struct Stream<'a, IO, C> {
62    pub io: &'a mut IO,
63    pub session: &'a mut C,
64    pub eof: bool,
65}
66
67impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C>
68where
69    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
70    SD: SideData,
71{
72    pub fn new(io: &'a mut IO, session: &'a mut C) -> Self {
73        Stream {
74            io,
75            session,
76            // The state so far is only used to detect EOF, so either Stream
77            // or EarlyData state should both be all right.
78            eof: false,
79        }
80    }
81
82    pub fn set_eof(mut self, eof: bool) -> Self {
83        self.eof = eof;
84        self
85    }
86
87    pub fn as_mut_pin(&mut self) -> Pin<&mut Self> {
88        Pin::new(self)
89    }
90
91    pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
92        let mut reader = SyncReadAdapter { io: self.io, cx };
93
94        let n = match self.session.read_tls(&mut reader) {
95            Ok(n) => n,
96            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
97            Err(err) => return Poll::Ready(Err(err)),
98        };
99
100        let stats = self.session.process_new_packets().map_err(|err| {
101            // In case we have an alert to send describing this error,
102            // try a last-gasp write -- but don't predate the primary
103            // error.
104            let _ = self.write_io(cx);
105
106            io::Error::new(io::ErrorKind::InvalidData, err)
107        })?;
108
109        if stats.peer_has_closed() && self.session.is_handshaking() {
110            return Poll::Ready(Err(io::Error::new(
111                io::ErrorKind::UnexpectedEof,
112                "tls handshake alert",
113            )));
114        }
115
116        Poll::Ready(Ok(n))
117    }
118
119    pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
120        struct Writer<'a, 'b, T> {
121            io: &'a mut T,
122            cx: &'a mut Context<'b>,
123        }
124
125        impl<'a, 'b, T: Unpin> Writer<'a, 'b, T> {
126            #[inline]
127            fn poll_with<U>(
128                &mut self,
129                f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
130            ) -> io::Result<U> {
131                match f(Pin::new(self.io), self.cx) {
132                    Poll::Ready(result) => result,
133                    Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
134                }
135            }
136        }
137
138        impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> {
139            #[inline]
140            fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
141                self.poll_with(|io, cx| io.poll_write(cx, buf))
142            }
143
144            #[inline]
145            fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
146                self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
147            }
148
149            fn flush(&mut self) -> io::Result<()> {
150                self.poll_with(|io, cx| io.poll_flush(cx))
151            }
152        }
153
154        let mut writer = Writer { io: self.io, cx };
155
156        match self.session.write_tls(&mut writer) {
157            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
158            result => Poll::Ready(result),
159        }
160    }
161
162    pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
163        let mut wrlen = 0;
164        let mut rdlen = 0;
165
166        loop {
167            let mut write_would_block = false;
168            let mut read_would_block = false;
169            let mut need_flush = false;
170
171            while self.session.wants_write() {
172                match self.write_io(cx) {
173                    Poll::Ready(Ok(n)) => {
174                        wrlen += n;
175                        need_flush = true;
176                    }
177                    Poll::Pending => {
178                        write_would_block = true;
179                        break;
180                    }
181                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
182                }
183            }
184
185            if need_flush {
186                match Pin::new(&mut self.io).poll_flush(cx) {
187                    Poll::Ready(Ok(())) => (),
188                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
189                    Poll::Pending => write_would_block = true,
190                }
191            }
192
193            while !self.eof && self.session.wants_read() {
194                match self.read_io(cx) {
195                    Poll::Ready(Ok(0)) => self.eof = true,
196                    Poll::Ready(Ok(n)) => rdlen += n,
197                    Poll::Pending => {
198                        read_would_block = true;
199                        break;
200                    }
201                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
202                }
203            }
204
205            return match (self.eof, self.session.is_handshaking()) {
206                (true, true) => {
207                    let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
208                    Poll::Ready(Err(err))
209                }
210                (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
211                (_, true) if write_would_block || read_would_block => {
212                    if rdlen != 0 || wrlen != 0 {
213                        Poll::Ready(Ok((rdlen, wrlen)))
214                    } else {
215                        Poll::Pending
216                    }
217                }
218                (..) => continue,
219            };
220        }
221    }
222}
223
224impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'a, IO, C>
225where
226    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
227    SD: SideData,
228{
229    fn poll_read(
230        mut self: Pin<&mut Self>,
231        cx: &mut Context<'_>,
232        buf: &mut ReadBuf<'_>,
233    ) -> Poll<io::Result<()>> {
234        let mut io_pending = false;
235
236        // read a packet
237        while !self.eof && self.session.wants_read() {
238            match self.read_io(cx) {
239                Poll::Ready(Ok(0)) => {
240                    break;
241                }
242                Poll::Ready(Ok(_)) => (),
243                Poll::Pending => {
244                    io_pending = true;
245                    break;
246                }
247                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
248            }
249        }
250
251        match self.session.reader().read(buf.initialize_unfilled()) {
252            // If Rustls returns `Ok(0)` (while `buf` is non-empty), the peer closed the
253            // connection with a `CloseNotify` message and no more data will be forthcoming.
254            //
255            // Rustls yielded more data: advance the buffer, then see if more data is coming.
256            //
257            // We don't need to modify `self.eof` here, because it is only a temporary mark.
258            // rustls will only return 0 if is has received `CloseNotify`,
259            // in which case no additional processing is required.
260            Ok(n) => {
261                buf.advance(n);
262                Poll::Ready(Ok(()))
263            }
264
265            // Rustls doesn't have more data to yield, but it believes the connection is open.
266            Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
267                if !io_pending {
268                    // If `wants_read()` is satisfied, rustls will not return `WouldBlock`.
269                    // but if it does, we can try again.
270                    //
271                    // If the rustls state is abnormal, it may cause a cyclic wakeup.
272                    // but tokio's cooperative budget will prevent infinite wakeup.
273                    cx.waker().wake_by_ref();
274                }
275
276                Poll::Pending
277            }
278
279            Err(err) => Poll::Ready(Err(err)),
280        }
281    }
282}
283
284impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'a, IO, C>
285where
286    C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
287    SD: SideData,
288{
289    fn poll_write(
290        mut self: Pin<&mut Self>,
291        cx: &mut Context,
292        buf: &[u8],
293    ) -> Poll<io::Result<usize>> {
294        let mut pos = 0;
295
296        while pos != buf.len() {
297            let mut would_block = false;
298
299            match self.session.writer().write(&buf[pos..]) {
300                Ok(n) => pos += n,
301                Err(err) => return Poll::Ready(Err(err)),
302            };
303
304            while self.session.wants_write() {
305                match self.write_io(cx) {
306                    Poll::Ready(Ok(0)) | Poll::Pending => {
307                        would_block = true;
308                        break;
309                    }
310                    Poll::Ready(Ok(_)) => (),
311                    Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
312                }
313            }
314
315            return match (pos, would_block) {
316                (0, true) => Poll::Pending,
317                (n, true) => Poll::Ready(Ok(n)),
318                (_, false) => continue,
319            };
320        }
321
322        Poll::Ready(Ok(pos))
323    }
324
325    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
326        self.session.writer().flush()?;
327        while self.session.wants_write() {
328            ready!(self.write_io(cx))?;
329        }
330        Pin::new(&mut self.io).poll_flush(cx)
331    }
332
333    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
334        while self.session.wants_write() {
335            ready!(self.write_io(cx))?;
336        }
337        Pin::new(&mut self.io).poll_shutdown(cx)
338    }
339}
340
341/// An adapter that implements a [`Read`] interface for [`AsyncRead`] types and an
342/// associated [`Context`].
343///
344/// Turns `Poll::Pending` into `WouldBlock`.
345pub struct SyncReadAdapter<'a, 'b, T> {
346    pub io: &'a mut T,
347    pub cx: &'a mut Context<'b>,
348}
349
350impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> {
351    #[inline]
352    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
353        let mut buf = ReadBuf::new(buf);
354        match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) {
355            Poll::Ready(Ok(())) => Ok(buf.filled().len()),
356            Poll::Ready(Err(err)) => Err(err),
357            Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
358        }
359    }
360}
361
362#[cfg(test)]
363mod test_stream;