// tokio_rustls/server.rs
1#[cfg(unix)]
2use std::os::unix::io::{AsRawFd, RawFd};
3#[cfg(windows)]
4use std::os::windows::io::{AsRawSocket, RawSocket};
5
6use super::*;
7use crate::common::IoSession;
8
9#[derive(Debug)]
12pub struct TlsStream<IO> {
13 pub(crate) io: IO,
14 pub(crate) session: ServerConnection,
15 pub(crate) state: TlsState,
16}
17
18impl<IO> TlsStream<IO> {
19 #[inline]
20 pub fn get_ref(&self) -> (&IO, &ServerConnection) {
21 (&self.io, &self.session)
22 }
23
24 #[inline]
25 pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) {
26 (&mut self.io, &mut self.session)
27 }
28
29 #[inline]
30 pub fn into_inner(self) -> (IO, ServerConnection) {
31 (self.io, self.session)
32 }
33}
34
35impl<IO> IoSession for TlsStream<IO> {
36 type Io = IO;
37 type Session = ServerConnection;
38
39 #[inline]
40 fn skip_handshake(&self) -> bool {
41 false
42 }
43
44 #[inline]
45 fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
46 (&mut self.state, &mut self.io, &mut self.session)
47 }
48
49 #[inline]
50 fn into_io(self) -> Self::Io {
51 self.io
52 }
53}
54
55impl<IO> AsyncRead for TlsStream<IO>
56where
57 IO: AsyncRead + AsyncWrite + Unpin,
58{
59 fn poll_read(
60 self: Pin<&mut Self>,
61 cx: &mut Context<'_>,
62 buf: &mut ReadBuf<'_>,
63 ) -> Poll<io::Result<()>> {
64 let this = self.get_mut();
65 let mut stream =
66 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
67
68 match &this.state {
69 TlsState::Stream | TlsState::WriteShutdown => {
70 let prev = buf.remaining();
71
72 match stream.as_mut_pin().poll_read(cx, buf) {
73 Poll::Ready(Ok(())) => {
74 if prev == buf.remaining() || stream.eof {
75 this.state.shutdown_read();
76 }
77
78 Poll::Ready(Ok(()))
79 }
80 Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::UnexpectedEof => {
81 this.state.shutdown_read();
82 Poll::Ready(Err(err))
83 }
84 output => output,
85 }
86 }
87 TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
88 #[cfg(feature = "early-data")]
89 s => unreachable!("server TLS can not hit this state: {:?}", s),
90 }
91 }
92}
93
94impl<IO> AsyncWrite for TlsStream<IO>
95where
96 IO: AsyncRead + AsyncWrite + Unpin,
97{
98 fn poll_write(
101 self: Pin<&mut Self>,
102 cx: &mut Context<'_>,
103 buf: &[u8],
104 ) -> Poll<io::Result<usize>> {
105 let this = self.get_mut();
106 let mut stream =
107 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
108 stream.as_mut_pin().poll_write(cx, buf)
109 }
110
111 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
112 let this = self.get_mut();
113 let mut stream =
114 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
115 stream.as_mut_pin().poll_flush(cx)
116 }
117
118 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
119 if self.state.writeable() {
120 self.session.send_close_notify();
121 self.state.shutdown_write();
122 }
123
124 let this = self.get_mut();
125 let mut stream =
126 Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
127 stream.as_mut_pin().poll_shutdown(cx)
128 }
129}
130
131#[cfg(unix)]
132impl<IO> AsRawFd for TlsStream<IO>
133where
134 IO: AsRawFd,
135{
136 fn as_raw_fd(&self) -> RawFd {
137 self.get_ref().0.as_raw_fd()
138 }
139}
140
#[cfg(windows)]
impl<IO> AsRawSocket for TlsStream<IO>
where
    IO: AsRawSocket,
{
    /// Returns the raw socket handle of the wrapped transport.
    fn as_raw_socket(&self) -> RawSocket {
        let transport = &self.io;
        transport.as_raw_socket()
    }
}