// openvm_native_recursion/challenger/duplex.rs

1use openvm_native_compiler::{
2    ir::{RVar, DIGEST_SIZE, PERMUTATION_WIDTH},
3    prelude::*,
4};
5use openvm_native_compiler_derive::iter_zip;
6use openvm_stark_backend::p3_field::{Field, PrimeCharacteristicRing};
7
8use crate::{
9    challenger::{
10        CanCheckWitness, CanObserveDigest, CanObserveVariable, CanSampleBitsVariable,
11        CanSampleVariable, ChallengerVariable, FeltChallenger,
12    },
13    digest::DigestVariable,
14};
15
16/// Reference: [`openvm_stark_backend::p3_challenger::DuplexChallenger`]
17#[derive(Clone)]
18pub struct DuplexChallengerVariable<C: Config> {
19    pub sponge_state: Array<C, Felt<C::F>>,
20
21    pub input_ptr: Ptr<C::N>,
22    pub output_ptr: Ptr<C::N>,
23    pub io_empty_ptr: Ptr<C::N>,
24    pub io_full_ptr: Ptr<C::N>,
25}
26
27impl<C: Config> DuplexChallengerVariable<C> {
28    /// Creates a new duplex challenger with the default state.
29    pub fn new(builder: &mut Builder<C>) -> Self {
30        let sponge_state = builder.dyn_array(PERMUTATION_WIDTH);
31
32        builder
33            .range(0, sponge_state.len())
34            .for_each(|i_vec, builder| {
35                builder.set(&sponge_state, i_vec[0], C::F::ZERO);
36            });
37        let io_empty_ptr = sponge_state.ptr();
38        let io_full_ptr: Ptr<_> = builder.eval(io_empty_ptr + C::N::from_usize(DIGEST_SIZE));
39        let input_ptr = builder.eval(io_empty_ptr);
40        let output_ptr = builder.eval(io_empty_ptr);
41
42        DuplexChallengerVariable::<C> {
43            sponge_state,
44            input_ptr,
45            output_ptr,
46            io_empty_ptr,
47            io_full_ptr,
48        }
49    }
50
51    pub fn duplexing(&self, builder: &mut Builder<C>) {
52        builder.assign(&self.input_ptr, self.io_empty_ptr);
53
54        builder.poseidon2_permute_mut(&self.sponge_state);
55
56        builder.assign(&self.output_ptr, self.io_full_ptr);
57    }
58
59    fn observe(&self, builder: &mut Builder<C>, value: Felt<C::F>) {
60        builder.assign(&self.output_ptr, self.io_empty_ptr);
61
62        builder.iter_ptr_set(&self.sponge_state, self.input_ptr.address.into(), value);
63        builder.assign(&self.input_ptr, self.input_ptr + C::N::ONE);
64
65        builder
66            .if_eq(self.input_ptr.address, self.io_full_ptr.address)
67            .then(|builder| {
68                self.duplexing(builder);
69            })
70    }
71
72    fn observe_commitment(&self, builder: &mut Builder<C>, commitment: &Array<C, Felt<C::F>>) {
73        for i in 0..DIGEST_SIZE {
74            let element = builder.get(commitment, i);
75            self.observe(builder, element);
76        }
77    }
78
79    fn sample(&self, builder: &mut Builder<C>) -> Felt<C::F> {
80        builder
81            .if_ne(self.input_ptr.address, self.io_empty_ptr.address)
82            .then_or_else(
83                |builder| {
84                    self.duplexing(builder);
85                },
86                |builder| {
87                    builder
88                        .if_eq(self.output_ptr.address, self.io_empty_ptr.address)
89                        .then(|builder| {
90                            self.duplexing(builder);
91                        });
92                },
93            );
94        builder.assign(&self.output_ptr, self.output_ptr - C::N::ONE);
95        builder.iter_ptr_get(&self.sponge_state, self.output_ptr.address.into())
96    }
97
98    fn sample_ext(&self, builder: &mut Builder<C>) -> Ext<C::F, C::EF> {
99        let a = self.sample(builder);
100        let b = self.sample(builder);
101        let c = self.sample(builder);
102        let d = self.sample(builder);
103        builder.ext_from_base_slice(&[a, b, c, d])
104    }
105
106    fn sample_bits(&self, builder: &mut Builder<C>, nb_bits: RVar<C::N>) -> Array<C, Var<C::N>>
107    where
108        C::N: Field,
109    {
110        let rand_f = self.sample(builder);
111        let bits = builder.num2bits_f(rand_f, C::N::bits() as u32);
112
113        builder
114            .range(nb_bits, bits.len())
115            .for_each(|i_vec, builder| {
116                builder.set(&bits, i_vec[0], C::N::ZERO);
117            });
118        bits
119    }
120
121    pub fn check_witness(&self, builder: &mut Builder<C>, nb_bits: usize, witness: Felt<C::F>) {
122        if nb_bits == 0 {
123            return;
124        }
125        self.observe(builder, witness);
126        let element_bits = self.sample_bits(builder, RVar::from(nb_bits));
127        let element_bits_truncated = element_bits.slice(builder, 0, nb_bits);
128        iter_zip!(builder, element_bits_truncated).for_each(|ptr_vec, builder| {
129            let element = builder.iter_ptr_get(&element_bits_truncated, ptr_vec[0]);
130            builder.assert_var_eq(element, C::N::ZERO);
131        });
132    }
133}
134
135impl<C: Config> CanObserveVariable<C, Felt<C::F>> for DuplexChallengerVariable<C> {
136    fn observe(&mut self, builder: &mut Builder<C>, value: Felt<C::F>) {
137        DuplexChallengerVariable::observe(self, builder, value);
138    }
139
140    fn observe_slice(&mut self, builder: &mut Builder<C>, values: Array<C, Felt<C::F>>) {
141        iter_zip!(builder, values).for_each(|ptr_vec, builder| {
142            let element = builder.iter_ptr_get(&values, ptr_vec[0]);
143            self.observe(builder, element);
144        });
145    }
146}
147
148impl<C: Config> CanSampleVariable<C, Felt<C::F>> for DuplexChallengerVariable<C> {
149    fn sample(&mut self, builder: &mut Builder<C>) -> Felt<C::F> {
150        DuplexChallengerVariable::sample(self, builder)
151    }
152}
153
154impl<C: Config> CanSampleBitsVariable<C> for DuplexChallengerVariable<C> {
155    fn sample_bits(
156        &mut self,
157        builder: &mut Builder<C>,
158        nb_bits: RVar<C::N>,
159    ) -> Array<C, Var<C::N>> {
160        DuplexChallengerVariable::sample_bits(self, builder, nb_bits)
161    }
162}
163
164impl<C: Config> CanObserveDigest<C> for DuplexChallengerVariable<C> {
165    fn observe_digest(&mut self, builder: &mut Builder<C>, commitment: DigestVariable<C>) {
166        if let DigestVariable::Felt(commitment) = commitment {
167            self.observe_commitment(builder, &commitment);
168        } else {
169            panic!("Expected a felt digest");
170        }
171    }
172}
173
174impl<C: Config> FeltChallenger<C> for DuplexChallengerVariable<C> {
175    fn sample_ext(&mut self, builder: &mut Builder<C>) -> Ext<C::F, C::EF> {
176        DuplexChallengerVariable::sample_ext(self, builder)
177    }
178}
179
180impl<C: Config> CanCheckWitness<C> for DuplexChallengerVariable<C> {
181    fn check_witness(&mut self, builder: &mut Builder<C>, nb_bits: usize, witness: Felt<C::F>) {
182        DuplexChallengerVariable::check_witness(self, builder, nb_bits, witness);
183    }
184}
185
186impl<C: Config> ChallengerVariable<C> for DuplexChallengerVariable<C> {
187    fn new(builder: &mut Builder<C>) -> Self {
188        DuplexChallengerVariable::new(builder)
189    }
190}
191
#[cfg(test)]
mod tests {
    use openvm_native_circuit::execute_program;
    use openvm_native_compiler::{
        asm::{AsmBuilder, AsmConfig},
        ir::Felt,
    };
    use openvm_stark_backend::{
        config::{StarkGenericConfig, Val},
        p3_challenger::{CanObserve, CanSample},
        p3_field::PrimeCharacteristicRing,
    };
    use openvm_stark_sdk::{
        config::baby_bear_poseidon2::{default_engine, BabyBearPoseidon2Config},
        engine::StarkEngine,
        p3_baby_bear::BabyBear,
        utils::create_seeded_rng,
    };
    use rand::Rng;

    use super::DuplexChallengerVariable;

    type SC = BabyBearPoseidon2Config;
    type F = Val<SC>;
    type EF = <SC as StarkGenericConfig>::Challenge;

    /// Observes `num_observations` pseudo-random felts, then draws `num_samples` felts.
    ///
    /// Checks in-circuit that every draw of [`DuplexChallengerVariable`] equals the matching
    /// draw of the native challenger from `default_engine()`.
    fn check_against_native(num_observations: usize, num_samples: usize) {
        let mut rng = create_seeded_rng();
        let observations = (0..num_observations)
            .map(|_| BabyBear::from_u32(rng.random_range(0..(1 << 30))))
            .collect::<Vec<_>>();

        // Native reference transcript.
        let engine = default_engine();
        let mut native = engine.new_challenger();
        for &observation in &observations {
            native.observe(observation);
        }
        let mut expected = Vec::with_capacity(num_samples);
        for _ in 0..num_samples {
            let sample: F = native.sample();
            expected.push(sample);
        }
        println!("expected samples: {expected:?}");

        // The same transcript, replayed in-circuit.
        let mut builder = AsmBuilder::<F, EF>::default();
        let challenger = DuplexChallengerVariable::<AsmConfig<F, EF>>::new(&mut builder);
        for &observation in &observations {
            let observation: Felt<_> = builder.eval(observation);
            challenger.observe(&mut builder, observation);
        }
        for &expected_sample in &expected {
            let sampled = challenger.sample(&mut builder);
            let expected_sample: Felt<_> = builder.eval(expected_sample);
            builder.assert_felt_eq(expected_sample, sampled);
        }

        builder.halt();

        let program = builder.compile_isa();
        execute_program(program, vec![]);
    }

    #[test]
    fn test_compiler_challenger() {
        for num_observations in [1, 4, 8, 10, 16, 20, 50] {
            check_against_native(num_observations, 1);
        }
    }

    #[test]
    fn test_compiler_challenger_multiple_samples() {
        // Covers squeezing with no prior observations, squeezing starting exactly on a full-rate
        // boundary, and draining more than one rate's worth of outputs (forcing re-permutation).
        for (num_observations, num_samples) in [(0, 3), (3, 9), (8, 8), (8, 17), (10, 20)] {
            check_against_native(num_observations, num_samples);
        }
    }
}