// rustls/msgs/handshake.rs

1#![allow(non_camel_case_types)]
2use crate::dns_name::{DnsName, DnsNameRef};
3use crate::enums::{CipherSuite, HandshakeType, ProtocolVersion, SignatureScheme};
4use crate::error::InvalidMessage;
5use crate::key;
6#[cfg(feature = "logging")]
7use crate::log::warn;
8use crate::msgs::base::{Payload, PayloadU16, PayloadU24, PayloadU8};
9use crate::msgs::codec::{self, Codec, ListLength, Reader, TlsListElement};
10use crate::msgs::enums::{
11    CertificateStatusType, ClientCertificateType, Compression, ECCurveType, ECPointFormat,
12    ExtensionType, KeyUpdateRequest, NamedGroup, PSKKeyExchangeMode, ServerNameType,
13};
14use crate::rand;
15use crate::verify::DigitallySignedStruct;
16
17use std::collections;
18use std::fmt;
19use std::net::IpAddr;
20use std::str::FromStr;
21
22/// Create a newtype wrapper around a given type.
23///
24/// This is used to create newtypes for the various TLS message types which is used to wrap
25/// the `PayloadU8` or `PayloadU16` types. This is typically used for types where we don't need
26/// anything other than access to the underlying bytes.
/// Declares a newtype around one of the raw payload types.
///
/// The generated struct wraps the given `$inner` type (`PayloadU8` or `PayloadU16` in this
/// file) for TLS message fields where only the underlying bytes matter. Each generated type
/// gets:
/// - `Clone` and `Debug`,
/// - `From<Vec<u8>>` (via `$inner::new`),
/// - `AsRef<[u8]>` exposing the wrapped bytes,
/// - a `Codec` impl that forwards both directions to `$inner`.
///
/// Doc comments and other attributes written before the name are copied onto the struct.
macro_rules! wrapped_payload(
  ($(#[$attr:meta])* $name:ident, $inner:ident,) => {
    $(#[$attr])*
    #[derive(Clone, Debug)]
    pub struct $name($inner);

    impl From<Vec<u8>> for $name {
        fn from(bytes: Vec<u8>) -> Self {
            let inner = $inner::new(bytes);
            Self(inner)
        }
    }

    impl AsRef<[u8]> for $name {
        fn as_ref(&self) -> &[u8] {
            let inner = &self.0;
            inner.0.as_slice()
        }
    }

    impl Codec for $name {
        fn encode(&self, bytes: &mut Vec<u8>) {
            let inner = &self.0;
            inner.encode(bytes);
        }

        fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
            $inner::read(r).map($name)
        }
    }
  }
);
56
57#[derive(Clone, Copy, Eq, PartialEq)]
58pub struct Random(pub [u8; 32]);
59
60impl fmt::Debug for Random {
61    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62        super::base::hex(f, &self.0)
63    }
64}
65
66static HELLO_RETRY_REQUEST_RANDOM: Random = Random([
67    0xcf, 0x21, 0xad, 0x74, 0xe5, 0x9a, 0x61, 0x11, 0xbe, 0x1d, 0x8c, 0x02, 0x1e, 0x65, 0xb8, 0x91,
68    0xc2, 0xa2, 0x11, 0x16, 0x7a, 0xbb, 0x8c, 0x5e, 0x07, 0x9e, 0x09, 0xe2, 0xc8, 0xa8, 0x33, 0x9c,
69]);
70
71static ZERO_RANDOM: Random = Random([0u8; 32]);
72
73impl Codec for Random {
74    fn encode(&self, bytes: &mut Vec<u8>) {
75        bytes.extend_from_slice(&self.0);
76    }
77
78    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
79        let bytes = match r.take(32) {
80            Some(bytes) => bytes,
81            None => return Err(InvalidMessage::MissingData("Random")),
82        };
83
84        let mut opaque = [0; 32];
85        opaque.clone_from_slice(bytes);
86        Ok(Self(opaque))
87    }
88}
89
90impl Random {
91    pub fn new() -> Result<Self, rand::GetRandomFailed> {
92        let mut data = [0u8; 32];
93        rand::fill_random(&mut data)?;
94        Ok(Self(data))
95    }
96
97    pub fn write_slice(&self, bytes: &mut [u8]) {
98        let buf = self.get_encoding();
99        bytes.copy_from_slice(&buf);
100    }
101}
102
103impl From<[u8; 32]> for Random {
104    #[inline]
105    fn from(bytes: [u8; 32]) -> Self {
106        Self(bytes)
107    }
108}
109
110#[derive(Copy, Clone)]
111pub struct SessionId {
112    len: usize,
113    data: [u8; 32],
114}
115
116impl fmt::Debug for SessionId {
117    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
118        super::base::hex(f, &self.data[..self.len])
119    }
120}
121
122impl PartialEq for SessionId {
123    fn eq(&self, other: &Self) -> bool {
124        if self.len != other.len {
125            return false;
126        }
127
128        let mut diff = 0u8;
129        for i in 0..self.len {
130            diff |= self.data[i] ^ other.data[i];
131        }
132
133        diff == 0u8
134    }
135}
136
137impl Codec for SessionId {
138    fn encode(&self, bytes: &mut Vec<u8>) {
139        debug_assert!(self.len <= 32);
140        bytes.push(self.len as u8);
141        bytes.extend_from_slice(&self.data[..self.len]);
142    }
143
144    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
145        let len = u8::read(r)? as usize;
146        if len > 32 {
147            return Err(InvalidMessage::TrailingData("SessionID"));
148        }
149
150        let bytes = match r.take(len) {
151            Some(bytes) => bytes,
152            None => return Err(InvalidMessage::MissingData("SessionID")),
153        };
154
155        let mut out = [0u8; 32];
156        out[..len].clone_from_slice(&bytes[..len]);
157        Ok(Self { data: out, len })
158    }
159}
160
161impl SessionId {
162    pub fn random() -> Result<Self, rand::GetRandomFailed> {
163        let mut data = [0u8; 32];
164        rand::fill_random(&mut data)?;
165        Ok(Self { data, len: 32 })
166    }
167
168    pub fn empty() -> Self {
169        Self {
170            data: [0u8; 32],
171            len: 0,
172        }
173    }
174
175    pub fn len(&self) -> usize {
176        self.len
177    }
178
179    pub fn is_empty(&self) -> bool {
180        self.len == 0
181    }
182}
183
184#[derive(Clone, Debug)]
185pub struct UnknownExtension {
186    pub typ: ExtensionType,
187    pub payload: Payload,
188}
189
190impl UnknownExtension {
191    fn encode(&self, bytes: &mut Vec<u8>) {
192        self.payload.encode(bytes);
193    }
194
195    fn read(typ: ExtensionType, r: &mut Reader) -> Self {
196        let payload = Payload::read(r);
197        Self { typ, payload }
198    }
199}
200
201impl TlsListElement for ECPointFormat {
202    const SIZE_LEN: ListLength = ListLength::U8;
203}
204
205impl TlsListElement for NamedGroup {
206    const SIZE_LEN: ListLength = ListLength::U16;
207}
208
209impl TlsListElement for SignatureScheme {
210    const SIZE_LEN: ListLength = ListLength::U16;
211}
212
213#[derive(Clone, Debug)]
214pub enum ServerNamePayload {
215    HostName(DnsName),
216    IpAddress(PayloadU16),
217    Unknown(Payload),
218}
219
220impl ServerNamePayload {
221    pub fn new_hostname(hostname: DnsName) -> Self {
222        Self::HostName(hostname)
223    }
224
225    fn read_hostname(r: &mut Reader) -> Result<Self, InvalidMessage> {
226        let raw = PayloadU16::read(r)?;
227        match DnsName::try_from_ascii(&raw.0) {
228            Ok(dns_name) => Ok(Self::HostName(dns_name)),
229            Err(_) => {
230                let _ = IpAddr::from_str(&String::from_utf8_lossy(&raw.0))
231                    .map_err(|_| InvalidMessage::InvalidServerName)?;
232                Ok(Self::IpAddress(raw))
233            }
234        }
235    }
236
237    fn encode(&self, bytes: &mut Vec<u8>) {
238        match *self {
239            Self::HostName(ref name) => {
240                (name.as_ref().len() as u16).encode(bytes);
241                bytes.extend_from_slice(name.as_ref().as_bytes());
242            }
243            Self::IpAddress(ref r) => r.encode(bytes),
244            Self::Unknown(ref r) => r.encode(bytes),
245        }
246    }
247}
248
249#[derive(Clone, Debug)]
250pub struct ServerName {
251    pub typ: ServerNameType,
252    pub payload: ServerNamePayload,
253}
254
255impl Codec for ServerName {
256    fn encode(&self, bytes: &mut Vec<u8>) {
257        self.typ.encode(bytes);
258        self.payload.encode(bytes);
259    }
260
261    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
262        let typ = ServerNameType::read(r)?;
263
264        let payload = match typ {
265            ServerNameType::HostName => ServerNamePayload::read_hostname(r)?,
266            _ => ServerNamePayload::Unknown(Payload::read(r)),
267        };
268
269        Ok(Self { typ, payload })
270    }
271}
272
273impl TlsListElement for ServerName {
274    const SIZE_LEN: ListLength = ListLength::U16;
275}
276
277pub trait ConvertServerNameList {
278    fn has_duplicate_names_for_type(&self) -> bool;
279    fn get_single_hostname(&self) -> Option<DnsNameRef>;
280}
281
282impl ConvertServerNameList for [ServerName] {
283    /// RFC6066: "The ServerNameList MUST NOT contain more than one name of the same name_type."
284    fn has_duplicate_names_for_type(&self) -> bool {
285        let mut seen = collections::HashSet::new();
286
287        for name in self {
288            if !seen.insert(name.typ.get_u8()) {
289                return true;
290            }
291        }
292
293        false
294    }
295
296    fn get_single_hostname(&self) -> Option<DnsNameRef> {
297        fn only_dns_hostnames(name: &ServerName) -> Option<DnsNameRef> {
298            if let ServerNamePayload::HostName(ref dns) = name.payload {
299                Some(dns.borrow())
300            } else {
301                None
302            }
303        }
304
305        self.iter()
306            .filter_map(only_dns_hostnames)
307            .next()
308    }
309}
310
311wrapped_payload!(ProtocolName, PayloadU8,);
312
313impl TlsListElement for ProtocolName {
314    const SIZE_LEN: ListLength = ListLength::U16;
315}
316
317pub trait ConvertProtocolNameList {
318    fn from_slices(names: &[&[u8]]) -> Self;
319    fn to_slices(&self) -> Vec<&[u8]>;
320    fn as_single_slice(&self) -> Option<&[u8]>;
321}
322
323impl ConvertProtocolNameList for Vec<ProtocolName> {
324    fn from_slices(names: &[&[u8]]) -> Self {
325        let mut ret = Self::new();
326
327        for name in names {
328            ret.push(ProtocolName::from(name.to_vec()));
329        }
330
331        ret
332    }
333
334    fn to_slices(&self) -> Vec<&[u8]> {
335        self.iter()
336            .map(|proto| proto.as_ref())
337            .collect::<Vec<&[u8]>>()
338    }
339
340    fn as_single_slice(&self) -> Option<&[u8]> {
341        if self.len() == 1 {
342            Some(self[0].as_ref())
343        } else {
344            None
345        }
346    }
347}
348
349// --- TLS 1.3 Key shares ---
350#[derive(Clone, Debug)]
351pub struct KeyShareEntry {
352    pub group: NamedGroup,
353    pub payload: PayloadU16,
354}
355
356impl KeyShareEntry {
357    pub fn new(group: NamedGroup, payload: &[u8]) -> Self {
358        Self {
359            group,
360            payload: PayloadU16::new(payload.to_vec()),
361        }
362    }
363}
364
365impl Codec for KeyShareEntry {
366    fn encode(&self, bytes: &mut Vec<u8>) {
367        self.group.encode(bytes);
368        self.payload.encode(bytes);
369    }
370
371    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
372        let group = NamedGroup::read(r)?;
373        let payload = PayloadU16::read(r)?;
374
375        Ok(Self { group, payload })
376    }
377}
378
379// --- TLS 1.3 PresharedKey offers ---
380#[derive(Clone, Debug)]
381pub struct PresharedKeyIdentity {
382    pub identity: PayloadU16,
383    pub obfuscated_ticket_age: u32,
384}
385
386impl PresharedKeyIdentity {
387    pub fn new(id: Vec<u8>, age: u32) -> Self {
388        Self {
389            identity: PayloadU16::new(id),
390            obfuscated_ticket_age: age,
391        }
392    }
393}
394
395impl Codec for PresharedKeyIdentity {
396    fn encode(&self, bytes: &mut Vec<u8>) {
397        self.identity.encode(bytes);
398        self.obfuscated_ticket_age.encode(bytes);
399    }
400
401    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
402        Ok(Self {
403            identity: PayloadU16::read(r)?,
404            obfuscated_ticket_age: u32::read(r)?,
405        })
406    }
407}
408
409impl TlsListElement for PresharedKeyIdentity {
410    const SIZE_LEN: ListLength = ListLength::U16;
411}
412
413wrapped_payload!(PresharedKeyBinder, PayloadU8,);
414
415impl TlsListElement for PresharedKeyBinder {
416    const SIZE_LEN: ListLength = ListLength::U16;
417}
418
419#[derive(Clone, Debug)]
420pub struct PresharedKeyOffer {
421    pub identities: Vec<PresharedKeyIdentity>,
422    pub binders: Vec<PresharedKeyBinder>,
423}
424
425impl PresharedKeyOffer {
426    /// Make a new one with one entry.
427    pub fn new(id: PresharedKeyIdentity, binder: Vec<u8>) -> Self {
428        Self {
429            identities: vec![id],
430            binders: vec![PresharedKeyBinder::from(binder)],
431        }
432    }
433}
434
435impl Codec for PresharedKeyOffer {
436    fn encode(&self, bytes: &mut Vec<u8>) {
437        self.identities.encode(bytes);
438        self.binders.encode(bytes);
439    }
440
441    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
442        Ok(Self {
443            identities: Vec::read(r)?,
444            binders: Vec::read(r)?,
445        })
446    }
447}
448
449// --- RFC6066 certificate status request ---
450wrapped_payload!(ResponderId, PayloadU16,);
451
452impl TlsListElement for ResponderId {
453    const SIZE_LEN: ListLength = ListLength::U16;
454}
455
456#[derive(Clone, Debug)]
457pub struct OCSPCertificateStatusRequest {
458    pub responder_ids: Vec<ResponderId>,
459    pub extensions: PayloadU16,
460}
461
462impl Codec for OCSPCertificateStatusRequest {
463    fn encode(&self, bytes: &mut Vec<u8>) {
464        CertificateStatusType::OCSP.encode(bytes);
465        self.responder_ids.encode(bytes);
466        self.extensions.encode(bytes);
467    }
468
469    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
470        Ok(Self {
471            responder_ids: Vec::read(r)?,
472            extensions: PayloadU16::read(r)?,
473        })
474    }
475}
476
477#[derive(Clone, Debug)]
478pub enum CertificateStatusRequest {
479    OCSP(OCSPCertificateStatusRequest),
480    Unknown((CertificateStatusType, Payload)),
481}
482
483impl Codec for CertificateStatusRequest {
484    fn encode(&self, bytes: &mut Vec<u8>) {
485        match self {
486            Self::OCSP(ref r) => r.encode(bytes),
487            Self::Unknown((typ, payload)) => {
488                typ.encode(bytes);
489                payload.encode(bytes);
490            }
491        }
492    }
493
494    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
495        let typ = CertificateStatusType::read(r)?;
496
497        match typ {
498            CertificateStatusType::OCSP => {
499                let ocsp_req = OCSPCertificateStatusRequest::read(r)?;
500                Ok(Self::OCSP(ocsp_req))
501            }
502            _ => {
503                let data = Payload::read(r);
504                Ok(Self::Unknown((typ, data)))
505            }
506        }
507    }
508}
509
510impl CertificateStatusRequest {
511    pub fn build_ocsp() -> Self {
512        let ocsp = OCSPCertificateStatusRequest {
513            responder_ids: Vec::new(),
514            extensions: PayloadU16::empty(),
515        };
516        Self::OCSP(ocsp)
517    }
518}
519
520// ---
521// SCTs
522
523wrapped_payload!(Sct, PayloadU16,);
524
525impl TlsListElement for Sct {
526    const SIZE_LEN: ListLength = ListLength::U16;
527}
528
529// ---
530
531impl TlsListElement for PSKKeyExchangeMode {
532    const SIZE_LEN: ListLength = ListLength::U8;
533}
534
535impl TlsListElement for KeyShareEntry {
536    const SIZE_LEN: ListLength = ListLength::U16;
537}
538
539impl TlsListElement for ProtocolVersion {
540    const SIZE_LEN: ListLength = ListLength::U8;
541}
542
543#[derive(Clone, Debug)]
544pub enum ClientExtension {
545    ECPointFormats(Vec<ECPointFormat>),
546    NamedGroups(Vec<NamedGroup>),
547    SignatureAlgorithms(Vec<SignatureScheme>),
548    ServerName(Vec<ServerName>),
549    SessionTicket(ClientSessionTicket),
550    Protocols(Vec<ProtocolName>),
551    SupportedVersions(Vec<ProtocolVersion>),
552    KeyShare(Vec<KeyShareEntry>),
553    PresharedKeyModes(Vec<PSKKeyExchangeMode>),
554    PresharedKey(PresharedKeyOffer),
555    Cookie(PayloadU16),
556    ExtendedMasterSecretRequest,
557    CertificateStatusRequest(CertificateStatusRequest),
558    SignedCertificateTimestampRequest,
559    TransportParameters(Vec<u8>),
560    TransportParametersDraft(Vec<u8>),
561    EarlyData,
562    Unknown(UnknownExtension),
563}
564
565impl ClientExtension {
566    pub fn get_type(&self) -> ExtensionType {
567        match *self {
568            Self::ECPointFormats(_) => ExtensionType::ECPointFormats,
569            Self::NamedGroups(_) => ExtensionType::EllipticCurves,
570            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
571            Self::ServerName(_) => ExtensionType::ServerName,
572            Self::SessionTicket(_) => ExtensionType::SessionTicket,
573            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
574            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
575            Self::KeyShare(_) => ExtensionType::KeyShare,
576            Self::PresharedKeyModes(_) => ExtensionType::PSKKeyExchangeModes,
577            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
578            Self::Cookie(_) => ExtensionType::Cookie,
579            Self::ExtendedMasterSecretRequest => ExtensionType::ExtendedMasterSecret,
580            Self::CertificateStatusRequest(_) => ExtensionType::StatusRequest,
581            Self::SignedCertificateTimestampRequest => ExtensionType::SCT,
582            Self::TransportParameters(_) => ExtensionType::TransportParameters,
583            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
584            Self::EarlyData => ExtensionType::EarlyData,
585            Self::Unknown(ref r) => r.typ,
586        }
587    }
588}
589
590impl Codec for ClientExtension {
591    fn encode(&self, bytes: &mut Vec<u8>) {
592        self.get_type().encode(bytes);
593
594        let mut sub: Vec<u8> = Vec::new();
595        match *self {
596            Self::ECPointFormats(ref r) => r.encode(&mut sub),
597            Self::NamedGroups(ref r) => r.encode(&mut sub),
598            Self::SignatureAlgorithms(ref r) => r.encode(&mut sub),
599            Self::ServerName(ref r) => r.encode(&mut sub),
600            Self::SessionTicket(ClientSessionTicket::Request)
601            | Self::ExtendedMasterSecretRequest
602            | Self::SignedCertificateTimestampRequest
603            | Self::EarlyData => {}
604            Self::SessionTicket(ClientSessionTicket::Offer(ref r)) => r.encode(&mut sub),
605            Self::Protocols(ref r) => r.encode(&mut sub),
606            Self::SupportedVersions(ref r) => r.encode(&mut sub),
607            Self::KeyShare(ref r) => r.encode(&mut sub),
608            Self::PresharedKeyModes(ref r) => r.encode(&mut sub),
609            Self::PresharedKey(ref r) => r.encode(&mut sub),
610            Self::Cookie(ref r) => r.encode(&mut sub),
611            Self::CertificateStatusRequest(ref r) => r.encode(&mut sub),
612            Self::TransportParameters(ref r) | Self::TransportParametersDraft(ref r) => {
613                sub.extend_from_slice(r);
614            }
615            Self::Unknown(ref r) => r.encode(&mut sub),
616        }
617
618        (sub.len() as u16).encode(bytes);
619        bytes.append(&mut sub);
620    }
621
622    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
623        let typ = ExtensionType::read(r)?;
624        let len = u16::read(r)? as usize;
625        let mut sub = r.sub(len)?;
626
627        let ext = match typ {
628            ExtensionType::ECPointFormats => Self::ECPointFormats(Vec::read(&mut sub)?),
629            ExtensionType::EllipticCurves => Self::NamedGroups(Vec::read(&mut sub)?),
630            ExtensionType::SignatureAlgorithms => Self::SignatureAlgorithms(Vec::read(&mut sub)?),
631            ExtensionType::ServerName => Self::ServerName(Vec::read(&mut sub)?),
632            ExtensionType::SessionTicket => {
633                if sub.any_left() {
634                    let contents = Payload::read(&mut sub);
635                    Self::SessionTicket(ClientSessionTicket::Offer(contents))
636                } else {
637                    Self::SessionTicket(ClientSessionTicket::Request)
638                }
639            }
640            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
641            ExtensionType::SupportedVersions => Self::SupportedVersions(Vec::read(&mut sub)?),
642            ExtensionType::KeyShare => Self::KeyShare(Vec::read(&mut sub)?),
643            ExtensionType::PSKKeyExchangeModes => Self::PresharedKeyModes(Vec::read(&mut sub)?),
644            ExtensionType::PreSharedKey => Self::PresharedKey(PresharedKeyOffer::read(&mut sub)?),
645            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
646            ExtensionType::ExtendedMasterSecret if !sub.any_left() => {
647                Self::ExtendedMasterSecretRequest
648            }
649            ExtensionType::StatusRequest => {
650                let csr = CertificateStatusRequest::read(&mut sub)?;
651                Self::CertificateStatusRequest(csr)
652            }
653            ExtensionType::SCT if !sub.any_left() => Self::SignedCertificateTimestampRequest,
654            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
655            ExtensionType::TransportParametersDraft => {
656                Self::TransportParametersDraft(sub.rest().to_vec())
657            }
658            ExtensionType::EarlyData if !sub.any_left() => Self::EarlyData,
659            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
660        };
661
662        sub.expect_empty("ClientExtension")
663            .map(|_| ext)
664    }
665}
666
667fn trim_hostname_trailing_dot_for_sni(dns_name: DnsNameRef) -> DnsName {
668    let dns_name_str: &str = dns_name.as_ref();
669
670    // RFC6066: "The hostname is represented as a byte string using
671    // ASCII encoding without a trailing dot"
672    if dns_name_str.ends_with('.') {
673        let trimmed = &dns_name_str[0..dns_name_str.len() - 1];
674        DnsNameRef::try_from(trimmed)
675            .unwrap()
676            .to_owned()
677    } else {
678        dns_name.to_owned()
679    }
680}
681
682impl ClientExtension {
683    /// Make a basic SNI ServerNameRequest quoting `hostname`.
684    pub fn make_sni(dns_name: DnsNameRef) -> Self {
685        let name = ServerName {
686            typ: ServerNameType::HostName,
687            payload: ServerNamePayload::new_hostname(trim_hostname_trailing_dot_for_sni(dns_name)),
688        };
689
690        Self::ServerName(vec![name])
691    }
692}
693
694#[derive(Clone, Debug)]
695pub enum ClientSessionTicket {
696    Request,
697    Offer(Payload),
698}
699
700#[derive(Clone, Debug)]
701pub enum ServerExtension {
702    ECPointFormats(Vec<ECPointFormat>),
703    ServerNameAck,
704    SessionTicketAck,
705    RenegotiationInfo(PayloadU8),
706    Protocols(Vec<ProtocolName>),
707    KeyShare(KeyShareEntry),
708    PresharedKey(u16),
709    ExtendedMasterSecretAck,
710    CertificateStatusAck,
711    SignedCertificateTimestamp(Vec<Sct>),
712    SupportedVersions(ProtocolVersion),
713    TransportParameters(Vec<u8>),
714    TransportParametersDraft(Vec<u8>),
715    EarlyData,
716    Unknown(UnknownExtension),
717}
718
719impl ServerExtension {
720    pub fn get_type(&self) -> ExtensionType {
721        match *self {
722            Self::ECPointFormats(_) => ExtensionType::ECPointFormats,
723            Self::ServerNameAck => ExtensionType::ServerName,
724            Self::SessionTicketAck => ExtensionType::SessionTicket,
725            Self::RenegotiationInfo(_) => ExtensionType::RenegotiationInfo,
726            Self::Protocols(_) => ExtensionType::ALProtocolNegotiation,
727            Self::KeyShare(_) => ExtensionType::KeyShare,
728            Self::PresharedKey(_) => ExtensionType::PreSharedKey,
729            Self::ExtendedMasterSecretAck => ExtensionType::ExtendedMasterSecret,
730            Self::CertificateStatusAck => ExtensionType::StatusRequest,
731            Self::SignedCertificateTimestamp(_) => ExtensionType::SCT,
732            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
733            Self::TransportParameters(_) => ExtensionType::TransportParameters,
734            Self::TransportParametersDraft(_) => ExtensionType::TransportParametersDraft,
735            Self::EarlyData => ExtensionType::EarlyData,
736            Self::Unknown(ref r) => r.typ,
737        }
738    }
739}
740
741impl Codec for ServerExtension {
742    fn encode(&self, bytes: &mut Vec<u8>) {
743        self.get_type().encode(bytes);
744
745        let mut sub: Vec<u8> = Vec::new();
746        match *self {
747            Self::ECPointFormats(ref r) => r.encode(&mut sub),
748            Self::ServerNameAck
749            | Self::SessionTicketAck
750            | Self::ExtendedMasterSecretAck
751            | Self::CertificateStatusAck
752            | Self::EarlyData => {}
753            Self::RenegotiationInfo(ref r) => r.encode(&mut sub),
754            Self::Protocols(ref r) => r.encode(&mut sub),
755            Self::KeyShare(ref r) => r.encode(&mut sub),
756            Self::PresharedKey(r) => r.encode(&mut sub),
757            Self::SignedCertificateTimestamp(ref r) => r.encode(&mut sub),
758            Self::SupportedVersions(ref r) => r.encode(&mut sub),
759            Self::TransportParameters(ref r) | Self::TransportParametersDraft(ref r) => {
760                sub.extend_from_slice(r);
761            }
762            Self::Unknown(ref r) => r.encode(&mut sub),
763        }
764
765        (sub.len() as u16).encode(bytes);
766        bytes.append(&mut sub);
767    }
768
769    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
770        let typ = ExtensionType::read(r)?;
771        let len = u16::read(r)? as usize;
772        let mut sub = r.sub(len)?;
773
774        let ext = match typ {
775            ExtensionType::ECPointFormats => Self::ECPointFormats(Vec::read(&mut sub)?),
776            ExtensionType::ServerName => Self::ServerNameAck,
777            ExtensionType::SessionTicket => Self::SessionTicketAck,
778            ExtensionType::StatusRequest => Self::CertificateStatusAck,
779            ExtensionType::RenegotiationInfo => Self::RenegotiationInfo(PayloadU8::read(&mut sub)?),
780            ExtensionType::ALProtocolNegotiation => Self::Protocols(Vec::read(&mut sub)?),
781            ExtensionType::KeyShare => Self::KeyShare(KeyShareEntry::read(&mut sub)?),
782            ExtensionType::PreSharedKey => Self::PresharedKey(u16::read(&mut sub)?),
783            ExtensionType::ExtendedMasterSecret => Self::ExtendedMasterSecretAck,
784            ExtensionType::SCT => Self::SignedCertificateTimestamp(Vec::read(&mut sub)?),
785            ExtensionType::SupportedVersions => {
786                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
787            }
788            ExtensionType::TransportParameters => Self::TransportParameters(sub.rest().to_vec()),
789            ExtensionType::TransportParametersDraft => {
790                Self::TransportParametersDraft(sub.rest().to_vec())
791            }
792            ExtensionType::EarlyData => Self::EarlyData,
793            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
794        };
795
796        sub.expect_empty("ServerExtension")
797            .map(|_| ext)
798    }
799}
800
801impl ServerExtension {
802    pub fn make_alpn(proto: &[&[u8]]) -> Self {
803        Self::Protocols(Vec::from_slices(proto))
804    }
805
806    pub fn make_empty_renegotiation_info() -> Self {
807        let empty = Vec::new();
808        Self::RenegotiationInfo(PayloadU8::new(empty))
809    }
810
811    pub fn make_sct(sctl: Vec<u8>) -> Self {
812        let scts = Vec::read_bytes(&sctl).expect("invalid SCT list");
813        Self::SignedCertificateTimestamp(scts)
814    }
815}
816
817#[derive(Debug)]
818pub struct ClientHelloPayload {
819    pub client_version: ProtocolVersion,
820    pub random: Random,
821    pub session_id: SessionId,
822    pub cipher_suites: Vec<CipherSuite>,
823    pub compression_methods: Vec<Compression>,
824    pub extensions: Vec<ClientExtension>,
825}
826
827impl Codec for ClientHelloPayload {
828    fn encode(&self, bytes: &mut Vec<u8>) {
829        self.client_version.encode(bytes);
830        self.random.encode(bytes);
831        self.session_id.encode(bytes);
832        self.cipher_suites.encode(bytes);
833        self.compression_methods.encode(bytes);
834
835        if !self.extensions.is_empty() {
836            self.extensions.encode(bytes);
837        }
838    }
839
840    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
841        let mut ret = Self {
842            client_version: ProtocolVersion::read(r)?,
843            random: Random::read(r)?,
844            session_id: SessionId::read(r)?,
845            cipher_suites: Vec::read(r)?,
846            compression_methods: Vec::read(r)?,
847            extensions: Vec::new(),
848        };
849
850        if r.any_left() {
851            ret.extensions = Vec::read(r)?;
852        }
853
854        match (r.any_left(), ret.extensions.is_empty()) {
855            (true, _) => Err(InvalidMessage::TrailingData("ClientHelloPayload")),
856            (_, true) => Err(InvalidMessage::MissingData("ClientHelloPayload")),
857            _ => Ok(ret),
858        }
859    }
860}
861
862impl TlsListElement for CipherSuite {
863    const SIZE_LEN: ListLength = ListLength::U16;
864}
865
866impl TlsListElement for Compression {
867    const SIZE_LEN: ListLength = ListLength::U8;
868}
869
870impl TlsListElement for ClientExtension {
871    const SIZE_LEN: ListLength = ListLength::U16;
872}
873
874impl ClientHelloPayload {
875    /// Returns true if there is more than one extension of a given
876    /// type.
877    pub fn has_duplicate_extension(&self) -> bool {
878        let mut seen = collections::HashSet::new();
879
880        for ext in &self.extensions {
881            let typ = ext.get_type().get_u16();
882
883            if seen.contains(&typ) {
884                return true;
885            }
886            seen.insert(typ);
887        }
888
889        false
890    }
891
892    pub fn find_extension(&self, ext: ExtensionType) -> Option<&ClientExtension> {
893        self.extensions
894            .iter()
895            .find(|x| x.get_type() == ext)
896    }
897
898    pub fn get_sni_extension(&self) -> Option<&[ServerName]> {
899        let ext = self.find_extension(ExtensionType::ServerName)?;
900        match *ext {
901            // Does this comply with RFC6066?
902            //
903            // [RFC6066][] specifies that literal IP addresses are illegal in
904            // `ServerName`s with a `name_type` of `host_name`.
905            //
906            // Some clients incorrectly send such extensions: we choose to
907            // successfully parse these (into `ServerNamePayload::IpAddress`)
908            // but then act like the client sent no `server_name` extension.
909            //
910            // [RFC6066]: https://datatracker.ietf.org/doc/html/rfc6066#section-3
911            ClientExtension::ServerName(ref req)
912                if !req
913                    .iter()
914                    .any(|name| matches!(name.payload, ServerNamePayload::IpAddress(_))) =>
915            {
916                Some(req)
917            }
918            _ => None,
919        }
920    }
921
922    pub fn get_sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
923        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
924        match *ext {
925            ClientExtension::SignatureAlgorithms(ref req) => Some(req),
926            _ => None,
927        }
928    }
929
930    pub fn get_namedgroups_extension(&self) -> Option<&[NamedGroup]> {
931        let ext = self.find_extension(ExtensionType::EllipticCurves)?;
932        match *ext {
933            ClientExtension::NamedGroups(ref req) => Some(req),
934            _ => None,
935        }
936    }
937
938    pub fn get_ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
939        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
940        match *ext {
941            ClientExtension::ECPointFormats(ref req) => Some(req),
942            _ => None,
943        }
944    }
945
946    pub fn get_alpn_extension(&self) -> Option<&Vec<ProtocolName>> {
947        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
948        match *ext {
949            ClientExtension::Protocols(ref req) => Some(req),
950            _ => None,
951        }
952    }
953
954    pub fn get_quic_params_extension(&self) -> Option<Vec<u8>> {
955        let ext = self
956            .find_extension(ExtensionType::TransportParameters)
957            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
958        match *ext {
959            ClientExtension::TransportParameters(ref bytes)
960            | ClientExtension::TransportParametersDraft(ref bytes) => Some(bytes.to_vec()),
961            _ => None,
962        }
963    }
964
965    pub fn get_ticket_extension(&self) -> Option<&ClientExtension> {
966        self.find_extension(ExtensionType::SessionTicket)
967    }
968
969    pub fn get_versions_extension(&self) -> Option<&[ProtocolVersion]> {
970        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
971        match *ext {
972            ClientExtension::SupportedVersions(ref vers) => Some(vers),
973            _ => None,
974        }
975    }
976
977    pub fn get_keyshare_extension(&self) -> Option<&[KeyShareEntry]> {
978        let ext = self.find_extension(ExtensionType::KeyShare)?;
979        match *ext {
980            ClientExtension::KeyShare(ref shares) => Some(shares),
981            _ => None,
982        }
983    }
984
985    pub fn has_keyshare_extension_with_duplicates(&self) -> bool {
986        if let Some(entries) = self.get_keyshare_extension() {
987            let mut seen = collections::HashSet::new();
988
989            for kse in entries {
990                let grp = kse.group.get_u16();
991
992                if !seen.insert(grp) {
993                    return true;
994                }
995            }
996        }
997
998        false
999    }
1000
1001    pub fn get_psk(&self) -> Option<&PresharedKeyOffer> {
1002        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1003        match *ext {
1004            ClientExtension::PresharedKey(ref psk) => Some(psk),
1005            _ => None,
1006        }
1007    }
1008
1009    pub fn check_psk_ext_is_last(&self) -> bool {
1010        self.extensions
1011            .last()
1012            .map_or(false, |ext| ext.get_type() == ExtensionType::PreSharedKey)
1013    }
1014
1015    pub fn get_psk_modes(&self) -> Option<&[PSKKeyExchangeMode]> {
1016        let ext = self.find_extension(ExtensionType::PSKKeyExchangeModes)?;
1017        match *ext {
1018            ClientExtension::PresharedKeyModes(ref psk_modes) => Some(psk_modes),
1019            _ => None,
1020        }
1021    }
1022
1023    pub fn psk_mode_offered(&self, mode: PSKKeyExchangeMode) -> bool {
1024        self.get_psk_modes()
1025            .map(|modes| modes.contains(&mode))
1026            .unwrap_or(false)
1027    }
1028
1029    pub fn set_psk_binder(&mut self, binder: impl Into<Vec<u8>>) {
1030        let last_extension = self.extensions.last_mut();
1031        if let Some(ClientExtension::PresharedKey(ref mut offer)) = last_extension {
1032            offer.binders[0] = PresharedKeyBinder::from(binder.into());
1033        }
1034    }
1035
1036    pub fn ems_support_offered(&self) -> bool {
1037        self.find_extension(ExtensionType::ExtendedMasterSecret)
1038            .is_some()
1039    }
1040
1041    pub fn early_data_extension_offered(&self) -> bool {
1042        self.find_extension(ExtensionType::EarlyData)
1043            .is_some()
1044    }
1045}
1046
1047#[derive(Debug)]
1048pub enum HelloRetryExtension {
1049    KeyShare(NamedGroup),
1050    Cookie(PayloadU16),
1051    SupportedVersions(ProtocolVersion),
1052    Unknown(UnknownExtension),
1053}
1054
1055impl HelloRetryExtension {
1056    pub fn get_type(&self) -> ExtensionType {
1057        match *self {
1058            Self::KeyShare(_) => ExtensionType::KeyShare,
1059            Self::Cookie(_) => ExtensionType::Cookie,
1060            Self::SupportedVersions(_) => ExtensionType::SupportedVersions,
1061            Self::Unknown(ref r) => r.typ,
1062        }
1063    }
1064}
1065
1066impl Codec for HelloRetryExtension {
1067    fn encode(&self, bytes: &mut Vec<u8>) {
1068        self.get_type().encode(bytes);
1069
1070        let mut sub: Vec<u8> = Vec::new();
1071        match *self {
1072            Self::KeyShare(ref r) => r.encode(&mut sub),
1073            Self::Cookie(ref r) => r.encode(&mut sub),
1074            Self::SupportedVersions(ref r) => r.encode(&mut sub),
1075            Self::Unknown(ref r) => r.encode(&mut sub),
1076        }
1077
1078        (sub.len() as u16).encode(bytes);
1079        bytes.append(&mut sub);
1080    }
1081
1082    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1083        let typ = ExtensionType::read(r)?;
1084        let len = u16::read(r)? as usize;
1085        let mut sub = r.sub(len)?;
1086
1087        let ext = match typ {
1088            ExtensionType::KeyShare => Self::KeyShare(NamedGroup::read(&mut sub)?),
1089            ExtensionType::Cookie => Self::Cookie(PayloadU16::read(&mut sub)?),
1090            ExtensionType::SupportedVersions => {
1091                Self::SupportedVersions(ProtocolVersion::read(&mut sub)?)
1092            }
1093            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1094        };
1095
1096        sub.expect_empty("HelloRetryExtension")
1097            .map(|_| ext)
1098    }
1099}
1100
1101impl TlsListElement for HelloRetryExtension {
1102    const SIZE_LEN: ListLength = ListLength::U16;
1103}
1104
1105#[derive(Debug)]
1106pub struct HelloRetryRequest {
1107    pub legacy_version: ProtocolVersion,
1108    pub session_id: SessionId,
1109    pub cipher_suite: CipherSuite,
1110    pub extensions: Vec<HelloRetryExtension>,
1111}
1112
1113impl Codec for HelloRetryRequest {
1114    fn encode(&self, bytes: &mut Vec<u8>) {
1115        self.legacy_version.encode(bytes);
1116        HELLO_RETRY_REQUEST_RANDOM.encode(bytes);
1117        self.session_id.encode(bytes);
1118        self.cipher_suite.encode(bytes);
1119        Compression::Null.encode(bytes);
1120        self.extensions.encode(bytes);
1121    }
1122
1123    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1124        let session_id = SessionId::read(r)?;
1125        let cipher_suite = CipherSuite::read(r)?;
1126        let compression = Compression::read(r)?;
1127
1128        if compression != Compression::Null {
1129            return Err(InvalidMessage::UnsupportedCompression);
1130        }
1131
1132        Ok(Self {
1133            legacy_version: ProtocolVersion::Unknown(0),
1134            session_id,
1135            cipher_suite,
1136            extensions: Vec::read(r)?,
1137        })
1138    }
1139}
1140
1141impl HelloRetryRequest {
1142    /// Returns true if there is more than one extension of a given
1143    /// type.
1144    pub fn has_duplicate_extension(&self) -> bool {
1145        let mut seen = collections::HashSet::new();
1146
1147        for ext in &self.extensions {
1148            let typ = ext.get_type().get_u16();
1149
1150            if seen.contains(&typ) {
1151                return true;
1152            }
1153            seen.insert(typ);
1154        }
1155
1156        false
1157    }
1158
1159    pub fn has_unknown_extension(&self) -> bool {
1160        self.extensions.iter().any(|ext| {
1161            ext.get_type() != ExtensionType::KeyShare
1162                && ext.get_type() != ExtensionType::SupportedVersions
1163                && ext.get_type() != ExtensionType::Cookie
1164        })
1165    }
1166
1167    fn find_extension(&self, ext: ExtensionType) -> Option<&HelloRetryExtension> {
1168        self.extensions
1169            .iter()
1170            .find(|x| x.get_type() == ext)
1171    }
1172
1173    pub fn get_requested_key_share_group(&self) -> Option<NamedGroup> {
1174        let ext = self.find_extension(ExtensionType::KeyShare)?;
1175        match *ext {
1176            HelloRetryExtension::KeyShare(grp) => Some(grp),
1177            _ => None,
1178        }
1179    }
1180
1181    pub fn get_cookie(&self) -> Option<&PayloadU16> {
1182        let ext = self.find_extension(ExtensionType::Cookie)?;
1183        match *ext {
1184            HelloRetryExtension::Cookie(ref ck) => Some(ck),
1185            _ => None,
1186        }
1187    }
1188
1189    pub fn get_supported_versions(&self) -> Option<ProtocolVersion> {
1190        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1191        match *ext {
1192            HelloRetryExtension::SupportedVersions(ver) => Some(ver),
1193            _ => None,
1194        }
1195    }
1196}
1197
1198#[derive(Debug)]
1199pub struct ServerHelloPayload {
1200    pub legacy_version: ProtocolVersion,
1201    pub random: Random,
1202    pub session_id: SessionId,
1203    pub cipher_suite: CipherSuite,
1204    pub compression_method: Compression,
1205    pub extensions: Vec<ServerExtension>,
1206}
1207
1208impl Codec for ServerHelloPayload {
1209    fn encode(&self, bytes: &mut Vec<u8>) {
1210        self.legacy_version.encode(bytes);
1211        self.random.encode(bytes);
1212
1213        self.session_id.encode(bytes);
1214        self.cipher_suite.encode(bytes);
1215        self.compression_method.encode(bytes);
1216
1217        if !self.extensions.is_empty() {
1218            self.extensions.encode(bytes);
1219        }
1220    }
1221
1222    // minus version and random, which have already been read.
1223    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1224        let session_id = SessionId::read(r)?;
1225        let suite = CipherSuite::read(r)?;
1226        let compression = Compression::read(r)?;
1227
1228        // RFC5246:
1229        // "The presence of extensions can be detected by determining whether
1230        //  there are bytes following the compression_method field at the end of
1231        //  the ServerHello."
1232        let extensions = if r.any_left() { Vec::read(r)? } else { vec![] };
1233
1234        let ret = Self {
1235            legacy_version: ProtocolVersion::Unknown(0),
1236            random: ZERO_RANDOM,
1237            session_id,
1238            cipher_suite: suite,
1239            compression_method: compression,
1240            extensions,
1241        };
1242
1243        r.expect_empty("ServerHelloPayload")
1244            .map(|_| ret)
1245    }
1246}
1247
1248impl HasServerExtensions for ServerHelloPayload {
1249    fn get_extensions(&self) -> &[ServerExtension] {
1250        &self.extensions
1251    }
1252}
1253
1254impl ServerHelloPayload {
1255    pub fn get_key_share(&self) -> Option<&KeyShareEntry> {
1256        let ext = self.find_extension(ExtensionType::KeyShare)?;
1257        match *ext {
1258            ServerExtension::KeyShare(ref share) => Some(share),
1259            _ => None,
1260        }
1261    }
1262
1263    pub fn get_psk_index(&self) -> Option<u16> {
1264        let ext = self.find_extension(ExtensionType::PreSharedKey)?;
1265        match *ext {
1266            ServerExtension::PresharedKey(ref index) => Some(*index),
1267            _ => None,
1268        }
1269    }
1270
1271    pub fn get_ecpoints_extension(&self) -> Option<&[ECPointFormat]> {
1272        let ext = self.find_extension(ExtensionType::ECPointFormats)?;
1273        match *ext {
1274            ServerExtension::ECPointFormats(ref fmts) => Some(fmts),
1275            _ => None,
1276        }
1277    }
1278
1279    pub fn ems_support_acked(&self) -> bool {
1280        self.find_extension(ExtensionType::ExtendedMasterSecret)
1281            .is_some()
1282    }
1283
1284    pub fn get_sct_list(&self) -> Option<&[Sct]> {
1285        let ext = self.find_extension(ExtensionType::SCT)?;
1286        match *ext {
1287            ServerExtension::SignedCertificateTimestamp(ref sctl) => Some(sctl),
1288            _ => None,
1289        }
1290    }
1291
1292    pub fn get_supported_versions(&self) -> Option<ProtocolVersion> {
1293        let ext = self.find_extension(ExtensionType::SupportedVersions)?;
1294        match *ext {
1295            ServerExtension::SupportedVersions(vers) => Some(vers),
1296            _ => None,
1297        }
1298    }
1299}
1300
1301pub type CertificatePayload = Vec<key::Certificate>;
1302
1303impl TlsListElement for key::Certificate {
1304    const SIZE_LEN: ListLength = ListLength::U24 { max: 0x1_0000 };
1305}
1306
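// TLS1.3 changes the Certificate payload encoding: each certificate is followed
// by its own list of extensions (see `CertificateEntry`). As a result, parsing
// is no longer context-free, because the same message must be decoded
// differently depending on the negotiated protocol version.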
1310
1311#[derive(Debug)]
1312pub enum CertificateExtension {
1313    CertificateStatus(CertificateStatus),
1314    SignedCertificateTimestamp(Vec<Sct>),
1315    Unknown(UnknownExtension),
1316}
1317
1318impl CertificateExtension {
1319    pub fn get_type(&self) -> ExtensionType {
1320        match *self {
1321            Self::CertificateStatus(_) => ExtensionType::StatusRequest,
1322            Self::SignedCertificateTimestamp(_) => ExtensionType::SCT,
1323            Self::Unknown(ref r) => r.typ,
1324        }
1325    }
1326
1327    pub fn make_sct(sct_list: Vec<u8>) -> Self {
1328        let sctl = Vec::read_bytes(&sct_list).expect("invalid SCT list");
1329        Self::SignedCertificateTimestamp(sctl)
1330    }
1331
1332    pub fn get_cert_status(&self) -> Option<&Vec<u8>> {
1333        match *self {
1334            Self::CertificateStatus(ref cs) => Some(&cs.ocsp_response.0),
1335            _ => None,
1336        }
1337    }
1338
1339    pub fn get_sct_list(&self) -> Option<&[Sct]> {
1340        match *self {
1341            Self::SignedCertificateTimestamp(ref sctl) => Some(sctl),
1342            _ => None,
1343        }
1344    }
1345}
1346
1347impl Codec for CertificateExtension {
1348    fn encode(&self, bytes: &mut Vec<u8>) {
1349        self.get_type().encode(bytes);
1350
1351        let mut sub: Vec<u8> = Vec::new();
1352        match *self {
1353            Self::CertificateStatus(ref r) => r.encode(&mut sub),
1354            Self::SignedCertificateTimestamp(ref r) => r.encode(&mut sub),
1355            Self::Unknown(ref r) => r.encode(&mut sub),
1356        }
1357
1358        (sub.len() as u16).encode(bytes);
1359        bytes.append(&mut sub);
1360    }
1361
1362    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1363        let typ = ExtensionType::read(r)?;
1364        let len = u16::read(r)? as usize;
1365        let mut sub = r.sub(len)?;
1366
1367        let ext = match typ {
1368            ExtensionType::StatusRequest => {
1369                let st = CertificateStatus::read(&mut sub)?;
1370                Self::CertificateStatus(st)
1371            }
1372            ExtensionType::SCT => Self::SignedCertificateTimestamp(Vec::read(&mut sub)?),
1373            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1374        };
1375
1376        sub.expect_empty("CertificateExtension")
1377            .map(|_| ext)
1378    }
1379}
1380
1381impl TlsListElement for CertificateExtension {
1382    const SIZE_LEN: ListLength = ListLength::U16;
1383}
1384
1385#[derive(Debug)]
1386pub struct CertificateEntry {
1387    pub cert: key::Certificate,
1388    pub exts: Vec<CertificateExtension>,
1389}
1390
1391impl Codec for CertificateEntry {
1392    fn encode(&self, bytes: &mut Vec<u8>) {
1393        self.cert.encode(bytes);
1394        self.exts.encode(bytes);
1395    }
1396
1397    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1398        Ok(Self {
1399            cert: key::Certificate::read(r)?,
1400            exts: Vec::read(r)?,
1401        })
1402    }
1403}
1404
1405impl CertificateEntry {
1406    pub fn new(cert: key::Certificate) -> Self {
1407        Self {
1408            cert,
1409            exts: Vec::new(),
1410        }
1411    }
1412
1413    pub fn has_duplicate_extension(&self) -> bool {
1414        let mut seen = collections::HashSet::new();
1415
1416        for ext in &self.exts {
1417            let typ = ext.get_type().get_u16();
1418
1419            if seen.contains(&typ) {
1420                return true;
1421            }
1422            seen.insert(typ);
1423        }
1424
1425        false
1426    }
1427
1428    pub fn has_unknown_extension(&self) -> bool {
1429        self.exts.iter().any(|ext| {
1430            ext.get_type() != ExtensionType::StatusRequest && ext.get_type() != ExtensionType::SCT
1431        })
1432    }
1433
1434    pub fn get_ocsp_response(&self) -> Option<&Vec<u8>> {
1435        self.exts
1436            .iter()
1437            .find(|ext| ext.get_type() == ExtensionType::StatusRequest)
1438            .and_then(CertificateExtension::get_cert_status)
1439    }
1440
1441    pub fn get_scts(&self) -> Option<&[Sct]> {
1442        self.exts
1443            .iter()
1444            .find(|ext| ext.get_type() == ExtensionType::SCT)
1445            .and_then(CertificateExtension::get_sct_list)
1446    }
1447}
1448
1449impl TlsListElement for CertificateEntry {
1450    const SIZE_LEN: ListLength = ListLength::U24 { max: 0x1_0000 };
1451}
1452
1453#[derive(Debug)]
1454pub struct CertificatePayloadTLS13 {
1455    pub context: PayloadU8,
1456    pub entries: Vec<CertificateEntry>,
1457}
1458
1459impl Codec for CertificatePayloadTLS13 {
1460    fn encode(&self, bytes: &mut Vec<u8>) {
1461        self.context.encode(bytes);
1462        self.entries.encode(bytes);
1463    }
1464
1465    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1466        Ok(Self {
1467            context: PayloadU8::read(r)?,
1468            entries: Vec::read(r)?,
1469        })
1470    }
1471}
1472
1473impl CertificatePayloadTLS13 {
1474    pub fn new(entries: Vec<CertificateEntry>) -> Self {
1475        Self {
1476            context: PayloadU8::empty(),
1477            entries,
1478        }
1479    }
1480
1481    pub fn any_entry_has_duplicate_extension(&self) -> bool {
1482        for entry in &self.entries {
1483            if entry.has_duplicate_extension() {
1484                return true;
1485            }
1486        }
1487
1488        false
1489    }
1490
1491    pub fn any_entry_has_unknown_extension(&self) -> bool {
1492        for entry in &self.entries {
1493            if entry.has_unknown_extension() {
1494                return true;
1495            }
1496        }
1497
1498        false
1499    }
1500
1501    pub fn any_entry_has_extension(&self) -> bool {
1502        for entry in &self.entries {
1503            if !entry.exts.is_empty() {
1504                return true;
1505            }
1506        }
1507
1508        false
1509    }
1510
1511    pub fn get_end_entity_ocsp(&self) -> Vec<u8> {
1512        self.entries
1513            .first()
1514            .and_then(CertificateEntry::get_ocsp_response)
1515            .cloned()
1516            .unwrap_or_default()
1517    }
1518
1519    pub fn get_end_entity_scts(&self) -> Option<&[Sct]> {
1520        self.entries
1521            .first()
1522            .and_then(CertificateEntry::get_scts)
1523    }
1524
1525    pub fn convert(&self) -> CertificatePayload {
1526        let mut ret = Vec::new();
1527        for entry in &self.entries {
1528            ret.push(entry.cert.clone());
1529        }
1530        ret
1531    }
1532}
1533
/// Identifies a key exchange algorithm.
///
/// `ServerKeyExchangePayload::unwrap_given_kxa` uses this to decide how to parse a
/// ServerKeyExchange body; only `ECDHE` is handled there.
///
/// `Eq` is derived alongside `PartialEq`: equality on these fieldless variants is total,
/// and deriving it allows use where `Eq` is required.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum KeyExchangeAlgorithm {
    BulkOnly,
    DH,
    DHE,
    RSA,
    ECDH,
    ECDHE,
}
1543
1544// We don't support arbitrary curves.  It's a terrible
1545// idea and unnecessary attack surface.  Please,
1546// get a grip.
1547#[derive(Debug)]
1548pub struct ECParameters {
1549    pub curve_type: ECCurveType,
1550    pub named_group: NamedGroup,
1551}
1552
1553impl Codec for ECParameters {
1554    fn encode(&self, bytes: &mut Vec<u8>) {
1555        self.curve_type.encode(bytes);
1556        self.named_group.encode(bytes);
1557    }
1558
1559    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1560        let ct = ECCurveType::read(r)?;
1561        if ct != ECCurveType::NamedCurve {
1562            return Err(InvalidMessage::UnsupportedCurveType);
1563        }
1564
1565        let grp = NamedGroup::read(r)?;
1566
1567        Ok(Self {
1568            curve_type: ct,
1569            named_group: grp,
1570        })
1571    }
1572}
1573
1574#[derive(Debug)]
1575pub struct ClientECDHParams {
1576    pub public: PayloadU8,
1577}
1578
1579impl Codec for ClientECDHParams {
1580    fn encode(&self, bytes: &mut Vec<u8>) {
1581        self.public.encode(bytes);
1582    }
1583
1584    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1585        let pb = PayloadU8::read(r)?;
1586        Ok(Self { public: pb })
1587    }
1588}
1589
1590#[derive(Debug)]
1591pub struct ServerECDHParams {
1592    pub curve_params: ECParameters,
1593    pub public: PayloadU8,
1594}
1595
1596impl ServerECDHParams {
1597    pub fn new(named_group: NamedGroup, pubkey: &[u8]) -> Self {
1598        Self {
1599            curve_params: ECParameters {
1600                curve_type: ECCurveType::NamedCurve,
1601                named_group,
1602            },
1603            public: PayloadU8::new(pubkey.to_vec()),
1604        }
1605    }
1606}
1607
1608impl Codec for ServerECDHParams {
1609    fn encode(&self, bytes: &mut Vec<u8>) {
1610        self.curve_params.encode(bytes);
1611        self.public.encode(bytes);
1612    }
1613
1614    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1615        let cp = ECParameters::read(r)?;
1616        let pb = PayloadU8::read(r)?;
1617
1618        Ok(Self {
1619            curve_params: cp,
1620            public: pb,
1621        })
1622    }
1623}
1624
1625#[derive(Debug)]
1626pub struct ECDHEServerKeyExchange {
1627    pub params: ServerECDHParams,
1628    pub dss: DigitallySignedStruct,
1629}
1630
1631impl Codec for ECDHEServerKeyExchange {
1632    fn encode(&self, bytes: &mut Vec<u8>) {
1633        self.params.encode(bytes);
1634        self.dss.encode(bytes);
1635    }
1636
1637    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1638        let params = ServerECDHParams::read(r)?;
1639        let dss = DigitallySignedStruct::read(r)?;
1640
1641        Ok(Self { params, dss })
1642    }
1643}
1644
1645#[derive(Debug)]
1646pub enum ServerKeyExchangePayload {
1647    ECDHE(ECDHEServerKeyExchange),
1648    Unknown(Payload),
1649}
1650
1651impl Codec for ServerKeyExchangePayload {
1652    fn encode(&self, bytes: &mut Vec<u8>) {
1653        match *self {
1654            Self::ECDHE(ref x) => x.encode(bytes),
1655            Self::Unknown(ref x) => x.encode(bytes),
1656        }
1657    }
1658
1659    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1660        // read as Unknown, fully parse when we know the
1661        // KeyExchangeAlgorithm
1662        Ok(Self::Unknown(Payload::read(r)))
1663    }
1664}
1665
1666impl ServerKeyExchangePayload {
1667    pub fn unwrap_given_kxa(&self, kxa: KeyExchangeAlgorithm) -> Option<ECDHEServerKeyExchange> {
1668        if let Self::Unknown(ref unk) = *self {
1669            let mut rd = Reader::init(&unk.0);
1670
1671            let result = match kxa {
1672                KeyExchangeAlgorithm::ECDHE => ECDHEServerKeyExchange::read(&mut rd),
1673                _ => return None,
1674            };
1675
1676            if !rd.any_left() {
1677                return result.ok();
1678            };
1679        }
1680
1681        None
1682    }
1683}
1684
1685// -- EncryptedExtensions (TLS1.3 only) --
1686
1687impl TlsListElement for ServerExtension {
1688    const SIZE_LEN: ListLength = ListLength::U16;
1689}
1690
1691pub trait HasServerExtensions {
1692    fn get_extensions(&self) -> &[ServerExtension];
1693
1694    /// Returns true if there is more than one extension of a given
1695    /// type.
1696    fn has_duplicate_extension(&self) -> bool {
1697        let mut seen = collections::HashSet::new();
1698
1699        for ext in self.get_extensions() {
1700            let typ = ext.get_type().get_u16();
1701
1702            if seen.contains(&typ) {
1703                return true;
1704            }
1705            seen.insert(typ);
1706        }
1707
1708        false
1709    }
1710
1711    fn find_extension(&self, ext: ExtensionType) -> Option<&ServerExtension> {
1712        self.get_extensions()
1713            .iter()
1714            .find(|x| x.get_type() == ext)
1715    }
1716
1717    fn get_alpn_protocol(&self) -> Option<&[u8]> {
1718        let ext = self.find_extension(ExtensionType::ALProtocolNegotiation)?;
1719        match *ext {
1720            ServerExtension::Protocols(ref protos) => protos.as_single_slice(),
1721            _ => None,
1722        }
1723    }
1724
1725    fn get_quic_params_extension(&self) -> Option<Vec<u8>> {
1726        let ext = self
1727            .find_extension(ExtensionType::TransportParameters)
1728            .or_else(|| self.find_extension(ExtensionType::TransportParametersDraft))?;
1729        match *ext {
1730            ServerExtension::TransportParameters(ref bytes)
1731            | ServerExtension::TransportParametersDraft(ref bytes) => Some(bytes.to_vec()),
1732            _ => None,
1733        }
1734    }
1735
1736    fn early_data_extension_offered(&self) -> bool {
1737        self.find_extension(ExtensionType::EarlyData)
1738            .is_some()
1739    }
1740}
1741
1742impl HasServerExtensions for Vec<ServerExtension> {
1743    fn get_extensions(&self) -> &[ServerExtension] {
1744        self
1745    }
1746}
1747
1748impl TlsListElement for ClientCertificateType {
1749    const SIZE_LEN: ListLength = ListLength::U8;
1750}
1751
1752wrapped_payload!(
1753    /// A `DistinguishedName` is a `Vec<u8>` wrapped in internal types.
1754    ///
1755    /// It contains the DER or BER encoded [`Subject` field from RFC 5280](https://datatracker.ietf.org/doc/html/rfc5280#section-4.1.2.6)
1756    /// for a single certificate. The Subject field is [encoded as an RFC 5280 `Name`](https://datatracker.ietf.org/doc/html/rfc5280#page-116).
1757    /// It can be decoded using [x509-parser's FromDer trait](https://docs.rs/x509-parser/latest/x509_parser/prelude/trait.FromDer.html).
1758    ///
1759    /// ```ignore
1760    /// for name in distinguished_names {
1761    ///     use x509_parser::prelude::FromDer;
1762    ///     println!("{}", x509_parser::x509::X509Name::from_der(&name.0)?.1);
1763    /// }
1764    /// ```
1765    DistinguishedName,
1766    PayloadU16,
1767);
1768
1769impl TlsListElement for DistinguishedName {
1770    const SIZE_LEN: ListLength = ListLength::U16;
1771}
1772
1773#[derive(Debug)]
1774pub struct CertificateRequestPayload {
1775    pub certtypes: Vec<ClientCertificateType>,
1776    pub sigschemes: Vec<SignatureScheme>,
1777    pub canames: Vec<DistinguishedName>,
1778}
1779
1780impl Codec for CertificateRequestPayload {
1781    fn encode(&self, bytes: &mut Vec<u8>) {
1782        self.certtypes.encode(bytes);
1783        self.sigschemes.encode(bytes);
1784        self.canames.encode(bytes);
1785    }
1786
1787    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1788        let certtypes = Vec::read(r)?;
1789        let sigschemes = Vec::read(r)?;
1790        let canames = Vec::read(r)?;
1791
1792        if sigschemes.is_empty() {
1793            warn!("meaningless CertificateRequest message");
1794            Err(InvalidMessage::NoSignatureSchemes)
1795        } else {
1796            Ok(Self {
1797                certtypes,
1798                sigschemes,
1799                canames,
1800            })
1801        }
1802    }
1803}
1804
1805#[derive(Debug)]
1806pub enum CertReqExtension {
1807    SignatureAlgorithms(Vec<SignatureScheme>),
1808    AuthorityNames(Vec<DistinguishedName>),
1809    Unknown(UnknownExtension),
1810}
1811
1812impl CertReqExtension {
1813    pub fn get_type(&self) -> ExtensionType {
1814        match *self {
1815            Self::SignatureAlgorithms(_) => ExtensionType::SignatureAlgorithms,
1816            Self::AuthorityNames(_) => ExtensionType::CertificateAuthorities,
1817            Self::Unknown(ref r) => r.typ,
1818        }
1819    }
1820}
1821
1822impl Codec for CertReqExtension {
1823    fn encode(&self, bytes: &mut Vec<u8>) {
1824        self.get_type().encode(bytes);
1825
1826        let mut sub: Vec<u8> = Vec::new();
1827        match *self {
1828            Self::SignatureAlgorithms(ref r) => r.encode(&mut sub),
1829            Self::AuthorityNames(ref r) => r.encode(&mut sub),
1830            Self::Unknown(ref r) => r.encode(&mut sub),
1831        }
1832
1833        (sub.len() as u16).encode(bytes);
1834        bytes.append(&mut sub);
1835    }
1836
1837    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1838        let typ = ExtensionType::read(r)?;
1839        let len = u16::read(r)? as usize;
1840        let mut sub = r.sub(len)?;
1841
1842        let ext = match typ {
1843            ExtensionType::SignatureAlgorithms => {
1844                let schemes = Vec::read(&mut sub)?;
1845                if schemes.is_empty() {
1846                    return Err(InvalidMessage::NoSignatureSchemes);
1847                }
1848                Self::SignatureAlgorithms(schemes)
1849            }
1850            ExtensionType::CertificateAuthorities => {
1851                let cas = Vec::read(&mut sub)?;
1852                Self::AuthorityNames(cas)
1853            }
1854            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1855        };
1856
1857        sub.expect_empty("CertReqExtension")
1858            .map(|_| ext)
1859    }
1860}
1861
1862impl TlsListElement for CertReqExtension {
1863    const SIZE_LEN: ListLength = ListLength::U16;
1864}
1865
1866#[derive(Debug)]
1867pub struct CertificateRequestPayloadTLS13 {
1868    pub context: PayloadU8,
1869    pub extensions: Vec<CertReqExtension>,
1870}
1871
1872impl Codec for CertificateRequestPayloadTLS13 {
1873    fn encode(&self, bytes: &mut Vec<u8>) {
1874        self.context.encode(bytes);
1875        self.extensions.encode(bytes);
1876    }
1877
1878    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1879        let context = PayloadU8::read(r)?;
1880        let extensions = Vec::read(r)?;
1881
1882        Ok(Self {
1883            context,
1884            extensions,
1885        })
1886    }
1887}
1888
1889impl CertificateRequestPayloadTLS13 {
1890    pub fn find_extension(&self, ext: ExtensionType) -> Option<&CertReqExtension> {
1891        self.extensions
1892            .iter()
1893            .find(|x| x.get_type() == ext)
1894    }
1895
1896    pub fn get_sigalgs_extension(&self) -> Option<&[SignatureScheme]> {
1897        let ext = self.find_extension(ExtensionType::SignatureAlgorithms)?;
1898        match *ext {
1899            CertReqExtension::SignatureAlgorithms(ref sa) => Some(sa),
1900            _ => None,
1901        }
1902    }
1903
1904    pub fn get_authorities_extension(&self) -> Option<&[DistinguishedName]> {
1905        let ext = self.find_extension(ExtensionType::CertificateAuthorities)?;
1906        match *ext {
1907            CertReqExtension::AuthorityNames(ref an) => Some(an),
1908            _ => None,
1909        }
1910    }
1911}
1912
1913// -- NewSessionTicket --
1914#[derive(Debug)]
1915pub struct NewSessionTicketPayload {
1916    pub lifetime_hint: u32,
1917    pub ticket: PayloadU16,
1918}
1919
1920impl NewSessionTicketPayload {
1921    pub fn new(lifetime_hint: u32, ticket: Vec<u8>) -> Self {
1922        Self {
1923            lifetime_hint,
1924            ticket: PayloadU16::new(ticket),
1925        }
1926    }
1927}
1928
1929impl Codec for NewSessionTicketPayload {
1930    fn encode(&self, bytes: &mut Vec<u8>) {
1931        self.lifetime_hint.encode(bytes);
1932        self.ticket.encode(bytes);
1933    }
1934
1935    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1936        let lifetime = u32::read(r)?;
1937        let ticket = PayloadU16::read(r)?;
1938
1939        Ok(Self {
1940            lifetime_hint: lifetime,
1941            ticket,
1942        })
1943    }
1944}
1945
// -- NewSessionTicket (TLS 1.3 form, with extensions) --
/// An extension carried in a TLS 1.3 `NewSessionTicket` message.
#[derive(Debug)]
pub enum NewSessionTicketExtension {
    /// The `EarlyData` extension; its `u32` is the value surfaced by
    /// `NewSessionTicketPayloadTLS13::get_max_early_data_size`.
    EarlyData(u32),
    /// Any other extension type, parsed with `UnknownExtension::read`.
    Unknown(UnknownExtension),
}
1952
1953impl NewSessionTicketExtension {
1954    pub fn get_type(&self) -> ExtensionType {
1955        match *self {
1956            Self::EarlyData(_) => ExtensionType::EarlyData,
1957            Self::Unknown(ref r) => r.typ,
1958        }
1959    }
1960}
1961
1962impl Codec for NewSessionTicketExtension {
1963    fn encode(&self, bytes: &mut Vec<u8>) {
1964        self.get_type().encode(bytes);
1965
1966        let mut sub: Vec<u8> = Vec::new();
1967        match *self {
1968            Self::EarlyData(r) => r.encode(&mut sub),
1969            Self::Unknown(ref r) => r.encode(&mut sub),
1970        }
1971
1972        (sub.len() as u16).encode(bytes);
1973        bytes.append(&mut sub);
1974    }
1975
1976    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
1977        let typ = ExtensionType::read(r)?;
1978        let len = u16::read(r)? as usize;
1979        let mut sub = r.sub(len)?;
1980
1981        let ext = match typ {
1982            ExtensionType::EarlyData => Self::EarlyData(u32::read(&mut sub)?),
1983            _ => Self::Unknown(UnknownExtension::read(typ, &mut sub)),
1984        };
1985
1986        sub.expect_empty("NewSessionTicketExtension")
1987            .map(|_| ext)
1988    }
1989}
1990
// Lists of `NewSessionTicketExtension` (e.g. `NewSessionTicketPayloadTLS13::exts`)
// are encoded with a `ListLength::U16` length prefix.
impl TlsListElement for NewSessionTicketExtension {
    const SIZE_LEN: ListLength = ListLength::U16;
}
1994
/// Body of a TLS 1.3 `NewSessionTicket` message.
///
/// Fields are listed in wire order.
#[derive(Debug)]
pub struct NewSessionTicketPayloadTLS13 {
    /// Ticket lifetime, encoded as a `u32` (units per RFC 8446; presumably seconds — confirm).
    pub lifetime: u32,
    /// Age-add value, encoded as a `u32`.
    pub age_add: u32,
    /// Ticket nonce, encoded as a `PayloadU8`.
    pub nonce: PayloadU8,
    /// Opaque ticket bytes, encoded as a `PayloadU16`.
    pub ticket: PayloadU16,
    /// Ticket extensions; empty when built with `new`.
    pub exts: Vec<NewSessionTicketExtension>,
}
2003
2004impl NewSessionTicketPayloadTLS13 {
2005    pub fn new(lifetime: u32, age_add: u32, nonce: Vec<u8>, ticket: Vec<u8>) -> Self {
2006        Self {
2007            lifetime,
2008            age_add,
2009            nonce: PayloadU8::new(nonce),
2010            ticket: PayloadU16::new(ticket),
2011            exts: vec![],
2012        }
2013    }
2014
2015    pub fn has_duplicate_extension(&self) -> bool {
2016        let mut seen = collections::HashSet::new();
2017
2018        for ext in &self.exts {
2019            let typ = ext.get_type().get_u16();
2020
2021            if seen.contains(&typ) {
2022                return true;
2023            }
2024            seen.insert(typ);
2025        }
2026
2027        false
2028    }
2029
2030    pub fn find_extension(&self, ext: ExtensionType) -> Option<&NewSessionTicketExtension> {
2031        self.exts
2032            .iter()
2033            .find(|x| x.get_type() == ext)
2034    }
2035
2036    pub fn get_max_early_data_size(&self) -> Option<u32> {
2037        let ext = self.find_extension(ExtensionType::EarlyData)?;
2038        match *ext {
2039            NewSessionTicketExtension::EarlyData(ref sz) => Some(*sz),
2040            _ => None,
2041        }
2042    }
2043}
2044
2045impl Codec for NewSessionTicketPayloadTLS13 {
2046    fn encode(&self, bytes: &mut Vec<u8>) {
2047        self.lifetime.encode(bytes);
2048        self.age_add.encode(bytes);
2049        self.nonce.encode(bytes);
2050        self.ticket.encode(bytes);
2051        self.exts.encode(bytes);
2052    }
2053
2054    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2055        let lifetime = u32::read(r)?;
2056        let age_add = u32::read(r)?;
2057        let nonce = PayloadU8::read(r)?;
2058        let ticket = PayloadU16::read(r)?;
2059        let exts = Vec::read(r)?;
2060
2061        Ok(Self {
2062            lifetime,
2063            age_add,
2064            nonce,
2065            ticket,
2066            exts,
2067        })
2068    }
2069}
2070
2071// -- RFC6066 certificate status types
2072
/// Body of a `CertificateStatus` message (RFC 6066).
///
/// Only supports OCSP: reading any other status type fails with
/// `InvalidMessage::InvalidCertificateStatusType`.
#[derive(Debug)]
pub struct CertificateStatus {
    /// The OCSP response bytes, encoded as a `PayloadU24`.
    pub ocsp_response: PayloadU24,
}
2078
2079impl Codec for CertificateStatus {
2080    fn encode(&self, bytes: &mut Vec<u8>) {
2081        CertificateStatusType::OCSP.encode(bytes);
2082        self.ocsp_response.encode(bytes);
2083    }
2084
2085    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2086        let typ = CertificateStatusType::read(r)?;
2087
2088        match typ {
2089            CertificateStatusType::OCSP => Ok(Self {
2090                ocsp_response: PayloadU24::read(r)?,
2091            }),
2092            _ => Err(InvalidMessage::InvalidCertificateStatusType),
2093        }
2094    }
2095}
2096
2097impl CertificateStatus {
2098    pub fn new(ocsp: Vec<u8>) -> Self {
2099        Self {
2100            ocsp_response: PayloadU24::new(ocsp),
2101        }
2102    }
2103
2104    pub fn into_inner(self) -> Vec<u8> {
2105        self.ocsp_response.0
2106    }
2107}
2108
/// The decoded body of a handshake message.
///
/// Some message types have separate TLS 1.3 variants (`...TLS13`). The reader
/// picks between the two layouts using the protocol version it is given.
#[derive(Debug)]
pub enum HandshakePayload {
    /// `HelloRequest` with an empty body.
    HelloRequest,
    ClientHello(ClientHelloPayload),
    ServerHello(ServerHelloPayload),
    /// Shares the `ServerHello` wire type; recognised on read by its special random value.
    HelloRetryRequest(HelloRetryRequest),
    /// `Certificate` in the pre-TLS 1.3 layout.
    Certificate(CertificatePayload),
    /// `Certificate` in the TLS 1.3 layout.
    CertificateTLS13(CertificatePayloadTLS13),
    ServerKeyExchange(ServerKeyExchangePayload),
    /// `CertificateRequest` in the pre-TLS 1.3 layout.
    CertificateRequest(CertificateRequestPayload),
    /// `CertificateRequest` in the TLS 1.3 layout.
    CertificateRequestTLS13(CertificateRequestPayloadTLS13),
    CertificateVerify(DigitallySignedStruct),
    /// Has no body; a non-empty body is rejected on read.
    ServerHelloDone,
    /// Has no body; a non-empty body is rejected on read.
    EndOfEarlyData,
    /// Key exchange body kept as a raw `Payload`.
    ClientKeyExchange(Payload),
    NewSessionTicket(NewSessionTicketPayload),
    NewSessionTicketTLS13(NewSessionTicketPayloadTLS13),
    EncryptedExtensions(Vec<ServerExtension>),
    KeyUpdate(KeyUpdateRequest),
    /// `Finished` body kept as a raw `Payload`.
    Finished(Payload),
    CertificateStatus(CertificateStatus),
    /// Synthetic message wrapping a hash; rejected if it appears on the wire.
    MessageHash(Payload),
    /// Body of any other handshake type (including a `HelloRequest` with a
    /// non-empty body), kept as a raw `Payload`.
    Unknown(Payload),
}
2133
2134impl HandshakePayload {
2135    fn encode(&self, bytes: &mut Vec<u8>) {
2136        use self::HandshakePayload::*;
2137        match *self {
2138            HelloRequest | ServerHelloDone | EndOfEarlyData => {}
2139            ClientHello(ref x) => x.encode(bytes),
2140            ServerHello(ref x) => x.encode(bytes),
2141            HelloRetryRequest(ref x) => x.encode(bytes),
2142            Certificate(ref x) => x.encode(bytes),
2143            CertificateTLS13(ref x) => x.encode(bytes),
2144            ServerKeyExchange(ref x) => x.encode(bytes),
2145            ClientKeyExchange(ref x) => x.encode(bytes),
2146            CertificateRequest(ref x) => x.encode(bytes),
2147            CertificateRequestTLS13(ref x) => x.encode(bytes),
2148            CertificateVerify(ref x) => x.encode(bytes),
2149            NewSessionTicket(ref x) => x.encode(bytes),
2150            NewSessionTicketTLS13(ref x) => x.encode(bytes),
2151            EncryptedExtensions(ref x) => x.encode(bytes),
2152            KeyUpdate(ref x) => x.encode(bytes),
2153            Finished(ref x) => x.encode(bytes),
2154            CertificateStatus(ref x) => x.encode(bytes),
2155            MessageHash(ref x) => x.encode(bytes),
2156            Unknown(ref x) => x.encode(bytes),
2157        }
2158    }
2159}
2160
/// A complete handshake message: its type plus the decoded body.
#[derive(Debug)]
pub struct HandshakeMessagePayload {
    /// Handshake type. `HelloRetryRequest` is a local distinction only; it is
    /// written to the wire as `ServerHello`.
    pub typ: HandshakeType,
    /// The decoded message body.
    pub payload: HandshakePayload,
}
2166
2167impl Codec for HandshakeMessagePayload {
2168    fn encode(&self, bytes: &mut Vec<u8>) {
2169        // encode payload to learn length
2170        let mut sub: Vec<u8> = Vec::new();
2171        self.payload.encode(&mut sub);
2172
2173        // output type, length, and encoded payload
2174        match self.typ {
2175            HandshakeType::HelloRetryRequest => HandshakeType::ServerHello,
2176            _ => self.typ,
2177        }
2178        .encode(bytes);
2179        codec::u24(sub.len() as u32).encode(bytes);
2180        bytes.append(&mut sub);
2181    }
2182
2183    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
2184        Self::read_version(r, ProtocolVersion::TLSv1_2)
2185    }
2186}
2187
impl HandshakeMessagePayload {
    /// Reads one handshake message (type, 24-bit length, body) from `r`.
    ///
    /// `vers` selects between the TLS 1.3 and earlier body layouts for
    /// `Certificate`, `CertificateRequest` and `NewSessionTicket`. A
    /// `ServerHello` whose random equals `HELLO_RETRY_REQUEST_RANDOM` is decoded
    /// as a `HelloRetryRequest`, and the returned `typ` is rewritten to match.
    ///
    /// # Errors
    ///
    /// Fails if the header or body cannot be decoded, if the type is
    /// `MessageHash` or `HelloRetryRequest` (neither may appear on the wire),
    /// or if any body bytes remain unconsumed after decoding.
    pub fn read_version(r: &mut Reader, vers: ProtocolVersion) -> Result<Self, InvalidMessage> {
        let mut typ = HandshakeType::read(r)?;
        let len = codec::u24::read(r)?.0 as usize;
        // Parse the body from a sub-reader over `len` bytes; leftovers are
        // rejected by the `expect_empty` check at the end.
        let mut sub = r.sub(len)?;

        let payload = match typ {
            // A HelloRequest with a non-empty body falls through to `Unknown` below.
            HandshakeType::HelloRequest if sub.left() == 0 => HandshakePayload::HelloRequest,
            HandshakeType::ClientHello => {
                HandshakePayload::ClientHello(ClientHelloPayload::read(&mut sub)?)
            }
            HandshakeType::ServerHello => {
                // ServerHello and HelloRetryRequest share this wire type. Read the
                // common prefix, then branch on the random value.
                let version = ProtocolVersion::read(&mut sub)?;
                let random = Random::read(&mut sub)?;

                if random == HELLO_RETRY_REQUEST_RANDOM {
                    let mut hrr = HelloRetryRequest::read(&mut sub)?;
                    hrr.legacy_version = version;
                    typ = HandshakeType::HelloRetryRequest;
                    HandshakePayload::HelloRetryRequest(hrr)
                } else {
                    // The prefix was consumed above, so copy it into the payload.
                    let mut shp = ServerHelloPayload::read(&mut sub)?;
                    shp.legacy_version = version;
                    shp.random = random;
                    HandshakePayload::ServerHello(shp)
                }
            }
            // Version-guarded arms must precede their unguarded fallbacks.
            HandshakeType::Certificate if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificatePayloadTLS13::read(&mut sub)?;
                HandshakePayload::CertificateTLS13(p)
            }
            HandshakeType::Certificate => {
                HandshakePayload::Certificate(CertificatePayload::read(&mut sub)?)
            }
            HandshakeType::ServerKeyExchange => {
                let p = ServerKeyExchangePayload::read(&mut sub)?;
                HandshakePayload::ServerKeyExchange(p)
            }
            HandshakeType::ServerHelloDone => {
                sub.expect_empty("ServerHelloDone")?;
                HandshakePayload::ServerHelloDone
            }
            HandshakeType::ClientKeyExchange => {
                HandshakePayload::ClientKeyExchange(Payload::read(&mut sub))
            }
            HandshakeType::CertificateRequest if vers == ProtocolVersion::TLSv1_3 => {
                let p = CertificateRequestPayloadTLS13::read(&mut sub)?;
                HandshakePayload::CertificateRequestTLS13(p)
            }
            HandshakeType::CertificateRequest => {
                let p = CertificateRequestPayload::read(&mut sub)?;
                HandshakePayload::CertificateRequest(p)
            }
            HandshakeType::CertificateVerify => {
                HandshakePayload::CertificateVerify(DigitallySignedStruct::read(&mut sub)?)
            }
            HandshakeType::NewSessionTicket if vers == ProtocolVersion::TLSv1_3 => {
                let p = NewSessionTicketPayloadTLS13::read(&mut sub)?;
                HandshakePayload::NewSessionTicketTLS13(p)
            }
            HandshakeType::NewSessionTicket => {
                let p = NewSessionTicketPayload::read(&mut sub)?;
                HandshakePayload::NewSessionTicket(p)
            }
            HandshakeType::EncryptedExtensions => {
                HandshakePayload::EncryptedExtensions(Vec::read(&mut sub)?)
            }
            HandshakeType::KeyUpdate => {
                HandshakePayload::KeyUpdate(KeyUpdateRequest::read(&mut sub)?)
            }
            HandshakeType::EndOfEarlyData => {
                sub.expect_empty("EndOfEarlyData")?;
                HandshakePayload::EndOfEarlyData
            }
            HandshakeType::Finished => HandshakePayload::Finished(Payload::read(&mut sub)),
            HandshakeType::CertificateStatus => {
                HandshakePayload::CertificateStatus(CertificateStatus::read(&mut sub)?)
            }
            HandshakeType::MessageHash => {
                // does not appear on the wire
                return Err(InvalidMessage::UnexpectedMessage("MessageHash"));
            }
            HandshakeType::HelloRetryRequest => {
                // not legal on wire
                return Err(InvalidMessage::UnexpectedMessage("HelloRetryRequest"));
            }
            _ => HandshakePayload::Unknown(Payload::read(&mut sub)),
        };

        sub.expect_empty("HandshakeMessagePayload")
            .map(|_| Self { typ, payload })
    }

    /// Builds a `KeyUpdate` message carrying `KeyUpdateRequest::UpdateNotRequested`.
    pub fn build_key_update_notify() -> Self {
        Self {
            typ: HandshakeType::KeyUpdate,
            payload: HandshakePayload::KeyUpdate(KeyUpdateRequest::UpdateNotRequested),
        }
    }

    /// Returns this message's encoding with the PSK binders list cut off the end.
    ///
    /// The truncation only happens when the payload is a `ClientHello` whose
    /// last extension is `PresharedKey`. In every other case the full encoding
    /// is returned unchanged. The number of bytes removed is the length of
    /// `offer.binders` as produced by its own `encode`.
    ///
    /// NOTE(review): assumes the encoded binders are the final bytes of the
    /// message encoding. This should hold when `PresharedKey` is the last
    /// extension and binders are the last field of the offer, but the offer's
    /// layout is defined elsewhere — confirm.
    pub fn get_encoding_for_binder_signing(&self) -> Vec<u8> {
        let mut ret = self.get_encoding();

        let binder_len = match self.payload {
            HandshakePayload::ClientHello(ref ch) => match ch.extensions.last() {
                Some(ClientExtension::PresharedKey(ref offer)) => {
                    let mut binders_encoding = Vec::new();
                    offer
                        .binders
                        .encode(&mut binders_encoding);
                    binders_encoding.len()
                }
                _ => 0,
            },
            _ => 0,
        };

        let ret_len = ret.len() - binder_len;
        ret.truncate(ret_len);
        ret
    }

    /// Builds a synthetic `MessageHash` message whose body is a copy of `hash`.
    pub fn build_handshake_hash(hash: &[u8]) -> Self {
        Self {
            typ: HandshakeType::MessageHash,
            payload: HandshakePayload::MessageHash(Payload::new(hash.to_vec())),
        }
    }
}