// openvm_native_recursion/challenger/duplex.rs
1use openvm_native_compiler::{
2 ir::{RVar, DIGEST_SIZE, PERMUTATION_WIDTH},
3 prelude::*,
4};
5use openvm_native_compiler_derive::iter_zip;
6use openvm_stark_backend::p3_field::{Field, FieldAlgebra};
7
8use crate::{
9 challenger::{
10 CanCheckWitness, CanObserveDigest, CanObserveVariable, CanSampleBitsVariable,
11 CanSampleVariable, ChallengerVariable, FeltChallenger,
12 },
13 digest::DigestVariable,
14};
15
16#[derive(Clone)]
18pub struct DuplexChallengerVariable<C: Config> {
19 pub sponge_state: Array<C, Felt<C::F>>,
20
21 pub input_ptr: Ptr<C::N>,
22 pub output_ptr: Ptr<C::N>,
23 pub io_empty_ptr: Ptr<C::N>,
24 pub io_full_ptr: Ptr<C::N>,
25}
26
27impl<C: Config> DuplexChallengerVariable<C> {
28 pub fn new(builder: &mut Builder<C>) -> Self {
30 let sponge_state = builder.dyn_array(PERMUTATION_WIDTH);
31
32 builder
33 .range(0, sponge_state.len())
34 .for_each(|i_vec, builder| {
35 builder.set(&sponge_state, i_vec[0], C::F::ZERO);
36 });
37 let io_empty_ptr = sponge_state.ptr();
38 let io_full_ptr: Ptr<_> =
39 builder.eval(io_empty_ptr + C::N::from_canonical_usize(DIGEST_SIZE));
40 let input_ptr = builder.eval(io_empty_ptr);
41 let output_ptr = builder.eval(io_empty_ptr);
42
43 DuplexChallengerVariable::<C> {
44 sponge_state,
45 input_ptr,
46 output_ptr,
47 io_empty_ptr,
48 io_full_ptr,
49 }
50 }
51
52 pub fn duplexing(&self, builder: &mut Builder<C>) {
53 builder.assign(&self.input_ptr, self.io_empty_ptr);
54
55 builder.poseidon2_permute_mut(&self.sponge_state);
56
57 builder.assign(&self.output_ptr, self.io_full_ptr);
58 }
59
60 fn observe(&self, builder: &mut Builder<C>, value: Felt<C::F>) {
61 builder.assign(&self.output_ptr, self.io_empty_ptr);
62
63 builder.iter_ptr_set(&self.sponge_state, self.input_ptr.address.into(), value);
64 builder.assign(&self.input_ptr, self.input_ptr + C::N::ONE);
65
66 builder
67 .if_eq(self.input_ptr.address, self.io_full_ptr.address)
68 .then(|builder| {
69 self.duplexing(builder);
70 })
71 }
72
73 fn observe_commitment(&self, builder: &mut Builder<C>, commitment: &Array<C, Felt<C::F>>) {
74 for i in 0..DIGEST_SIZE {
75 let element = builder.get(commitment, i);
76 self.observe(builder, element);
77 }
78 }
79
80 fn sample(&self, builder: &mut Builder<C>) -> Felt<C::F> {
81 builder
82 .if_ne(self.input_ptr.address, self.io_empty_ptr.address)
83 .then_or_else(
84 |builder| {
85 self.duplexing(builder);
86 },
87 |builder| {
88 builder
89 .if_eq(self.output_ptr.address, self.io_empty_ptr.address)
90 .then(|builder| {
91 self.duplexing(builder);
92 });
93 },
94 );
95 builder.assign(&self.output_ptr, self.output_ptr - C::N::ONE);
96 builder.iter_ptr_get(&self.sponge_state, self.output_ptr.address.into())
97 }
98
99 fn sample_ext(&self, builder: &mut Builder<C>) -> Ext<C::F, C::EF> {
100 let a = self.sample(builder);
101 let b = self.sample(builder);
102 let c = self.sample(builder);
103 let d = self.sample(builder);
104 builder.ext_from_base_slice(&[a, b, c, d])
105 }
106
107 fn sample_bits(&self, builder: &mut Builder<C>, nb_bits: RVar<C::N>) -> Array<C, Var<C::N>>
108 where
109 C::N: Field,
110 {
111 let rand_f = self.sample(builder);
112 let bits = builder.num2bits_f(rand_f, C::N::bits() as u32);
113
114 builder
115 .range(nb_bits, bits.len())
116 .for_each(|i_vec, builder| {
117 builder.set(&bits, i_vec[0], C::N::ZERO);
118 });
119 bits
120 }
121
122 pub fn check_witness(&self, builder: &mut Builder<C>, nb_bits: usize, witness: Felt<C::F>) {
123 self.observe(builder, witness);
124 let element_bits = self.sample_bits(builder, RVar::from(nb_bits));
125 let element_bits_truncated = element_bits.slice(builder, 0, nb_bits);
126 iter_zip!(builder, element_bits_truncated).for_each(|ptr_vec, builder| {
127 let element = builder.iter_ptr_get(&element_bits_truncated, ptr_vec[0]);
128 builder.assert_var_eq(element, C::N::ZERO);
129 });
130 }
131}
132
133impl<C: Config> CanObserveVariable<C, Felt<C::F>> for DuplexChallengerVariable<C> {
134 fn observe(&mut self, builder: &mut Builder<C>, value: Felt<C::F>) {
135 DuplexChallengerVariable::observe(self, builder, value);
136 }
137
138 fn observe_slice(&mut self, builder: &mut Builder<C>, values: Array<C, Felt<C::F>>) {
139 iter_zip!(builder, values).for_each(|ptr_vec, builder| {
140 let element = builder.iter_ptr_get(&values, ptr_vec[0]);
141 self.observe(builder, element);
142 });
143 }
144}
145
146impl<C: Config> CanSampleVariable<C, Felt<C::F>> for DuplexChallengerVariable<C> {
147 fn sample(&mut self, builder: &mut Builder<C>) -> Felt<C::F> {
148 DuplexChallengerVariable::sample(self, builder)
149 }
150}
151
152impl<C: Config> CanSampleBitsVariable<C> for DuplexChallengerVariable<C> {
153 fn sample_bits(
154 &mut self,
155 builder: &mut Builder<C>,
156 nb_bits: RVar<C::N>,
157 ) -> Array<C, Var<C::N>> {
158 DuplexChallengerVariable::sample_bits(self, builder, nb_bits)
159 }
160}
161
162impl<C: Config> CanObserveDigest<C> for DuplexChallengerVariable<C> {
163 fn observe_digest(&mut self, builder: &mut Builder<C>, commitment: DigestVariable<C>) {
164 if let DigestVariable::Felt(commitment) = commitment {
165 self.observe_commitment(builder, &commitment);
166 } else {
167 panic!("Expected a felt digest");
168 }
169 }
170}
171
172impl<C: Config> FeltChallenger<C> for DuplexChallengerVariable<C> {
173 fn sample_ext(&mut self, builder: &mut Builder<C>) -> Ext<C::F, C::EF> {
174 DuplexChallengerVariable::sample_ext(self, builder)
175 }
176}
177
178impl<C: Config> CanCheckWitness<C> for DuplexChallengerVariable<C> {
179 fn check_witness(&mut self, builder: &mut Builder<C>, nb_bits: usize, witness: Felt<C::F>) {
180 DuplexChallengerVariable::check_witness(self, builder, nb_bits, witness);
181 }
182}
183
184impl<C: Config> ChallengerVariable<C> for DuplexChallengerVariable<C> {
185 fn new(builder: &mut Builder<C>) -> Self {
186 DuplexChallengerVariable::new(builder)
187 }
188}
189
#[cfg(test)]
mod tests {
    use openvm_native_circuit::execute_program;
    use openvm_native_compiler::{
        asm::{AsmBuilder, AsmConfig},
        ir::Felt,
    };
    use openvm_stark_backend::{
        config::{StarkGenericConfig, Val},
        p3_challenger::{CanObserve, CanSample},
        p3_field::FieldAlgebra,
    };
    use openvm_stark_sdk::{
        config::baby_bear_poseidon2::{default_engine, BabyBearPoseidon2Config},
        engine::StarkEngine,
        p3_baby_bear::BabyBear,
    };
    use rand::Rng;

    use super::DuplexChallengerVariable;

    /// Feeds `num_challenges` random observations to both the engine's native
    /// challenger and the DSL challenger, then checks inside the compiled program
    /// that the two produce the same sample.
    fn test_compiler_challenger_with_num_challenges(num_challenges: usize) {
        type SC = BabyBearPoseidon2Config;
        type F = Val<SC>;
        type EF = <SC as StarkGenericConfig>::Challenge;

        let mut rng = rand::thread_rng();
        let observations: Vec<BabyBear> = (0..num_challenges)
            .map(|_| BabyBear::from_canonical_u32(rng.gen_range(0..(1 << 30))))
            .collect();

        // Reference value from the native challenger.
        let engine = default_engine();
        let mut native_challenger = engine.new_challenger();
        for &value in &observations {
            native_challenger.observe(value);
        }
        let expected: F = native_challenger.sample();
        println!("expected result: {}", expected);

        // The same transcript replayed through the DSL challenger.
        let mut builder = AsmBuilder::<F, EF>::default();
        let dsl_challenger = DuplexChallengerVariable::<AsmConfig<F, EF>>::new(&mut builder);
        for &value in &observations {
            let observed: Felt<_> = builder.eval(value);
            dsl_challenger.observe(&mut builder, observed);
        }
        let sampled = dsl_challenger.sample(&mut builder);

        let expected_felt: Felt<_> = builder.eval(expected);
        builder.assert_felt_eq(expected_felt, sampled);

        builder.halt();

        let program = builder.compile_isa();
        execute_program(program, vec![]);
    }

    /// Exercises the helper with observation counts from 1 to 50.
    #[test]
    fn test_compiler_challenger() {
        for &num_challenges in &[1, 4, 8, 10, 16, 20, 50] {
            test_compiler_challenger_with_num_challenges(num_challenges);
        }
    }
}