// p3_challenger/multi_field_challenger.rs

1use alloc::string::String;
2use alloc::vec;
3use alloc::vec::Vec;
4
5use p3_field::{reduce_32, split_32, ExtensionField, Field, PrimeField, PrimeField32};
6use p3_symmetric::{CryptographicPermutation, Hash};
7
8use crate::{CanObserve, CanSample, CanSampleBits, FieldChallenger};
9
10/// A challenger that operates natively on PF but produces challenges of F: PrimeField32.
11///
12/// Used for optimizing the cost of recursive proof verification of STARKs in SNARKs.
13///
14/// SAFETY: There are some bias complications with using this challenger. In particular,
15/// samples are actually random in [0, 2^64) and then reduced to be in F.
16#[derive(Clone, Debug)]
17pub struct MultiField32Challenger<F, PF, P, const WIDTH: usize, const RATE: usize>
18where
19    F: PrimeField32,
20    PF: Field,
21    P: CryptographicPermutation<[PF; WIDTH]>,
22{
23    sponge_state: [PF; WIDTH],
24    input_buffer: Vec<F>,
25    output_buffer: Vec<F>,
26    permutation: P,
27    num_f_elms: usize,
28}
29
30impl<F, PF, P, const WIDTH: usize, const RATE: usize> MultiField32Challenger<F, PF, P, WIDTH, RATE>
31where
32    F: PrimeField32,
33    PF: Field,
34    P: CryptographicPermutation<[PF; WIDTH]>,
35{
36    pub fn new(permutation: P) -> Result<Self, String> {
37        if F::order() >= PF::order() {
38            return Err(String::from("F::order() must be less than PF::order()"));
39        }
40        let num_f_elms = PF::bits() / 64;
41        Ok(Self {
42            sponge_state: [PF::default(); WIDTH],
43            input_buffer: vec![],
44            output_buffer: vec![],
45            permutation,
46            num_f_elms,
47        })
48    }
49}
50
51impl<F, PF, P, const WIDTH: usize, const RATE: usize> MultiField32Challenger<F, PF, P, WIDTH, RATE>
52where
53    F: PrimeField32,
54    PF: PrimeField,
55    P: CryptographicPermutation<[PF; WIDTH]>,
56{
57    fn duplexing(&mut self) {
58        assert!(self.input_buffer.len() <= self.num_f_elms * RATE);
59
60        for (i, f_chunk) in self.input_buffer.chunks(self.num_f_elms).enumerate() {
61            self.sponge_state[i] = reduce_32(f_chunk);
62        }
63        self.input_buffer.clear();
64
65        // Apply the permutation.
66        self.permutation.permute_mut(&mut self.sponge_state);
67
68        self.output_buffer.clear();
69        for &pf_val in self.sponge_state.iter() {
70            let f_vals = split_32(pf_val, self.num_f_elms);
71            for f_val in f_vals {
72                self.output_buffer.push(f_val);
73            }
74        }
75    }
76}
77
78impl<F, PF, P, const WIDTH: usize, const RATE: usize> FieldChallenger<F>
79    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
80where
81    F: PrimeField32,
82    PF: PrimeField,
83    P: CryptographicPermutation<[PF; WIDTH]>,
84{
85}
86
87impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanObserve<F>
88    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
89where
90    F: PrimeField32,
91    PF: PrimeField,
92    P: CryptographicPermutation<[PF; WIDTH]>,
93{
94    fn observe(&mut self, value: F) {
95        // Any buffered output is now invalid.
96        self.output_buffer.clear();
97
98        self.input_buffer.push(value);
99
100        if self.input_buffer.len() == self.num_f_elms * RATE {
101            self.duplexing();
102        }
103    }
104}
105
106impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize> CanObserve<[F; N]>
107    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
108where
109    F: PrimeField32,
110    PF: PrimeField,
111    P: CryptographicPermutation<[PF; WIDTH]>,
112{
113    fn observe(&mut self, values: [F; N]) {
114        for value in values {
115            self.observe(value);
116        }
117    }
118}
119
120impl<F, PF, const N: usize, P, const WIDTH: usize, const RATE: usize> CanObserve<Hash<F, PF, N>>
121    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
122where
123    F: PrimeField32,
124    PF: PrimeField,
125    P: CryptographicPermutation<[PF; WIDTH]>,
126{
127    fn observe(&mut self, values: Hash<F, PF, N>) {
128        for pf_val in values {
129            let f_vals: Vec<F> = split_32(pf_val, self.num_f_elms);
130            for f_val in f_vals {
131                self.observe(f_val);
132            }
133        }
134    }
135}
136
137// for TrivialPcs
138impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanObserve<Vec<Vec<F>>>
139    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
140where
141    F: PrimeField32,
142    PF: PrimeField,
143    P: CryptographicPermutation<[PF; WIDTH]>,
144{
145    fn observe(&mut self, valuess: Vec<Vec<F>>) {
146        for values in valuess {
147            for value in values {
148                self.observe(value);
149            }
150        }
151    }
152}
153
154impl<F, EF, PF, P, const WIDTH: usize, const RATE: usize> CanSample<EF>
155    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
156where
157    F: PrimeField32,
158    EF: ExtensionField<F>,
159    PF: PrimeField,
160    P: CryptographicPermutation<[PF; WIDTH]>,
161{
162    fn sample(&mut self) -> EF {
163        EF::from_base_fn(|_| {
164            // If we have buffered inputs, we must perform a duplexing so that the challenge will
165            // reflect them. Or if we've run out of outputs, we must perform a duplexing to get more.
166            if !self.input_buffer.is_empty() || self.output_buffer.is_empty() {
167                self.duplexing();
168            }
169
170            self.output_buffer
171                .pop()
172                .expect("Output buffer should be non-empty")
173        })
174    }
175}
176
177impl<F, PF, P, const WIDTH: usize, const RATE: usize> CanSampleBits<usize>
178    for MultiField32Challenger<F, PF, P, WIDTH, RATE>
179where
180    F: PrimeField32,
181    PF: PrimeField,
182    P: CryptographicPermutation<[PF; WIDTH]>,
183{
184    fn sample_bits(&mut self, bits: usize) -> usize {
185        assert!(bits < (usize::BITS as usize));
186        assert!((1 << bits) < F::ORDER_U32);
187        let rand_f: F = self.sample();
188        let rand_usize = rand_f.as_canonical_u32() as usize;
189        rand_usize & ((1 << bits) - 1)
190    }
191}