/// Evaluates a `std::task::Poll` expression and yields the value carried by
/// `Poll::Ready`.
///
/// When the expression is `Poll::Pending`, the macro returns `Poll::Pending`
/// from the enclosing function instead of producing a value.
macro_rules! ready {
    ( $poll:expr ) => {
        match $poll {
            // Not ready yet: propagate the pending state to our own caller.
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}
47
48pub mod client;
49mod common;
50pub mod server;
51
52use common::{MidHandshake, Stream, TlsState};
53use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
54use std::future::Future;
55use std::io;
56#[cfg(unix)]
57use std::os::unix::io::{AsRawFd, RawFd};
58#[cfg(windows)]
59use std::os::windows::io::{AsRawSocket, RawSocket};
60use std::pin::Pin;
61use std::sync::Arc;
62use std::task::{Context, Poll};
63use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
64
65pub use rustls;
66
67#[derive(Clone)]
69pub struct TlsConnector {
70 inner: Arc<ClientConfig>,
71 #[cfg(feature = "early-data")]
72 early_data: bool,
73}
74
75#[derive(Clone)]
77pub struct TlsAcceptor {
78 inner: Arc<ServerConfig>,
79}
80
81impl From<Arc<ClientConfig>> for TlsConnector {
82 fn from(inner: Arc<ClientConfig>) -> TlsConnector {
83 TlsConnector {
84 inner,
85 #[cfg(feature = "early-data")]
86 early_data: false,
87 }
88 }
89}
90
91impl From<Arc<ServerConfig>> for TlsAcceptor {
92 fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
93 TlsAcceptor { inner }
94 }
95}
96
97impl TlsConnector {
98 #[cfg(feature = "early-data")]
103 pub fn early_data(mut self, flag: bool) -> TlsConnector {
104 self.early_data = flag;
105 self
106 }
107
108 #[inline]
109 pub fn connect<IO>(&self, domain: rustls::ServerName, stream: IO) -> Connect<IO>
110 where
111 IO: AsyncRead + AsyncWrite + Unpin,
112 {
113 self.connect_with(domain, stream, |_| ())
114 }
115
116 pub fn connect_with<IO, F>(&self, domain: rustls::ServerName, stream: IO, f: F) -> Connect<IO>
117 where
118 IO: AsyncRead + AsyncWrite + Unpin,
119 F: FnOnce(&mut ClientConnection),
120 {
121 let mut session = match ClientConnection::new(self.inner.clone(), domain) {
122 Ok(session) => session,
123 Err(error) => {
124 return Connect(MidHandshake::Error {
125 io: stream,
126 error: io::Error::new(io::ErrorKind::Other, error),
129 });
130 }
131 };
132 f(&mut session);
133
134 Connect(MidHandshake::Handshaking(client::TlsStream {
135 io: stream,
136
137 #[cfg(not(feature = "early-data"))]
138 state: TlsState::Stream,
139
140 #[cfg(feature = "early-data")]
141 state: if self.early_data && session.early_data().is_some() {
142 TlsState::EarlyData(0, Vec::new())
143 } else {
144 TlsState::Stream
145 },
146
147 #[cfg(feature = "early-data")]
148 early_waker: None,
149
150 session,
151 }))
152 }
153}
154
155impl TlsAcceptor {
156 #[inline]
157 pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
158 where
159 IO: AsyncRead + AsyncWrite + Unpin,
160 {
161 self.accept_with(stream, |_| ())
162 }
163
164 pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
165 where
166 IO: AsyncRead + AsyncWrite + Unpin,
167 F: FnOnce(&mut ServerConnection),
168 {
169 let mut session = match ServerConnection::new(self.inner.clone()) {
170 Ok(session) => session,
171 Err(error) => {
172 return Accept(MidHandshake::Error {
173 io: stream,
174 error: io::Error::new(io::ErrorKind::Other, error),
177 });
178 }
179 };
180 f(&mut session);
181
182 Accept(MidHandshake::Handshaking(server::TlsStream {
183 session,
184 io: stream,
185 state: TlsState::Stream,
186 }))
187 }
188}
189
190pub struct LazyConfigAcceptor<IO> {
191 acceptor: rustls::server::Acceptor,
192 io: Option<IO>,
193}
194
195impl<IO> LazyConfigAcceptor<IO>
196where
197 IO: AsyncRead + AsyncWrite + Unpin,
198{
199 #[inline]
200 pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
201 Self {
202 acceptor,
203 io: Some(io),
204 }
205 }
206
207 pub fn take_io(&mut self) -> Option<IO> {
249 self.io.take()
250 }
251}
252
253impl<IO> Future for LazyConfigAcceptor<IO>
254where
255 IO: AsyncRead + AsyncWrite + Unpin,
256{
257 type Output = Result<StartHandshake<IO>, io::Error>;
258
259 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
260 let this = self.get_mut();
261 loop {
262 let io = match this.io.as_mut() {
263 Some(io) => io,
264 None => {
265 return Poll::Ready(Err(io::Error::new(
266 io::ErrorKind::Other,
267 "acceptor cannot be polled after acceptance",
268 )))
269 }
270 };
271
272 let mut reader = common::SyncReadAdapter { io, cx };
273 match this.acceptor.read_tls(&mut reader) {
274 Ok(0) => return Err(io::ErrorKind::UnexpectedEof.into()).into(),
275 Ok(_) => {}
276 Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
277 Err(e) => return Err(e).into(),
278 }
279
280 match this.acceptor.accept() {
281 Ok(Some(accepted)) => {
282 let io = this.io.take().unwrap();
283 return Poll::Ready(Ok(StartHandshake { accepted, io }));
284 }
285 Ok(None) => continue,
286 Err(err) => {
287 return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err)))
288 }
289 }
290 }
291 }
292}
293
294pub struct StartHandshake<IO> {
295 accepted: rustls::server::Accepted,
296 io: IO,
297}
298
299impl<IO> StartHandshake<IO>
300where
301 IO: AsyncRead + AsyncWrite + Unpin,
302{
303 pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
304 self.accepted.client_hello()
305 }
306
307 pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
308 self.into_stream_with(config, |_| ())
309 }
310
311 pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
312 where
313 F: FnOnce(&mut ServerConnection),
314 {
315 let mut conn = match self.accepted.into_connection(config) {
316 Ok(conn) => conn,
317 Err(error) => {
318 return Accept(MidHandshake::Error {
319 io: self.io,
320 error: io::Error::new(io::ErrorKind::Other, error),
323 });
324 }
325 };
326 f(&mut conn);
327
328 Accept(MidHandshake::Handshaking(server::TlsStream {
329 session: conn,
330 io: self.io,
331 state: TlsState::Stream,
332 }))
333 }
334}
335
336pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
339
340pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
343
344pub struct FallibleConnect<IO>(MidHandshake<client::TlsStream<IO>>);
346
347pub struct FallibleAccept<IO>(MidHandshake<server::TlsStream<IO>>);
349
350impl<IO> Connect<IO> {
351 #[inline]
352 pub fn into_fallible(self) -> FallibleConnect<IO> {
353 FallibleConnect(self.0)
354 }
355
356 pub fn get_ref(&self) -> Option<&IO> {
357 match &self.0 {
358 MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
359 MidHandshake::Error { io, .. } => Some(io),
360 MidHandshake::End => None,
361 }
362 }
363
364 pub fn get_mut(&mut self) -> Option<&mut IO> {
365 match &mut self.0 {
366 MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
367 MidHandshake::Error { io, .. } => Some(io),
368 MidHandshake::End => None,
369 }
370 }
371}
372
373impl<IO> Accept<IO> {
374 #[inline]
375 pub fn into_fallible(self) -> FallibleAccept<IO> {
376 FallibleAccept(self.0)
377 }
378
379 pub fn get_ref(&self) -> Option<&IO> {
380 match &self.0 {
381 MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
382 MidHandshake::Error { io, .. } => Some(io),
383 MidHandshake::End => None,
384 }
385 }
386
387 pub fn get_mut(&mut self) -> Option<&mut IO> {
388 match &mut self.0 {
389 MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
390 MidHandshake::Error { io, .. } => Some(io),
391 MidHandshake::End => None,
392 }
393 }
394}
395
396impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
397 type Output = io::Result<client::TlsStream<IO>>;
398
399 #[inline]
400 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
401 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
402 }
403}
404
405impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
406 type Output = io::Result<server::TlsStream<IO>>;
407
408 #[inline]
409 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
410 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
411 }
412}
413
414impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
415 type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
416
417 #[inline]
418 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
419 Pin::new(&mut self.0).poll(cx)
420 }
421}
422
423impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
424 type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
425
426 #[inline]
427 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
428 Pin::new(&mut self.0).poll(cx)
429 }
430}
431
432#[allow(clippy::large_enum_variant)] #[derive(Debug)]
438pub enum TlsStream<T> {
439 Client(client::TlsStream<T>),
440 Server(server::TlsStream<T>),
441}
442
443impl<T> TlsStream<T> {
444 pub fn get_ref(&self) -> (&T, &CommonState) {
445 use TlsStream::*;
446 match self {
447 Client(io) => {
448 let (io, session) = io.get_ref();
449 (io, session)
450 }
451 Server(io) => {
452 let (io, session) = io.get_ref();
453 (io, session)
454 }
455 }
456 }
457
458 pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
459 use TlsStream::*;
460 match self {
461 Client(io) => {
462 let (io, session) = io.get_mut();
463 (io, &mut *session)
464 }
465 Server(io) => {
466 let (io, session) = io.get_mut();
467 (io, &mut *session)
468 }
469 }
470 }
471}
472
473impl<T> From<client::TlsStream<T>> for TlsStream<T> {
474 fn from(s: client::TlsStream<T>) -> Self {
475 Self::Client(s)
476 }
477}
478
479impl<T> From<server::TlsStream<T>> for TlsStream<T> {
480 fn from(s: server::TlsStream<T>) -> Self {
481 Self::Server(s)
482 }
483}
484
485#[cfg(unix)]
486impl<S> AsRawFd for TlsStream<S>
487where
488 S: AsRawFd,
489{
490 fn as_raw_fd(&self) -> RawFd {
491 self.get_ref().0.as_raw_fd()
492 }
493}
494
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    /// Returns the raw socket of the underlying IO.
    fn as_raw_socket(&self) -> RawSocket {
        let (io, _) = self.get_ref();
        io.as_raw_socket()
    }
}
504
505impl<T> AsyncRead for TlsStream<T>
506where
507 T: AsyncRead + AsyncWrite + Unpin,
508{
509 #[inline]
510 fn poll_read(
511 self: Pin<&mut Self>,
512 cx: &mut Context<'_>,
513 buf: &mut ReadBuf<'_>,
514 ) -> Poll<io::Result<()>> {
515 match self.get_mut() {
516 TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
517 TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
518 }
519 }
520}
521
522impl<T> AsyncWrite for TlsStream<T>
523where
524 T: AsyncRead + AsyncWrite + Unpin,
525{
526 #[inline]
527 fn poll_write(
528 self: Pin<&mut Self>,
529 cx: &mut Context<'_>,
530 buf: &[u8],
531 ) -> Poll<io::Result<usize>> {
532 match self.get_mut() {
533 TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
534 TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
535 }
536 }
537
538 #[inline]
539 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
540 match self.get_mut() {
541 TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
542 TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
543 }
544 }
545
546 #[inline]
547 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
548 match self.get_mut() {
549 TlsStream::Client(x) => Pin::new(x).poll_shutdown(cx),
550 TlsStream::Server(x) => Pin::new(x).poll_shutdown(cx),
551 }
552 }
553}