p3_challenger/
hash_challenger.rs

1use alloc::vec;
2use alloc::vec::Vec;
3
4use p3_symmetric::CryptographicHasher;
5
6use crate::{CanObserve, CanSample};
7
/// A generic challenger that uses a cryptographic hash function to generate challenges.
///
/// Observed values accumulate in an input buffer. When a challenge is requested and no
/// hashed output is buffered, the whole input buffer is hashed. The digest is then both
/// handed out as challenges (last element first) and fed back as the start of the next
/// hash input, so successive hashes are chained together.
///
/// Trait bounds are stated on the `impl` blocks rather than on the type itself. This lets
/// the type be named in other generic code without restating the hasher bound.
#[derive(Clone, Debug)]
pub struct HashChallenger<T, H, const OUT_LEN: usize> {
    /// Buffer to store observed values before hashing.
    input_buffer: Vec<T>,
    /// Buffer to store hashed output values, which are consumed when sampling.
    output_buffer: Vec<T>,
    /// The cryptographic hash function used for generating challenges.
    hasher: H,
}
22
23impl<T, H, const OUT_LEN: usize> HashChallenger<T, H, OUT_LEN>
24where
25    T: Clone,
26    H: CryptographicHasher<T, [T; OUT_LEN]>,
27{
28    pub const fn new(initial_state: Vec<T>, hasher: H) -> Self {
29        Self {
30            input_buffer: initial_state,
31            output_buffer: vec![],
32            hasher,
33        }
34    }
35
36    fn flush(&mut self) {
37        let inputs = self.input_buffer.drain(..);
38        let output = self.hasher.hash_iter(inputs);
39
40        // Chaining values.
41        self.input_buffer.extend_from_slice(&output);
42        self.output_buffer = output.into();
43    }
44}
45
46impl<T, H, const OUT_LEN: usize> CanObserve<T> for HashChallenger<T, H, OUT_LEN>
47where
48    T: Clone,
49    H: CryptographicHasher<T, [T; OUT_LEN]>,
50{
51    fn observe(&mut self, value: T) {
52        // Any buffered output is now invalid.
53        self.output_buffer.clear();
54
55        self.input_buffer.push(value);
56    }
57}
58
59impl<T, H, const N: usize, const OUT_LEN: usize> CanObserve<[T; N]>
60    for HashChallenger<T, H, OUT_LEN>
61where
62    T: Clone,
63    H: CryptographicHasher<T, [T; OUT_LEN]>,
64{
65    fn observe(&mut self, values: [T; N]) {
66        if N == 0 {
67            return;
68        }
69
70        self.output_buffer.clear();
71        self.input_buffer.extend(values);
72    }
73}
74
75impl<T, H, const OUT_LEN: usize> CanSample<T> for HashChallenger<T, H, OUT_LEN>
76where
77    T: Clone,
78    H: CryptographicHasher<T, [T; OUT_LEN]>,
79{
80    fn sample(&mut self) -> T {
81        if self.output_buffer.is_empty() {
82            self.flush();
83        }
84        self.output_buffer
85            .pop()
86            .expect("Output buffer should be non-empty")
87    }
88}
89
#[cfg(test)]
mod tests {
    use p3_field::PrimeCharacteristicRing;
    use p3_goldilocks::Goldilocks;

    use super::*;

    const OUT_LEN: usize = 2;
    type F = Goldilocks;

    /// Toy, non-cryptographic hasher whose output is easy to predict in tests.
    #[derive(Clone)]
    struct TestHasher {}

    impl CryptographicHasher<F, [F; OUT_LEN]> for TestHasher {
        /// Returns `[sum, count]`: the sum of all input elements and the number of
        /// input elements (as a field element).
        fn hash_iter<I>(&self, input: I) -> [F; OUT_LEN]
        where
            I: IntoIterator<Item = F>,
        {
            let (sum, len) = input
                .into_iter()
                .fold((F::ZERO, 0_usize), |(acc_sum, acc_len), f| {
                    (acc_sum + f, acc_len + 1)
                });
            [sum, F::from_usize(len)]
        }

        /// Slice variant of `hash_iter`: returns `[sum, count]` taken over the
        /// elements of all slices combined.
        fn hash_iter_slices<'a, I>(&self, input: I) -> [F; OUT_LEN]
        where
            I: IntoIterator<Item = &'a [F]>,
            F: 'a,
        {
            let (sum, len) = input
                .into_iter()
                .fold((F::ZERO, 0_usize), |(acc_sum, acc_len), n| {
                    (
                        acc_sum + n.iter().fold(F::ZERO, |acc, f| acc + *f),
                        acc_len + n.len(),
                    )
                });
            [sum, F::from_usize(len)]
        }
    }

    #[test]
    fn test_hash_challenger() {
        let initial_state = (1..11_u8).map(F::from_u8).collect::<Vec<_>>();
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(initial_state.clone(), test_hasher);

        assert_eq!(hash_challenger.input_buffer, initial_state);
        assert_eq!(hash_challenger.output_buffer, vec![]);

        hash_challenger.flush();

        let expected_sum = F::from_u8(55);
        let expected_len = F::from_u8(10);
        assert_eq!(
            hash_challenger.input_buffer,
            vec![expected_sum, expected_len]
        );
        assert_eq!(
            hash_challenger.output_buffer,
            vec![expected_sum, expected_len]
        );

        let new_element = F::from_u8(11);
        hash_challenger.observe(new_element);
        assert_eq!(
            hash_challenger.input_buffer,
            vec![expected_sum, expected_len, new_element]
        );
        assert_eq!(hash_challenger.output_buffer, vec![]);

        // hash([55, 10, 11]) = [76, 3]; sampling pops the last element first.
        let new_expected_len = 3;
        let new_expected_sum = 76;

        let new_element = hash_challenger.sample();
        assert_eq!(new_element, F::from_u8(new_expected_len));
        assert_eq!(
            hash_challenger.output_buffer,
            [F::from_u8(new_expected_sum)]
        );
    }

    #[test]
    fn test_hash_challenger_flush() {
        let initial_state = (1..11_u8).map(F::from_u8).collect::<Vec<_>>();
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        // The first sample triggers a flush; the second is served from the buffered output.
        let first_sample = hash_challenger.sample();

        let second_sample = hash_challenger.sample();

        // Verify that the first sample is the length of 1..11 (i.e. 10).
        assert_eq!(first_sample, F::from_u8(10));
        // Verify that the second sample is the sum of the numbers from 1 to 10 (i.e. 55).
        assert_eq!(second_sample, F::from_u8(55));

        // Both digest elements have been consumed, so the output buffer is now empty.
        assert!(hash_challenger.output_buffer.is_empty());
    }

    #[test]
    fn test_observe_single_value() {
        let test_hasher = TestHasher {};
        // Initial state non-empty
        let mut hash_challenger = HashChallenger::new(vec![F::from_u8(123)], test_hasher);

        // Observe a single value
        let value = F::from_u8(42);
        hash_challenger.observe(value);

        // Check that the input buffer contains the initial and observed values
        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(123), F::from_u8(42)]
        );
        // Check that the output buffer is empty (it is cleared on observation)
        assert!(hash_challenger.output_buffer.is_empty());
    }

    #[test]
    fn test_observe_array() {
        let test_hasher = TestHasher {};
        // Initial state non-empty
        let mut hash_challenger = HashChallenger::new(vec![F::from_u8(123)], test_hasher);

        // Observe an array of values
        let values = [F::from_u8(1), F::from_u8(2), F::from_u8(3)];
        hash_challenger.observe(values);

        // Check that the input buffer holds the initial state followed by the observed values
        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(123), F::from_u8(1), F::from_u8(2), F::from_u8(3)]
        );
        // Check that the output buffer is empty (it is cleared on observation)
        assert!(hash_challenger.output_buffer.is_empty());
    }

    #[test]
    fn test_observe_empty_array_keeps_output() {
        let test_hasher = TestHasher {};
        let mut hash_challenger =
            HashChallenger::new(vec![F::from_u8(1), F::from_u8(2)], test_hasher);

        // hash([1, 2]) = [3, 2], which is both the buffered output and the chained input.
        hash_challenger.flush();
        let expected = vec![F::from_u8(3), F::from_u8(2)];

        // Observing an empty array adds nothing, so the buffered output stays valid.
        let empty: [F; 0] = [];
        hash_challenger.observe(empty);
        assert_eq!(hash_challenger.input_buffer, expected);
        assert_eq!(hash_challenger.output_buffer, expected);

        // A non-empty array does invalidate it.
        hash_challenger.observe([F::from_u8(4)]);
        assert!(hash_challenger.output_buffer.is_empty());
    }

    #[test]
    fn test_sample_output_buffer() {
        let test_hasher = TestHasher {};
        let initial_state = vec![F::from_u8(5), F::from_u8(10)];
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        let sample = hash_challenger.sample();
        // Verify that the sample is the length of the initial state
        assert_eq!(sample, F::from_u8(2));
        // Check that the output buffer still holds the sum of the initial state
        assert_eq!(hash_challenger.output_buffer, vec![F::from_u8(15)]);
    }

    #[test]
    fn test_sample_chains_across_refills() {
        let test_hasher = TestHasher {};
        let mut hash_challenger =
            HashChallenger::new(vec![F::from_u8(1), F::from_u8(2)], test_hasher);

        // First batch: hash([1, 2]) = [3, 2], handed out from the back.
        assert_eq!(hash_challenger.sample(), F::from_u8(2));
        assert_eq!(hash_challenger.sample(), F::from_u8(3));

        // Second batch hashes the chained digest: hash([3, 2]) = [5, 2].
        assert_eq!(hash_challenger.sample(), F::from_u8(2));
        assert_eq!(hash_challenger.sample(), F::from_u8(5));
        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(5), F::from_u8(2)]
        );
    }

    #[test]
    fn test_flush_empty_buffer() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        // Flush empty buffer
        hash_challenger.flush();

        // Check that the input and output buffers contain the sum and length of the empty buffer
        assert_eq!(hash_challenger.input_buffer, vec![F::ZERO, F::ZERO]);
        assert_eq!(hash_challenger.output_buffer, vec![F::ZERO, F::ZERO]);
    }

    #[test]
    fn test_flush_with_data() {
        let test_hasher = TestHasher {};
        // Initial state non-empty
        let initial_state = vec![F::from_u8(1), F::from_u8(2)];
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        hash_challenger.flush();

        // Check that the input buffer contains the sum and length of the initial state
        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(3), F::from_u8(2)]
        );
        // Check that the output buffer contains the sum and length of the initial state
        assert_eq!(
            hash_challenger.output_buffer,
            vec![F::from_u8(3), F::from_u8(2)]
        );
    }

    #[test]
    fn test_sample_after_observe() {
        let test_hasher = TestHasher {};
        let initial_state = vec![F::from_u8(1), F::from_u8(2)];
        let mut hash_challenger = HashChallenger::new(initial_state, test_hasher);

        // Observe will clear the output buffer
        hash_challenger.observe(F::from_u8(3));

        // Verify that the output buffer is empty
        assert!(hash_challenger.output_buffer.is_empty());

        // Verify the new value is in the input buffer
        assert_eq!(
            hash_challenger.input_buffer,
            vec![F::from_u8(1), F::from_u8(2), F::from_u8(3)]
        );

        let sample = hash_challenger.sample();

        // Length of initial state + observed value
        assert_eq!(sample, F::from_u8(3));
    }

    #[test]
    fn test_sample_with_non_empty_output_buffer() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        hash_challenger.output_buffer = vec![F::from_u8(42), F::from_u8(24)];

        let sample = hash_challenger.sample();

        // Sample pops the last element from the output buffer without re-hashing
        assert_eq!(sample, F::from_u8(24));

        // Check that the output buffer is now one element shorter
        assert_eq!(hash_challenger.output_buffer, vec![F::from_u8(42)]);
    }

    #[test]
    fn test_output_buffer_cleared_on_observe() {
        let test_hasher = TestHasher {};
        let mut hash_challenger = HashChallenger::new(vec![], test_hasher);

        // Artificially populate the output buffer
        hash_challenger.output_buffer.push(F::from_u8(42));

        // Ensure the output buffer is populated
        assert!(!hash_challenger.output_buffer.is_empty());

        // Observe a new value
        hash_challenger.observe(F::from_u8(3));

        // Verify that the output buffer is cleared after observing
        assert!(hash_challenger.output_buffer.is_empty());
    }
}