// tokio_rustls/client.rs

1use super::*;
2use crate::common::IoSession;
3#[cfg(unix)]
4use std::os::unix::io::{AsRawFd, RawFd};
5#[cfg(windows)]
6use std::os::windows::io::{AsRawSocket, RawSocket};
7
8/// A wrapper around an underlying raw stream which implements the TLS or SSL
9/// protocol.
10#[derive(Debug)]
11pub struct TlsStream<IO> {
12    pub(crate) io: IO,
13    pub(crate) session: ClientConnection,
14    pub(crate) state: TlsState,
15
16    #[cfg(feature = "early-data")]
17    pub(crate) early_waker: Option<std::task::Waker>,
18}
19
20impl<IO> TlsStream<IO> {
21    #[inline]
22    pub fn get_ref(&self) -> (&IO, &ClientConnection) {
23        (&self.io, &self.session)
24    }
25
26    #[inline]
27    pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
28        (&mut self.io, &mut self.session)
29    }
30
31    #[inline]
32    pub fn into_inner(self) -> (IO, ClientConnection) {
33        (self.io, self.session)
34    }
35}
36
37#[cfg(unix)]
38impl<S> AsRawFd for TlsStream<S>
39where
40    S: AsRawFd,
41{
42    fn as_raw_fd(&self) -> RawFd {
43        self.get_ref().0.as_raw_fd()
44    }
45}
46
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Returns the raw socket handle of the wrapped transport.
    fn as_raw_socket(&self) -> RawSocket {
        self.io.as_raw_socket()
    }
}
56
57impl<IO> IoSession for TlsStream<IO> {
58    type Io = IO;
59    type Session = ClientConnection;
60
61    #[inline]
62    fn skip_handshake(&self) -> bool {
63        self.state.is_early_data()
64    }
65
66    #[inline]
67    fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
68        (&mut self.state, &mut self.io, &mut self.session)
69    }
70
71    #[inline]
72    fn into_io(self) -> Self::Io {
73        self.io
74    }
75}
76
77impl<IO> AsyncRead for TlsStream<IO>
78where
79    IO: AsyncRead + AsyncWrite + Unpin,
80{
81    fn poll_read(
82        self: Pin<&mut Self>,
83        cx: &mut Context<'_>,
84        buf: &mut ReadBuf<'_>,
85    ) -> Poll<io::Result<()>> {
86        match self.state {
87            #[cfg(feature = "early-data")]
88            TlsState::EarlyData(..) => {
89                let this = self.get_mut();
90
91                // In the EarlyData state, we have not really established a Tls connection.
92                // Before writing data through `AsyncWrite` and completing the tls handshake,
93                // we ignore read readiness and return to pending.
94                //
95                // In order to avoid event loss,
96                // we need to register a waker and wake it up after tls is connected.
97                if this
98                    .early_waker
99                    .as_ref()
100                    .filter(|waker| cx.waker().will_wake(waker))
101                    .is_none()
102                {
103                    this.early_waker = Some(cx.waker().clone());
104                }
105
106                Poll::Pending
107            }
108            TlsState::Stream | TlsState::WriteShutdown => {
109                let this = self.get_mut();
110                let mut stream =
111                    Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
112                let prev = buf.remaining();
113
114                match stream.as_mut_pin().poll_read(cx, buf) {
115                    Poll::Ready(Ok(())) => {
116                        if prev == buf.remaining() || stream.eof {
117                            this.state.shutdown_read();
118                        }
119
120                        Poll::Ready(Ok(()))
121                    }
122                    Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
123                        this.state.shutdown_read();
124                        Poll::Ready(Err(err))
125                    }
126                    output => output,
127                }
128            }
129            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
130        }
131    }
132}
133
134impl<IO> AsyncWrite for TlsStream<IO>
135where
136    IO: AsyncRead + AsyncWrite + Unpin,
137{
138    /// Note: that it does not guarantee the final data to be sent.
139    /// To be cautious, you must manually call `flush`.
140    fn poll_write(
141        self: Pin<&mut Self>,
142        cx: &mut Context<'_>,
143        buf: &[u8],
144    ) -> Poll<io::Result<usize>> {
145        let this = self.get_mut();
146        let mut stream =
147            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
148
149        #[allow(clippy::match_single_binding)]
150        match this.state {
151            #[cfg(feature = "early-data")]
152            TlsState::EarlyData(ref mut pos, ref mut data) => {
153                use std::io::Write;
154
155                // write early data
156                if let Some(mut early_data) = stream.session.early_data() {
157                    let len = match early_data.write(buf) {
158                        Ok(n) => n,
159                        Err(err) => return Poll::Ready(Err(err)),
160                    };
161                    if len != 0 {
162                        data.extend_from_slice(&buf[..len]);
163                        return Poll::Ready(Ok(len));
164                    }
165                }
166
167                // complete handshake
168                while stream.session.is_handshaking() {
169                    ready!(stream.handshake(cx))?;
170                }
171
172                // write early data (fallback)
173                if !stream.session.is_early_data_accepted() {
174                    while *pos < data.len() {
175                        let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
176                        *pos += len;
177                    }
178                }
179
180                // end
181                this.state = TlsState::Stream;
182
183                if let Some(waker) = this.early_waker.take() {
184                    waker.wake();
185                }
186
187                stream.as_mut_pin().poll_write(cx, buf)
188            }
189            _ => stream.as_mut_pin().poll_write(cx, buf),
190        }
191    }
192
193    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
194        let this = self.get_mut();
195        let mut stream =
196            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
197
198        #[cfg(feature = "early-data")]
199        {
200            if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
201                // complete handshake
202                while stream.session.is_handshaking() {
203                    ready!(stream.handshake(cx))?;
204                }
205
206                // write early data (fallback)
207                if !stream.session.is_early_data_accepted() {
208                    while *pos < data.len() {
209                        let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
210                        *pos += len;
211                    }
212                }
213
214                this.state = TlsState::Stream;
215
216                if let Some(waker) = this.early_waker.take() {
217                    waker.wake();
218                }
219            }
220        }
221
222        stream.as_mut_pin().poll_flush(cx)
223    }
224
225    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
226        #[cfg(feature = "early-data")]
227        {
228            // complete handshake
229            if matches!(self.state, TlsState::EarlyData(..)) {
230                ready!(self.as_mut().poll_flush(cx))?;
231            }
232        }
233
234        if self.state.writeable() {
235            self.session.send_close_notify();
236            self.state.shutdown_write();
237        }
238
239        let this = self.get_mut();
240        let mut stream =
241            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
242        stream.as_mut_pin().poll_shutdown(cx)
243    }
244}