1mod handshake;
2
3pub(crate) use handshake::{IoSession, MidHandshake};
4use rustls::{ConnectionCommon, SideData};
5use std::io::{self, IoSlice, Read, Write};
6use std::ops::{Deref, DerefMut};
7use std::pin::Pin;
8use std::task::{Context, Poll};
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10
/// Tracks which directions of a TLS stream are still open.
///
/// With the `early-data` feature enabled there is an extra `EarlyData` variant
/// carrying a `usize` and a byte buffer; their meaning is defined by the code
/// that constructs that variant (not visible here).
#[derive(Debug)]
pub enum TlsState {
    #[cfg(feature = "early-data")]
    EarlyData(usize, Vec<u8>),
    Stream,
    ReadShutdown,
    WriteShutdown,
    FullyShutdown,
}

impl TlsState {
    /// Marks the read half as closed.
    ///
    /// Becomes `FullyShutdown` when the write half is already closed,
    /// otherwise `ReadShutdown` (any previous state, including early data,
    /// is replaced).
    #[inline]
    pub fn shutdown_read(&mut self) {
        *self = if self.writeable() {
            TlsState::ReadShutdown
        } else {
            TlsState::FullyShutdown
        };
    }

    /// Marks the write half as closed.
    ///
    /// Becomes `FullyShutdown` when the read half is already closed,
    /// otherwise `WriteShutdown`.
    #[inline]
    pub fn shutdown_write(&mut self) {
        *self = if self.readable() {
            TlsState::WriteShutdown
        } else {
            TlsState::FullyShutdown
        };
    }

    /// Returns `true` unless the write half has been shut down.
    #[inline]
    pub fn writeable(&self) -> bool {
        let write_closed = matches!(self, TlsState::WriteShutdown | TlsState::FullyShutdown);
        !write_closed
    }

    /// Returns `true` unless the read half has been shut down.
    #[inline]
    pub fn readable(&self) -> bool {
        let read_closed = matches!(self, TlsState::ReadShutdown | TlsState::FullyShutdown);
        !read_closed
    }

    /// Returns `true` while in the `EarlyData` state.
    #[inline]
    #[cfg(feature = "early-data")]
    pub fn is_early_data(&self) -> bool {
        matches!(*self, TlsState::EarlyData(..))
    }

    /// Without the `early-data` feature the stream can never be in early data.
    #[inline]
    #[cfg(not(feature = "early-data"))]
    pub const fn is_early_data(&self) -> bool {
        false
    }
}
60
/// A short-lived view that pairs a transport `IO` with a rustls session `C`
/// so TLS records can be shuttled between them.
pub struct Stream<'a, IO, C> {
    /// The underlying byte transport.
    pub io: &'a mut IO,
    /// The rustls connection (anything that derefs to `ConnectionCommon`).
    pub session: &'a mut C,
    /// Set once a transport read has returned zero bytes; read loops in this
    /// module stop pulling from the transport while it is `true`.
    pub eof: bool,
}
66
67impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> Stream<'a, IO, C>
68where
69 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
70 SD: SideData,
71{
72 pub fn new(io: &'a mut IO, session: &'a mut C) -> Self {
73 Stream {
74 io,
75 session,
76 eof: false,
79 }
80 }
81
82 pub fn set_eof(mut self, eof: bool) -> Self {
83 self.eof = eof;
84 self
85 }
86
87 pub fn as_mut_pin(&mut self) -> Pin<&mut Self> {
88 Pin::new(self)
89 }
90
91 pub fn read_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
92 let mut reader = SyncReadAdapter { io: self.io, cx };
93
94 let n = match self.session.read_tls(&mut reader) {
95 Ok(n) => n,
96 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
97 Err(err) => return Poll::Ready(Err(err)),
98 };
99
100 let stats = self.session.process_new_packets().map_err(|err| {
101 let _ = self.write_io(cx);
105
106 io::Error::new(io::ErrorKind::InvalidData, err)
107 })?;
108
109 if stats.peer_has_closed() && self.session.is_handshaking() {
110 return Poll::Ready(Err(io::Error::new(
111 io::ErrorKind::UnexpectedEof,
112 "tls handshake alert",
113 )));
114 }
115
116 Poll::Ready(Ok(n))
117 }
118
119 pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
120 struct Writer<'a, 'b, T> {
121 io: &'a mut T,
122 cx: &'a mut Context<'b>,
123 }
124
125 impl<'a, 'b, T: Unpin> Writer<'a, 'b, T> {
126 #[inline]
127 fn poll_with<U>(
128 &mut self,
129 f: impl FnOnce(Pin<&mut T>, &mut Context<'_>) -> Poll<io::Result<U>>,
130 ) -> io::Result<U> {
131 match f(Pin::new(self.io), self.cx) {
132 Poll::Ready(result) => result,
133 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
134 }
135 }
136 }
137
138 impl<'a, 'b, T: AsyncWrite + Unpin> Write for Writer<'a, 'b, T> {
139 #[inline]
140 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
141 self.poll_with(|io, cx| io.poll_write(cx, buf))
142 }
143
144 #[inline]
145 fn write_vectored(&mut self, bufs: &[IoSlice<'_>]) -> io::Result<usize> {
146 self.poll_with(|io, cx| io.poll_write_vectored(cx, bufs))
147 }
148
149 fn flush(&mut self) -> io::Result<()> {
150 self.poll_with(|io, cx| io.poll_flush(cx))
151 }
152 }
153
154 let mut writer = Writer { io: self.io, cx };
155
156 match self.session.write_tls(&mut writer) {
157 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
158 result => Poll::Ready(result),
159 }
160 }
161
162 pub fn handshake(&mut self, cx: &mut Context) -> Poll<io::Result<(usize, usize)>> {
163 let mut wrlen = 0;
164 let mut rdlen = 0;
165
166 loop {
167 let mut write_would_block = false;
168 let mut read_would_block = false;
169 let mut need_flush = false;
170
171 while self.session.wants_write() {
172 match self.write_io(cx) {
173 Poll::Ready(Ok(n)) => {
174 wrlen += n;
175 need_flush = true;
176 }
177 Poll::Pending => {
178 write_would_block = true;
179 break;
180 }
181 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
182 }
183 }
184
185 if need_flush {
186 match Pin::new(&mut self.io).poll_flush(cx) {
187 Poll::Ready(Ok(())) => (),
188 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
189 Poll::Pending => write_would_block = true,
190 }
191 }
192
193 while !self.eof && self.session.wants_read() {
194 match self.read_io(cx) {
195 Poll::Ready(Ok(0)) => self.eof = true,
196 Poll::Ready(Ok(n)) => rdlen += n,
197 Poll::Pending => {
198 read_would_block = true;
199 break;
200 }
201 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
202 }
203 }
204
205 return match (self.eof, self.session.is_handshaking()) {
206 (true, true) => {
207 let err = io::Error::new(io::ErrorKind::UnexpectedEof, "tls handshake eof");
208 Poll::Ready(Err(err))
209 }
210 (_, false) => Poll::Ready(Ok((rdlen, wrlen))),
211 (_, true) if write_would_block || read_would_block => {
212 if rdlen != 0 || wrlen != 0 {
213 Poll::Ready(Ok((rdlen, wrlen)))
214 } else {
215 Poll::Pending
216 }
217 }
218 (..) => continue,
219 };
220 }
221 }
222}
223
224impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncRead for Stream<'a, IO, C>
225where
226 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
227 SD: SideData,
228{
229 fn poll_read(
230 mut self: Pin<&mut Self>,
231 cx: &mut Context<'_>,
232 buf: &mut ReadBuf<'_>,
233 ) -> Poll<io::Result<()>> {
234 let mut io_pending = false;
235
236 while !self.eof && self.session.wants_read() {
238 match self.read_io(cx) {
239 Poll::Ready(Ok(0)) => {
240 break;
241 }
242 Poll::Ready(Ok(_)) => (),
243 Poll::Pending => {
244 io_pending = true;
245 break;
246 }
247 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
248 }
249 }
250
251 match self.session.reader().read(buf.initialize_unfilled()) {
252 Ok(n) => {
261 buf.advance(n);
262 Poll::Ready(Ok(()))
263 }
264
265 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
267 if !io_pending {
268 cx.waker().wake_by_ref();
274 }
275
276 Poll::Pending
277 }
278
279 Err(err) => Poll::Ready(Err(err)),
280 }
281 }
282}
283
284impl<'a, IO: AsyncRead + AsyncWrite + Unpin, C, SD> AsyncWrite for Stream<'a, IO, C>
285where
286 C: DerefMut + Deref<Target = ConnectionCommon<SD>>,
287 SD: SideData,
288{
289 fn poll_write(
290 mut self: Pin<&mut Self>,
291 cx: &mut Context,
292 buf: &[u8],
293 ) -> Poll<io::Result<usize>> {
294 let mut pos = 0;
295
296 while pos != buf.len() {
297 let mut would_block = false;
298
299 match self.session.writer().write(&buf[pos..]) {
300 Ok(n) => pos += n,
301 Err(err) => return Poll::Ready(Err(err)),
302 };
303
304 while self.session.wants_write() {
305 match self.write_io(cx) {
306 Poll::Ready(Ok(0)) | Poll::Pending => {
307 would_block = true;
308 break;
309 }
310 Poll::Ready(Ok(_)) => (),
311 Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
312 }
313 }
314
315 return match (pos, would_block) {
316 (0, true) => Poll::Pending,
317 (n, true) => Poll::Ready(Ok(n)),
318 (_, false) => continue,
319 };
320 }
321
322 Poll::Ready(Ok(pos))
323 }
324
325 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
326 self.session.writer().flush()?;
327 while self.session.wants_write() {
328 ready!(self.write_io(cx))?;
329 }
330 Pin::new(&mut self.io).poll_flush(cx)
331 }
332
333 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
334 while self.session.wants_write() {
335 ready!(self.write_io(cx))?;
336 }
337 Pin::new(&mut self.io).poll_shutdown(cx)
338 }
339}
340
/// Presents an async reader as a `std::io::Read` by polling it once with the
/// borrowed task context.
///
/// The `Read` impl in this module maps `Pending` to `WouldBlock`.
pub struct SyncReadAdapter<'a, 'b, T> {
    /// The wrapped async reader.
    pub io: &'a mut T,
    /// Task context used for the single poll performed per `read` call.
    pub cx: &'a mut Context<'b>,
}
349
350impl<'a, 'b, T: AsyncRead + Unpin> Read for SyncReadAdapter<'a, 'b, T> {
351 #[inline]
352 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
353 let mut buf = ReadBuf::new(buf);
354 match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) {
355 Poll::Ready(Ok(())) => Ok(buf.filled().len()),
356 Poll::Ready(Err(err)) => Err(err),
357 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
358 }
359 }
360}
361
362#[cfg(test)]
363mod test_stream;