1use alloc::vec;
2use alloc::vec::Vec;
3
4use p3_symmetric::CryptographicHasher;
5
6use crate::{CanObserve, CanSample};
7
8#[derive(Clone, Debug)]
10pub struct HashChallenger<T, H, const OUT_LEN: usize>
11where
12 T: Clone,
13 H: CryptographicHasher<T, [T; OUT_LEN]>,
14{
15 input_buffer: Vec<T>,
17 output_buffer: Vec<T>,
19 hasher: H,
21}
22
23impl<T, H, const OUT_LEN: usize> HashChallenger<T, H, OUT_LEN>
24where
25 T: Clone,
26 H: CryptographicHasher<T, [T; OUT_LEN]>,
27{
28 pub fn new(initial_state: Vec<T>, hasher: H) -> Self {
29 Self {
30 input_buffer: initial_state,
31 output_buffer: vec![],
32 hasher,
33 }
34 }
35
36 fn flush(&mut self) {
37 let inputs = self.input_buffer.drain(..);
38 let output = self.hasher.hash_iter(inputs);
39
40 self.output_buffer = output.to_vec();
41
42 self.input_buffer.extend(output.to_vec());
44 }
45}
46
47impl<T, H, const OUT_LEN: usize> CanObserve<T> for HashChallenger<T, H, OUT_LEN>
48where
49 T: Clone,
50 H: CryptographicHasher<T, [T; OUT_LEN]>,
51{
52 fn observe(&mut self, value: T) {
53 self.output_buffer.clear();
55
56 self.input_buffer.push(value);
57 }
58}
59
60impl<T, H, const N: usize, const OUT_LEN: usize> CanObserve<[T; N]>
61 for HashChallenger<T, H, OUT_LEN>
62where
63 T: Clone,
64 H: CryptographicHasher<T, [T; OUT_LEN]>,
65{
66 fn observe(&mut self, values: [T; N]) {
67 for value in values {
68 self.observe(value);
69 }
70 }
71}
72
73impl<T, H, const OUT_LEN: usize> CanSample<T> for HashChallenger<T, H, OUT_LEN>
74where
75 T: Clone,
76 H: CryptographicHasher<T, [T; OUT_LEN]>,
77{
78 fn sample(&mut self) -> T {
79 if self.output_buffer.is_empty() {
80 self.flush();
81 }
82 self.output_buffer
83 .pop()
84 .expect("Output buffer should be non-empty")
85 }
86}
87
#[cfg(test)]
mod tests {
    use p3_field::FieldAlgebra;
    use p3_goldilocks::Goldilocks;

    use super::*;

    const OUT_LEN: usize = 2;
    type F = Goldilocks;

    /// Toy, non-cryptographic hasher that makes the challenger's state easy to predict.
    /// It maps an input sequence to `[sum of the elements, number of elements]`.
    #[derive(Clone)]
    struct TestHasher {}

    impl CryptographicHasher<F, [F; OUT_LEN]> for TestHasher {
        /// Returns `[sum, len]` over the input elements.
        fn hash_iter<I>(&self, input: I) -> [F; OUT_LEN]
        where
            I: IntoIterator<Item = F>,
        {
            let (sum, len) = input
                .into_iter()
                .fold((F::ZERO, 0_usize), |(acc_sum, acc_len), f| {
                    (acc_sum + f, acc_len + 1)
                });
            [sum, F::from_canonical_usize(len)]
        }

        /// Same as `hash_iter`, but over a sequence of slices. Returns
        /// `[sum of all elements, total number of elements]`.
        fn hash_iter_slices<'a, I>(&self, input: I) -> [F; OUT_LEN]
        where
            I: IntoIterator<Item = &'a [F]>,
            F: 'a,
        {
            let (sum, len) = input
                .into_iter()
                .fold((F::ZERO, 0_usize), |(acc_sum, acc_len), n| {
                    (
                        acc_sum + n.iter().fold(F::ZERO, |acc, f| acc + *f),
                        acc_len + n.len(),
                    )
                });
            [sum, F::from_canonical_usize(len)]
        }
    }

    /// Walks through construction -> flush -> observe -> sample and checks both buffers
    /// after each step.
    #[test]
    fn test_hash_challenger() {
        // Initial state is 1..=10, so the first digest is [55, 10].
        let initial_state = (1..11_u8).map(F::from_canonical_u8).collect::<Vec<_>>();
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(initial_state.clone(), test_hasher);

        assert_eq!(hash_challenger.input_buffer, initial_state);
        assert_eq!(hash_challenger.output_buffer, vec![]);

        hash_challenger.flush();

        let expected_sum = F::from_canonical_u8(55);
        let expected_len = F::from_canonical_u8(10);
        // After a flush, the digest is both the pending output and the new (chained) input.
        assert_eq!(
            hash_challenger.input_buffer,
            vec![expected_sum, expected_len]
        );
        assert_eq!(
            hash_challenger.output_buffer,
            vec![expected_sum, expected_len]
        );

        let new_element = F::from_canonical_u8(11);
        hash_challenger.observe(new_element);
        // Observing appends to the input and discards any unsampled output.
        assert_eq!(
            hash_challenger.input_buffer,
            vec![expected_sum, expected_len, new_element]
        );
        assert_eq!(hash_challenger.output_buffer, vec![]);

        // The next digest is over [55, 10, 11]: sum = 76, len = 3.
        let new_expected_len = 3;
        let new_expected_sum = 76;

        // `sample` pops from the back of the digest, so the length element comes out first.
        let new_element = hash_challenger.sample();
        assert_eq!(new_element, F::from_canonical_u8(new_expected_len));
        assert_eq!(
            hash_challenger.output_buffer,
            [F::from_canonical_u8(new_expected_sum)]
        )
    }

    /// Two samples from a fresh challenger consume one full digest, back to front.
    #[test]
    fn test_hash_challenger_flush() {
        let initial_state = (1..11_u8).map(F::from_canonical_u8).collect::<Vec<_>>();
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(initial_state.clone(), test_hasher);

        // The first sample triggers the flush that produces [55, 10].
        let first_sample = hash_challenger.sample();

        // The second sample is served from the leftover output without re-hashing.
        let second_sample = hash_challenger.sample();

        assert_eq!(first_sample, F::from_canonical_u8(10));
        assert_eq!(second_sample, F::from_canonical_u8(55));

        assert!(hash_challenger.output_buffer.is_empty());
    }

    /// Observing a single value appends it to the input buffer.
    #[test]
    fn test_observe_single_value() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![F::from_canonical_u8(123)], test_hasher);

        let value = F::from_canonical_u8(42);
        hash_challenger.observe(value);

        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_canonical_u8(123), F::from_canonical_u8(42)]
        );
        assert!(hash_challenger.output_buffer.is_empty());
    }

    /// Observing an array appends its elements in order.
    #[test]
    fn test_observe_array() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![F::from_canonical_u8(123)], test_hasher);

        let values = [
            F::from_canonical_u8(1),
            F::from_canonical_u8(2),
            F::from_canonical_u8(3),
        ];
        hash_challenger.observe(values);

        assert_eq!(
            hash_challenger.input_buffer,
            vec![
                F::from_canonical_u8(123),
                F::from_canonical_u8(1),
                F::from_canonical_u8(2),
                F::from_canonical_u8(3)
            ]
        );
        assert!(hash_challenger.output_buffer.is_empty());
    }

    /// A sample from [5, 10] hashes to [15, 2], returns 2 and leaves 15 buffered.
    #[test]
    fn test_sample_output_buffer() {
        let test_hasher = TestHasher {};
        let initial_state = vec![F::from_canonical_u8(5), F::from_canonical_u8(10)];
        let mut hash_challenger = HashChallenger::new(initial_state.clone(), test_hasher);

        let sample = hash_challenger.sample();
        assert_eq!(sample, F::from_canonical_u8(2));
        assert_eq!(
            hash_challenger.output_buffer,
            vec![F::from_canonical_u8(15)]
        );
    }

    /// Hashing an empty input with the toy hasher yields [0, 0] in both buffers.
    #[test]
    fn test_flush_empty_buffer() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        hash_challenger.flush();

        assert_eq!(hash_challenger.input_buffer, vec![F::ZERO, F::ZERO]);
        assert_eq!(hash_challenger.output_buffer, vec![F::ZERO, F::ZERO]);
    }

    /// Flushing [1, 2] gives the digest [3, 2] in both the input and output buffers.
    #[test]
    fn test_flush_with_data() {
        let test_hasher = TestHasher {};
        let initial_state = vec![F::from_canonical_u8(1), F::from_canonical_u8(2)];
        let mut hash_challenger = HashChallenger::new(initial_state.clone(), test_hasher);

        hash_challenger.flush();

        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_canonical_u8(3), F::from_canonical_u8(2)]
        );
        assert_eq!(
            hash_challenger.output_buffer,
            vec![F::from_canonical_u8(3), F::from_canonical_u8(2)]
        );
    }

    /// After observing 3, the input is [1, 2, 3]. The digest is [6, 3], so the first sample is 3.
    #[test]
    fn test_sample_after_observe() {
        let test_hasher = TestHasher {};
        let initial_state = vec![F::from_canonical_u8(1), F::from_canonical_u8(2)];
        let mut hash_challenger = HashChallenger::new(initial_state.clone(), test_hasher);

        hash_challenger.observe(F::from_canonical_u8(3));

        assert!(hash_challenger.output_buffer.is_empty());

        assert_eq!(
            hash_challenger.input_buffer,
            vec![
                F::from_canonical_u8(1),
                F::from_canonical_u8(2),
                F::from_canonical_u8(3)
            ]
        );

        let sample = hash_challenger.sample();

        assert_eq!(sample, F::from_canonical_u8(3));
    }

    /// A non-empty output buffer is served directly (popping from the back) without re-hashing.
    #[test]
    fn test_sample_with_non_empty_output_buffer() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        hash_challenger.output_buffer = vec![F::from_canonical_u8(42), F::from_canonical_u8(24)];

        let sample = hash_challenger.sample();

        assert_eq!(sample, F::from_canonical_u8(24));

        assert_eq!(
            hash_challenger.output_buffer,
            vec![F::from_canonical_u8(42)]
        );
    }

    /// Observing a value invalidates any leftover output.
    #[test]
    fn test_output_buffer_cleared_on_observe() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        hash_challenger.output_buffer.push(F::from_canonical_u8(42));

        assert!(!hash_challenger.output_buffer.is_empty());

        hash_challenger.observe(F::from_canonical_u8(3));

        assert!(hash_challenger.output_buffer.is_empty());
    }
}