openvm_native_compiler/ir/
builder.rs

1use std::{iter::Zip, vec::IntoIter};
2
3use backtrace::Backtrace;
4use itertools::izip;
5use openvm_native_compiler_derive::iter_zip;
6use openvm_stark_backend::p3_field::{Field, FieldAlgebra, FieldExtensionAlgebra};
7use serde::{Deserialize, Serialize};
8
9use super::{
10    Array, Config, DslIr, Ext, Felt, FromConstant, MemIndex, MemVariable, RVar, SymbolicExt,
11    SymbolicFelt, SymbolicVar, Usize, Var, Variable, WitnessRef,
12};
13use crate::ir::{collections::ArrayLike, Ptr};
14
15/// TracedVec is a Vec wrapper that records a trace whenever an element is pushed. When extending
16/// from another TracedVec, the traces are copied over.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct TracedVec<T> {
19    pub vec: Vec<T>,
20    pub traces: Vec<Option<Backtrace>>,
21}
22
23impl<T> Default for TracedVec<T> {
24    fn default() -> Self {
25        Self::new()
26    }
27}
28
29impl<T> From<Vec<T>> for TracedVec<T> {
30    fn from(vec: Vec<T>) -> Self {
31        let len = vec.len();
32        Self {
33            vec,
34            traces: vec![None; len],
35        }
36    }
37}
38
39impl<T> TracedVec<T> {
40    pub const fn new() -> Self {
41        Self {
42            vec: Vec::new(),
43            traces: Vec::new(),
44        }
45    }
46
47    pub fn push(&mut self, value: T) {
48        self.vec.push(value);
49        self.traces.push(None);
50    }
51
52    /// Pushes a value to the vector and records a backtrace if RUST_BACKTRACE is enabled
53    pub fn trace_push(&mut self, value: T) {
54        self.vec.push(value);
55        if std::env::var_os("RUST_BACKTRACE").is_none() {
56            self.traces.push(None);
57        } else {
58            self.traces.push(Some(Backtrace::new_unresolved()));
59        }
60    }
61
62    pub fn extend<I: IntoIterator<Item = (T, Option<Backtrace>)>>(&mut self, iter: I) {
63        let iter = iter.into_iter();
64        let len = iter.size_hint().0;
65        self.vec.reserve(len);
66        self.traces.reserve(len);
67        for (value, trace) in iter {
68            self.vec.push(value);
69            self.traces.push(trace);
70        }
71    }
72
73    pub fn is_empty(&self) -> bool {
74        self.vec.is_empty()
75    }
76}
77
78impl<T> IntoIterator for TracedVec<T> {
79    type Item = (T, Option<Backtrace>);
80    type IntoIter = Zip<IntoIter<T>, IntoIter<Option<Backtrace>>>;
81
82    fn into_iter(self) -> Self::IntoIter {
83        self.vec.into_iter().zip(self.traces)
84    }
85}
86
/// Switches that configure how a [`Builder`] behaves. Both default to `false`.
#[derive(Debug, Copy, Clone, Default)]
pub struct BuilderFlags {
    /// Debug switch. NOTE(review): nothing in this file reads it; confirm its consumers.
    pub debug: bool,
    /// When set, dynamic branching, dynamic looping and heap memory are disabled.
    pub static_only: bool,
}
93
94/// A builder for the DSL.
95///
96/// Can compile to both assembly and a set of constraints.
97#[derive(Debug, Clone, Default)]
98pub struct Builder<C: Config> {
99    pub(crate) var_count: u32,
100    pub(crate) felt_count: u32,
101    pub(crate) ext_count: u32,
102    pub operations: TracedVec<DslIr<C>>,
103    pub(crate) nb_public_values: Option<Var<C::N>>,
104    pub(crate) witness_var_count: u32,
105    pub(crate) witness_felt_count: u32,
106    pub(crate) witness_ext_count: u32,
107    pub(crate) witness_space: Vec<Vec<WitnessRef>>,
108    pub flags: BuilderFlags,
109    pub is_sub_builder: bool,
110}
111
112impl<C: Config> Builder<C> {
113    /// Creates a new builder with a given number of counts for each type.
114    pub fn create_sub_builder(&self) -> Self {
115        Self {
116            var_count: self.var_count,
117            felt_count: self.felt_count,
118            ext_count: self.ext_count,
119            // Witness counts are only used when the target is a circuit.  And sub-builders are
120            // not used when the target is a circuit, so it is fine to set the witness counts to 0.
121            witness_var_count: 0,
122            witness_felt_count: 0,
123            witness_ext_count: 0,
124            witness_space: Default::default(),
125            operations: Default::default(),
126            nb_public_values: self.nb_public_values,
127            flags: self.flags,
128            is_sub_builder: true,
129        }
130    }
131
132    /// Pushes an operation to the builder.
133    pub fn push(&mut self, op: DslIr<C>) {
134        self.operations.push(op);
135    }
136
137    /// Pushes an operation to the builder and records a trace if RUST_BACKTRACE=1.
138    pub fn trace_push(&mut self, op: DslIr<C>) {
139        self.operations.trace_push(op);
140    }
141
142    /// Creates an uninitialized variable.
143    pub fn uninit<V: Variable<C>>(&mut self) -> V {
144        V::uninit(self)
145    }
146
147    /// Evaluates an expression and returns a variable.
148    pub fn eval<V: Variable<C>, E: Into<V::Expression>>(&mut self, expr: E) -> V {
149        V::eval(self, expr)
150    }
151
152    /// Evaluates an expression and returns a right value.
153    pub fn eval_expr(&mut self, expr: impl Into<SymbolicVar<C::N>>) -> RVar<C::N> {
154        let expr = expr.into();
155        match expr {
156            SymbolicVar::Const(c, _) => RVar::Const(c),
157            SymbolicVar::Val(val, _) => RVar::Val(val),
158            _ => {
159                let ret: Var<_> = self.eval(expr);
160                RVar::Val(ret)
161            }
162        }
163    }
164
165    /// Increments Usize by one.
166    pub fn inc(&mut self, u: &Usize<C::N>) {
167        self.assign(u, u.clone() + RVar::one());
168    }
169
170    /// Evaluates a constant expression and returns a variable.
171    pub fn constant<V: FromConstant<C>>(&mut self, value: V::Constant) -> V {
172        V::constant(value, self)
173    }
174
175    /// Assigns an expression to a variable.
176    pub fn assign<V: Variable<C>, E: Into<V::Expression>>(&mut self, dst: &V, expr: E) {
177        dst.assign(expr.into(), self);
178    }
179
180    /// Casts a Felt to a Var.
181    pub fn cast_felt_to_var(&mut self, felt: Felt<C::F>) -> Var<C::N> {
182        let var: Var<_> = self.uninit();
183        self.operations.push(DslIr::CastFV(var, felt));
184        var
185    }
186    /// Casts a Var to a Felt.
187    pub fn unsafe_cast_var_to_felt(&mut self, var: Var<C::N>) -> Felt<C::F> {
188        assert!(!self.flags.static_only, "dynamic mode only");
189        let felt: Felt<_> = self.uninit();
190        self.operations.push(DslIr::UnsafeCastVF(felt, var));
191        felt
192    }
193
194    /// Asserts that a Usize is non-zero
195    pub fn assert_nonzero(&mut self, u: &Usize<C::N>) {
196        self.operations.push(DslIr::AssertNonZero(u.clone()));
197    }
198
199    /// Asserts that two expressions are equal.
200    pub fn assert_eq<V: Variable<C>>(
201        &mut self,
202        lhs: impl Into<V::Expression>,
203        rhs: impl Into<V::Expression>,
204    ) {
205        V::assert_eq(lhs, rhs, self);
206    }
207
208    /// Assert that two vars are equal.
209    pub fn assert_var_eq<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
210        &mut self,
211        lhs: LhsExpr,
212        rhs: RhsExpr,
213    ) {
214        self.assert_eq::<Var<C::N>>(lhs, rhs);
215    }
216
217    /// Assert that two felts are equal.
218    pub fn assert_felt_eq<LhsExpr: Into<SymbolicFelt<C::F>>, RhsExpr: Into<SymbolicFelt<C::F>>>(
219        &mut self,
220        lhs: LhsExpr,
221        rhs: RhsExpr,
222    ) {
223        self.assert_eq::<Felt<C::F>>(lhs, rhs);
224    }
225
226    /// Assert that two exts are equal.
227    pub fn assert_ext_eq<
228        LhsExpr: Into<SymbolicExt<C::F, C::EF>>,
229        RhsExpr: Into<SymbolicExt<C::F, C::EF>>,
230    >(
231        &mut self,
232        lhs: LhsExpr,
233        rhs: RhsExpr,
234    ) {
235        self.assert_eq::<Ext<C::F, C::EF>>(lhs, rhs);
236    }
237
238    pub fn assert_usize_eq<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
239        &mut self,
240        lhs: LhsExpr,
241        rhs: RhsExpr,
242    ) {
243        self.assert_eq::<Usize<C::N>>(lhs, rhs);
244    }
245
246    /// Assert that two arrays are equal.
247    pub fn assert_var_array_eq(&mut self, lhs: &Array<C, Var<C::N>>, rhs: &Array<C, Var<C::N>>) {
248        self.assert_var_eq(lhs.len(), rhs.len());
249        self.range(0, lhs.len()).for_each(|idx_vec, builder| {
250            let l = builder.get(lhs, idx_vec[0]);
251            let r = builder.get(rhs, idx_vec[0]);
252            builder.assert_var_eq(l, r);
253        });
254    }
255
256    /// Evaluate a block of operations if two expressions are equal.
257    pub fn if_eq<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
258        &mut self,
259        lhs: LhsExpr,
260        rhs: RhsExpr,
261    ) -> IfBuilder<C> {
262        IfBuilder {
263            lhs: lhs.into(),
264            rhs: rhs.into(),
265            is_eq: true,
266            builder: self,
267        }
268    }
269
270    /// Evaluate a block of operations if two expressions are not equal.
271    pub fn if_ne<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
272        &mut self,
273        lhs: LhsExpr,
274        rhs: RhsExpr,
275    ) -> IfBuilder<C> {
276        IfBuilder {
277            lhs: lhs.into(),
278            rhs: rhs.into(),
279            is_eq: false,
280            builder: self,
281        }
282    }
283
284    /// Asserts that lhs is less than rhs in time O(rhs).
285    pub fn assert_less_than_slow_small_rhs<
286        LhsExpr: Into<SymbolicVar<C::N>>,
287        RhsExpr: Into<SymbolicVar<C::N>>,
288    >(
289        &mut self,
290        lhs: LhsExpr,
291        rhs: RhsExpr,
292    ) {
293        let lhs: Usize<_> = self.eval(lhs.into());
294        let rhs: Usize<_> = self.eval(rhs.into());
295        let product: Usize<_> = self.eval(lhs.clone());
296        self.range(1, rhs).for_each(|i_vec, builder| {
297            let i = i_vec[0];
298            let diff: Usize<_> = builder.eval(lhs.clone() - i);
299            builder.assign(&product, product.clone() * diff);
300        });
301        self.assert_usize_eq(product, RVar::from(0));
302    }
303
304    /// Asserts that lhs is less than rhs in time O(log(lhs) + log(rhs)).
305    ///
306    /// Only works for Felt == BabyBear and in the VM.
307    ///
308    /// Uses bit decomposition hint, which has large constant factor overhead, so prefer
309    /// [Self::assert_less_than_slow_small_rhs] when rhs is small.
310    pub fn assert_less_than_slow_bit_decomp(&mut self, lhs: Var<C::N>, rhs: Var<C::N>) {
311        let lhs = self.unsafe_cast_var_to_felt(lhs);
312        let rhs = self.unsafe_cast_var_to_felt(rhs);
313
314        let lhs_bits = self.num2bits_f(lhs, C::N::bits() as u32);
315        let rhs_bits = self.num2bits_f(rhs, C::N::bits() as u32);
316
317        let is_lt: Var<_> = self.eval(C::N::ZERO);
318
319        iter_zip!(self, lhs_bits, rhs_bits).for_each(|ptr_vec, builder| {
320            let lhs_bit = builder.iter_ptr_get(&lhs_bits, ptr_vec[0]);
321            let rhs_bit = builder.iter_ptr_get(&rhs_bits, ptr_vec[1]);
322
323            builder.if_ne(lhs_bit, rhs_bit).then(|builder| {
324                builder.assign(&is_lt, rhs_bit);
325            });
326        });
327        self.assert_var_eq(is_lt, C::N::ONE);
328    }
329
330    /// asserts that x has at most num_bits bits
331    pub fn range_check_var(&mut self, x: Var<C::N>, num_bits: usize) {
332        assert!(!self.flags.static_only, "range_check_var is dynamic");
333        assert!(num_bits <= 30);
334        self.trace_push(DslIr::RangeCheckV(x, num_bits));
335    }
336
337    /// Evaluate a block of operations over a range from start to end.
338    pub fn range(
339        &mut self,
340        start: impl Into<RVar<C::N>>,
341        end: impl Into<RVar<C::N>>,
342    ) -> IteratorBuilder<C> {
343        self.range_with_step(start, end, C::N::ONE)
344    }
345    /// Evaluate a block of operations over a range from start to end with a custom step.
346    pub fn range_with_step(
347        &mut self,
348        start: impl Into<RVar<C::N>>,
349        end: impl Into<RVar<C::N>>,
350        step: C::N,
351    ) -> IteratorBuilder<C> {
352        let start = start.into();
353        let end0 = end.into();
354        IteratorBuilder {
355            starts: vec![start],
356            end0,
357            step_sizes: vec![step],
358            builder: self,
359        }
360    }
361
362    pub fn zip<'a>(
363        &'a mut self,
364        arrays: &'a [Box<dyn ArrayLike<C> + 'a>],
365    ) -> IteratorBuilder<'a, C> {
366        assert!(!arrays.is_empty());
367        if arrays.iter().all(|array| array.is_fixed()) {
368            IteratorBuilder {
369                starts: vec![RVar::zero(); arrays.len()],
370                end0: arrays[0].len().into(),
371                step_sizes: vec![C::N::ONE; arrays.len()],
372                builder: self,
373            }
374        } else if arrays.iter().all(|array| !array.is_fixed()) {
375            IteratorBuilder {
376                starts: arrays
377                    .iter()
378                    .map(|array| array.ptr().address.into())
379                    .collect(),
380                end0: {
381                    let len: RVar<C::N> = arrays[0].len().into();
382                    let size = arrays[0].element_size_of();
383                    let end: Var<C::N> =
384                        self.eval(arrays[0].ptr().address + len * RVar::from(size));
385                    end.into()
386                },
387                step_sizes: arrays
388                    .iter()
389                    .map(|array| C::N::from_canonical_usize(array.element_size_of()))
390                    .collect(),
391                builder: self,
392            }
393        } else {
394            panic!("Cannot use zipped pointer iterator with mixed arrays");
395        }
396    }
397
398    pub fn print_debug(&mut self, val: usize) {
399        let constant = self.eval(C::N::from_canonical_usize(val));
400        self.print_v(constant);
401    }
402
403    /// Print a variable.
404    pub fn print_v(&mut self, dst: Var<C::N>) {
405        self.operations.push(DslIr::PrintV(dst));
406    }
407
408    /// Print a felt.
409    pub fn print_f(&mut self, dst: Felt<C::F>) {
410        self.operations.push(DslIr::PrintF(dst));
411    }
412
413    /// Print an ext.
414    pub fn print_e(&mut self, dst: Ext<C::F, C::EF>) {
415        self.operations.push(DslIr::PrintE(dst));
416    }
417
418    pub fn hint_var(&mut self) -> Var<C::N> {
419        let ptr = self.alloc(RVar::one(), 1);
420        // Prepare data for hinting.
421        self.operations.push(DslIr::HintFelt());
422        let index = MemIndex {
423            index: RVar::zero(),
424            offset: 0,
425            size: 1,
426        };
427        self.operations.push(DslIr::StoreHintWord(ptr, index));
428        let v: Var<C::N> = self.uninit();
429        self.load(v, ptr, index);
430        v
431    }
432
433    pub fn hint_felt(&mut self) -> Felt<C::F> {
434        let ptr = self.alloc(RVar::one(), 1);
435        // Prepare data for hinting.
436        self.operations.push(DslIr::HintFelt());
437        let index = MemIndex {
438            index: RVar::zero(),
439            offset: 0,
440            size: 1,
441        };
442        self.operations.push(DslIr::StoreHintWord(ptr, index));
443        let f: Felt<C::F> = self.uninit();
444        self.load(f, ptr, index);
445        f
446    }
447
448    pub fn hint_ext(&mut self) -> Ext<C::F, C::EF> {
449        let flattened = self.hint_felts_fixed(C::EF::D);
450
451        // Simply recast memory as Array<Ext>.
452        let array: Array<C, Ext<_, _>> = match flattened {
453            Array::Fixed(_) => unreachable!(),
454            Array::Dyn(ptr, _) => Array::Dyn(ptr, Usize::from(1)),
455        };
456        self.get(&array, 0)
457    }
458
459    /// Hint a vector of variables.
460    ///
461    /// Writes the next element of the witness stream into memory and returns it.
462    pub fn hint_vars(&mut self) -> Array<C, Var<C::N>> {
463        self.hint_words()
464    }
465
466    /// Hint a vector of felts.
467    pub fn hint_felts(&mut self) -> Array<C, Felt<C::F>> {
468        self.hint_words()
469    }
470
471    pub fn hint_felts_fixed(&mut self, len: impl Into<RVar<C::N>>) -> Array<C, Felt<C::F>> {
472        self.hint_words_fixed(len)
473    }
474
475    /// Hints an array of V and assumes V::size_of() == 1.
476    fn hint_words<V: MemVariable<C>>(&mut self) -> Array<C, V> {
477        assert_eq!(V::size_of(), 1);
478
479        // Allocate space for the length variable. We assume that mem[ptr..] is empty.
480        let ptr = self.alloc(RVar::one(), 1);
481
482        // Prepare length + data for hinting.
483        self.operations.push(DslIr::HintInputVec());
484
485        // Write and retrieve length hint.
486        let index = MemIndex {
487            index: RVar::zero(),
488            offset: 0,
489            size: 1,
490        };
491        // MemIndex.index share the same pointer, but it doesn't matter.
492        self.operations.push(DslIr::StoreHintWord(ptr, index));
493
494        let vlen: Var<C::N> = self.uninit();
495        self.load(vlen, ptr, index);
496        let arr = self.dyn_array(vlen);
497
498        // Write the content hints directly into the array memory.
499        iter_zip!(self, arr).for_each(|ptr_vec, builder| {
500            let index = MemIndex {
501                index: 0.into(),
502                offset: 0,
503                size: 1,
504            };
505            builder.operations.push(DslIr::StoreHintWord(
506                Ptr {
507                    address: ptr_vec[0].variable(),
508                },
509                index,
510            ));
511        });
512        arr
513    }
514
515    /// Hints an array of V and assumes V::size_of() == 1.
516    fn hint_words_fixed<V: MemVariable<C>>(&mut self, len: impl Into<RVar<C::N>>) -> Array<C, V> {
517        assert_eq!(V::size_of(), 1);
518
519        let arr = self.dyn_array(len.into());
520        // Write the content hints directly into the array memory.
521        iter_zip!(self, arr).for_each(|ptr_vec, builder| {
522            let index = MemIndex {
523                index: 0.into(),
524                offset: 0,
525                size: 1,
526            };
527            builder.operations.push(DslIr::HintFelt());
528            builder.operations.push(DslIr::StoreHintWord(
529                Ptr {
530                    address: ptr_vec[0].variable(),
531                },
532                index,
533            ));
534        });
535        arr
536    }
537
538    /// Hint a vector of exts.
539    ///
540    /// Emits two hint opcodes: the first for the number of exts, the second for the list of exts
541    /// themselves.
542    pub fn hint_exts(&mut self) -> Array<C, Ext<C::F, C::EF>> {
543        let len = self.hint_var();
544        let flattened = self.hint_felts();
545
546        let size = <Ext<C::F, C::EF> as MemVariable<C>>::size_of();
547        self.assert_usize_eq(flattened.len(), len * C::N::from_canonical_usize(size));
548
549        // Simply recast memory as Array<Ext>.
550        match flattened {
551            Array::Fixed(_) => unreachable!(),
552            Array::Dyn(ptr, _) => Array::Dyn(ptr, Usize::Var(len)),
553        }
554    }
555
556    /// Move data from input stream into hint space. Return an ID which can be used to load the
557    /// data at runtime.
558    pub fn hint_load(&mut self) -> Var<C::N> {
559        self.trace_push(DslIr::HintLoad());
560        let ptr = self.alloc(RVar::one(), 1);
561        let index = MemIndex {
562            index: RVar::zero(),
563            offset: 0,
564            size: 1,
565        };
566        // MemIndex.index share the same pointer, but it doesn't matter.
567        self.operations.push(DslIr::StoreHintWord(ptr, index));
568        let id: Var<C::N> = self.uninit();
569        self.load(id, ptr, index);
570        id
571    }
572
573    pub fn witness_var(&mut self) -> Var<C::N> {
574        assert!(
575            !self.is_sub_builder,
576            "Cannot create a witness var with a sub builder"
577        );
578        let witness = self.uninit();
579        self.operations
580            .push(DslIr::WitnessVar(witness, self.witness_var_count));
581        self.witness_var_count += 1;
582        witness
583    }
584
585    pub fn witness_felt(&mut self) -> Felt<C::F> {
586        assert!(
587            !self.is_sub_builder,
588            "Cannot create a witness felt with a sub builder"
589        );
590        let witness = self.uninit();
591        self.operations
592            .push(DslIr::WitnessFelt(witness, self.witness_felt_count));
593        self.witness_felt_count += 1;
594        witness
595    }
596
597    pub fn witness_ext(&mut self) -> Ext<C::F, C::EF> {
598        assert!(
599            !self.is_sub_builder,
600            "Cannot create a witness ext with a sub builder"
601        );
602        let witness = self.uninit();
603        self.operations
604            .push(DslIr::WitnessExt(witness, self.witness_ext_count));
605        self.witness_ext_count += 1;
606        witness
607    }
608
609    pub fn witness_load(&mut self, witness_refs: Vec<WitnessRef>) -> Usize<C::N> {
610        assert!(
611            !self.is_sub_builder,
612            "Cannot load witness refs with a sub builder"
613        );
614        let ret = self.witness_space.len();
615        self.witness_space.push(witness_refs);
616        ret.into()
617    }
618
619    pub fn get_witness_refs(&self, id: Usize<C::N>) -> &[WitnessRef] {
620        self.witness_space.get(id.value()).unwrap()
621    }
622
623    /// Throws an error.
624    pub fn error(&mut self) {
625        self.operations.trace_push(DslIr::Error());
626    }
627
628    fn get_nb_public_values(&mut self) -> Var<C::N> {
629        assert!(
630            !self.is_sub_builder,
631            "Cannot commit to public values with a sub builder"
632        );
633        if self.nb_public_values.is_none() {
634            self.nb_public_values = Some(self.eval(C::N::ZERO));
635        }
636        *self.nb_public_values.as_ref().unwrap()
637    }
638
639    fn commit_public_value_and_increment(&mut self, val: Felt<C::F>, nb_public_values: Var<C::N>) {
640        assert!(
641            !self.flags.static_only,
642            "Static mode should use static_commit_public_value"
643        );
644        self.operations.push(DslIr::Publish(val, nb_public_values));
645        self.assign(&nb_public_values, nb_public_values + C::N::ONE);
646    }
647
648    /// Commits a Var as public value. This value will be constrained when verified. This method
649    /// should only be used in static mode.
650    pub fn static_commit_public_value(&mut self, index: usize, val: Var<C::N>) {
651        assert!(
652            self.flags.static_only,
653            "Dynamic mode should use commit_public_value instead."
654        );
655        self.operations.push(DslIr::CircuitPublish(val, index));
656    }
657
658    /// Register and commits a felt as public value.  This value will be constrained when verified.
659    pub fn commit_public_value(&mut self, val: Felt<C::F>) {
660        let nb_public_values = self.get_nb_public_values();
661        self.commit_public_value_and_increment(val, nb_public_values);
662    }
663
664    /// Commits an array of felts in public values.
665    pub fn commit_public_values(&mut self, vals: &Array<C, Felt<C::F>>) {
666        let nb_public_values = self.get_nb_public_values();
667        let len = vals.len();
668        self.range(0, len).for_each(|idx_vec, builder| {
669            let val = builder.get(vals, idx_vec[0]);
670            builder.commit_public_value_and_increment(val, nb_public_values);
671        });
672    }
673
674    pub fn cycle_tracker_start(&mut self, name: &str) {
675        self.operations
676            .push(DslIr::CycleTrackerStart(name.to_string()));
677    }
678
679    pub fn cycle_tracker_end(&mut self, name: &str) {
680        self.operations
681            .push(DslIr::CycleTrackerEnd(name.to_string()));
682    }
683
684    pub fn halt(&mut self) {
685        self.operations.push(DslIr::Halt);
686    }
687}
688
689/// A builder for the DSL that handles if statements.
690pub struct IfBuilder<'a, C: Config> {
691    lhs: SymbolicVar<C::N>,
692    rhs: SymbolicVar<C::N>,
693    is_eq: bool,
694    pub(crate) builder: &'a mut Builder<C>,
695}
696
697/// A set of conditions that if statements can be based on.
698enum IfCondition<N> {
699    EqConst(N, N),
700    NeConst(N, N),
701    Eq(Var<N>, Var<N>),
702    EqI(Var<N>, N),
703    Ne(Var<N>, Var<N>),
704    NeI(Var<N>, N),
705}
706
707impl<C: Config> IfBuilder<'_, C> {
708    pub fn then(&mut self, mut f: impl FnMut(&mut Builder<C>)) {
709        // Get the condition reduced from the expressions for lhs and rhs.
710        let condition = self.condition();
711        // Early return for const branches.
712        match condition {
713            IfCondition::EqConst(lhs, rhs) => {
714                if lhs == rhs {
715                    return f(self.builder);
716                }
717                return;
718            }
719            IfCondition::NeConst(lhs, rhs) => {
720                if lhs != rhs {
721                    return f(self.builder);
722                }
723                return;
724            }
725            _ => (),
726        }
727        assert!(
728            !self.builder.flags.static_only,
729            "Cannot use dynamic branch in static mode"
730        );
731
732        // Execute the `then` block and collect the instructions.
733        let mut f_builder = self.builder.create_sub_builder();
734        f(&mut f_builder);
735        let then_instructions = f_builder.operations;
736
737        // Dispatch instructions to the correct conditional block.
738        match condition {
739            IfCondition::Eq(lhs, rhs) => {
740                let op = DslIr::IfEq(lhs, rhs, then_instructions, Default::default());
741                self.builder.operations.push(op);
742            }
743            IfCondition::EqI(lhs, rhs) => {
744                let op = DslIr::IfEqI(lhs, rhs, then_instructions, Default::default());
745                self.builder.operations.push(op);
746            }
747            IfCondition::Ne(lhs, rhs) => {
748                let op = DslIr::IfNe(lhs, rhs, then_instructions, Default::default());
749                self.builder.operations.push(op);
750            }
751            IfCondition::NeI(lhs, rhs) => {
752                let op = DslIr::IfNeI(lhs, rhs, then_instructions, Default::default());
753                self.builder.operations.push(op);
754            }
755            _ => unreachable!("Const if should have returned early"),
756        }
757    }
758
759    pub fn then_or_else(
760        &mut self,
761        mut then_f: impl FnMut(&mut Builder<C>),
762        mut else_f: impl FnMut(&mut Builder<C>),
763    ) {
764        // Get the condition reduced from the expressions for lhs and rhs.
765        let condition = self.condition();
766        // Early return for const branches.
767        match condition {
768            IfCondition::EqConst(lhs, rhs) => {
769                if lhs == rhs {
770                    return then_f(self.builder);
771                }
772                return else_f(self.builder);
773            }
774            IfCondition::NeConst(lhs, rhs) => {
775                if lhs != rhs {
776                    return then_f(self.builder);
777                }
778                return else_f(self.builder);
779            }
780            _ => (),
781        }
782        assert!(
783            !self.builder.flags.static_only,
784            "Cannot use dynamic branch in static mode"
785        );
786        let mut then_builder = self.builder.create_sub_builder();
787
788        // Execute the `then` and `else_then` blocks and collect the instructions.
789        then_f(&mut then_builder);
790        let then_instructions = then_builder.operations;
791
792        let mut else_builder = self.builder.create_sub_builder();
793        else_f(&mut else_builder);
794        let else_instructions = else_builder.operations;
795
796        // Dispatch instructions to the correct conditional block.
797        match condition {
798            IfCondition::Eq(lhs, rhs) => {
799                let op = DslIr::IfEq(lhs, rhs, then_instructions, else_instructions);
800                self.builder.operations.push(op);
801            }
802            IfCondition::EqI(lhs, rhs) => {
803                let op = DslIr::IfEqI(lhs, rhs, then_instructions, else_instructions);
804                self.builder.operations.push(op);
805            }
806            IfCondition::Ne(lhs, rhs) => {
807                let op = DslIr::IfNe(lhs, rhs, then_instructions, else_instructions);
808                self.builder.operations.push(op);
809            }
810            IfCondition::NeI(lhs, rhs) => {
811                let op = DslIr::IfNeI(lhs, rhs, then_instructions, else_instructions);
812                self.builder.operations.push(op);
813            }
814            _ => unreachable!("Const if should have returned early"),
815        }
816    }
817
818    fn condition(&mut self) -> IfCondition<C::N> {
819        match (self.lhs.clone(), self.rhs.clone(), self.is_eq) {
820            (SymbolicVar::Const(lhs, _), SymbolicVar::Const(rhs, _), true) => {
821                IfCondition::EqConst(lhs, rhs)
822            }
823            (SymbolicVar::Const(lhs, _), SymbolicVar::Const(rhs, _), false) => {
824                IfCondition::NeConst(lhs, rhs)
825            }
826            (SymbolicVar::Const(lhs, _), SymbolicVar::Val(rhs, _), true) => {
827                IfCondition::EqI(rhs, lhs)
828            }
829            (SymbolicVar::Const(lhs, _), SymbolicVar::Val(rhs, _), false) => {
830                IfCondition::NeI(rhs, lhs)
831            }
832            (SymbolicVar::Const(lhs, _), rhs, true) => {
833                let rhs: Var<C::N> = self.builder.eval(rhs);
834                IfCondition::EqI(rhs, lhs)
835            }
836            (SymbolicVar::Const(lhs, _), rhs, false) => {
837                let rhs: Var<C::N> = self.builder.eval(rhs);
838                IfCondition::NeI(rhs, lhs)
839            }
840            (SymbolicVar::Val(lhs, _), SymbolicVar::Const(rhs, _), true) => {
841                let lhs: Var<C::N> = self.builder.eval(lhs);
842                IfCondition::EqI(lhs, rhs)
843            }
844            (SymbolicVar::Val(lhs, _), SymbolicVar::Const(rhs, _), false) => {
845                let lhs: Var<C::N> = self.builder.eval(lhs);
846                IfCondition::NeI(lhs, rhs)
847            }
848            (lhs, SymbolicVar::Const(rhs, _), true) => {
849                let lhs: Var<C::N> = self.builder.eval(lhs);
850                IfCondition::EqI(lhs, rhs)
851            }
852            (lhs, SymbolicVar::Const(rhs, _), false) => {
853                let lhs: Var<C::N> = self.builder.eval(lhs);
854                IfCondition::NeI(lhs, rhs)
855            }
856            (SymbolicVar::Val(lhs, _), SymbolicVar::Val(rhs, _), true) => IfCondition::Eq(lhs, rhs),
857            (SymbolicVar::Val(lhs, _), SymbolicVar::Val(rhs, _), false) => {
858                IfCondition::Ne(lhs, rhs)
859            }
860            (SymbolicVar::Val(lhs, _), rhs, true) => {
861                let rhs: Var<C::N> = self.builder.eval(rhs);
862                IfCondition::Eq(lhs, rhs)
863            }
864            (SymbolicVar::Val(lhs, _), rhs, false) => {
865                let rhs: Var<C::N> = self.builder.eval(rhs);
866                IfCondition::Ne(lhs, rhs)
867            }
868            (lhs, SymbolicVar::Val(rhs, _), true) => {
869                let lhs: Var<C::N> = self.builder.eval(lhs);
870                IfCondition::Eq(lhs, rhs)
871            }
872            (lhs, SymbolicVar::Val(rhs, _), false) => {
873                let lhs: Var<C::N> = self.builder.eval(lhs);
874                IfCondition::Ne(lhs, rhs)
875            }
876            (lhs, rhs, true) => {
877                let lhs: Var<C::N> = self.builder.eval(lhs);
878                let rhs: Var<C::N> = self.builder.eval(rhs);
879                IfCondition::Eq(lhs, rhs)
880            }
881            (lhs, rhs, false) => {
882                let lhs: Var<C::N> = self.builder.eval(lhs);
883                let rhs: Var<C::N> = self.builder.eval(rhs);
884                IfCondition::Ne(lhs, rhs)
885            }
886        }
887    }
888}
889
890// iterates through zipped pointers
891pub struct IteratorBuilder<'a, C: Config> {
892    starts: Vec<RVar<C::N>>,
893    end0: RVar<C::N>,
894    step_sizes: Vec<C::N>,
895    builder: &'a mut Builder<C>,
896}
897
898impl<C: Config> IteratorBuilder<'_, C> {
899    pub fn for_each(&mut self, mut f: impl FnMut(Vec<RVar<C::N>>, &mut Builder<C>)) {
900        assert!(self.starts.len() == self.step_sizes.len());
901        assert!(!self.starts.is_empty());
902
903        if self.starts.iter().all(|start| start.is_const()) && self.end0.is_const() {
904            self.for_each_unrolled(|ptrs, builder| {
905                f(ptrs, builder);
906            });
907            return;
908        }
909
910        self.for_each_dynamic(|ptrs, builder| {
911            f(ptrs, builder);
912        });
913    }
914
915    fn for_each_unrolled(&mut self, mut f: impl FnMut(Vec<RVar<C::N>>, &mut Builder<C>)) {
916        let mut ptrs: Vec<_> = self
917            .starts
918            .iter()
919            .map(|start| start.field_value())
920            .collect();
921        let end0 = self.end0.field_value();
922        while ptrs[0] != end0 {
923            f(
924                ptrs.iter().map(|ptr| RVar::Const(*ptr)).collect(),
925                self.builder,
926            );
927            for (ptr, step_size) in izip!(&mut ptrs, &self.step_sizes) {
928                *ptr += *step_size;
929            }
930        }
931    }
932
933    fn for_each_dynamic(&mut self, mut f: impl FnMut(Vec<RVar<C::N>>, &mut Builder<C>)) {
934        assert!(
935            !self.builder.flags.static_only,
936            "Cannot use dynamic loop in static mode"
937        );
938
939        let loop_variables: Vec<Var<C::N>> = (0..self.starts.len())
940            .map(|_| self.builder.uninit())
941            .collect();
942        let mut loop_body_builder = self.builder.create_sub_builder();
943
944        f(
945            loop_variables.iter().map(|&v| v.into()).collect(),
946            &mut loop_body_builder,
947        );
948
949        let loop_instructions = loop_body_builder.operations;
950        let op = DslIr::ZipFor(
951            self.starts.clone(),
952            self.end0,
953            self.step_sizes.clone(),
954            loop_variables,
955            loop_instructions,
956        );
957        self.builder.operations.push(op);
958    }
959}