rustls/client/handy.rs

1use crate::client;
2use crate::enums::SignatureScheme;
3use crate::error::Error;
4use crate::key;
5use crate::limited_cache;
6use crate::msgs::persist;
7use crate::sign;
8use crate::NamedGroup;
9use crate::ServerName;
10
11use std::collections::VecDeque;
12use std::sync::{Arc, Mutex};
13
14/// An implementer of `ClientSessionStore` which does nothing.
15pub(super) struct NoClientSessionStorage;
16
17impl client::ClientSessionStore for NoClientSessionStorage {
18    fn set_kx_hint(&self, _: &ServerName, _: NamedGroup) {}
19
20    fn kx_hint(&self, _: &ServerName) -> Option<NamedGroup> {
21        None
22    }
23
24    fn set_tls12_session(&self, _: &ServerName, _: persist::Tls12ClientSessionValue) {}
25
26    fn tls12_session(&self, _: &ServerName) -> Option<persist::Tls12ClientSessionValue> {
27        None
28    }
29
30    fn remove_tls12_session(&self, _: &ServerName) {}
31
32    fn insert_tls13_ticket(&self, _: &ServerName, _: persist::Tls13ClientSessionValue) {}
33
34    fn take_tls13_ticket(&self, _: &ServerName) -> Option<persist::Tls13ClientSessionValue> {
35        None
36    }
37}
38
39const MAX_TLS13_TICKETS_PER_SERVER: usize = 8;
40
41struct ServerData {
42    kx_hint: Option<NamedGroup>,
43
44    // Zero or one TLS1.2 sessions.
45    #[cfg(feature = "tls12")]
46    tls12: Option<persist::Tls12ClientSessionValue>,
47
48    // Up to MAX_TLS13_TICKETS_PER_SERVER TLS1.3 tickets, oldest first.
49    tls13: VecDeque<persist::Tls13ClientSessionValue>,
50}
51
52impl Default for ServerData {
53    fn default() -> Self {
54        Self {
55            kx_hint: None,
56            #[cfg(feature = "tls12")]
57            tls12: None,
58            tls13: VecDeque::with_capacity(MAX_TLS13_TICKETS_PER_SERVER),
59        }
60    }
61}
62
63/// An implementer of `ClientSessionStore` that stores everything
64/// in memory.
65///
66/// It enforces a limit on the number of entries to bound memory usage.
67pub struct ClientSessionMemoryCache {
68    servers: Mutex<limited_cache::LimitedCache<ServerName, ServerData>>,
69}
70
71impl ClientSessionMemoryCache {
72    /// Make a new ClientSessionMemoryCache.  `size` is the
73    /// maximum number of stored sessions.
74    pub fn new(size: usize) -> Self {
75        let max_servers =
76            size.saturating_add(MAX_TLS13_TICKETS_PER_SERVER - 1) / MAX_TLS13_TICKETS_PER_SERVER;
77        Self {
78            servers: Mutex::new(limited_cache::LimitedCache::new(max_servers)),
79        }
80    }
81}
82
83impl client::ClientSessionStore for ClientSessionMemoryCache {
84    fn set_kx_hint(&self, server_name: &ServerName, group: NamedGroup) {
85        self.servers
86            .lock()
87            .unwrap()
88            .get_or_insert_default_and_edit(server_name.clone(), |data| data.kx_hint = Some(group));
89    }
90
91    fn kx_hint(&self, server_name: &ServerName) -> Option<NamedGroup> {
92        self.servers
93            .lock()
94            .unwrap()
95            .get(server_name)
96            .and_then(|sd| sd.kx_hint)
97    }
98
99    fn set_tls12_session(
100        &self,
101        _server_name: &ServerName,
102        _value: persist::Tls12ClientSessionValue,
103    ) {
104        #[cfg(feature = "tls12")]
105        self.servers
106            .lock()
107            .unwrap()
108            .get_or_insert_default_and_edit(_server_name.clone(), |data| data.tls12 = Some(_value));
109    }
110
111    fn tls12_session(&self, _server_name: &ServerName) -> Option<persist::Tls12ClientSessionValue> {
112        #[cfg(not(feature = "tls12"))]
113        return None;
114
115        #[cfg(feature = "tls12")]
116        self.servers
117            .lock()
118            .unwrap()
119            .get(_server_name)
120            .and_then(|sd| sd.tls12.as_ref().cloned())
121    }
122
123    fn remove_tls12_session(&self, _server_name: &ServerName) {
124        #[cfg(feature = "tls12")]
125        self.servers
126            .lock()
127            .unwrap()
128            .get_mut(_server_name)
129            .and_then(|data| data.tls12.take());
130    }
131
132    fn insert_tls13_ticket(
133        &self,
134        server_name: &ServerName,
135        value: persist::Tls13ClientSessionValue,
136    ) {
137        self.servers
138            .lock()
139            .unwrap()
140            .get_or_insert_default_and_edit(server_name.clone(), |data| {
141                if data.tls13.len() == data.tls13.capacity() {
142                    data.tls13.pop_front();
143                }
144                data.tls13.push_back(value);
145            });
146    }
147
148    fn take_tls13_ticket(
149        &self,
150        server_name: &ServerName,
151    ) -> Option<persist::Tls13ClientSessionValue> {
152        self.servers
153            .lock()
154            .unwrap()
155            .get_mut(server_name)
156            .and_then(|data| data.tls13.pop_back())
157    }
158}
159
160pub(super) struct FailResolveClientCert {}
161
162impl client::ResolvesClientCert for FailResolveClientCert {
163    fn resolve(
164        &self,
165        _acceptable_issuers: &[&[u8]],
166        _sigschemes: &[SignatureScheme],
167    ) -> Option<Arc<sign::CertifiedKey>> {
168        None
169    }
170
171    fn has_certs(&self) -> bool {
172        false
173    }
174}
175
176pub(super) struct AlwaysResolvesClientCert(Arc<sign::CertifiedKey>);
177
178impl AlwaysResolvesClientCert {
179    pub(super) fn new(
180        chain: Vec<key::Certificate>,
181        priv_key: &key::PrivateKey,
182    ) -> Result<Self, Error> {
183        let key = sign::any_supported_type(priv_key)
184            .map_err(|_| Error::General("invalid private key".into()))?;
185        Ok(Self(Arc::new(sign::CertifiedKey::new(chain, key))))
186    }
187}
188
189impl client::ResolvesClientCert for AlwaysResolvesClientCert {
190    fn resolve(
191        &self,
192        _acceptable_issuers: &[&[u8]],
193        _sigschemes: &[SignatureScheme],
194    ) -> Option<Arc<sign::CertifiedKey>> {
195        Some(Arc::clone(&self.0))
196    }
197
198    fn has_certs(&self) -> bool {
199        true
200    }
201}
202
#[cfg(test)]
mod test {
    use super::NoClientSessionStorage;
    use crate::client::ClientSessionStore;
    use crate::msgs::enums::NamedGroup;
    #[cfg(feature = "tls12")]
    use crate::msgs::handshake::SessionId;
    use crate::msgs::persist::Tls13ClientSessionValue;
    use crate::suites::SupportedCipherSuite;

    /// Checks that `NoClientSessionStorage` retains nothing: after each setter
    /// is called, the matching getter still reports that nothing is stored.
    #[test]
    fn test_noclientsessionstorage_does_nothing() {
        let c = NoClientSessionStorage {};
        // The concrete type of `name` is inferred from its later uses as a
        // `&ServerName` argument.
        let name = "example.com".try_into().unwrap();
        let now = crate::ticketer::TimeBase::now().unwrap();

        c.set_kx_hint(&name, NamedGroup::X25519);
        assert_eq!(None, c.kx_hint(&name));

        // TLS1.2 session storage is only exercised when the `tls12` feature
        // is enabled.
        #[cfg(feature = "tls12")]
        {
            use crate::msgs::persist::Tls12ClientSessionValue;
            let tls12_suite = match crate::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 {
                SupportedCipherSuite::Tls12(inner) => inner,
                _ => unreachable!(),
            };

            // The session contents are placeholders; only "nothing comes back"
            // is being checked.
            c.set_tls12_session(
                &name,
                Tls12ClientSessionValue::new(
                    tls12_suite,
                    SessionId::empty(),
                    Vec::new(),
                    Vec::new(),
                    Vec::new(),
                    now,
                    0,
                    true,
                ),
            );
            assert!(c.tls12_session(&name).is_none());
            c.remove_tls12_session(&name);
        }

        // NOTE(review): the catch-all arm is feature-gated presumably because,
        // without `tls12`, `SupportedCipherSuite` has only the `Tls13` variant
        // and an extra arm would be unreachable — confirm against the enum.
        let tls13_suite = match crate::cipher_suite::TLS13_AES_256_GCM_SHA384 {
            SupportedCipherSuite::Tls13(inner) => inner,
            #[cfg(feature = "tls12")]
            _ => unreachable!(),
        };
        // Placeholder ticket contents; the store must discard it regardless.
        c.insert_tls13_ticket(
            &name,
            Tls13ClientSessionValue::new(
                tls13_suite,
                Vec::new(),
                Vec::new(),
                Vec::new(),
                now,
                0,
                0,
                0,
            ),
        );
        assert!(c.take_tls13_ticket(&name).is_none());
    }
}