tokio_rustls/
client.rs
1use super::*;
2use crate::common::IoSession;
3#[cfg(unix)]
4use std::os::unix::io::{AsRawFd, RawFd};
5#[cfg(windows)]
6use std::os::windows::io::{AsRawSocket, RawSocket};
7
8#[derive(Debug)]
11pub struct TlsStream<IO> {
12 pub(crate) io: IO,
13 pub(crate) session: ClientConnection,
14 pub(crate) state: TlsState,
15
16 #[cfg(feature = "early-data")]
17 pub(crate) early_waker: Option<std::task::Waker>,
18}
19
20impl<IO> TlsStream<IO> {
21 #[inline]
22 pub fn get_ref(&self) -> (&IO, &ClientConnection) {
23 (&self.io, &self.session)
24 }
25
26 #[inline]
27 pub fn get_mut(&mut self) -> (&mut IO, &mut ClientConnection) {
28 (&mut self.io, &mut self.session)
29 }
30
31 #[inline]
32 pub fn into_inner(self) -> (IO, ClientConnection) {
33 (self.io, self.session)
34 }
35}
36
37#[cfg(unix)]
38impl<S> AsRawFd for TlsStream<S>
39where
40 S: AsRawFd,
41{
42 fn as_raw_fd(&self) -> RawFd {
43 self.get_ref().0.as_raw_fd()
44 }
45}
46
#[cfg(windows)]
impl<IO> AsRawSocket for TlsStream<IO>
where
    IO: AsRawSocket,
{
    /// Returns the raw socket handle of the wrapped transport.
    fn as_raw_socket(&self) -> RawSocket {
        self.io.as_raw_socket()
    }
}
56
57impl<IO> IoSession for TlsStream<IO> {
58 type Io = IO;
59 type Session = ClientConnection;
60
61 #[inline]
62 fn skip_handshake(&self) -> bool {
63 self.state.is_early_data()
64 }
65
66 #[inline]
67 fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
68 (&mut self.state, &mut self.io, &mut self.session)
69 }
70
71 #[inline]
72 fn into_io(self) -> Self::Io {
73 self.io
74 }
75}
76
77impl<IO> AsyncRead for TlsStream<IO>
78where
79 IO: AsyncRead + AsyncWrite + Unpin,
80{
81 fn poll_read(
82 self: Pin<&mut Self>,
83 cx: &mut Context<'_>,
84 buf: &mut ReadBuf<'_>,
85 ) -> Poll<io::Result<()>> {
86 match self.state {
87 #[cfg(feature = "early-data")]
88 TlsState::EarlyData(..) => {
89 let this = self.get_mut();
90
91 if this
98 .early_waker
99 .as_ref()
100 .filter(|waker| cx.waker().will_wake(waker))
101 .is_none()
102 {
103 this.early_waker = Some(cx.waker().clone());
104 }
105
106 Poll::Pending
107 }
108 TlsState::Stream | TlsState::WriteShutdown => {
109 let this = self.get_mut();
110 let mut stream =
111 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
112 let prev = buf.remaining();
113
114 match stream.as_mut_pin().poll_read(cx, buf) {
115 Poll::Ready(Ok(())) => {
116 if prev == buf.remaining() || stream.eof {
117 this.state.shutdown_read();
118 }
119
120 Poll::Ready(Ok(()))
121 }
122 Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::ConnectionAborted => {
123 this.state.shutdown_read();
124 Poll::Ready(Err(err))
125 }
126 output => output,
127 }
128 }
129 TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
130 }
131 }
132}
133
impl<IO> AsyncWrite for TlsStream<IO>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    /// Writes `buf` to the TLS session.
    ///
    /// In the early-data state:
    /// - `buf` is first offered to the session's early-data writer. Any bytes it
    ///   accepts are also copied into the state's `data` buffer.
    /// - If nothing could be written that way, the handshake is driven to completion.
    /// - If the server did not accept early data, the buffered bytes are re-sent as
    ///   ordinary application data.
    /// - The stream then switches to `TlsState::Stream`, wakes any reader parked in
    ///   `poll_read`, and writes `buf` normally.
    ///
    /// NOTE(review): the `Stream` helper may leave written bytes buffered in the
    /// session; callers presumably need `poll_flush` to guarantee delivery — confirm.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        let mut stream =
            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

        // With the `early-data` feature disabled the only remaining arm is `_`,
        // which clippy would otherwise flag as a single-binding match.
        #[allow(clippy::match_single_binding)]
        match this.state {
            #[cfg(feature = "early-data")]
            TlsState::EarlyData(ref mut pos, ref mut data) => {
                use std::io::Write;

                // Try to send `buf` as early data; whatever is accepted is also kept in
                // `data` so it can be replayed if the server rejects early data.
                if let Some(mut early_data) = stream.session.early_data() {
                    let len = match early_data.write(buf) {
                        Ok(n) => n,
                        Err(err) => return Poll::Ready(Err(err)),
                    };
                    if len != 0 {
                        data.extend_from_slice(&buf[..len]);
                        return Poll::Ready(Ok(len));
                    }
                }

                // No (more) early data could be written: finish the handshake.
                while stream.session.is_handshaking() {
                    ready!(stream.handshake(cx))?;
                }

                // Early data was not accepted: replay the buffered bytes. `pos` lives in
                // the state, so progress survives `Pending` returns from `ready!`.
                if !stream.session.is_early_data_accepted() {
                    while *pos < data.len() {
                        let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
                        *pos += len;
                    }
                }

                this.state = TlsState::Stream;

                // Resume a reader that `poll_read` parked during the early-data phase.
                if let Some(waker) = this.early_waker.take() {
                    waker.wake();
                }

                stream.as_mut_pin().poll_write(cx, buf)
            }
            _ => stream.as_mut_pin().poll_write(cx, buf),
        }
    }

    /// Flushes the stream.
    ///
    /// If the stream is still in the early-data state, that phase is finished first:
    /// - the handshake is driven to completion;
    /// - the buffered early data is replayed if the server did not accept it;
    /// - the stream switches to `TlsState::Stream`;
    /// - any reader parked in `poll_read` is woken.
    ///
    /// The `Stream` helper's `poll_flush` is then called.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let this = self.get_mut();
        let mut stream =
            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());

        #[cfg(feature = "early-data")]
        {
            if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
                // Complete the handshake before anything can be flushed.
                while stream.session.is_handshaking() {
                    ready!(stream.handshake(cx))?;
                }

                // Early data was not accepted: replay the buffered bytes, resuming at
                // `pos` if an earlier call returned `Pending` part-way through.
                if !stream.session.is_early_data_accepted() {
                    while *pos < data.len() {
                        let len = ready!(stream.as_mut_pin().poll_write(cx, &data[*pos..]))?;
                        *pos += len;
                    }
                }

                this.state = TlsState::Stream;

                // Resume a reader that `poll_read` parked during the early-data phase.
                if let Some(waker) = this.early_waker.take() {
                    waker.wake();
                }
            }
        }

        stream.as_mut_pin().poll_flush(cx)
    }

    /// Shuts down the write side of the TLS stream.
    ///
    /// - Early-data state: the stream is flushed first, which completes the handshake
    ///   and replays early data if needed.
    /// - Still writeable: `send_close_notify` is called on the session and the write
    ///   side is marked shut down.
    ///
    /// Finally this delegates to the `Stream` helper's `poll_shutdown`.
    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        #[cfg(feature = "early-data")]
        {
            // Leave the early-data phase (handshake + replay) before closing.
            if matches!(self.state, TlsState::EarlyData(..)) {
                ready!(self.as_mut().poll_flush(cx))?;
            }
        }

        // NOTE(review): `shutdown_write` presumably makes `writeable()` false, so a
        // repeated call (e.g. after `Pending` below) does not send close_notify twice.
        if self.state.writeable() {
            self.session.send_close_notify();
            self.state.shutdown_write();
        }

        let this = self.get_mut();
        let mut stream =
            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
        stream.as_mut_pin().poll_shutdown(cx)
    }
}