tokio_rustls/
server.rs

1#[cfg(unix)]
2use std::os::unix::io::{AsRawFd, RawFd};
3#[cfg(windows)]
4use std::os::windows::io::{AsRawSocket, RawSocket};
5
6use super::*;
7use crate::common::IoSession;
8
9/// A wrapper around an underlying raw stream which implements the TLS or SSL
10/// protocol.
11#[derive(Debug)]
12pub struct TlsStream<IO> {
13    pub(crate) io: IO,
14    pub(crate) session: ServerConnection,
15    pub(crate) state: TlsState,
16}
17
18impl<IO> TlsStream<IO> {
19    #[inline]
20    pub fn get_ref(&self) -> (&IO, &ServerConnection) {
21        (&self.io, &self.session)
22    }
23
24    #[inline]
25    pub fn get_mut(&mut self) -> (&mut IO, &mut ServerConnection) {
26        (&mut self.io, &mut self.session)
27    }
28
29    #[inline]
30    pub fn into_inner(self) -> (IO, ServerConnection) {
31        (self.io, self.session)
32    }
33}
34
35impl<IO> IoSession for TlsStream<IO> {
36    type Io = IO;
37    type Session = ServerConnection;
38
39    #[inline]
40    fn skip_handshake(&self) -> bool {
41        false
42    }
43
44    #[inline]
45    fn get_mut(&mut self) -> (&mut TlsState, &mut Self::Io, &mut Self::Session) {
46        (&mut self.state, &mut self.io, &mut self.session)
47    }
48
49    #[inline]
50    fn into_io(self) -> Self::Io {
51        self.io
52    }
53}
54
55impl<IO> AsyncRead for TlsStream<IO>
56where
57    IO: AsyncRead + AsyncWrite + Unpin,
58{
59    fn poll_read(
60        self: Pin<&mut Self>,
61        cx: &mut Context<'_>,
62        buf: &mut ReadBuf<'_>,
63    ) -> Poll<io::Result<()>> {
64        let this = self.get_mut();
65        let mut stream =
66            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
67
68        match &this.state {
69            TlsState::Stream | TlsState::WriteShutdown => {
70                let prev = buf.remaining();
71
72                match stream.as_mut_pin().poll_read(cx, buf) {
73                    Poll::Ready(Ok(())) => {
74                        if prev == buf.remaining() || stream.eof {
75                            this.state.shutdown_read();
76                        }
77
78                        Poll::Ready(Ok(()))
79                    }
80                    Poll::Ready(Err(err)) if err.kind() == io::ErrorKind::UnexpectedEof => {
81                        this.state.shutdown_read();
82                        Poll::Ready(Err(err))
83                    }
84                    output => output,
85                }
86            }
87            TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
88            #[cfg(feature = "early-data")]
89            s => unreachable!("server TLS can not hit this state: {:?}", s),
90        }
91    }
92}
93
94impl<IO> AsyncWrite for TlsStream<IO>
95where
96    IO: AsyncRead + AsyncWrite + Unpin,
97{
98    /// Note: that it does not guarantee the final data to be sent.
99    /// To be cautious, you must manually call `flush`.
100    fn poll_write(
101        self: Pin<&mut Self>,
102        cx: &mut Context<'_>,
103        buf: &[u8],
104    ) -> Poll<io::Result<usize>> {
105        let this = self.get_mut();
106        let mut stream =
107            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
108        stream.as_mut_pin().poll_write(cx, buf)
109    }
110
111    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
112        let this = self.get_mut();
113        let mut stream =
114            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
115        stream.as_mut_pin().poll_flush(cx)
116    }
117
118    fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
119        if self.state.writeable() {
120            self.session.send_close_notify();
121            self.state.shutdown_write();
122        }
123
124        let this = self.get_mut();
125        let mut stream =
126            Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
127        stream.as_mut_pin().poll_shutdown(cx)
128    }
129}
130
131#[cfg(unix)]
132impl<IO> AsRawFd for TlsStream<IO>
133where
134    IO: AsRawFd,
135{
136    fn as_raw_fd(&self) -> RawFd {
137        self.get_ref().0.as_raw_fd()
138    }
139}
140
#[cfg(windows)]
impl<IO> AsRawSocket for TlsStream<IO>
where
    IO: AsRawSocket,
{
    /// Exposes the raw socket handle of the wrapped transport.
    fn as_raw_socket(&self) -> RawSocket {
        self.io.as_raw_socket()
    }
}