rustls/
hash_hs.rs

1use crate::msgs::codec::Codec;
2use crate::msgs::handshake::HandshakeMessagePayload;
3use crate::msgs::message::{Message, MessagePayload};
4use ring::digest;
5use std::mem;
6
/// Buffers raw handshake messages until the transcript hash algorithm is known.
///
/// The hash function used to verify the handshake is only chosen part-way
/// through it, so until then the encoded messages are simply accumulated here.
/// A HelloRetryRequest can restart the transcript, in which case a
/// `HandshakeHash` is turned back into a `HandshakeHashBuffer`.
pub(crate) struct HandshakeHashBuffer {
    /// Concatenated encodings of the handshake messages recorded so far.
    buffer: Vec<u8>,
    /// Whether the raw transcript must be kept once hashing starts, because
    /// client auth may need it.
    client_auth_enabled: bool,
}
16
17impl HandshakeHashBuffer {
18    pub(crate) fn new() -> Self {
19        Self {
20            buffer: Vec::new(),
21            client_auth_enabled: false,
22        }
23    }
24
25    /// We might be doing client auth, so need to keep a full
26    /// log of the handshake.
27    pub(crate) fn set_client_auth_enabled(&mut self) {
28        self.client_auth_enabled = true;
29    }
30
31    /// Hash/buffer a handshake message.
32    pub(crate) fn add_message(&mut self, m: &Message) {
33        if let MessagePayload::Handshake { encoded, .. } = &m.payload {
34            self.buffer
35                .extend_from_slice(&encoded.0);
36        }
37    }
38
39    /// Hash or buffer a byte slice.
40    #[cfg(test)]
41    fn update_raw(&mut self, buf: &[u8]) {
42        self.buffer.extend_from_slice(buf);
43    }
44
45    /// Get the hash value if we were to hash `extra` too.
46    pub(crate) fn get_hash_given(
47        &self,
48        hash: &'static digest::Algorithm,
49        extra: &[u8],
50    ) -> digest::Digest {
51        let mut ctx = digest::Context::new(hash);
52        ctx.update(&self.buffer);
53        ctx.update(extra);
54        ctx.finish()
55    }
56
57    /// We now know what hash function the verify_data will use.
58    pub(crate) fn start_hash(self, alg: &'static digest::Algorithm) -> HandshakeHash {
59        let mut ctx = digest::Context::new(alg);
60        ctx.update(&self.buffer);
61        HandshakeHash {
62            ctx,
63            client_auth: match self.client_auth_enabled {
64                true => Some(self.buffer),
65                false => None,
66            },
67        }
68    }
69}
70
71/// This deals with keeping a running hash of the handshake
72/// payloads.  This is computed by buffering initially.  Once
73/// we know what hash function we need to use we switch to
74/// incremental hashing.
75///
76/// For client auth, we also need to buffer all the messages.
77/// This is disabled in cases where client auth is not possible.
78pub(crate) struct HandshakeHash {
79    /// None before we know what hash function we're using
80    ctx: digest::Context,
81
82    /// buffer for client-auth.
83    client_auth: Option<Vec<u8>>,
84}
85
86impl HandshakeHash {
87    /// We decided not to do client auth after all, so discard
88    /// the transcript.
89    pub(crate) fn abandon_client_auth(&mut self) {
90        self.client_auth = None;
91    }
92
93    /// Hash/buffer a handshake message.
94    pub(crate) fn add_message(&mut self, m: &Message) -> &mut Self {
95        if let MessagePayload::Handshake { encoded, .. } = &m.payload {
96            self.update_raw(&encoded.0);
97        }
98        self
99    }
100
101    /// Hash or buffer a byte slice.
102    fn update_raw(&mut self, buf: &[u8]) -> &mut Self {
103        self.ctx.update(buf);
104
105        if let Some(buffer) = &mut self.client_auth {
106            buffer.extend_from_slice(buf);
107        }
108
109        self
110    }
111
112    /// Get the hash value if we were to hash `extra` too,
113    /// using hash function `hash`.
114    pub(crate) fn get_hash_given(&self, extra: &[u8]) -> digest::Digest {
115        let mut ctx = self.ctx.clone();
116        ctx.update(extra);
117        ctx.finish()
118    }
119
120    pub(crate) fn into_hrr_buffer(self) -> HandshakeHashBuffer {
121        let old_hash = self.ctx.finish();
122        let old_handshake_hash_msg =
123            HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref());
124
125        HandshakeHashBuffer {
126            client_auth_enabled: self.client_auth.is_some(),
127            buffer: old_handshake_hash_msg.get_encoding(),
128        }
129    }
130
131    /// Take the current hash value, and encapsulate it in a
132    /// 'handshake_hash' handshake message.  Start this hash
133    /// again, with that message at the front.
134    pub(crate) fn rollup_for_hrr(&mut self) {
135        let ctx = &mut self.ctx;
136
137        let old_ctx = mem::replace(ctx, digest::Context::new(ctx.algorithm()));
138        let old_hash = old_ctx.finish();
139        let old_handshake_hash_msg =
140            HandshakeMessagePayload::build_handshake_hash(old_hash.as_ref());
141
142        self.update_raw(&old_handshake_hash_msg.get_encoding());
143    }
144
145    /// Get the current hash value.
146    pub(crate) fn get_current_hash(&self) -> digest::Digest {
147        self.ctx.clone().finish()
148    }
149
150    /// Takes this object's buffer containing all handshake messages
151    /// so far.  This method only works once; it resets the buffer
152    /// to empty.
153    #[cfg(feature = "tls12")]
154    pub(crate) fn take_handshake_buf(&mut self) -> Option<Vec<u8>> {
155        self.client_auth.take()
156    }
157
158    /// The digest algorithm
159    pub(crate) fn algorithm(&self) -> &'static digest::Algorithm {
160        self.ctx.algorithm()
161    }
162}
163
#[cfg(test)]
mod test {
    use super::HandshakeHashBuffer;
    use ring::digest;

    /// First four bytes of SHA-256("helloworld"), which every test hashes.
    const HELLOWORLD_PREFIX: [u8; 4] = [0x93, 0x6a, 0x18, 0x5c];

    /// Assert that `hash` starts with the expected SHA-256("helloworld") bytes.
    fn assert_helloworld(hash: &digest::Digest) {
        assert_eq!(&hash.as_ref()[..4], &HELLOWORLD_PREFIX[..]);
    }

    #[test]
    fn hashes_correctly() {
        let mut buffered = HandshakeHashBuffer::new();
        buffered.update_raw(b"hello");
        assert_eq!(buffered.buffer.len(), 5);

        let mut running = buffered.start_hash(&digest::SHA256);
        assert!(running.client_auth.is_none());

        running.update_raw(b"world");
        assert_helloworld(&running.get_current_hash());
    }

    #[cfg(feature = "tls12")]
    #[test]
    fn buffers_correctly() {
        let mut buffered = HandshakeHashBuffer::new();
        buffered.set_client_auth_enabled();
        buffered.update_raw(b"hello");
        assert_eq!(buffered.buffer.len(), 5);

        let mut running = buffered.start_hash(&digest::SHA256);
        assert_eq!(running.client_auth.as_ref().map(Vec::len), Some(5));

        running.update_raw(b"world");
        assert_eq!(running.client_auth.as_ref().map(Vec::len), Some(10));
        assert_helloworld(&running.get_current_hash());

        let taken = running.take_handshake_buf();
        assert_eq!(Some(b"helloworld".to_vec()), taken);
    }

    #[test]
    fn abandon() {
        let mut buffered = HandshakeHashBuffer::new();
        buffered.set_client_auth_enabled();
        buffered.update_raw(b"hello");
        assert_eq!(buffered.buffer.len(), 5);

        let mut running = buffered.start_hash(&digest::SHA256);
        assert_eq!(running.client_auth.as_ref().map(Vec::len), Some(5));

        running.abandon_client_auth();
        assert_eq!(running.client_auth, None);

        running.update_raw(b"world");
        assert_eq!(running.client_auth, None);
        assert_helloworld(&running.get_current_hash());
    }
}