// hyper_rustls/acceptor.rs
1use core::task::{Context, Poll};
2use std::future::Future;
3use std::io;
4use std::pin::Pin;
5use std::sync::Arc;
6
7use futures_util::ready;
8use hyper::server::{
9 accept::Accept,
10 conn::{AddrIncoming, AddrStream},
11};
12use rustls::{ServerConfig, ServerConnection};
13use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
14
15mod builder;
16pub use builder::AcceptorBuilder;
17use builder::WantsTlsConfig;
18
19pub struct TlsAcceptor<A = AddrIncoming> {
21 config: Arc<ServerConfig>,
22 acceptor: A,
23}
24
25impl TlsAcceptor {
27 pub fn builder() -> AcceptorBuilder<WantsTlsConfig> {
29 AcceptorBuilder::new()
30 }
31
32 pub fn new(config: Arc<ServerConfig>, incoming: AddrIncoming) -> Self {
34 Self {
35 config,
36 acceptor: incoming,
37 }
38 }
39}
40
41impl<A> Accept for TlsAcceptor<A>
42where
43 A: Accept<Error = io::Error> + Unpin,
44 A::Conn: AsyncRead + AsyncWrite + Unpin,
45{
46 type Conn = TlsStream<A::Conn>;
47 type Error = io::Error;
48
49 fn poll_accept(
50 self: Pin<&mut Self>,
51 cx: &mut Context<'_>,
52 ) -> Poll<Option<Result<Self::Conn, Self::Error>>> {
53 let pin = self.get_mut();
54 Poll::Ready(match ready!(Pin::new(&mut pin.acceptor).poll_accept(cx)) {
55 Some(Ok(sock)) => Some(Ok(TlsStream::new(sock, pin.config.clone()))),
56 Some(Err(e)) => Some(Err(e)),
57 None => None,
58 })
59 }
60}
61
62impl<C, I> From<(C, I)> for TlsAcceptor
63where
64 C: Into<Arc<ServerConfig>>,
65 I: Into<AddrIncoming>,
66{
67 fn from((config, incoming): (C, I)) -> Self {
68 Self::new(config.into(), incoming.into())
69 }
70}
71
72pub struct TlsStream<C = AddrStream> {
77 state: State<C>,
78}
79
80impl<C: AsyncRead + AsyncWrite + Unpin> TlsStream<C> {
81 fn new(stream: C, config: Arc<ServerConfig>) -> Self {
82 let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream);
83 Self {
84 state: State::Handshaking(accept),
85 }
86 }
87 pub fn io(&self) -> Option<&C> {
91 match &self.state {
92 State::Handshaking(accept) => accept.get_ref(),
93 State::Streaming(stream) => Some(stream.get_ref().0),
94 }
95 }
96
97 pub fn connection(&self) -> Option<&ServerConnection> {
101 match &self.state {
102 State::Handshaking(_) => None,
103 State::Streaming(stream) => Some(stream.get_ref().1),
104 }
105 }
106}
107
108impl<C: AsyncRead + AsyncWrite + Unpin> AsyncRead for TlsStream<C> {
109 fn poll_read(
110 self: Pin<&mut Self>,
111 cx: &mut Context,
112 buf: &mut ReadBuf,
113 ) -> Poll<io::Result<()>> {
114 let pin = self.get_mut();
115 let accept = match &mut pin.state {
116 State::Handshaking(accept) => accept,
117 State::Streaming(stream) => return Pin::new(stream).poll_read(cx, buf),
118 };
119
120 let mut stream = match ready!(Pin::new(accept).poll(cx)) {
121 Ok(stream) => stream,
122 Err(err) => return Poll::Ready(Err(err)),
123 };
124
125 let result = Pin::new(&mut stream).poll_read(cx, buf);
126 pin.state = State::Streaming(stream);
127 result
128 }
129}
130
131impl<C: AsyncRead + AsyncWrite + Unpin> AsyncWrite for TlsStream<C> {
132 fn poll_write(
133 self: Pin<&mut Self>,
134 cx: &mut Context<'_>,
135 buf: &[u8],
136 ) -> Poll<io::Result<usize>> {
137 let pin = self.get_mut();
138 let accept = match &mut pin.state {
139 State::Handshaking(accept) => accept,
140 State::Streaming(stream) => return Pin::new(stream).poll_write(cx, buf),
141 };
142
143 let mut stream = match ready!(Pin::new(accept).poll(cx)) {
144 Ok(stream) => stream,
145 Err(err) => return Poll::Ready(Err(err)),
146 };
147
148 let result = Pin::new(&mut stream).poll_write(cx, buf);
149 pin.state = State::Streaming(stream);
150 result
151 }
152
153 fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
154 match &mut self.state {
155 State::Handshaking(_) => Poll::Ready(Ok(())),
156 State::Streaming(stream) => Pin::new(stream).poll_flush(cx),
157 }
158 }
159
160 fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
161 match &mut self.state {
162 State::Handshaking(_) => Poll::Ready(Ok(())),
163 State::Streaming(stream) => Pin::new(stream).poll_shutdown(cx),
164 }
165 }
166}
167
168enum State<C> {
169 Handshaking(tokio_rustls::Accept<C>),
170 Streaming(tokio_rustls::server::TlsStream<C>),
171}