rustls/client/handy.rs
1use crate::client;
2use crate::enums::SignatureScheme;
3use crate::error::Error;
4use crate::key;
5use crate::limited_cache;
6use crate::msgs::persist;
7use crate::sign;
8use crate::NamedGroup;
9use crate::ServerName;
10
11use std::collections::VecDeque;
12use std::sync::{Arc, Mutex};
13
14pub(super) struct NoClientSessionStorage;
16
17impl client::ClientSessionStore for NoClientSessionStorage {
18 fn set_kx_hint(&self, _: &ServerName, _: NamedGroup) {}
19
20 fn kx_hint(&self, _: &ServerName) -> Option<NamedGroup> {
21 None
22 }
23
24 fn set_tls12_session(&self, _: &ServerName, _: persist::Tls12ClientSessionValue) {}
25
26 fn tls12_session(&self, _: &ServerName) -> Option<persist::Tls12ClientSessionValue> {
27 None
28 }
29
30 fn remove_tls12_session(&self, _: &ServerName) {}
31
32 fn insert_tls13_ticket(&self, _: &ServerName, _: persist::Tls13ClientSessionValue) {}
33
34 fn take_tls13_ticket(&self, _: &ServerName) -> Option<persist::Tls13ClientSessionValue> {
35 None
36 }
37}
38
39const MAX_TLS13_TICKETS_PER_SERVER: usize = 8;
40
41struct ServerData {
42 kx_hint: Option<NamedGroup>,
43
44 #[cfg(feature = "tls12")]
46 tls12: Option<persist::Tls12ClientSessionValue>,
47
48 tls13: VecDeque<persist::Tls13ClientSessionValue>,
50}
51
52impl Default for ServerData {
53 fn default() -> Self {
54 Self {
55 kx_hint: None,
56 #[cfg(feature = "tls12")]
57 tls12: None,
58 tls13: VecDeque::with_capacity(MAX_TLS13_TICKETS_PER_SERVER),
59 }
60 }
61}
62
63pub struct ClientSessionMemoryCache {
68 servers: Mutex<limited_cache::LimitedCache<ServerName, ServerData>>,
69}
70
71impl ClientSessionMemoryCache {
72 pub fn new(size: usize) -> Self {
75 let max_servers =
76 size.saturating_add(MAX_TLS13_TICKETS_PER_SERVER - 1) / MAX_TLS13_TICKETS_PER_SERVER;
77 Self {
78 servers: Mutex::new(limited_cache::LimitedCache::new(max_servers)),
79 }
80 }
81}
82
83impl client::ClientSessionStore for ClientSessionMemoryCache {
84 fn set_kx_hint(&self, server_name: &ServerName, group: NamedGroup) {
85 self.servers
86 .lock()
87 .unwrap()
88 .get_or_insert_default_and_edit(server_name.clone(), |data| data.kx_hint = Some(group));
89 }
90
91 fn kx_hint(&self, server_name: &ServerName) -> Option<NamedGroup> {
92 self.servers
93 .lock()
94 .unwrap()
95 .get(server_name)
96 .and_then(|sd| sd.kx_hint)
97 }
98
99 fn set_tls12_session(
100 &self,
101 _server_name: &ServerName,
102 _value: persist::Tls12ClientSessionValue,
103 ) {
104 #[cfg(feature = "tls12")]
105 self.servers
106 .lock()
107 .unwrap()
108 .get_or_insert_default_and_edit(_server_name.clone(), |data| data.tls12 = Some(_value));
109 }
110
111 fn tls12_session(&self, _server_name: &ServerName) -> Option<persist::Tls12ClientSessionValue> {
112 #[cfg(not(feature = "tls12"))]
113 return None;
114
115 #[cfg(feature = "tls12")]
116 self.servers
117 .lock()
118 .unwrap()
119 .get(_server_name)
120 .and_then(|sd| sd.tls12.as_ref().cloned())
121 }
122
123 fn remove_tls12_session(&self, _server_name: &ServerName) {
124 #[cfg(feature = "tls12")]
125 self.servers
126 .lock()
127 .unwrap()
128 .get_mut(_server_name)
129 .and_then(|data| data.tls12.take());
130 }
131
132 fn insert_tls13_ticket(
133 &self,
134 server_name: &ServerName,
135 value: persist::Tls13ClientSessionValue,
136 ) {
137 self.servers
138 .lock()
139 .unwrap()
140 .get_or_insert_default_and_edit(server_name.clone(), |data| {
141 if data.tls13.len() == data.tls13.capacity() {
142 data.tls13.pop_front();
143 }
144 data.tls13.push_back(value);
145 });
146 }
147
148 fn take_tls13_ticket(
149 &self,
150 server_name: &ServerName,
151 ) -> Option<persist::Tls13ClientSessionValue> {
152 self.servers
153 .lock()
154 .unwrap()
155 .get_mut(server_name)
156 .and_then(|data| data.tls13.pop_back())
157 }
158}
159
160pub(super) struct FailResolveClientCert {}
161
162impl client::ResolvesClientCert for FailResolveClientCert {
163 fn resolve(
164 &self,
165 _acceptable_issuers: &[&[u8]],
166 _sigschemes: &[SignatureScheme],
167 ) -> Option<Arc<sign::CertifiedKey>> {
168 None
169 }
170
171 fn has_certs(&self) -> bool {
172 false
173 }
174}
175
176pub(super) struct AlwaysResolvesClientCert(Arc<sign::CertifiedKey>);
177
178impl AlwaysResolvesClientCert {
179 pub(super) fn new(
180 chain: Vec<key::Certificate>,
181 priv_key: &key::PrivateKey,
182 ) -> Result<Self, Error> {
183 let key = sign::any_supported_type(priv_key)
184 .map_err(|_| Error::General("invalid private key".into()))?;
185 Ok(Self(Arc::new(sign::CertifiedKey::new(chain, key))))
186 }
187}
188
189impl client::ResolvesClientCert for AlwaysResolvesClientCert {
190 fn resolve(
191 &self,
192 _acceptable_issuers: &[&[u8]],
193 _sigschemes: &[SignatureScheme],
194 ) -> Option<Arc<sign::CertifiedKey>> {
195 Some(Arc::clone(&self.0))
196 }
197
198 fn has_certs(&self) -> bool {
199 true
200 }
201}
202
#[cfg(test)]
mod test {
    use super::{ClientSessionMemoryCache, NoClientSessionStorage};
    use crate::client::ClientSessionStore;
    use crate::msgs::enums::NamedGroup;
    #[cfg(feature = "tls12")]
    use crate::msgs::handshake::SessionId;
    use crate::msgs::persist::Tls13ClientSessionValue;
    use crate::suites::SupportedCipherSuite;

    /// `NoClientSessionStorage` must accept every write and still report nothing stored.
    #[test]
    fn test_noclientsessionstorage_does_nothing() {
        let c = NoClientSessionStorage {};
        let name = "example.com".try_into().unwrap();
        let now = crate::ticketer::TimeBase::now().unwrap();

        c.set_kx_hint(&name, NamedGroup::X25519);
        assert_eq!(None, c.kx_hint(&name));

        #[cfg(feature = "tls12")]
        {
            use crate::msgs::persist::Tls12ClientSessionValue;
            let tls12_suite = match crate::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 {
                SupportedCipherSuite::Tls12(inner) => inner,
                _ => unreachable!(),
            };

            c.set_tls12_session(
                &name,
                Tls12ClientSessionValue::new(
                    tls12_suite,
                    SessionId::empty(),
                    Vec::new(),
                    Vec::new(),
                    Vec::new(),
                    now,
                    0,
                    true,
                ),
            );
            assert!(c.tls12_session(&name).is_none());
            c.remove_tls12_session(&name);
        }

        let tls13_suite = match crate::cipher_suite::TLS13_AES_256_GCM_SHA384 {
            SupportedCipherSuite::Tls13(inner) => inner,
            #[cfg(feature = "tls12")]
            _ => unreachable!(),
        };
        c.insert_tls13_ticket(
            &name,
            Tls13ClientSessionValue::new(
                tls13_suite,
                Vec::new(),
                Vec::new(),
                Vec::new(),
                now,
                0,
                0,
                0,
            ),
        );
        assert!(c.take_tls13_ticket(&name).is_none());
    }

    /// `ClientSessionMemoryCache` must hand back what was stored, per server, and forget
    /// TLS 1.3 tickets once they have been taken.
    #[test]
    fn test_clientsessionmemorycache_round_trips() {
        let c = ClientSessionMemoryCache::new(32);
        let name = "example.com".try_into().unwrap();
        let other = "example.org".try_into().unwrap();
        let now = crate::ticketer::TimeBase::now().unwrap();

        // Nothing is known about a server before anything is stored for it.
        assert_eq!(None, c.kx_hint(&name));
        c.set_kx_hint(&name, NamedGroup::X25519);
        assert_eq!(Some(NamedGroup::X25519), c.kx_hint(&name));
        assert_eq!(None, c.kx_hint(&other));

        #[cfg(feature = "tls12")]
        {
            use crate::msgs::persist::Tls12ClientSessionValue;
            let tls12_suite = match crate::cipher_suite::TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 {
                SupportedCipherSuite::Tls12(inner) => inner,
                _ => unreachable!(),
            };

            c.set_tls12_session(
                &name,
                Tls12ClientSessionValue::new(
                    tls12_suite,
                    SessionId::empty(),
                    Vec::new(),
                    Vec::new(),
                    Vec::new(),
                    now,
                    0,
                    true,
                ),
            );
            // Reading a TLS 1.2 session does not consume it; removing it does.
            assert!(c.tls12_session(&name).is_some());
            assert!(c.tls12_session(&name).is_some());
            assert!(c.tls12_session(&other).is_none());
            c.remove_tls12_session(&name);
            assert!(c.tls12_session(&name).is_none());
        }

        let tls13_suite = match crate::cipher_suite::TLS13_AES_256_GCM_SHA384 {
            SupportedCipherSuite::Tls13(inner) => inner,
            #[cfg(feature = "tls12")]
            _ => unreachable!(),
        };
        let ticket = || {
            Tls13ClientSessionValue::new(
                tls13_suite,
                Vec::new(),
                Vec::new(),
                Vec::new(),
                now,
                0,
                0,
                0,
            )
        };

        // Fewer tickets than the per-server limit: every one of them comes back exactly once.
        for _ in 0..3 {
            c.insert_tls13_ticket(&name, ticket());
        }
        let mut taken = 0;
        while c.take_tls13_ticket(&name).is_some() {
            taken += 1;
        }
        assert_eq!(3, taken);
        assert!(c.take_tls13_ticket(&other).is_none());
    }
}
267}