hyper_rustls/
stream.rs

1// Copied from hyperium/hyper-tls#62e3376/src/stream.rs
2use std::fmt;
3use std::io;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use hyper::client::connect::{Connected, Connection};
8
9use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
10use tokio_rustls::client::TlsStream;
11
/// A stream that might be protected with TLS.
///
/// `Http` holds the transport `T` directly, while `Https` holds the same kind of
/// transport wrapped in a client-side `tokio_rustls` [`TlsStream`]. The I/O trait
/// impls for this type dispatch to whichever variant is active, so callers can
/// treat plain-text and encrypted connections uniformly.
// Clippy's `large_enum_variant` fires because `Https(TlsStream<T>)` is
// significantly larger than `Http(T)`. The TLS variant is kept inline rather
// than boxed. NOTE(review): presumably this avoids a heap allocation and pointer
// indirection per connection; confirm that was the intended trade-off.
#[allow(clippy::large_enum_variant)]
pub enum MaybeHttpsStream<T> {
    /// A stream over plain text.
    Http(T),
    /// A stream protected with TLS.
    Https(TlsStream<T>),
}
20
21impl<T: AsyncRead + AsyncWrite + Connection + Unpin> Connection for MaybeHttpsStream<T> {
22    fn connected(&self) -> Connected {
23        match self {
24            Self::Http(s) => s.connected(),
25            Self::Https(s) => {
26                let (tcp, tls) = s.get_ref();
27                if tls.alpn_protocol() == Some(b"h2") {
28                    tcp.connected().negotiated_h2()
29                } else {
30                    tcp.connected()
31                }
32            }
33        }
34    }
35}
36
37impl<T: fmt::Debug> fmt::Debug for MaybeHttpsStream<T> {
38    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
39        match *self {
40            Self::Http(..) => f.pad("Http(..)"),
41            Self::Https(..) => f.pad("Https(..)"),
42        }
43    }
44}
45
46impl<T> From<T> for MaybeHttpsStream<T> {
47    fn from(inner: T) -> Self {
48        Self::Http(inner)
49    }
50}
51
52impl<T> From<TlsStream<T>> for MaybeHttpsStream<T> {
53    fn from(inner: TlsStream<T>) -> Self {
54        Self::Https(inner)
55    }
56}
57
58impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for MaybeHttpsStream<T> {
59    #[inline]
60    fn poll_read(
61        self: Pin<&mut Self>,
62        cx: &mut Context,
63        buf: &mut ReadBuf<'_>,
64    ) -> Poll<Result<(), io::Error>> {
65        match Pin::get_mut(self) {
66            Self::Http(s) => Pin::new(s).poll_read(cx, buf),
67            Self::Https(s) => Pin::new(s).poll_read(cx, buf),
68        }
69    }
70}
71
72impl<T: AsyncWrite + AsyncRead + Unpin> AsyncWrite for MaybeHttpsStream<T> {
73    #[inline]
74    fn poll_write(
75        self: Pin<&mut Self>,
76        cx: &mut Context<'_>,
77        buf: &[u8],
78    ) -> Poll<Result<usize, io::Error>> {
79        match Pin::get_mut(self) {
80            Self::Http(s) => Pin::new(s).poll_write(cx, buf),
81            Self::Https(s) => Pin::new(s).poll_write(cx, buf),
82        }
83    }
84
85    #[inline]
86    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
87        match Pin::get_mut(self) {
88            Self::Http(s) => Pin::new(s).poll_flush(cx),
89            Self::Https(s) => Pin::new(s).poll_flush(cx),
90        }
91    }
92
93    #[inline]
94    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
95        match Pin::get_mut(self) {
96            Self::Http(s) => Pin::new(s).poll_shutdown(cx),
97            Self::Https(s) => Pin::new(s).poll_shutdown(cx),
98        }
99    }
100}