// rustls/msgs/persist.rs

1use crate::dns_name::DnsName;
2use crate::enums::{CipherSuite, ProtocolVersion};
3use crate::error::InvalidMessage;
4use crate::key;
5use crate::msgs::base::{PayloadU16, PayloadU8};
6use crate::msgs::codec::{Codec, Reader};
7use crate::msgs::handshake::CertificatePayload;
8use crate::msgs::handshake::SessionId;
9use crate::ticketer::TimeBase;
10#[cfg(feature = "tls12")]
11use crate::tls12::Tls12CipherSuite;
12use crate::tls13::Tls13CipherSuite;
13
14use std::cmp;
15#[cfg(feature = "tls12")]
16use std::mem;
17
18pub struct Retrieved<T> {
19    pub value: T,
20    retrieved_at: TimeBase,
21}
22
23impl<T> Retrieved<T> {
24    pub fn new(value: T, retrieved_at: TimeBase) -> Self {
25        Self {
26            value,
27            retrieved_at,
28        }
29    }
30
31    pub fn map<M>(&self, f: impl FnOnce(&T) -> Option<&M>) -> Option<Retrieved<&M>> {
32        Some(Retrieved {
33            value: f(&self.value)?,
34            retrieved_at: self.retrieved_at,
35        })
36    }
37}
38
39impl Retrieved<&Tls13ClientSessionValue> {
40    pub fn obfuscated_ticket_age(&self) -> u32 {
41        let age_secs = self
42            .retrieved_at
43            .as_secs()
44            .saturating_sub(self.value.common.epoch);
45        let age_millis = age_secs as u32 * 1000;
46        age_millis.wrapping_add(self.value.age_add)
47    }
48}
49
50impl<T: std::ops::Deref<Target = ClientSessionCommon>> Retrieved<T> {
51    pub fn has_expired(&self) -> bool {
52        let common = &*self.value;
53        common.lifetime_secs != 0
54            && common
55                .epoch
56                .saturating_add(u64::from(common.lifetime_secs))
57                < self.retrieved_at.as_secs()
58    }
59}
60
61impl<T> std::ops::Deref for Retrieved<T> {
62    type Target = T;
63
64    fn deref(&self) -> &Self::Target {
65        &self.value
66    }
67}
68
69#[derive(Debug)]
70pub struct Tls13ClientSessionValue {
71    suite: &'static Tls13CipherSuite,
72    age_add: u32,
73    max_early_data_size: u32,
74    pub(crate) common: ClientSessionCommon,
75    #[cfg(feature = "quic")]
76    quic_params: PayloadU16,
77}
78
79impl Tls13ClientSessionValue {
80    pub(crate) fn new(
81        suite: &'static Tls13CipherSuite,
82        ticket: Vec<u8>,
83        secret: Vec<u8>,
84        server_cert_chain: Vec<key::Certificate>,
85        time_now: TimeBase,
86        lifetime_secs: u32,
87        age_add: u32,
88        max_early_data_size: u32,
89    ) -> Self {
90        Self {
91            suite,
92            age_add,
93            max_early_data_size,
94            common: ClientSessionCommon::new(
95                ticket,
96                secret,
97                time_now,
98                lifetime_secs,
99                server_cert_chain,
100            ),
101            #[cfg(feature = "quic")]
102            quic_params: PayloadU16(Vec::new()),
103        }
104    }
105
106    pub fn max_early_data_size(&self) -> u32 {
107        self.max_early_data_size
108    }
109
110    pub fn suite(&self) -> &'static Tls13CipherSuite {
111        self.suite
112    }
113
114    #[doc(hidden)]
115    /// Test only: rewind epoch by `delta` seconds.
116    pub fn rewind_epoch(&mut self, delta: u32) {
117        self.common.epoch -= delta as u64;
118    }
119
120    #[cfg(feature = "quic")]
121    pub fn set_quic_params(&mut self, quic_params: &[u8]) {
122        self.quic_params = PayloadU16(quic_params.to_vec());
123    }
124
125    #[cfg(feature = "quic")]
126    pub fn quic_params(&self) -> Vec<u8> {
127        self.quic_params.0.clone()
128    }
129}
130
131impl std::ops::Deref for Tls13ClientSessionValue {
132    type Target = ClientSessionCommon;
133
134    fn deref(&self) -> &Self::Target {
135        &self.common
136    }
137}
138
/// Stored state for resuming a TLS 1.2 session, by session ID or by ticket.
///
/// Without the `tls12` feature every field is compiled out, leaving an empty type.
#[derive(Debug, Clone)]
pub struct Tls12ClientSessionValue {
    /// Cipher suite the session was established with.
    #[cfg(feature = "tls12")]
    suite: &'static Tls12CipherSuite,
    /// Session ID recorded for this session.
    #[cfg(feature = "tls12")]
    pub(crate) session_id: SessionId,
    /// Whether the session used an extended master secret.
    #[cfg(feature = "tls12")]
    extended_ms: bool,
    #[doc(hidden)]
    #[cfg(feature = "tls12")]
    pub(crate) common: ClientSessionCommon,
}
151
#[cfg(feature = "tls12")]
impl Tls12ClientSessionValue {
    /// Creates a TLS 1.2 session value.
    ///
    /// `master_secret` is stored as the shared secret, and `time_now` becomes the
    /// session epoch. `lifetime_secs` is capped by `ClientSessionCommon::new`.
    pub(crate) fn new(
        suite: &'static Tls12CipherSuite,
        session_id: SessionId,
        ticket: Vec<u8>,
        master_secret: Vec<u8>,
        server_cert_chain: Vec<key::Certificate>,
        time_now: TimeBase,
        lifetime_secs: u32,
        extended_ms: bool,
    ) -> Self {
        Self {
            suite,
            session_id,
            extended_ms,
            common: ClientSessionCommon::new(
                ticket,
                master_secret,
                time_now,
                lifetime_secs,
                server_cert_chain,
            ),
        }
    }

    /// Moves the ticket bytes out, leaving an empty ticket behind.
    pub(crate) fn take_ticket(&mut self) -> Vec<u8> {
        mem::take(&mut self.common.ticket.0)
    }

    /// Whether the session used an extended master secret.
    pub(crate) fn extended_ms(&self) -> bool {
        self.extended_ms
    }

    /// Cipher suite the session was established with.
    pub(crate) fn suite(&self) -> &'static Tls12CipherSuite {
        self.suite
    }

    #[doc(hidden)]
    /// Test only: rewind epoch by `delta` seconds.
    ///
    /// Saturates at zero instead of underflowing (which would panic in debug
    /// builds) when `delta` exceeds the current epoch.
    pub fn rewind_epoch(&mut self, delta: u32) {
        self.common.epoch = self.common.epoch.saturating_sub(u64::from(delta));
    }
}
196
/// Exposes the state shared with TLS 1.3 sessions directly on the TLS 1.2 value.
#[cfg(feature = "tls12")]
impl std::ops::Deref for Tls12ClientSessionValue {
    type Target = ClientSessionCommon;

    fn deref(&self) -> &ClientSessionCommon {
        let Self { common, .. } = self;
        common
    }
}
205
206#[derive(Debug, Clone)]
207pub struct ClientSessionCommon {
208    ticket: PayloadU16,
209    secret: PayloadU8,
210    epoch: u64,
211    lifetime_secs: u32,
212    server_cert_chain: CertificatePayload,
213}
214
215impl ClientSessionCommon {
216    fn new(
217        ticket: Vec<u8>,
218        secret: Vec<u8>,
219        time_now: TimeBase,
220        lifetime_secs: u32,
221        server_cert_chain: Vec<key::Certificate>,
222    ) -> Self {
223        Self {
224            ticket: PayloadU16(ticket),
225            secret: PayloadU8(secret),
226            epoch: time_now.as_secs(),
227            lifetime_secs: cmp::min(lifetime_secs, MAX_TICKET_LIFETIME),
228            server_cert_chain,
229        }
230    }
231
232    pub(crate) fn server_cert_chain(&self) -> &[key::Certificate] {
233        self.server_cert_chain.as_ref()
234    }
235
236    pub(crate) fn secret(&self) -> &[u8] {
237        self.secret.0.as_ref()
238    }
239
240    pub(crate) fn ticket(&self) -> &[u8] {
241        self.ticket.0.as_ref()
242    }
243}
244
/// Upper bound, in seconds, on a client session's recorded lifetime: seven days.
/// Larger lifetimes passed to `ClientSessionCommon::new` are clamped to this.
const MAX_TICKET_LIFETIME: u32 = 7 * 24 * 60 * 60;

/// This is the maximum allowed skew between server and client clocks, over
/// the maximum ticket lifetime period.  This encompasses TCP retransmission
/// times in case packet loss occurs when the client sends the ClientHello
/// or receives the NewSessionTicket, _and_ actual clock skew over this period.
const MAX_FRESHNESS_SKEW_MS: u32 = 60 * 1000;
252
253// --- Server types ---
254pub type ServerSessionKey = SessionId;
255
256#[derive(Debug)]
257pub struct ServerSessionValue {
258    pub sni: Option<DnsName>,
259    pub version: ProtocolVersion,
260    pub cipher_suite: CipherSuite,
261    pub master_secret: PayloadU8,
262    pub extended_ms: bool,
263    pub client_cert_chain: Option<CertificatePayload>,
264    pub alpn: Option<PayloadU8>,
265    pub application_data: PayloadU16,
266    pub creation_time_sec: u64,
267    pub age_obfuscation_offset: u32,
268    freshness: Option<bool>,
269}
270
271impl Codec for ServerSessionValue {
272    fn encode(&self, bytes: &mut Vec<u8>) {
273        if let Some(ref sni) = self.sni {
274            1u8.encode(bytes);
275            let sni_bytes: &str = sni.as_ref();
276            PayloadU8::new(Vec::from(sni_bytes)).encode(bytes);
277        } else {
278            0u8.encode(bytes);
279        }
280        self.version.encode(bytes);
281        self.cipher_suite.encode(bytes);
282        self.master_secret.encode(bytes);
283        (u8::from(self.extended_ms)).encode(bytes);
284        if let Some(ref chain) = self.client_cert_chain {
285            1u8.encode(bytes);
286            chain.encode(bytes);
287        } else {
288            0u8.encode(bytes);
289        }
290        if let Some(ref alpn) = self.alpn {
291            1u8.encode(bytes);
292            alpn.encode(bytes);
293        } else {
294            0u8.encode(bytes);
295        }
296        self.application_data.encode(bytes);
297        self.creation_time_sec.encode(bytes);
298        self.age_obfuscation_offset
299            .encode(bytes);
300    }
301
302    fn read(r: &mut Reader) -> Result<Self, InvalidMessage> {
303        let has_sni = u8::read(r)?;
304        let sni = if has_sni == 1 {
305            let dns_name = PayloadU8::read(r)?;
306            let dns_name = match DnsName::try_from_ascii(&dns_name.0) {
307                Ok(dns_name) => dns_name,
308                Err(_) => return Err(InvalidMessage::InvalidServerName),
309            };
310
311            Some(dns_name)
312        } else {
313            None
314        };
315
316        let v = ProtocolVersion::read(r)?;
317        let cs = CipherSuite::read(r)?;
318        let ms = PayloadU8::read(r)?;
319        let ems = u8::read(r)?;
320        let has_ccert = u8::read(r)? == 1;
321        let ccert = if has_ccert {
322            Some(CertificatePayload::read(r)?)
323        } else {
324            None
325        };
326        let has_alpn = u8::read(r)? == 1;
327        let alpn = if has_alpn {
328            Some(PayloadU8::read(r)?)
329        } else {
330            None
331        };
332        let application_data = PayloadU16::read(r)?;
333        let creation_time_sec = u64::read(r)?;
334        let age_obfuscation_offset = u32::read(r)?;
335
336        Ok(Self {
337            sni,
338            version: v,
339            cipher_suite: cs,
340            master_secret: ms,
341            extended_ms: ems == 1u8,
342            client_cert_chain: ccert,
343            alpn,
344            application_data,
345            creation_time_sec,
346            age_obfuscation_offset,
347            freshness: None,
348        })
349    }
350}
351
352impl ServerSessionValue {
353    pub fn new(
354        sni: Option<&DnsName>,
355        v: ProtocolVersion,
356        cs: CipherSuite,
357        ms: Vec<u8>,
358        client_cert_chain: Option<CertificatePayload>,
359        alpn: Option<Vec<u8>>,
360        application_data: Vec<u8>,
361        creation_time: TimeBase,
362        age_obfuscation_offset: u32,
363    ) -> Self {
364        Self {
365            sni: sni.cloned(),
366            version: v,
367            cipher_suite: cs,
368            master_secret: PayloadU8::new(ms),
369            extended_ms: false,
370            client_cert_chain,
371            alpn: alpn.map(PayloadU8::new),
372            application_data: PayloadU16::new(application_data),
373            creation_time_sec: creation_time.as_secs(),
374            age_obfuscation_offset,
375            freshness: None,
376        }
377    }
378
379    pub fn set_extended_ms_used(&mut self) {
380        self.extended_ms = true;
381    }
382
383    pub fn set_freshness(mut self, obfuscated_client_age_ms: u32, time_now: TimeBase) -> Self {
384        let client_age_ms = obfuscated_client_age_ms.wrapping_sub(self.age_obfuscation_offset);
385        let server_age_ms = (time_now
386            .as_secs()
387            .saturating_sub(self.creation_time_sec) as u32)
388            .saturating_mul(1000);
389
390        let age_difference = if client_age_ms < server_age_ms {
391            server_age_ms - client_age_ms
392        } else {
393            client_age_ms - server_age_ms
394        };
395
396        self.freshness = Some(age_difference <= MAX_FRESHNESS_SKEW_MS);
397        self
398    }
399
400    pub fn is_fresh(&self) -> bool {
401        self.freshness.unwrap_or_default()
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use crate::enums::*;

    #[test]
    fn serversessionvalue_is_debug() {
        let ssv = ServerSessionValue::new(
            None,
            ProtocolVersion::TLSv1_3,
            CipherSuite::TLS13_AES_128_GCM_SHA256,
            vec![1, 2, 3],
            None,
            None,
            vec![4, 5, 6],
            TimeBase::now().unwrap(),
            0x12345678,
        );
        println!("{:?}", ssv);
    }

    #[test]
    fn serversessionvalue_no_sni() {
        let bytes = [
            0x00, 0x03, 0x03, 0xc0, 0x23, 0x03, 0x01, 0x02, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00,
            0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78, 0x89, 0xfe, 0xed, 0xf0, 0x0d,
        ];
        let mut rd = Reader::init(&bytes);
        let ssv = ServerSessionValue::read(&mut rd).unwrap();
        assert_eq!(ssv.get_encoding(), bytes);
    }

    #[test]
    fn serversessionvalue_with_cert() {
        // Same layout as `serversessionvalue_no_sni`, except that the
        // client-cert flag is 1 and is followed by a u24-length-prefixed
        // chain holding one certificate (u24 length 1, body 0x01).
        let bytes = [
            0x00, 0x03, 0x03, 0xc0, 0x23, 0x03, 0x01, 0x02, 0x03, 0x00, 0x01, 0x00, 0x00, 0x04,
            0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x12, 0x23, 0x34, 0x45, 0x56, 0x67, 0x78,
            0x89, 0xfe, 0xed, 0xf0, 0x0d,
        ];
        let mut rd = Reader::init(&bytes);
        let ssv = ServerSessionValue::read(&mut rd).unwrap();
        assert_eq!(
            ssv.client_cert_chain
                .as_ref()
                .map(|chain| chain.len()),
            Some(1)
        );
        assert_eq!(ssv.get_encoding(), bytes);
    }
}