// openvm_native_recursion/challenger/duplex.rs

1use openvm_native_compiler::{
2    ir::{RVar, DIGEST_SIZE, PERMUTATION_WIDTH},
3    prelude::*,
4};
5use openvm_native_compiler_derive::iter_zip;
6use openvm_stark_backend::p3_field::{Field, FieldAlgebra};
7
8use crate::{
9    challenger::{
10        CanCheckWitness, CanObserveDigest, CanObserveVariable, CanSampleBitsVariable,
11        CanSampleVariable, ChallengerVariable, FeltChallenger,
12    },
13    digest::DigestVariable,
14};
15
16/// Reference: [`openvm_stark_backend::p3_challenger::DuplexChallenger`]
17#[derive(Clone)]
18pub struct DuplexChallengerVariable<C: Config> {
19    pub sponge_state: Array<C, Felt<C::F>>,
20
21    pub input_ptr: Ptr<C::N>,
22    pub output_ptr: Ptr<C::N>,
23    pub io_empty_ptr: Ptr<C::N>,
24    pub io_full_ptr: Ptr<C::N>,
25}
26
27impl<C: Config> DuplexChallengerVariable<C> {
28    /// Creates a new duplex challenger with the default state.
29    pub fn new(builder: &mut Builder<C>) -> Self {
30        let sponge_state = builder.dyn_array(PERMUTATION_WIDTH);
31
32        builder
33            .range(0, sponge_state.len())
34            .for_each(|i_vec, builder| {
35                builder.set(&sponge_state, i_vec[0], C::F::ZERO);
36            });
37        let io_empty_ptr = sponge_state.ptr();
38        let io_full_ptr: Ptr<_> =
39            builder.eval(io_empty_ptr + C::N::from_canonical_usize(DIGEST_SIZE));
40        let input_ptr = builder.eval(io_empty_ptr);
41        let output_ptr = builder.eval(io_empty_ptr);
42
43        DuplexChallengerVariable::<C> {
44            sponge_state,
45            input_ptr,
46            output_ptr,
47            io_empty_ptr,
48            io_full_ptr,
49        }
50    }
51
52    pub fn duplexing(&self, builder: &mut Builder<C>) {
53        builder.assign(&self.input_ptr, self.io_empty_ptr);
54
55        builder.poseidon2_permute_mut(&self.sponge_state);
56
57        builder.assign(&self.output_ptr, self.io_full_ptr);
58    }
59
60    fn observe(&self, builder: &mut Builder<C>, value: Felt<C::F>) {
61        builder.assign(&self.output_ptr, self.io_empty_ptr);
62
63        builder.iter_ptr_set(&self.sponge_state, self.input_ptr.address.into(), value);
64        builder.assign(&self.input_ptr, self.input_ptr + C::N::ONE);
65
66        builder
67            .if_eq(self.input_ptr.address, self.io_full_ptr.address)
68            .then(|builder| {
69                self.duplexing(builder);
70            })
71    }
72
73    fn observe_commitment(&self, builder: &mut Builder<C>, commitment: &Array<C, Felt<C::F>>) {
74        for i in 0..DIGEST_SIZE {
75            let element = builder.get(commitment, i);
76            self.observe(builder, element);
77        }
78    }
79
80    fn sample(&self, builder: &mut Builder<C>) -> Felt<C::F> {
81        builder
82            .if_ne(self.input_ptr.address, self.io_empty_ptr.address)
83            .then_or_else(
84                |builder| {
85                    self.duplexing(builder);
86                },
87                |builder| {
88                    builder
89                        .if_eq(self.output_ptr.address, self.io_empty_ptr.address)
90                        .then(|builder| {
91                            self.duplexing(builder);
92                        });
93                },
94            );
95        builder.assign(&self.output_ptr, self.output_ptr - C::N::ONE);
96        builder.iter_ptr_get(&self.sponge_state, self.output_ptr.address.into())
97    }
98
99    fn sample_ext(&self, builder: &mut Builder<C>) -> Ext<C::F, C::EF> {
100        let a = self.sample(builder);
101        let b = self.sample(builder);
102        let c = self.sample(builder);
103        let d = self.sample(builder);
104        builder.ext_from_base_slice(&[a, b, c, d])
105    }
106
107    fn sample_bits(&self, builder: &mut Builder<C>, nb_bits: RVar<C::N>) -> Array<C, Var<C::N>>
108    where
109        C::N: Field,
110    {
111        let rand_f = self.sample(builder);
112        let bits = builder.num2bits_f(rand_f, C::N::bits() as u32);
113
114        builder
115            .range(nb_bits, bits.len())
116            .for_each(|i_vec, builder| {
117                builder.set(&bits, i_vec[0], C::N::ZERO);
118            });
119        bits
120    }
121
122    pub fn check_witness(&self, builder: &mut Builder<C>, nb_bits: usize, witness: Felt<C::F>) {
123        self.observe(builder, witness);
124        let element_bits = self.sample_bits(builder, RVar::from(nb_bits));
125        let element_bits_truncated = element_bits.slice(builder, 0, nb_bits);
126        iter_zip!(builder, element_bits_truncated).for_each(|ptr_vec, builder| {
127            let element = builder.iter_ptr_get(&element_bits_truncated, ptr_vec[0]);
128            builder.assert_var_eq(element, C::N::ZERO);
129        });
130    }
131}
132
133impl<C: Config> CanObserveVariable<C, Felt<C::F>> for DuplexChallengerVariable<C> {
134    fn observe(&mut self, builder: &mut Builder<C>, value: Felt<C::F>) {
135        DuplexChallengerVariable::observe(self, builder, value);
136    }
137
138    fn observe_slice(&mut self, builder: &mut Builder<C>, values: Array<C, Felt<C::F>>) {
139        iter_zip!(builder, values).for_each(|ptr_vec, builder| {
140            let element = builder.iter_ptr_get(&values, ptr_vec[0]);
141            self.observe(builder, element);
142        });
143    }
144}
145
146impl<C: Config> CanSampleVariable<C, Felt<C::F>> for DuplexChallengerVariable<C> {
147    fn sample(&mut self, builder: &mut Builder<C>) -> Felt<C::F> {
148        DuplexChallengerVariable::sample(self, builder)
149    }
150}
151
152impl<C: Config> CanSampleBitsVariable<C> for DuplexChallengerVariable<C> {
153    fn sample_bits(
154        &mut self,
155        builder: &mut Builder<C>,
156        nb_bits: RVar<C::N>,
157    ) -> Array<C, Var<C::N>> {
158        DuplexChallengerVariable::sample_bits(self, builder, nb_bits)
159    }
160}
161
162impl<C: Config> CanObserveDigest<C> for DuplexChallengerVariable<C> {
163    fn observe_digest(&mut self, builder: &mut Builder<C>, commitment: DigestVariable<C>) {
164        if let DigestVariable::Felt(commitment) = commitment {
165            self.observe_commitment(builder, &commitment);
166        } else {
167            panic!("Expected a felt digest");
168        }
169    }
170}
171
172impl<C: Config> FeltChallenger<C> for DuplexChallengerVariable<C> {
173    fn sample_ext(&mut self, builder: &mut Builder<C>) -> Ext<C::F, C::EF> {
174        DuplexChallengerVariable::sample_ext(self, builder)
175    }
176}
177
178impl<C: Config> CanCheckWitness<C> for DuplexChallengerVariable<C> {
179    fn check_witness(&mut self, builder: &mut Builder<C>, nb_bits: usize, witness: Felt<C::F>) {
180        DuplexChallengerVariable::check_witness(self, builder, nb_bits, witness);
181    }
182}
183
184impl<C: Config> ChallengerVariable<C> for DuplexChallengerVariable<C> {
185    fn new(builder: &mut Builder<C>) -> Self {
186        DuplexChallengerVariable::new(builder)
187    }
188}
189
#[cfg(test)]
mod tests {
    use openvm_native_circuit::execute_program;
    use openvm_native_compiler::{
        asm::{AsmBuilder, AsmConfig},
        ir::Felt,
    };
    use openvm_stark_backend::{
        config::{StarkGenericConfig, Val},
        p3_challenger::{CanObserve, CanSample},
        p3_field::FieldAlgebra,
    };
    use openvm_stark_sdk::{
        config::baby_bear_poseidon2::{default_engine, BabyBearPoseidon2Config},
        engine::StarkEngine,
        p3_baby_bear::BabyBear,
    };
    use rand::Rng;

    use super::DuplexChallengerVariable;

    /// Cross-checks the in-circuit challenger against the native one.
    ///
    /// The helper:
    /// 1. observes `num_observations` random field elements in the native challenger and draws
    ///    `num_samples` reference values from it;
    /// 2. replays the same observations in a compiled program using [`DuplexChallengerVariable`];
    /// 3. asserts inside that program that each in-circuit sample equals the corresponding
    ///    reference value;
    /// 4. executes the program, so any mismatch fails the in-program assertion.
    fn assert_matches_native_challenger(num_observations: usize, num_samples: usize) {
        type SC = BabyBearPoseidon2Config;
        type F = Val<SC>;
        type EF = <SC as StarkGenericConfig>::Challenge;

        let mut rng = rand::thread_rng();
        let observations = (0..num_observations)
            .map(|_| BabyBear::from_canonical_u32(rng.gen_range(0..(1 << 30))))
            .collect::<Vec<_>>();

        // Reference values from the native challenger.
        let engine = default_engine();
        let mut native = engine.new_challenger();
        for &observation in &observations {
            native.observe(observation);
        }
        let mut expected = Vec::with_capacity(num_samples);
        for _ in 0..num_samples {
            let value: F = native.sample();
            expected.push(value);
        }

        // The same transcript, replayed in-circuit.
        let mut builder = AsmBuilder::<F, EF>::default();
        let challenger = DuplexChallengerVariable::<AsmConfig<F, EF>>::new(&mut builder);
        for &observation in &observations {
            let observation: Felt<_> = builder.eval(observation);
            challenger.observe(&mut builder, observation);
        }
        for &value in &expected {
            let sampled = challenger.sample(&mut builder);
            let expected_value: Felt<_> = builder.eval(value);
            builder.assert_felt_eq(expected_value, sampled);
        }

        builder.halt();

        let program = builder.compile_isa();
        execute_program(program, vec![]);
    }

    #[test]
    fn test_compiler_challenger() {
        for num_observations in [1, 4, 8, 10, 16, 20, 50] {
            assert_matches_native_challenger(num_observations, 1);
        }
    }

    /// The multi-sample cases cover behavior a single sample never reaches:
    /// - several consecutive samples check the output-buffer pop order;
    /// - more samples than one squeeze provides (assuming a rate of 8) force a duplexing with no
    ///   pending input;
    /// - zero observations check sampling from a freshly constructed challenger.
    #[test]
    fn test_compiler_challenger_multiple_samples() {
        for num_observations in [0, 3, 8, 9] {
            for num_samples in [2, 10] {
                assert_matches_native_challenger(num_observations, num_samples);
            }
        }
    }
}