1use alloc::vec;
2use alloc::vec::Vec;
3
4use p3_field::{ExtensionField, Field, PrimeField64};
5use p3_symmetric::{CryptographicPermutation, Hash};
6
7use crate::{CanObserve, CanSample, CanSampleBits, FieldChallenger};
8
/// A challenger built around a `WIDTH`-element state and a permutation `P`.
///
/// Observed values are queued in `input_buffer`. When a duplexing step runs,
/// the queued values overwrite the leading elements of `sponge_state`, the
/// permutation is applied, and the first `RATE` elements of the resulting
/// state are copied into `output_buffer`, from which samples are taken.
#[derive(Clone, Debug)]
pub struct DuplexChallenger<F, P, const WIDTH: usize, const RATE: usize>
where
    F: Clone,
    P: CryptographicPermutation<[F; WIDTH]>,
{
    /// The state array that `permutation` is applied to.
    pub sponge_state: [F; WIDTH],
    /// Observed values that have not yet been written into `sponge_state`.
    pub input_buffer: Vec<F>,
    /// Values copied out of `sponge_state` that have not yet been sampled.
    pub output_buffer: Vec<F>,
    /// The permutation applied to `sponge_state` on every duplexing step.
    pub permutation: P,
}
20
21impl<F, P, const WIDTH: usize, const RATE: usize> DuplexChallenger<F, P, WIDTH, RATE>
22where
23 F: Copy,
24 P: CryptographicPermutation<[F; WIDTH]>,
25{
26 pub fn new(permutation: P) -> Self
27 where
28 F: Default,
29 {
30 Self {
31 sponge_state: [F::default(); WIDTH],
32 input_buffer: vec![],
33 output_buffer: vec![],
34 permutation,
35 }
36 }
37
38 fn duplexing(&mut self) {
39 assert!(self.input_buffer.len() <= RATE);
40
41 for (i, val) in self.input_buffer.drain(..).enumerate() {
43 self.sponge_state[i] = val;
44 }
45
46 self.permutation.permute_mut(&mut self.sponge_state);
48
49 self.output_buffer.clear();
50 self.output_buffer.extend(&self.sponge_state[..RATE]);
51 }
52}
53
/// Implements `FieldChallenger<F>` for `DuplexChallenger` when `F` is a
/// `PrimeField64`.
///
/// The impl body is intentionally empty: no trait items are overridden here.
impl<F, P, const WIDTH: usize, const RATE: usize> FieldChallenger<F>
    for DuplexChallenger<F, P, WIDTH, RATE>
where
    F: PrimeField64,
    P: CryptographicPermutation<[F; WIDTH]>,
{
}
61
62impl<F, P, const WIDTH: usize, const RATE: usize> CanObserve<F>
63 for DuplexChallenger<F, P, WIDTH, RATE>
64where
65 F: Copy,
66 P: CryptographicPermutation<[F; WIDTH]>,
67{
68 fn observe(&mut self, value: F) {
69 self.output_buffer.clear();
71
72 self.input_buffer.push(value);
73
74 if self.input_buffer.len() == RATE {
75 self.duplexing();
76 }
77 }
78}
79
80impl<F, P, const N: usize, const WIDTH: usize, const RATE: usize> CanObserve<[F; N]>
81 for DuplexChallenger<F, P, WIDTH, RATE>
82where
83 F: Copy,
84 P: CryptographicPermutation<[F; WIDTH]>,
85{
86 fn observe(&mut self, values: [F; N]) {
87 for value in values {
88 self.observe(value);
89 }
90 }
91}
92
93impl<F, P, const N: usize, const WIDTH: usize, const RATE: usize> CanObserve<Hash<F, F, N>>
94 for DuplexChallenger<F, P, WIDTH, RATE>
95where
96 F: Copy,
97 P: CryptographicPermutation<[F; WIDTH]>,
98{
99 fn observe(&mut self, values: Hash<F, F, N>) {
100 for value in values {
101 self.observe(value);
102 }
103 }
104}
105
106impl<F, P, const WIDTH: usize, const RATE: usize> CanObserve<Vec<Vec<F>>>
108 for DuplexChallenger<F, P, WIDTH, RATE>
109where
110 F: Copy,
111 P: CryptographicPermutation<[F; WIDTH]>,
112{
113 fn observe(&mut self, valuess: Vec<Vec<F>>) {
114 for values in valuess {
115 for value in values {
116 self.observe(value);
117 }
118 }
119 }
120}
121
122impl<F, EF, P, const WIDTH: usize, const RATE: usize> CanSample<EF>
123 for DuplexChallenger<F, P, WIDTH, RATE>
124where
125 F: Field,
126 EF: ExtensionField<F>,
127 P: CryptographicPermutation<[F; WIDTH]>,
128{
129 fn sample(&mut self) -> EF {
130 EF::from_base_fn(|_| {
131 if !self.input_buffer.is_empty() || self.output_buffer.is_empty() {
134 self.duplexing();
135 }
136
137 self.output_buffer
138 .pop()
139 .expect("Output buffer should be non-empty")
140 })
141 }
142}
143
144impl<F, P, const WIDTH: usize, const RATE: usize> CanSampleBits<usize>
145 for DuplexChallenger<F, P, WIDTH, RATE>
146where
147 F: PrimeField64,
148 P: CryptographicPermutation<[F; WIDTH]>,
149{
150 fn sample_bits(&mut self, bits: usize) -> usize {
151 assert!(bits < (usize::BITS as usize));
152 assert!((1 << bits) < F::ORDER_U64);
153 let rand_f: F = self.sample();
154 let rand_usize = rand_f.as_canonical_u64() as usize;
155 rand_usize & ((1 << bits) - 1)
156 }
157}
158
#[cfg(test)]
mod tests {
    use core::iter;

    use p3_baby_bear::BabyBear;
    use p3_field::FieldAlgebra;
    use p3_goldilocks::Goldilocks;
    use p3_symmetric::Permutation;

    use super::*;
    use crate::grinding_challenger::GrindingChallenger;

    const WIDTH: usize = 24;
    const RATE: usize = 16;

    type G = Goldilocks;
    type BB = BabyBear;

    type GoldilocksChal = DuplexChallenger<G, TestPermutation, WIDTH, RATE>;
    type BabyBearChal = DuplexChallenger<BB, TestPermutation, WIDTH, RATE>;

    /// Stand-in permutation for tests: it only reverses the state, which makes
    /// the expected outputs easy to write down by hand.
    #[derive(Clone)]
    struct TestPermutation {}

    impl<F: Clone> Permutation<[F; WIDTH]> for TestPermutation {
        fn permute_mut(&self, input: &mut [F; WIDTH]) {
            input.reverse();
        }
    }

    impl<F: Clone> CryptographicPermutation<[F; WIDTH]> for TestPermutation {}

    /// Fresh Goldilocks challenger backed by the reversing test permutation.
    fn goldilocks_challenger() -> GoldilocksChal {
        GoldilocksChal::new(TestPermutation {})
    }

    /// Fresh BabyBear challenger backed by the reversing test permutation.
    fn baby_bear_challenger() -> BabyBearChal {
        BabyBearChal::new(TestPermutation {})
    }

    #[test]
    fn test_duplex_challenger() {
        let mut duplex_challenger = goldilocks_challenger();

        // Fewer than RATE values are observed, so they stay buffered until the
        // first sample triggers a duplexing step.
        for value in 0..12u8 {
            duplex_challenger.observe(G::from_canonical_u8(value));
        }

        // Reversing [0, 1, ..., 11, 0, ..., 0] gives twelve zeros followed by
        // 11 down to 0.
        let permuted_state: Vec<G> = iter::repeat(G::ZERO)
            .take(12)
            .chain((0..12).rev().map(G::from_canonical_u8))
            .collect();

        // Samples are popped from the back of the output buffer, so they come
        // out in reverse order of the first 16 state elements.
        let expected: Vec<G> = permuted_state[..16].iter().rev().copied().collect();
        let sampled = <GoldilocksChal as CanSample<G>>::sample_vec(&mut duplex_challenger, 16);
        assert_eq!(sampled, expected);
    }

    #[test]
    #[should_panic]
    fn test_duplex_challenger_sample_bits_security() {
        let mut duplex_challenger = goldilocks_challenger();

        // 129 bits is not below `usize::BITS`, so `sample_bits` must panic.
        for _ in 0..100 {
            assert!(duplex_challenger.sample_bits(129) < 4);
        }
    }

    #[test]
    #[should_panic]
    fn test_duplex_challenger_sample_bits_security_small_field() {
        let mut duplex_challenger = baby_bear_challenger();

        // Presumably `1 << 40` is not below the BabyBear field order, so the
        // order check in `sample_bits` is expected to panic.
        for _ in 0..100 {
            assert!(duplex_challenger.sample_bits(40) < 1 << 31);
        }
    }

    #[test]
    #[should_panic]
    fn test_duplex_challenger_grind_security() {
        let mut duplex_challenger = goldilocks_challenger();

        // A bit count equal to `usize::BITS` violates the bound asserted by
        // `sample_bits`. NOTE(review): this assumes `grind`/`check_witness`
        // reach `sample_bits` - confirm in `grinding_challenger`.
        let too_many_bits = usize::BITS as usize;

        let witness = duplex_challenger.grind(too_many_bits);
        assert!(duplex_challenger.check_witness(too_many_bits, witness));
    }
}