1use alloc::vec;
2use alloc::vec::Vec;
3use core::error::Error;
4use core::fmt::{Display, Formatter};
5
6use p3_field::{BasedVectorSpace, Field, PrimeField, PrimeField64};
7use p3_monty_31::{MontyField31, MontyParameters};
8use p3_symmetric::{CryptographicPermutation, Hash};
9
10use crate::{CanObserve, CanSample, CanSampleBits, CanSampleUniformBits, FieldChallenger};
11
12#[derive(Clone, Debug)]
29pub struct DuplexChallenger<F, P, const WIDTH: usize, const RATE: usize>
30where
31 F: Clone,
32 P: CryptographicPermutation<[F; WIDTH]>,
33{
34 pub sponge_state: [F; WIDTH],
41
42 pub input_buffer: Vec<F>,
48
49 pub output_buffer: Vec<F>,
55
56 pub permutation: P,
62}
63
64impl<F, P, const WIDTH: usize, const RATE: usize> DuplexChallenger<F, P, WIDTH, RATE>
65where
66 F: Copy,
67 P: CryptographicPermutation<[F; WIDTH]>,
68{
69 pub fn new(permutation: P) -> Self
70 where
71 F: Default,
72 {
73 Self {
74 sponge_state: [F::default(); WIDTH],
75 input_buffer: vec![],
76 output_buffer: vec![],
77 permutation,
78 }
79 }
80
81 pub(crate) fn duplexing(&mut self) {
82 assert!(self.input_buffer.len() <= RATE);
83
84 for (i, val) in self.input_buffer.drain(..).enumerate() {
86 self.sponge_state[i] = val;
87 }
88
89 self.permutation.permute_mut(&mut self.sponge_state);
91
92 self.output_buffer.clear();
93 self.output_buffer.extend(&self.sponge_state[..RATE]);
94 }
95}
96
97impl<F, P, const WIDTH: usize, const RATE: usize> DuplexChallenger<F, P, WIDTH, RATE>
98where
99 F: Copy + Default + PrimeField,
100 P: CryptographicPermutation<[F; WIDTH]>,
101{
102 pub fn absorb_rate_padded_with_tag(&mut self, values: &[F], length_tag: u8) {
109 const {
110 assert!(
111 RATE < WIDTH,
112 "RATE must be less than WIDTH for capacity length slot"
113 );
114 }
115 assert!(values.len() <= RATE);
116 self.input_buffer.clear();
117 self.output_buffer.clear();
118 for (i, &v) in values.iter().enumerate() {
119 self.sponge_state[i] = v;
120 }
121 self.sponge_state[values.len()..RATE].fill(F::ZERO);
122 self.sponge_state[RATE] += F::from_u8(length_tag);
123 self.permutation.permute_mut(&mut self.sponge_state);
124 self.output_buffer
125 .extend_from_slice(&self.sponge_state[..RATE]);
126 }
127}
128
129impl<F, P, const WIDTH: usize, const RATE: usize> FieldChallenger<F>
130 for DuplexChallenger<F, P, WIDTH, RATE>
131where
132 F: PrimeField64,
133 P: CryptographicPermutation<[F; WIDTH]>,
134{
135}
136
137impl<F, P, const WIDTH: usize, const RATE: usize> CanObserve<F>
138 for DuplexChallenger<F, P, WIDTH, RATE>
139where
140 F: Copy,
141 P: CryptographicPermutation<[F; WIDTH]>,
142{
143 fn observe(&mut self, value: F) {
144 self.output_buffer.clear();
146
147 self.input_buffer.push(value);
148
149 if self.input_buffer.len() == RATE {
150 self.duplexing();
151 }
152 }
153}
154
155impl<F, P, const N: usize, const WIDTH: usize, const RATE: usize> CanObserve<[F; N]>
156 for DuplexChallenger<F, P, WIDTH, RATE>
157where
158 F: Copy,
159 P: CryptographicPermutation<[F; WIDTH]>,
160{
161 fn observe(&mut self, values: [F; N]) {
162 for value in values {
163 self.observe(value);
164 }
165 }
166}
167
168impl<F, P, const N: usize, const WIDTH: usize, const RATE: usize> CanObserve<Hash<F, F, N>>
169 for DuplexChallenger<F, P, WIDTH, RATE>
170where
171 F: Copy,
172 P: CryptographicPermutation<[F; WIDTH]>,
173{
174 fn observe(&mut self, values: Hash<F, F, N>) {
175 for value in values {
176 self.observe(value);
177 }
178 }
179}
180
181impl<F, P, const WIDTH: usize, const RATE: usize> CanObserve<Vec<Vec<F>>>
183 for DuplexChallenger<F, P, WIDTH, RATE>
184where
185 F: Copy,
186 P: CryptographicPermutation<[F; WIDTH]>,
187{
188 fn observe(&mut self, valuess: Vec<Vec<F>>) {
189 for values in valuess {
190 for value in values {
191 self.observe(value);
192 }
193 }
194 }
195}
196
197impl<F, EF, P, const WIDTH: usize, const RATE: usize> CanSample<EF>
198 for DuplexChallenger<F, P, WIDTH, RATE>
199where
200 F: Field,
201 EF: BasedVectorSpace<F>,
202 P: CryptographicPermutation<[F; WIDTH]>,
203{
204 fn sample(&mut self) -> EF {
205 EF::from_basis_coefficients_fn(|_| {
206 if !self.input_buffer.is_empty() || self.output_buffer.is_empty() {
209 self.duplexing();
210 }
211
212 self.output_buffer
213 .pop()
214 .expect("Output buffer should be non-empty")
215 })
216 }
217}
218
219impl<F, P, const WIDTH: usize, const RATE: usize> CanSampleBits<usize>
220 for DuplexChallenger<F, P, WIDTH, RATE>
221where
222 F: PrimeField64,
223 P: CryptographicPermutation<[F; WIDTH]>,
224{
225 fn sample_bits(&mut self, bits: usize) -> usize {
234 assert!(bits < (usize::BITS as usize));
235 assert!((1 << bits) < F::ORDER_U64);
236 let rand_f: F = self.sample();
237 let rand_usize = rand_f.as_canonical_u64() as usize;
238 rand_usize & ((1 << bits) - 1)
239 }
240}
241
/// Field-specific constants used by `sample_uniform_bits` to draw uniformly
/// distributed bits via rejection sampling.
pub trait UniformSamplingField {
    /// Largest bit count extracted from a single sampled field element.
    /// Requests for more bits are split into two samples.
    const MAX_SINGLE_SAMPLE_BITS: usize;
    /// Rejection thresholds indexed by bit count. A sampled element whose
    /// canonical value is `>= SAMPLING_BITS_M[bits]` is rejected.
    // NOTE(review): presumably each entry is the largest multiple of `2^bits`
    // not exceeding the field order, so accepted values have uniform low bits.
    // Confirm against the per-field definitions.
    const SAMPLING_BITS_M: [u64; 64];
}
259
260impl<MP> UniformSamplingField for MontyField31<MP>
264where
265 MP: UniformSamplingField + MontyParameters,
266{
267 const MAX_SINGLE_SAMPLE_BITS: usize = MP::MAX_SINGLE_SAMPLE_BITS;
268 const SAMPLING_BITS_M: [u64; 64] = MP::SAMPLING_BITS_M;
269}
270
/// Bit-sampling strategy that keeps drawing fresh samples until one falls below
/// the rejection threshold.
pub(super) struct ResampleOnRejection;
/// Bit-sampling strategy that returns a `ResamplingError` as soon as a sample
/// is at or above the rejection threshold.
pub(super) struct ErrorOnRejection;
277
/// Error returned when uniform bit sampling draws a value that must be rejected
/// but resampling is disabled.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ResamplingError {
    /// Canonical value of the rejected sample.
    value: u64,
    /// Rejection threshold; the sample was not smaller than this.
    m: u64,
}
287
288impl Display for ResamplingError {
289 fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
290 write!(
291 f,
292 "Encountered value {0}, which requires resampling for uniform bits as it not smaller than {1}. But resampling is not enabled.",
293 self.value, self.m
294 )
295 }
296}
297
298impl Error for ResamplingError {}
299
300pub(super) trait BitSamplingStrategy<F, P, const W: usize, const R: usize>
302where
303 F: PrimeField64,
304 P: CryptographicPermutation<[F; W]>,
305{
306 const ERROR_ON_REJECTION: bool;
308
309 #[inline]
310 fn sample_value(
311 challenger: &mut DuplexChallenger<F, P, W, R>,
312 m: u64,
313 ) -> Result<F, ResamplingError> {
314 let mut result: F = challenger.sample();
315 if Self::ERROR_ON_REJECTION {
316 if result.as_canonical_u64() >= m {
317 return Err(ResamplingError {
318 value: result.as_canonical_u64(),
319 m,
320 });
321 }
322 } else {
323 while result.as_canonical_u64() >= m {
324 result = challenger.sample();
325 }
326 }
327 Ok(result)
328 }
329}
330
331impl<F, P, const W: usize, const R: usize> BitSamplingStrategy<F, P, W, R> for ResampleOnRejection
333where
334 F: PrimeField64,
335 P: CryptographicPermutation<[F; W]>,
336{
337 const ERROR_ON_REJECTION: bool = false;
338}
339
340impl<F, P, const W: usize, const R: usize> BitSamplingStrategy<F, P, W, R> for ErrorOnRejection
342where
343 F: PrimeField64,
344 P: CryptographicPermutation<[F; W]>,
345{
346 const ERROR_ON_REJECTION: bool = true;
347}
348
349impl<F, P, const WIDTH: usize, const RATE: usize> DuplexChallenger<F, P, WIDTH, RATE>
350where
351 F: UniformSamplingField + PrimeField64,
352 P: CryptographicPermutation<[F; WIDTH]>,
353{
354 #[inline]
356 fn sample_uniform_bits_with_strategy<S>(
357 &mut self,
358 bits: usize,
359 ) -> Result<usize, ResamplingError>
360 where
361 S: BitSamplingStrategy<F, P, WIDTH, RATE>,
362 {
363 if bits == 0 {
364 return Ok(0);
365 };
366 assert!(bits < usize::BITS as usize, "bit count must be valid");
367 assert!(
368 (1u64 << bits) < F::ORDER_U64,
369 "bit count exceeds field order"
370 );
371 let m = F::SAMPLING_BITS_M[bits];
372 if bits <= F::MAX_SINGLE_SAMPLE_BITS {
373 let rand_f = S::sample_value(self, m);
375 Ok(rand_f?.as_canonical_u64() as usize & ((1 << bits) - 1))
376 } else {
377 let half_bits1 = bits / 2;
380 let half_bits2 = bits - half_bits1;
381 let rand1 = S::sample_value(self, F::SAMPLING_BITS_M[half_bits1]);
383 let chunk1 = rand1?.as_canonical_u64() as usize & ((1 << half_bits1) - 1);
384 let rand2 = S::sample_value(self, F::SAMPLING_BITS_M[half_bits2]);
386 let chunk2 = rand2?.as_canonical_u64() as usize & ((1 << half_bits2) - 1);
387
388 Ok(chunk1 | (chunk2 << half_bits1))
390 }
391 }
392}
393
394impl<F, P, const WIDTH: usize, const RATE: usize> CanSampleUniformBits<F>
395 for DuplexChallenger<F, P, WIDTH, RATE>
396where
397 F: UniformSamplingField + PrimeField64,
398 P: CryptographicPermutation<[F; WIDTH]>,
399{
400 fn sample_uniform_bits<const RESAMPLE: bool>(
401 &mut self,
402 bits: usize,
403 ) -> Result<usize, ResamplingError> {
404 if RESAMPLE {
405 self.sample_uniform_bits_with_strategy::<ResampleOnRejection>(bits)
406 } else {
407 self.sample_uniform_bits_with_strategy::<ErrorOnRejection>(bits)
408 }
409 }
410}
411
#[cfg(test)]
mod tests {
    use core::iter;

    use p3_baby_bear::BabyBear;
    use p3_field::PrimeCharacteristicRing;
    use p3_field::extension::BinomialExtensionField;
    use p3_goldilocks::Goldilocks;
    use p3_symmetric::Permutation;

    use super::*;
    use crate::grinding_challenger::GrindingChallenger;

    const WIDTH: usize = 24;
    const RATE: usize = 16;

    type G = Goldilocks;
    type EF2G = BinomialExtensionField<G, 2>;

    type BB = BabyBear;

    /// Toy permutation that reverses the state, so expected challenger outputs
    /// can be derived by hand. It is not cryptographic.
    #[derive(Clone)]
    struct TestPermutation {}

    impl<F: Clone> Permutation<[F; WIDTH]> for TestPermutation {
        fn permute_mut(&self, input: &mut [F; WIDTH]) {
            input.reverse();
        }
    }

    impl<F: Clone> CryptographicPermutation<[F; WIDTH]> for TestPermutation {}

    #[test]
    fn test_duplex_challenger() {
        type Chal = DuplexChallenger<G, TestPermutation, WIDTH, RATE>;
        let permutation = TestPermutation {};
        let mut duplex_challenger = DuplexChallenger::new(permutation);

        // Observe 12 elements; fewer than RATE, so nothing is absorbed yet.
        (0..12).for_each(|element| duplex_challenger.observe(G::from_u8(element as u8)));

        let state_after_duplexing: Vec<_> = iter::repeat_n(G::ZERO, 12)
            .chain((0..12).map(G::from_u8).rev())
            .collect();

        let expected_samples: Vec<G> = state_after_duplexing[..16].iter().copied().rev().collect();
        let samples = <Chal as CanSample<G>>::sample_vec(&mut duplex_challenger, 16);
        assert_eq!(samples, expected_samples);
    }

    #[test]
    // 129 bits exceeds `usize::BITS`, so the bit-count precondition must be what
    // panics, not the range assertion inside the loop.
    #[should_panic(expected = "usize::BITS")]
    fn test_duplex_challenger_sample_bits_security() {
        type GoldilocksChal = DuplexChallenger<G, TestPermutation, WIDTH, RATE>;
        let permutation = TestPermutation {};
        let mut duplex_challenger = GoldilocksChal::new(permutation);

        for _ in 0..100 {
            assert!(duplex_challenger.sample_bits(129) < 4);
        }
    }

    #[test]
    #[should_panic]
    fn test_duplex_challenger_sample_bits_security_small_field() {
        type BabyBearChal = DuplexChallenger<BB, TestPermutation, WIDTH, RATE>;
        let permutation = TestPermutation {};
        let mut duplex_challenger = BabyBearChal::new(permutation);

        for _ in 0..100 {
            assert!(duplex_challenger.sample_bits(40) < 1 << 31);
        }
    }

    #[test]
    #[should_panic]
    fn test_duplex_challenger_grind_security() {
        type GoldilocksChal = DuplexChallenger<G, TestPermutation, WIDTH, RATE>;
        let permutation = TestPermutation {};
        let mut duplex_challenger = GoldilocksChal::new(permutation);

        // A bit count equal to `usize::BITS` is out of range and must panic.
        let too_many_bits = usize::BITS as usize;

        let witness = duplex_challenger.grind(too_many_bits);
        assert!(duplex_challenger.check_witness(too_many_bits, witness));
    }

    #[test]
    fn test_observe_single_value() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        chal.observe(G::from_u8(42));
        assert_eq!(chal.input_buffer, vec![G::from_u8(42)]);
        assert!(chal.output_buffer.is_empty());
    }

    #[test]
    fn test_observe_array_of_values() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        chal.observe([G::from_u8(1), G::from_u8(2), G::from_u8(3)]);
        assert_eq!(
            chal.input_buffer,
            vec![G::from_u8(1), G::from_u8(2), G::from_u8(3)]
        );
        assert!(chal.output_buffer.is_empty());
    }

    #[test]
    fn test_observe_hash_array() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        let hash = Hash::<G, G, 4>::from([G::from_u8(10); 4]);
        chal.observe(hash);
        assert_eq!(chal.input_buffer, vec![G::from_u8(10); 4]);
    }

    #[test]
    fn test_observe_nested_vecs() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        chal.observe(vec![
            vec![G::from_u8(1), G::from_u8(2)],
            vec![G::from_u8(3)],
        ]);
        assert_eq!(
            chal.input_buffer,
            vec![G::from_u8(1), G::from_u8(2), G::from_u8(3)]
        );
    }

    #[test]
    fn test_sample_triggers_duplex() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        chal.observe(G::from_u8(5));
        assert!(chal.output_buffer.is_empty());
        let _sample: G = chal.sample();
        assert!(!chal.output_buffer.is_empty());
    }

    #[test]
    fn test_sample_multiple_extension_field() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});

        chal.observe(G::from_u8(1));
        chal.observe(G::from_u8(2));
        let _: EF2G = chal.sample();
        let _: EF2G = chal.sample();
    }

    #[test]
    fn test_sample_bits_within_bounds() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        for i in 0..RATE {
            chal.observe(G::from_u8(i as u8));
        }

        // Absorbing 0..16 and reversing the 24-element state leaves the output
        // buffer as [0; 8] ++ [15, 14, ..., 8], so the first sample pops 8.
        // Four bits keep that value intact (8 & 0b1111 == 8), so popping the
        // wrong element would be caught; three bits would mask 8 down to 0.
        let bits = 4;
        let value = chal.sample_bits(bits);
        let expected = G::from_u8(8).as_canonical_u64() as usize & ((1 << bits) - 1);
        assert_eq!(value, expected);
        assert_eq!(value, 8);
    }

    #[test]
    fn test_sample_bits_trigger_duplex_when_empty() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        assert_eq!(chal.input_buffer.len(), 0);
        assert_eq!(chal.output_buffer.len(), 0);

        // Duplexing an all-zero state with the reversing permutation yields zeros.
        let bits = 2;
        let sample = chal.sample_bits(bits);
        let expected = G::ZERO.as_canonical_u64() as usize & ((1 << bits) - 1);
        assert_eq!(sample, expected);
    }

    #[test]
    fn test_output_buffer_pops_correctly() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});

        for i in 0..RATE {
            chal.observe(G::from_u8(i as u8));
        }

        let expected = [
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(15),
            G::from_u8(14),
            G::from_u8(13),
            G::from_u8(12),
            G::from_u8(11),
            G::from_u8(10),
            G::from_u8(9),
            G::from_u8(8),
        ]
        .to_vec();

        assert_eq!(chal.output_buffer, expected);

        let first: G = chal.sample();
        let second: G = chal.sample();

        // Samples are popped from the end of the output buffer.
        assert_eq!(first, G::from_u8(8));
        assert_eq!(second, G::from_u8(9));
    }

    #[test]
    fn test_duplexing_only_when_needed() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        chal.output_buffer = vec![G::from_u8(10), G::from_u8(20)];

        // Output is available and no input is pending, so no duplexing happens.
        let sample: G = chal.sample();
        assert_eq!(sample, G::from_u8(20));
        assert_eq!(chal.output_buffer, vec![G::from_u8(10)]);
    }

    #[test]
    fn test_flush_when_input_full() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});

        for i in 0..RATE {
            chal.observe(G::from_u8(i as u8));
        }

        let expected_output = [
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(0),
            G::from_u8(15),
            G::from_u8(14),
            G::from_u8(13),
            G::from_u8(12),
            G::from_u8(11),
            G::from_u8(10),
            G::from_u8(9),
            G::from_u8(8),
        ]
        .to_vec();

        assert!(chal.input_buffer.is_empty());

        assert_eq!(chal.output_buffer, expected_output);
    }

    #[test]
    fn test_absorb_rate_padded_with_tag_layout() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        let values: Vec<G> = (1u8..=3).map(G::from_u8).collect();

        chal.absorb_rate_padded_with_tag(&values, 7);

        // Before the reversing permutation the state is
        // [1, 2, 3, 0 x 13 | 7, 0 x 7]: the values, zero padding up to RATE,
        // then the length tag added into capacity slot RATE.
        let mut expected_state = [G::ZERO; WIDTH];
        expected_state[..3].copy_from_slice(&values);
        expected_state[RATE] = G::from_u8(7);
        expected_state.reverse();

        assert_eq!(chal.sponge_state, expected_state);
        assert_eq!(chal.output_buffer, expected_state[..RATE].to_vec());
        assert!(chal.input_buffer.is_empty());
    }

    #[test]
    fn test_observe_base_as_algebra_element_consistency_with_direct_observe() {
        let mut chal1 =
            DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        let mut chal2 =
            DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});

        let base_val = G::from_u8(99);

        chal1.observe_base_as_algebra_element::<EF2G>(base_val);

        let ext_val = EF2G::from(base_val);
        chal2.observe_algebra_element(ext_val);

        assert_eq!(chal1.input_buffer, chal2.input_buffer);
        assert_eq!(chal1.output_buffer, chal2.output_buffer);
        assert_eq!(chal1.sponge_state, chal2.sponge_state);
    }

    #[test]
    fn test_observe_base_as_algebra_element_stream_consistency() {
        let mut chal1 =
            DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        let mut chal2 =
            DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});

        let base_values: Vec<_> = (0u8..25).map(G::from_u8).collect();

        for &val in &base_values {
            chal1.observe_base_as_algebra_element::<EF2G>(val);
        }

        for &val in &base_values {
            let ext_val = EF2G::from(val);
            chal2.observe_algebra_element(ext_val);
        }

        assert_eq!(chal1.input_buffer, chal2.input_buffer);
        assert_eq!(chal1.output_buffer, chal2.output_buffer);
        assert_eq!(chal1.sponge_state, chal2.sponge_state);

        let sample1: EF2G = chal1.sample_algebra_element();
        let sample2: EF2G = chal2.sample_algebra_element();
        assert_eq!(sample1, sample2);

        assert_eq!(chal1.input_buffer, chal2.input_buffer);
        assert_eq!(chal1.output_buffer, chal2.output_buffer);
        assert_eq!(chal1.sponge_state, chal2.sponge_state);
    }

    #[test]
    fn test_observe_algebra_elements_equivalence() {
        let mut chal1 =
            DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        let mut chal2 =
            DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});

        let ext_values: Vec<EF2G> = (0u8..10).map(|i| EF2G::from(G::from_u8(i))).collect();

        chal1.observe_algebra_slice(&ext_values);

        for ext_val in &ext_values {
            chal2.observe_algebra_element(*ext_val);
        }

        assert_eq!(chal1.input_buffer, chal2.input_buffer);
        assert_eq!(chal1.output_buffer, chal2.output_buffer);
        assert_eq!(chal1.sponge_state, chal2.sponge_state);

        let sample1: EF2G = chal1.sample_algebra_element();
        let sample2: EF2G = chal2.sample_algebra_element();
        assert_eq!(sample1, sample2);
    }

    #[test]
    fn test_observe_algebra_elements_empty_slice() {
        let mut chal1 =
            DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});
        let mut chal2 =
            DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});

        chal1.observe(G::from_u8(42));
        chal2.observe(G::from_u8(42));

        // Observing an empty slice must leave the challenger untouched.
        let empty: Vec<EF2G> = vec![];
        chal1.observe_algebra_slice(&empty);

        assert_eq!(chal1.input_buffer, chal2.input_buffer);
        assert_eq!(chal1.output_buffer, chal2.output_buffer);
        assert_eq!(chal1.sponge_state, chal2.sponge_state);
    }

    #[test]
    fn test_observe_algebra_elements_triggers_duplexing() {
        let mut chal = DuplexChallenger::<G, TestPermutation, WIDTH, RATE>::new(TestPermutation {});

        // 8 degree-2 extension elements contribute 16 = RATE base elements.
        let ext_values: Vec<EF2G> = (0u8..8).map(|i| EF2G::from(G::from_u8(i))).collect();

        assert!(chal.input_buffer.is_empty());
        assert!(chal.output_buffer.is_empty());

        chal.observe_algebra_slice(&ext_values);

        assert!(chal.input_buffer.is_empty());
        assert!(!chal.output_buffer.is_empty());
    }
}
817}