rustls/
hash_hs.rs
1use crate::msgs::codec::Codec;
2use crate::msgs::handshake::HandshakeMessagePayload;
3use crate::msgs::message::{Message, MessagePayload};
4use ring::digest;
5use std::mem;
6
/// In-memory accumulator for raw handshake bytes.
///
/// Bytes are collected here while no digest algorithm has been selected;
/// `start_hash` later turns the accumulated bytes into a running
/// `HandshakeHash`.
pub(crate) struct HandshakeHashBuffer {
    /// Every byte appended so far, in the order it was added.
    buffer: Vec<u8>,
    /// When `true`, `start_hash` passes `buffer` on to the resulting
    /// `HandshakeHash` so the full byte log is retained alongside the digest.
    client_auth_enabled: bool,
}
16
17impl HandshakeHashBuffer {
18 pub(crate) fn new() -> Self {
19 Self {
20 buffer: Vec::new(),
21 client_auth_enabled: false,
22 }
23 }
24
25 pub(crate) fn set_client_auth_enabled(&mut self) {
28 self.client_auth_enabled = true;
29 }
30
31 pub(crate) fn add_message(&mut self, m: &Message) {
33 if let MessagePayload::Handshake { encoded, .. } = &m.payload {
34 self.buffer
35 .extend_from_slice(&encoded.0);
36 }
37 }
38
39 #[cfg(test)]
41 fn update_raw(&mut self, buf: &[u8]) {
42 self.buffer.extend_from_slice(buf);
43 }
44
45 pub(crate) fn get_hash_given(
47 &self,
48 hash: &'static digest::Algorithm,
49 extra: &[u8],
50 ) -> digest::Digest {
51 let mut ctx = digest::Context::new(hash);
52 ctx.update(&self.buffer);
53 ctx.update(extra);
54 ctx.finish()
55 }
56
57 pub(crate) fn start_hash(self, alg: &'static digest::Algorithm) -> HandshakeHash {
59 let mut ctx = digest::Context::new(alg);
60 ctx.update(&self.buffer);
61 HandshakeHash {
62 ctx,
63 client_auth: match self.client_auth_enabled {
64 true => Some(self.buffer),
65 false => None,
66 },
67 }
68 }
69}
70
71pub(crate) struct HandshakeHash {
79 ctx: digest::Context,
81
82 client_auth: Option<Vec<u8>>,
84}
85
86impl HandshakeHash {
87 pub(crate) fn abandon_client_auth(&mut self) {
90 self.client_auth = None;
91 }
92
93 pub(crate) fn add_message(&mut self, m: &Message) -> &mut Self {
95 if let MessagePayload::Handshake { encoded, .. } = &m.payload {
96 self.update_raw(&encoded.0);
97 }
98 self
99 }
100
101 fn update_raw(&mut self, buf: &[u8]) -> &mut Self {
103 self.ctx.update(buf);
104
105 if let Some(buffer) = &mut self.client_auth {
106 buffer.extend_from_slice(buf);
107 }
108
109 self
110 }
111
112 pub(crate) fn get_hash_given(&self, extra: &[u8]) -> digest::Digest {
115 let mut ctx = self.ctx.clone();
116 ctx.update(extra);
117 ctx.finish()
118 }
119
120 pub(crate) fn into_hrr_buffer(self) -> HandshakeHashBuffer {
121 let old_hash = self.ctx.finish();
122 let old_handshake_hash_msg =
123 HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref());
124
125 HandshakeHashBuffer {
126 client_auth_enabled: self.client_auth.is_some(),
127 buffer: old_handshake_hash_msg.get_encoding(),
128 }
129 }
130
131 pub(crate) fn rollup_for_hrr(&mut self) {
135 let ctx = &mut self.ctx;
136
137 let old_ctx = mem::replace(ctx, digest::Context::new(ctx.algorithm()));
138 let old_hash = old_ctx.finish();
139 let old_handshake_hash_msg =
140 HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref());
141
142 self.update_raw(&old_handshake_hash_msg.get_encoding());
143 }
144
145 pub(crate) fn get_current_hash(&self) -> digest::Digest {
147 self.ctx.clone().finish()
148 }
149
150 #[cfg(feature = "tls12")]
154 pub(crate) fn take_handshake_buf(&mut self) -> Option<Vec<u8>> {
155 self.client_auth.take()
156 }
157
158 pub(crate) fn algorithm(&self) -> &'static digest::Algorithm {
160 self.ctx.algorithm()
161 }
162}
163
#[cfg(test)]
mod test {
    use super::HandshakeHashBuffer;
    use ring::digest;

    /// Expected first four bytes of the SHA-256 digest after feeding
    /// `b"hello"` followed by `b"world"`.
    const EXPECTED_PREFIX: [u8; 4] = [0x93, 0x6a, 0x18, 0x5c];

    /// Builds a buffer pre-loaded with `b"hello"`, optionally with
    /// transcript retention enabled, and checks its length.
    fn hello_buffer(retain: bool) -> HandshakeHashBuffer {
        let mut hhb = HandshakeHashBuffer::new();
        if retain {
            hhb.set_client_auth_enabled();
        }
        hhb.update_raw(b"hello");
        assert_eq!(hhb.buffer.len(), 5);
        hhb
    }

    /// Asserts that `hash` starts with `EXPECTED_PREFIX`.
    fn assert_expected_prefix(hash: &digest::Digest) {
        let bytes = hash.as_ref();
        assert_eq!(bytes[0], EXPECTED_PREFIX[0]);
        assert_eq!(bytes[1], EXPECTED_PREFIX[1]);
        assert_eq!(bytes[2], EXPECTED_PREFIX[2]);
        assert_eq!(bytes[3], EXPECTED_PREFIX[3]);
    }

    /// Without retention, the digest covers buffered plus later bytes and
    /// no byte log is kept.
    #[test]
    fn hashes_correctly() {
        let mut hh = hello_buffer(false).start_hash(&digest::SHA256);
        assert!(hh.client_auth.is_none());

        hh.update_raw(b"world");
        assert_expected_prefix(&hh.get_current_hash());
    }

    /// With retention, the byte log grows in step with the digest and can
    /// be taken out intact.
    #[cfg(feature = "tls12")]
    #[test]
    fn buffers_correctly() {
        let mut hh = hello_buffer(true).start_hash(&digest::SHA256);
        assert_eq!(hh.client_auth.as_ref().map(Vec::len), Some(5));

        hh.update_raw(b"world");
        assert_eq!(hh.client_auth.as_ref().map(Vec::len), Some(10));

        assert_expected_prefix(&hh.get_current_hash());

        let taken = hh.take_handshake_buf();
        assert_eq!(Some(b"helloworld".to_vec()), taken);
    }

    /// Abandoning retention drops the byte log for good while the digest
    /// keeps working.
    #[test]
    fn abandon() {
        let mut hh = hello_buffer(true).start_hash(&digest::SHA256);
        assert_eq!(hh.client_auth.as_ref().map(Vec::len), Some(5));

        hh.abandon_client_auth();
        assert_eq!(hh.client_auth, None);

        hh.update_raw(b"world");
        assert_eq!(hh.client_auth, None);

        assert_expected_prefix(&hh.get_current_hash());
    }
}