p3_challenger/
hash_challenger.rs

1use alloc::vec;
2use alloc::vec::Vec;
3
4use p3_symmetric::CryptographicHasher;
5
6use crate::{CanObserve, CanSample};
7
8/// A generic challenger that uses a cryptographic hash function to generate challenges.
9#[derive(Clone, Debug)]
10pub struct HashChallenger<T, H, const OUT_LEN: usize>
11where
12    T: Clone,
13    H: CryptographicHasher<T, [T; OUT_LEN]>,
14{
15    /// Buffer to store observed values before hashing.
16    input_buffer: Vec<T>,
17    /// Buffer to store hashed output values, which are consumed when sampling.
18    output_buffer: Vec<T>,
19    /// The cryptographic hash function used for generating challenges.
20    hasher: H,
21}
22
23impl<T, H, const OUT_LEN: usize> HashChallenger<T, H, OUT_LEN>
24where
25    T: Clone,
26    H: CryptographicHasher<T, [T; OUT_LEN]>,
27{
28    pub fn new(initial_state: Vec<T>, hasher: H) -> Self {
29        Self {
30            input_buffer: initial_state,
31            output_buffer: vec![],
32            hasher,
33        }
34    }
35
36    fn flush(&mut self) {
37        let inputs = self.input_buffer.drain(..);
38        let output = self.hasher.hash_iter(inputs);
39
40        self.output_buffer = output.to_vec();
41
42        // Chaining values.
43        self.input_buffer.extend(output.to_vec());
44    }
45}
46
47impl<T, H, const OUT_LEN: usize> CanObserve<T> for HashChallenger<T, H, OUT_LEN>
48where
49    T: Clone,
50    H: CryptographicHasher<T, [T; OUT_LEN]>,
51{
52    fn observe(&mut self, value: T) {
53        // Any buffered output is now invalid.
54        self.output_buffer.clear();
55
56        self.input_buffer.push(value);
57    }
58}
59
60impl<T, H, const N: usize, const OUT_LEN: usize> CanObserve<[T; N]>
61    for HashChallenger<T, H, OUT_LEN>
62where
63    T: Clone,
64    H: CryptographicHasher<T, [T; OUT_LEN]>,
65{
66    fn observe(&mut self, values: [T; N]) {
67        for value in values {
68            self.observe(value);
69        }
70    }
71}
72
73impl<T, H, const OUT_LEN: usize> CanSample<T> for HashChallenger<T, H, OUT_LEN>
74where
75    T: Clone,
76    H: CryptographicHasher<T, [T; OUT_LEN]>,
77{
78    fn sample(&mut self) -> T {
79        if self.output_buffer.is_empty() {
80            self.flush();
81        }
82        self.output_buffer
83            .pop()
84            .expect("Output buffer should be non-empty")
85    }
86}
87
#[cfg(test)]
mod tests {
    use p3_field::FieldAlgebra;
    use p3_goldilocks::Goldilocks;

    use super::*;

    const OUT_LEN: usize = 2;
    type F = Goldilocks;

    /// Non-cryptographic stand-in hasher whose digest is `[sum of inputs, number of inputs]`.
    /// It makes every state transition of the challenger easy to compute by hand.
    #[derive(Clone)]
    struct TestHasher;

    impl CryptographicHasher<F, [F; OUT_LEN]> for TestHasher {
        /// Returns `[sum, len]` of the elements yielded by `input`.
        fn hash_iter<I>(&self, input: I) -> [F; OUT_LEN]
        where
            I: IntoIterator<Item = F>,
        {
            let (sum, len) = input
                .into_iter()
                .fold((F::ZERO, 0_usize), |(acc_sum, acc_len), f| {
                    (acc_sum + f, acc_len + 1)
                });
            [sum, F::from_canonical_usize(len)]
        }

        /// Returns `[sum, len]` over every element of every slice yielded by `input`,
        /// i.e. the same digest `hash_iter` gives for the concatenated elements.
        fn hash_iter_slices<'a, I>(&self, input: I) -> [F; OUT_LEN]
        where
            I: IntoIterator<Item = &'a [F]>,
            F: 'a,
        {
            let (sum, len) = input
                .into_iter()
                .fold((F::ZERO, 0_usize), |(acc_sum, acc_len), n| {
                    (
                        acc_sum + n.iter().fold(F::ZERO, |acc, f| acc + *f),
                        acc_len + n.len(),
                    )
                });
            [sum, F::from_canonical_usize(len)]
        }
    }

    #[test]
    fn test_hash_challenger() {
        let initial_state = (1..11_u8).map(F::from_canonical_u8).collect::<Vec<_>>();
        // `initial_state` is compared against below, so the challenger gets a copy.
        let mut hash_challenger = HashChallenger::new(initial_state.clone(), TestHasher);

        assert_eq!(hash_challenger.input_buffer, initial_state);
        assert_eq!(hash_challenger.output_buffer, vec![]);

        hash_challenger.flush();

        let expected_sum = F::from_canonical_u8(55);
        let expected_len = F::from_canonical_u8(10);
        assert_eq!(
            hash_challenger.input_buffer,
            vec![expected_sum, expected_len]
        );
        assert_eq!(
            hash_challenger.output_buffer,
            vec![expected_sum, expected_len]
        );

        let new_element = F::from_canonical_u8(11);
        hash_challenger.observe(new_element);
        assert_eq!(
            hash_challenger.input_buffer,
            vec![expected_sum, expected_len, new_element]
        );
        assert_eq!(hash_challenger.output_buffer, vec![]);

        // Hashing [55, 10, 11] gives [76, 3]; sampling pops the length first.
        let new_expected_len = 3;
        let new_expected_sum = 76;

        let new_element = hash_challenger.sample();
        assert_eq!(new_element, F::from_canonical_u8(new_expected_len));
        assert_eq!(
            hash_challenger.output_buffer,
            [F::from_canonical_u8(new_expected_sum)]
        );
    }

    #[test]
    fn test_hash_challenger_flush() {
        let initial_state = (1..11_u8).map(F::from_canonical_u8).collect::<Vec<_>>();
        let mut hash_challenger = HashChallenger::new(initial_state, TestHasher);

        // The first sample triggers a flush; the second consumes the remaining output.
        let first_sample = hash_challenger.sample();

        let second_sample = hash_challenger.sample();

        // Verify that the first sample is the length of 1..11, (i.e. 10).
        assert_eq!(first_sample, F::from_canonical_u8(10));
        // Verify that the second sample is the sum of numbers from 1 to 10 (i.e. 55).
        assert_eq!(second_sample, F::from_canonical_u8(55));

        // Verify that the output buffer is now empty
        assert!(hash_challenger.output_buffer.is_empty());
    }

    #[test]
    fn test_sample_refills_after_output_exhausted() {
        let initial_state = (1..11_u8).map(F::from_canonical_u8).collect::<Vec<_>>();
        let mut hash_challenger = HashChallenger::new(initial_state, TestHasher);

        // The first flush produces [55, 10]; samples are taken from the back.
        assert_eq!(hash_challenger.sample(), F::from_canonical_u8(10));
        assert_eq!(hash_challenger.sample(), F::from_canonical_u8(55));

        // Output exhausted: the next sample hashes the chained values [55, 10] into [65, 2].
        assert_eq!(hash_challenger.sample(), F::from_canonical_u8(2));
        assert_eq!(
            hash_challenger.output_buffer,
            vec![F::from_canonical_u8(65)]
        );
        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_canonical_u8(65), F::from_canonical_u8(2)]
        );
    }

    #[test]
    fn test_observe_single_value() {
        // Initial state non-empty
        let mut hash_challenger = HashChallenger::new(vec![F::from_canonical_u8(123)], TestHasher);

        // Observe a single value
        let value = F::from_canonical_u8(42);
        hash_challenger.observe(value);

        // Check that the input buffer contains the initial and observed values
        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_canonical_u8(123), F::from_canonical_u8(42)]
        );
        // Check that the output buffer is empty (clears after observation)
        assert!(hash_challenger.output_buffer.is_empty());
    }

    #[test]
    fn test_observe_array() {
        // Initial state non-empty
        let mut hash_challenger = HashChallenger::new(vec![F::from_canonical_u8(123)], TestHasher);

        // Observe an array of values
        let values = [
            F::from_canonical_u8(1),
            F::from_canonical_u8(2),
            F::from_canonical_u8(3),
        ];
        hash_challenger.observe(values);

        // Check that the input buffer contains the values
        assert_eq!(
            hash_challenger.input_buffer,
            vec![
                F::from_canonical_u8(123),
                F::from_canonical_u8(1),
                F::from_canonical_u8(2),
                F::from_canonical_u8(3)
            ]
        );
        // Check that the output buffer is empty (clears after observation)
        assert!(hash_challenger.output_buffer.is_empty());
    }

    #[test]
    fn test_sample_output_buffer() {
        let initial_state = vec![F::from_canonical_u8(5), F::from_canonical_u8(10)];
        let mut hash_challenger = HashChallenger::new(initial_state, TestHasher);

        let sample = hash_challenger.sample();
        // Verify that the sample is the length of the initial state
        assert_eq!(sample, F::from_canonical_u8(2));
        // Check that the output buffer contains the sum of the initial state
        assert_eq!(
            hash_challenger.output_buffer,
            vec![F::from_canonical_u8(15)]
        );
    }

    #[test]
    fn test_flush_empty_buffer() {
        let mut hash_challenger = HashChallenger::new(vec![], TestHasher);

        // Flush empty buffer
        hash_challenger.flush();

        // Check that the input and output buffers contain the sum and length of the empty buffer
        assert_eq!(hash_challenger.input_buffer, vec![F::ZERO, F::ZERO]);
        assert_eq!(hash_challenger.output_buffer, vec![F::ZERO, F::ZERO]);
    }

    #[test]
    fn test_flush_with_data() {
        // Initial state non-empty
        let initial_state = vec![F::from_canonical_u8(1), F::from_canonical_u8(2)];
        let mut hash_challenger = HashChallenger::new(initial_state, TestHasher);

        hash_challenger.flush();

        // Check that the input buffer contains the sum and length of the initial state
        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_canonical_u8(3), F::from_canonical_u8(2)]
        );
        // Check that the output buffer contains the sum and length of the initial state
        assert_eq!(
            hash_challenger.output_buffer,
            vec![F::from_canonical_u8(3), F::from_canonical_u8(2)]
        );
    }

    #[test]
    fn test_sample_after_observe() {
        let initial_state = vec![F::from_canonical_u8(1), F::from_canonical_u8(2)];
        let mut hash_challenger = HashChallenger::new(initial_state, TestHasher);

        // Observe will clear the output buffer
        hash_challenger.observe(F::from_canonical_u8(3));

        // Verify that the output buffer is empty
        assert!(hash_challenger.output_buffer.is_empty());

        // Verify the new value is in the input buffer
        assert_eq!(
            hash_challenger.input_buffer,
            vec![
                F::from_canonical_u8(1),
                F::from_canonical_u8(2),
                F::from_canonical_u8(3)
            ]
        );

        let sample = hash_challenger.sample();

        // Length of initial state + observed value
        assert_eq!(sample, F::from_canonical_u8(3));
    }

    #[test]
    fn test_sample_with_non_empty_output_buffer() {
        let mut hash_challenger = HashChallenger::new(vec![], TestHasher);

        hash_challenger.output_buffer = vec![F::from_canonical_u8(42), F::from_canonical_u8(24)];

        let sample = hash_challenger.sample();

        // Sample will pop the last element from the output buffer
        assert_eq!(sample, F::from_canonical_u8(24));

        // Check that the output buffer is now one element shorter
        assert_eq!(
            hash_challenger.output_buffer,
            vec![F::from_canonical_u8(42)]
        );
    }

    #[test]
    fn test_output_buffer_cleared_on_observe() {
        let mut hash_challenger = HashChallenger::new(vec![], TestHasher);

        // Populate artificially the output buffer
        hash_challenger.output_buffer.push(F::from_canonical_u8(42));

        // Ensure the output buffer is populated
        assert!(!hash_challenger.output_buffer.is_empty());

        // Observe a new value
        hash_challenger.observe(F::from_canonical_u8(3));

        // Verify that the output buffer is cleared after observing
        assert!(hash_challenger.output_buffer.is_empty());
    }
}