openvm_native_compiler/ir/
builder.rs

1use std::{iter::Zip, vec::IntoIter};
2
3use backtrace::Backtrace;
4use itertools::izip;
5use openvm_native_compiler_derive::iter_zip;
6use openvm_stark_backend::p3_field::{BasedVectorSpace, Field, PrimeCharacteristicRing};
7use serde::{Deserialize, Serialize};
8
9use super::{
10    Array, Config, DslIr, Ext, Felt, FromConstant, MemIndex, MemVariable, RVar, SymbolicExt,
11    SymbolicFelt, SymbolicVar, Usize, Var, Variable, WitnessRef,
12};
13use crate::ir::{collections::ArrayLike, Ptr};
14
/// TracedVec is a Vec wrapper that records a trace whenever an element is pushed. When extending
/// from another TracedVec, the traces are copied over.
///
/// The methods on this type always push to `vec` and `traces` together, so index `i` of
/// `traces` describes index `i` of `vec`. Both fields are public, so external code that mutates
/// them directly is responsible for keeping the two in step.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TracedVec<T> {
    /// The stored elements, in insertion order.
    pub vec: Vec<T>,
    /// Backtrace slot for the element at the same index of `vec`; `None` when nothing was
    /// recorded for that element.
    pub traces: Vec<Option<Backtrace>>,
}
22
/// The default `TracedVec` is the empty one produced by [`TracedVec::new`].
///
/// Implemented by hand (instead of derived) so that no `T: Default` bound is required.
impl<T> Default for TracedVec<T> {
    fn default() -> Self {
        Self::new()
    }
}
28
29impl<T> From<Vec<T>> for TracedVec<T> {
30    fn from(vec: Vec<T>) -> Self {
31        let len = vec.len();
32        Self {
33            vec,
34            traces: vec![None; len],
35        }
36    }
37}
38
39impl<T> TracedVec<T> {
40    pub const fn new() -> Self {
41        Self {
42            vec: Vec::new(),
43            traces: Vec::new(),
44        }
45    }
46
47    pub fn push(&mut self, value: T) {
48        self.vec.push(value);
49        self.traces.push(None);
50    }
51
52    /// Pushes a value to the vector and records a backtrace if RUST_BACKTRACE is enabled
53    pub fn trace_push(&mut self, value: T) {
54        self.vec.push(value);
55        if std::env::var_os("RUST_BACKTRACE").is_none() {
56            self.traces.push(None);
57        } else {
58            self.traces.push(Some(Backtrace::new_unresolved()));
59        }
60    }
61
62    pub fn extend<I: IntoIterator<Item = (T, Option<Backtrace>)>>(&mut self, iter: I) {
63        let iter = iter.into_iter();
64        let len = iter.size_hint().0;
65        self.vec.reserve(len);
66        self.traces.reserve(len);
67        for (value, trace) in iter {
68            self.vec.push(value);
69            self.traces.push(trace);
70        }
71    }
72
73    pub fn is_empty(&self) -> bool {
74        self.vec.is_empty()
75    }
76}
77
78impl<T> IntoIterator for TracedVec<T> {
79    type Item = (T, Option<Backtrace>);
80    type IntoIter = Zip<IntoIter<T>, IntoIter<Option<Backtrace>>>;
81
82    fn into_iter(self) -> Self::IntoIter {
83        self.vec.into_iter().zip(self.traces)
84    }
85}
86
/// Switches that change how a [`Builder`] behaves.
///
/// The flags are copied into every sub-builder by `Builder::create_sub_builder`.
#[derive(Debug, Copy, Clone, Default)]
pub struct BuilderFlags {
    /// Debug switch. NOTE(review): not read anywhere in this part of the file; presumably
    /// consumed by other builder code — confirm.
    pub debug: bool,
    /// If true, branching/looping/heap memory is disabled.
    pub static_only: bool,
}
93
/// A builder for the DSL.
///
/// Can compile to both assembly and a set of constraints.
///
/// Builder methods append `DslIr` operations to `operations`; the remaining fields are
/// bookkeeping for variable allocation, witnesses and public values.
#[derive(Debug, Clone, Default)]
pub struct Builder<C: Config> {
    /// Counter for `Var`s; inherited by sub-builders. NOTE(review): presumably advanced by
    /// `Variable::uninit`, which is defined elsewhere — confirm.
    pub(crate) var_count: u32,
    /// Counter for `Felt`s; inherited by sub-builders (same caveat as `var_count`).
    pub(crate) felt_count: u32,
    /// Counter for `Ext`s; inherited by sub-builders (same caveat as `var_count`).
    pub(crate) ext_count: u32,
    /// The operations emitted so far, each with an optional backtrace.
    pub operations: TracedVec<DslIr<C>>,
    /// Running public-value counter, created lazily the first time a public value is committed.
    pub(crate) nb_public_values: Option<Var<C::N>>,
    /// Index that the next `witness_var` call will use.
    pub(crate) witness_var_count: u32,
    /// Index that the next `witness_felt` call will use.
    pub(crate) witness_felt_count: u32,
    /// Index that the next `witness_ext` call will use.
    pub(crate) witness_ext_count: u32,
    /// Witness reference lists registered through `witness_load`, addressed by position.
    pub(crate) witness_space: Vec<Vec<WitnessRef>>,
    /// Behaviour flags (see [`BuilderFlags`]).
    pub flags: BuilderFlags,
    /// `true` for builders made by `create_sub_builder`; such builders refuse witness and
    /// public-value operations.
    pub is_sub_builder: bool,
}
111
112impl<C: Config> Builder<C> {
113    /// Creates a new builder with a given number of counts for each type.
114    pub fn create_sub_builder(&self) -> Self {
115        Self {
116            var_count: self.var_count,
117            felt_count: self.felt_count,
118            ext_count: self.ext_count,
119            // Witness counts are only used when the target is a circuit.  And sub-builders are
120            // not used when the target is a circuit, so it is fine to set the witness counts to 0.
121            witness_var_count: 0,
122            witness_felt_count: 0,
123            witness_ext_count: 0,
124            witness_space: Default::default(),
125            operations: Default::default(),
126            nb_public_values: self.nb_public_values,
127            flags: self.flags,
128            is_sub_builder: true,
129        }
130    }
131
    /// Pushes an operation to the builder.
    ///
    /// No backtrace is recorded for the operation; see `trace_push` for the traced variant.
    pub fn push(&mut self, op: DslIr<C>) {
        self.operations.push(op);
    }
136
    /// Pushes an operation to the builder and records a trace if RUST_BACKTRACE=1.
    ///
    /// Delegates to `TracedVec::trace_push`, which decides whether a backtrace is captured.
    pub fn trace_push(&mut self, op: DslIr<C>) {
        self.operations.trace_push(op);
    }
141
    /// Creates an uninitialized variable.
    ///
    /// Forwards to `V::uninit`, so the concrete behaviour is defined by the variable type.
    pub fn uninit<V: Variable<C>>(&mut self) -> V {
        V::uninit(self)
    }
146
    /// Evaluates an expression and returns a variable.
    ///
    /// Forwards to `V::eval`; `expr` may be anything convertible into `V::Expression`.
    pub fn eval<V: Variable<C>, E: Into<V::Expression>>(&mut self, expr: E) -> V {
        V::eval(self, expr)
    }
151
152    /// Evaluates an expression and returns a right value.
153    pub fn eval_expr(&mut self, expr: impl Into<SymbolicVar<C::N>>) -> RVar<C::N> {
154        let expr = expr.into();
155        match expr {
156            SymbolicVar::Const(c, _) => RVar::Const(c),
157            SymbolicVar::Val(val, _) => RVar::Val(val),
158            _ => {
159                let ret: Var<_> = self.eval(expr);
160                RVar::Val(ret)
161            }
162        }
163    }
164
    /// Increments Usize by one.
    ///
    /// Implemented as an assignment of `u + 1` back to `u`.
    pub fn inc(&mut self, u: &Usize<C::N>) {
        self.assign(u, u.clone() + RVar::one());
    }
169
    /// Evaluates a constant expression and returns a variable.
    ///
    /// Forwards to `V::constant`.
    pub fn constant<V: FromConstant<C>>(&mut self, value: V::Constant) -> V {
        V::constant(value, self)
    }
174
    /// Assigns an expression to a variable.
    ///
    /// `expr` is converted into `V::Expression` and handed to `V::assign` on `dst`.
    pub fn assign<V: Variable<C>, E: Into<V::Expression>>(&mut self, dst: &V, expr: E) {
        dst.assign(expr.into(), self);
    }
179
180    /// Casts a Felt to a Var.
181    pub fn cast_felt_to_var(&mut self, felt: Felt<C::F>) -> Var<C::N> {
182        let var: Var<_> = self.uninit();
183        self.operations.push(DslIr::CastFV(var, felt));
184        var
185    }
186    /// Casts a Var to a Felt.
187    pub fn unsafe_cast_var_to_felt(&mut self, var: Var<C::N>) -> Felt<C::F> {
188        assert!(!self.flags.static_only, "dynamic mode only");
189        let felt: Felt<_> = self.uninit();
190        self.operations.push(DslIr::UnsafeCastVF(felt, var));
191        felt
192    }
193
194    /// Asserts that a Usize is non-zero
195    pub fn assert_nonzero(&mut self, u: &Usize<C::N>) {
196        if self.flags.static_only {
197            assert_ne!(u.value(), 0, "assert_nonzero failed on constant zero");
198        } else {
199            self.operations.push(DslIr::AssertNonZero(u.clone()));
200        }
201    }
202
    /// Asserts that two expressions are equal.
    ///
    /// The variable type `V` selects which `Variable::assert_eq` implementation is used; both
    /// operands may be anything convertible into `V::Expression`.
    pub fn assert_eq<V: Variable<C>>(
        &mut self,
        lhs: impl Into<V::Expression>,
        rhs: impl Into<V::Expression>,
    ) {
        V::assert_eq(lhs, rhs, self);
    }
211
212    /// Assert that two vars are equal.
213    pub fn assert_var_eq<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
214        &mut self,
215        lhs: LhsExpr,
216        rhs: RhsExpr,
217    ) {
218        self.assert_eq::<Var<C::N>>(lhs, rhs);
219    }
220
221    /// Assert that two felts are equal.
222    pub fn assert_felt_eq<LhsExpr: Into<SymbolicFelt<C::F>>, RhsExpr: Into<SymbolicFelt<C::F>>>(
223        &mut self,
224        lhs: LhsExpr,
225        rhs: RhsExpr,
226    ) {
227        self.assert_eq::<Felt<C::F>>(lhs, rhs);
228    }
229
230    /// Assert that two exts are equal.
231    pub fn assert_ext_eq<
232        LhsExpr: Into<SymbolicExt<C::F, C::EF>>,
233        RhsExpr: Into<SymbolicExt<C::F, C::EF>>,
234    >(
235        &mut self,
236        lhs: LhsExpr,
237        rhs: RhsExpr,
238    ) {
239        self.assert_eq::<Ext<C::F, C::EF>>(lhs, rhs);
240    }
241
242    pub fn assert_usize_eq<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
243        &mut self,
244        lhs: LhsExpr,
245        rhs: RhsExpr,
246    ) {
247        self.assert_eq::<Usize<C::N>>(lhs, rhs);
248    }
249
250    /// Assert that two arrays are equal.
251    pub fn assert_var_array_eq(&mut self, lhs: &Array<C, Var<C::N>>, rhs: &Array<C, Var<C::N>>) {
252        self.assert_var_eq(lhs.len(), rhs.len());
253        self.range(0, lhs.len()).for_each(|idx_vec, builder| {
254            let l = builder.get(lhs, idx_vec[0]);
255            let r = builder.get(rhs, idx_vec[0]);
256            builder.assert_var_eq(l, r);
257        });
258    }
259
260    /// Evaluate a block of operations if two expressions are equal.
261    pub fn if_eq<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
262        &mut self,
263        lhs: LhsExpr,
264        rhs: RhsExpr,
265    ) -> IfBuilder<'_, C> {
266        IfBuilder {
267            lhs: lhs.into(),
268            rhs: rhs.into(),
269            is_eq: true,
270            builder: self,
271        }
272    }
273
274    /// Evaluate a block of operations if two expressions are not equal.
275    pub fn if_ne<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
276        &mut self,
277        lhs: LhsExpr,
278        rhs: RhsExpr,
279    ) -> IfBuilder<'_, C> {
280        IfBuilder {
281            lhs: lhs.into(),
282            rhs: rhs.into(),
283            is_eq: false,
284            builder: self,
285        }
286    }
287
    /// Asserts that lhs is less than rhs in time O(rhs).
    ///
    /// Builds the product `lhs * (lhs - 1) * ... * (lhs - (rhs - 1))` with a loop over
    /// `1..rhs` and asserts that it equals zero, i.e. that `lhs` equals one of
    /// `0, 1, ..., rhs - 1`.
    pub fn assert_less_than_slow_small_rhs<
        LhsExpr: Into<SymbolicVar<C::N>>,
        RhsExpr: Into<SymbolicVar<C::N>>,
    >(
        &mut self,
        lhs: LhsExpr,
        rhs: RhsExpr,
    ) {
        let lhs: Usize<_> = self.eval(lhs.into());
        let rhs: Usize<_> = self.eval(rhs.into());
        // The running product starts at `lhs` itself, which covers the `lhs == 0` case.
        let product: Usize<_> = self.eval(lhs.clone());
        self.range(1, rhs).for_each(|i_vec, builder| {
            let i = i_vec[0];
            // Multiply in the factor `lhs - i`, which vanishes exactly when `lhs == i`.
            let diff: Usize<_> = builder.eval(lhs.clone() - i);
            builder.assign(&product, product.clone() * diff);
        });
        self.assert_usize_eq(product, RVar::from(0));
    }
307
    /// Asserts that lhs is less than rhs in time O(lhs).
    ///
    /// Builds the product `rhs * (rhs - 1) * ... * (rhs - lhs)` with a loop over
    /// `1..lhs + 1` and asserts that it is non-zero, i.e. that `rhs` is none of
    /// `0, 1, ..., lhs`.
    pub fn assert_less_than_slow_small_lhs<
        LhsExpr: Into<SymbolicVar<C::N>>,
        RhsExpr: Into<SymbolicVar<C::N>>,
    >(
        &mut self,
        lhs: LhsExpr,
        rhs: RhsExpr,
    ) {
        let lhs: Usize<_> = self.eval(lhs.into());
        let rhs: Usize<_> = self.eval(rhs.into());
        // The running product starts at `rhs` itself, which rules out `rhs == 0`.
        let product: Usize<_> = self.eval(rhs.clone());
        // Exclusive upper bound of the loop, so that `i` reaches `lhs`.
        let lhs_plus_one: Usize<_> = self.eval(lhs.clone() + RVar::one());
        self.range(1, lhs_plus_one).for_each(|i_vec, builder| {
            let i = i_vec[0];
            // Multiply in the factor `rhs - i`, which vanishes exactly when `rhs == i`.
            let diff: Usize<_> = builder.eval(rhs.clone() - i);
            builder.assign(&product, product.clone() * diff);
        });
        self.assert_nonzero(&product);
    }
328
    /// Asserts that lhs is less than rhs in time O(log(lhs) + log(rhs)).
    ///
    /// Only works for Felt == BabyBear and in the VM.
    ///
    /// Uses bit decomposition hint, which has large constant factor overhead, so prefer
    /// [Self::assert_less_than_slow_small_rhs] when rhs is small.
    ///
    /// Both operands are cast to felts and decomposed into `C::N::bits()` bits. The bit arrays
    /// are walked in lockstep; at every position where the bits differ, `is_lt` is overwritten
    /// with the `rhs` bit, so the last differing position visited decides the result.
    /// NOTE(review): this is a correct comparison only if `num2bits_f` yields bits
    /// least-significant first — confirm against its definition.
    ///
    /// # Panics
    ///
    /// Panics in static-only mode (via `unsafe_cast_var_to_felt`).
    pub fn assert_less_than_slow_bit_decomp(&mut self, lhs: Var<C::N>, rhs: Var<C::N>) {
        let lhs = self.unsafe_cast_var_to_felt(lhs);
        let rhs = self.unsafe_cast_var_to_felt(rhs);

        let lhs_bits = self.num2bits_f(lhs, C::N::bits() as u32);
        let rhs_bits = self.num2bits_f(rhs, C::N::bits() as u32);

        // Starts at 0 so that equal operands (no differing bit) fail the final assertion.
        let is_lt: Var<_> = self.eval(C::N::ZERO);

        iter_zip!(self, lhs_bits, rhs_bits).for_each(|ptr_vec, builder| {
            let lhs_bit = builder.iter_ptr_get(&lhs_bits, ptr_vec[0]);
            let rhs_bit = builder.iter_ptr_get(&rhs_bits, ptr_vec[1]);

            // Where the bits differ, `rhs_bit == 1` means rhs is the larger one at this position.
            builder.if_ne(lhs_bit, rhs_bit).then(|builder| {
                builder.assign(&is_lt, rhs_bit);
            });
        });
        self.assert_var_eq(is_lt, C::N::ONE);
    }
354
    /// Asserts that `x` has at most `num_bits` bits.
    ///
    /// Emits a `DslIr::RangeCheckV` operation, recorded with a backtrace when tracing is
    /// enabled.
    ///
    /// # Panics
    ///
    /// Panics if the builder is in static-only mode or if `num_bits` exceeds 30.
    pub fn range_check_var(&mut self, x: Var<C::N>, num_bits: usize) {
        assert!(!self.flags.static_only, "range_check_var is dynamic");
        assert!(num_bits <= 30);
        self.trace_push(DslIr::RangeCheckV(x, num_bits));
    }
361
    /// Evaluate a block of operations over a range from start to end.
    ///
    /// Equivalent to `range_with_step` with a step of `C::N::ONE`.
    pub fn range(
        &mut self,
        start: impl Into<RVar<C::N>>,
        end: impl Into<RVar<C::N>>,
    ) -> IteratorBuilder<'_, C> {
        self.range_with_step(start, end, C::N::ONE)
    }
370    /// Evaluate a block of operations over a range from start to end with a custom step.
371    pub fn range_with_step(
372        &mut self,
373        start: impl Into<RVar<C::N>>,
374        end: impl Into<RVar<C::N>>,
375        step: C::N,
376    ) -> IteratorBuilder<'_, C> {
377        let start = start.into();
378        let end0 = end.into();
379        IteratorBuilder {
380            starts: vec![start],
381            end0,
382            step_sizes: vec![step],
383            builder: self,
384        }
385    }
386
    /// Creates an iterator that walks several arrays in lockstep.
    ///
    /// * If every array is fixed, the iterator counts indices from 0 with step 1 for each
    ///   array, bounded by the length of the first array.
    /// * If every array is dynamic, the iterator walks pointers: each array starts at its own
    ///   base address and advances by its own element size, and the bound is the end address
    ///   of the first array.
    ///
    /// Only the first array determines the bound; the other arrays' lengths are not checked
    /// here.
    ///
    /// # Panics
    ///
    /// Panics if `arrays` is empty or if it mixes fixed and dynamic arrays.
    pub fn zip<'a>(
        &'a mut self,
        arrays: &'a [Box<dyn ArrayLike<C> + 'a>],
    ) -> IteratorBuilder<'a, C> {
        assert!(!arrays.is_empty());
        if arrays.iter().all(|array| array.is_fixed()) {
            // Index-based iteration: one zero-based counter per array.
            IteratorBuilder {
                starts: vec![RVar::zero(); arrays.len()],
                end0: arrays[0].len().into(),
                step_sizes: vec![C::N::ONE; arrays.len()],
                builder: self,
            }
        } else if arrays.iter().all(|array| !array.is_fixed()) {
            // Pointer-based iteration: one address cursor per array.
            IteratorBuilder {
                starts: arrays
                    .iter()
                    .map(|array| array.ptr().address.into())
                    .collect(),
                end0: {
                    // End address of the first array: base + len * element size.
                    let len: RVar<C::N> = arrays[0].len().into();
                    let size = arrays[0].element_size_of();
                    let end: Var<C::N> =
                        self.eval(arrays[0].ptr().address + len * RVar::from(size));
                    end.into()
                },
                step_sizes: arrays
                    .iter()
                    .map(|array| C::N::from_usize(array.element_size_of()))
                    .collect(),
                builder: self,
            }
        } else {
            panic!("Cannot use zipped pointer iterator with mixed arrays");
        }
    }
422
    /// Prints the constant `val`.
    ///
    /// The value is evaluated into a variable and then printed with `print_v`.
    pub fn print_debug(&mut self, val: usize) {
        let constant = self.eval(C::N::from_usize(val));
        self.print_v(constant);
    }
427
    /// Print a variable.
    ///
    /// Emits a `DslIr::PrintV` operation for `dst`.
    pub fn print_v(&mut self, dst: Var<C::N>) {
        self.operations.push(DslIr::PrintV(dst));
    }
432
    /// Print a felt.
    ///
    /// Emits a `DslIr::PrintF` operation for `dst`.
    pub fn print_f(&mut self, dst: Felt<C::F>) {
        self.operations.push(DslIr::PrintF(dst));
    }
437
    /// Print an ext.
    ///
    /// Emits a `DslIr::PrintE` operation for `dst`.
    pub fn print_e(&mut self, dst: Ext<C::F, C::EF>) {
        self.operations.push(DslIr::PrintE(dst));
    }
442
443    pub fn hint_var(&mut self) -> Var<C::N> {
444        let ptr = self.alloc(RVar::one(), 1);
445        // Prepare data for hinting.
446        self.operations.push(DslIr::HintFelt());
447        let index = MemIndex {
448            index: RVar::zero(),
449            offset: 0,
450            size: 1,
451        };
452        self.operations.push(DslIr::StoreHintWord(ptr, index));
453        let v: Var<C::N> = self.uninit();
454        self.load(v, ptr, index);
455        v
456    }
457
458    pub fn hint_felt(&mut self) -> Felt<C::F> {
459        let ptr = self.alloc(RVar::one(), 1);
460        // Prepare data for hinting.
461        self.operations.push(DslIr::HintFelt());
462        let index = MemIndex {
463            index: RVar::zero(),
464            offset: 0,
465            size: 1,
466        };
467        self.operations.push(DslIr::StoreHintWord(ptr, index));
468        let f: Felt<C::F> = self.uninit();
469        self.load(f, ptr, index);
470        f
471    }
472
473    pub fn hint_ext(&mut self) -> Ext<C::F, C::EF> {
474        let flattened = self.hint_felts_fixed(C::EF::DIMENSION);
475
476        // Simply recast memory as Array<Ext>.
477        let array: Array<C, Ext<_, _>> = match flattened {
478            Array::Fixed(_) => unreachable!(),
479            Array::Dyn(ptr, _) => Array::Dyn(ptr, Usize::from(1)),
480        };
481        self.get(&array, 0)
482    }
483
    /// Hint a vector of variables.
    ///
    /// Writes the next element of the witness stream into memory and returns it.
    ///
    /// Thin wrapper around `hint_words` with `V = Var<C::N>`.
    pub fn hint_vars(&mut self) -> Array<C, Var<C::N>> {
        self.hint_words()
    }
490
    /// Hint a vector of felts.
    ///
    /// Thin wrapper around `hint_words` with `V = Felt<C::F>`.
    pub fn hint_felts(&mut self) -> Array<C, Felt<C::F>> {
        self.hint_words()
    }
495
    /// Hint `len` felts into a freshly allocated dynamic array.
    ///
    /// Thin wrapper around `hint_words_fixed` with `V = Felt<C::F>`.
    pub fn hint_felts_fixed(&mut self, len: impl Into<RVar<C::N>>) -> Array<C, Felt<C::F>> {
        self.hint_words_fixed(len)
    }
499
    /// Hints an array of V and assumes V::size_of() == 1.
    ///
    /// Emits `DslIr::HintInputVec`, stores the first hinted word (the length) into a scratch
    /// cell, allocates a dynamic array of that length, and then stores one hinted word into
    /// each element slot of the array.
    ///
    /// # Panics
    ///
    /// Panics if `V::size_of() != 1`.
    fn hint_words<V: MemVariable<C>>(&mut self) -> Array<C, V> {
        assert_eq!(V::size_of(), 1);

        // Allocate space for the length variable. We assume that mem[ptr..] is empty.
        let ptr = self.alloc(RVar::one(), 1);

        // Prepare length + data for hinting.
        self.operations.push(DslIr::HintInputVec());

        // Write and retrieve length hint.
        let index = MemIndex {
            index: RVar::zero(),
            offset: 0,
            size: 1,
        };
        // MemIndex.index share the same pointer, but it doesn't matter.
        self.operations.push(DslIr::StoreHintWord(ptr, index));

        let vlen: Var<C::N> = self.uninit();
        self.load(vlen, ptr, index);
        let arr = self.dyn_array(vlen);

        // Write the content hints directly into the array memory.
        iter_zip!(self, arr).for_each(|ptr_vec, builder| {
            // The iterator already yields the element's address, so the index/offset are zero.
            let index = MemIndex {
                index: 0.into(),
                offset: 0,
                size: 1,
            };
            builder.operations.push(DslIr::StoreHintWord(
                Ptr {
                    address: ptr_vec[0].variable(),
                },
                index,
            ));
        });
        arr
    }
539
    /// Hints an array of V and assumes V::size_of() == 1.
    ///
    /// Unlike `hint_words`, the length is supplied by the caller rather than read from the
    /// hint stream. For every element slot this emits a `DslIr::HintFelt` followed by a
    /// `DslIr::StoreHintWord` into that slot.
    ///
    /// # Panics
    ///
    /// Panics if `V::size_of() != 1`.
    fn hint_words_fixed<V: MemVariable<C>>(&mut self, len: impl Into<RVar<C::N>>) -> Array<C, V> {
        assert_eq!(V::size_of(), 1);

        let arr = self.dyn_array(len.into());
        // Write the content hints directly into the array memory.
        iter_zip!(self, arr).for_each(|ptr_vec, builder| {
            // The iterator already yields the element's address, so the index/offset are zero.
            let index = MemIndex {
                index: 0.into(),
                offset: 0,
                size: 1,
            };
            builder.operations.push(DslIr::HintFelt());
            builder.operations.push(DslIr::StoreHintWord(
                Ptr {
                    address: ptr_vec[0].variable(),
                },
                index,
            ));
        });
        arr
    }
562
563    /// Hint a vector of exts.
564    ///
565    /// Emits two hint opcodes: the first for the number of exts, the second for the list of exts
566    /// themselves.
567    pub fn hint_exts(&mut self) -> Array<C, Ext<C::F, C::EF>> {
568        let len = self.hint_var();
569        let flattened = self.hint_felts();
570
571        let size = <Ext<C::F, C::EF> as MemVariable<C>>::size_of();
572        self.assert_usize_eq(flattened.len(), len * C::N::from_usize(size));
573
574        // Simply recast memory as Array<Ext>.
575        match flattened {
576            Array::Fixed(_) => unreachable!(),
577            Array::Dyn(ptr, _) => Array::Dyn(ptr, Usize::Var(len)),
578        }
579    }
580
581    /// Move data from input stream into hint space. Return an ID which can be used to load the
582    /// data at runtime.
583    pub fn hint_load(&mut self) -> Var<C::N> {
584        self.trace_push(DslIr::HintLoad());
585        let ptr = self.alloc(RVar::one(), 1);
586        let index = MemIndex {
587            index: RVar::zero(),
588            offset: 0,
589            size: 1,
590        };
591        // MemIndex.index share the same pointer, but it doesn't matter.
592        self.operations.push(DslIr::StoreHintWord(ptr, index));
593        let id: Var<C::N> = self.uninit();
594        self.load(id, ptr, index);
595        id
596    }
597
598    pub fn witness_var(&mut self) -> Var<C::N> {
599        assert!(
600            !self.is_sub_builder,
601            "Cannot create a witness var with a sub builder"
602        );
603        let witness = self.uninit();
604        self.operations
605            .push(DslIr::WitnessVar(witness, self.witness_var_count));
606        self.witness_var_count += 1;
607        witness
608    }
609
610    pub fn witness_felt(&mut self) -> Felt<C::F> {
611        assert!(
612            !self.is_sub_builder,
613            "Cannot create a witness felt with a sub builder"
614        );
615        let witness = self.uninit();
616        self.operations
617            .push(DslIr::WitnessFelt(witness, self.witness_felt_count));
618        self.witness_felt_count += 1;
619        witness
620    }
621
622    pub fn witness_ext(&mut self) -> Ext<C::F, C::EF> {
623        assert!(
624            !self.is_sub_builder,
625            "Cannot create a witness ext with a sub builder"
626        );
627        let witness = self.uninit();
628        self.operations
629            .push(DslIr::WitnessExt(witness, self.witness_ext_count));
630        self.witness_ext_count += 1;
631        witness
632    }
633
634    pub fn witness_load(&mut self, witness_refs: Vec<WitnessRef>) -> Usize<C::N> {
635        assert!(
636            !self.is_sub_builder,
637            "Cannot load witness refs with a sub builder"
638        );
639        let ret = self.witness_space.len();
640        self.witness_space.push(witness_refs);
641        ret.into()
642    }
643
    /// Returns the witness references registered under `id` by `witness_load`.
    ///
    /// # Panics
    ///
    /// Panics if no entry exists at position `id.value()` in `witness_space`.
    pub fn get_witness_refs(&self, id: Usize<C::N>) -> &[WitnessRef] {
        self.witness_space.get(id.value()).unwrap()
    }
647
    /// Throws an error.
    ///
    /// Emits a `DslIr::Error` operation, recorded with a backtrace when tracing is enabled.
    pub fn error(&mut self) {
        self.operations.trace_push(DslIr::Error());
    }
652
653    fn get_nb_public_values(&mut self) -> Var<C::N> {
654        assert!(
655            !self.is_sub_builder,
656            "Cannot commit to public values with a sub builder"
657        );
658        if self.nb_public_values.is_none() {
659            self.nb_public_values = Some(self.eval(C::N::ZERO));
660        }
661        *self.nb_public_values.as_ref().unwrap()
662    }
663
664    fn commit_public_value_and_increment(&mut self, val: Felt<C::F>, nb_public_values: Var<C::N>) {
665        assert!(
666            !self.flags.static_only,
667            "Static mode should use static_commit_public_value"
668        );
669        self.operations.push(DslIr::Publish(val, nb_public_values));
670        self.assign(&nb_public_values, nb_public_values + C::N::ONE);
671    }
672
    /// Commits a Var as public value. This value will be constrained when verified. This method
    /// should only be used in static mode.
    ///
    /// Emits a `DslIr::CircuitPublish` for `val` at the caller-chosen `index`.
    ///
    /// # Panics
    ///
    /// Panics if the builder is not in static-only mode.
    pub fn static_commit_public_value(&mut self, index: usize, val: Var<C::N>) {
        assert!(
            self.flags.static_only,
            "Dynamic mode should use commit_public_value instead."
        );
        self.operations.push(DslIr::CircuitPublish(val, index));
    }
682
    /// Register and commits a felt as public value.  This value will be constrained when verified.
    ///
    /// The value is published at the current public-value counter, which is then incremented.
    ///
    /// # Panics
    ///
    /// Panics on a sub-builder or in static-only mode.
    pub fn commit_public_value(&mut self, val: Felt<C::F>) {
        let nb_public_values = self.get_nb_public_values();
        self.commit_public_value_and_increment(val, nb_public_values);
    }
688
689    /// Commits an array of felts in public values.
690    pub fn commit_public_values(&mut self, vals: &Array<C, Felt<C::F>>) {
691        let nb_public_values = self.get_nb_public_values();
692        let len = vals.len();
693        self.range(0, len).for_each(|idx_vec, builder| {
694            let val = builder.get(vals, idx_vec[0]);
695            builder.commit_public_value_and_increment(val, nb_public_values);
696        });
697    }
698
699    pub fn cycle_tracker_start(&mut self, name: &str) {
700        self.operations
701            .push(DslIr::CycleTrackerStart(name.to_string()));
702    }
703
704    pub fn cycle_tracker_end(&mut self, name: &str) {
705        self.operations
706            .push(DslIr::CycleTrackerEnd(name.to_string()));
707    }
708
    /// Emits a `DslIr::Halt` operation.
    pub fn halt(&mut self) {
        self.operations.push(DslIr::Halt);
    }
712}
713
/// A builder for the DSL that handles if statements.
///
/// Created by `Builder::if_eq` / `Builder::if_ne`; the branch is emitted when `then` or
/// `then_or_else` is called.
pub struct IfBuilder<'a, C: Config> {
    /// Left-hand side of the comparison.
    lhs: SymbolicVar<C::N>,
    /// Right-hand side of the comparison.
    rhs: SymbolicVar<C::N>,
    /// `true` for an equality test (`if_eq`), `false` for an inequality test (`if_ne`).
    is_eq: bool,
    /// The builder that receives the resulting operations.
    pub(crate) builder: &'a mut Builder<C>,
}
721
/// A set of conditions that if statements can be based on.
///
/// Produced by `IfBuilder::condition` after reducing both operands to constants or variables.
enum IfCondition<N> {
    /// Equality of two constants; decided at build time.
    EqConst(N, N),
    /// Inequality of two constants; decided at build time.
    NeConst(N, N),
    /// Equality of two variables.
    Eq(Var<N>, Var<N>),
    /// Equality of a variable and a constant.
    EqI(Var<N>, N),
    /// Inequality of two variables.
    Ne(Var<N>, Var<N>),
    /// Inequality of a variable and a constant.
    NeI(Var<N>, N),
}
731
732impl<C: Config> IfBuilder<'_, C> {
733    pub fn then(&mut self, mut f: impl FnMut(&mut Builder<C>)) {
734        // Get the condition reduced from the expressions for lhs and rhs.
735        let condition = self.condition();
736        // Early return for const branches.
737        match condition {
738            IfCondition::EqConst(lhs, rhs) => {
739                if lhs == rhs {
740                    return f(self.builder);
741                }
742                return;
743            }
744            IfCondition::NeConst(lhs, rhs) => {
745                if lhs != rhs {
746                    return f(self.builder);
747                }
748                return;
749            }
750            _ => (),
751        }
752        assert!(
753            !self.builder.flags.static_only,
754            "Cannot use dynamic branch in static mode"
755        );
756
757        // Execute the `then` block and collect the instructions.
758        let mut f_builder = self.builder.create_sub_builder();
759        f(&mut f_builder);
760        let then_instructions = f_builder.operations;
761
762        // Dispatch instructions to the correct conditional block.
763        match condition {
764            IfCondition::Eq(lhs, rhs) => {
765                let op = DslIr::IfEq(lhs, rhs, then_instructions, Default::default());
766                self.builder.operations.push(op);
767            }
768            IfCondition::EqI(lhs, rhs) => {
769                let op = DslIr::IfEqI(lhs, rhs, then_instructions, Default::default());
770                self.builder.operations.push(op);
771            }
772            IfCondition::Ne(lhs, rhs) => {
773                let op = DslIr::IfNe(lhs, rhs, then_instructions, Default::default());
774                self.builder.operations.push(op);
775            }
776            IfCondition::NeI(lhs, rhs) => {
777                let op = DslIr::IfNeI(lhs, rhs, then_instructions, Default::default());
778                self.builder.operations.push(op);
779            }
780            _ => unreachable!("Const if should have returned early"),
781        }
782    }
783
784    pub fn then_or_else(
785        &mut self,
786        mut then_f: impl FnMut(&mut Builder<C>),
787        mut else_f: impl FnMut(&mut Builder<C>),
788    ) {
789        // Get the condition reduced from the expressions for lhs and rhs.
790        let condition = self.condition();
791        // Early return for const branches.
792        match condition {
793            IfCondition::EqConst(lhs, rhs) => {
794                if lhs == rhs {
795                    return then_f(self.builder);
796                }
797                return else_f(self.builder);
798            }
799            IfCondition::NeConst(lhs, rhs) => {
800                if lhs != rhs {
801                    return then_f(self.builder);
802                }
803                return else_f(self.builder);
804            }
805            _ => (),
806        }
807        assert!(
808            !self.builder.flags.static_only,
809            "Cannot use dynamic branch in static mode"
810        );
811        let mut then_builder = self.builder.create_sub_builder();
812
813        // Execute the `then` and `else_then` blocks and collect the instructions.
814        then_f(&mut then_builder);
815        let then_instructions = then_builder.operations;
816
817        let mut else_builder = self.builder.create_sub_builder();
818        else_f(&mut else_builder);
819        let else_instructions = else_builder.operations;
820
821        // Dispatch instructions to the correct conditional block.
822        match condition {
823            IfCondition::Eq(lhs, rhs) => {
824                let op = DslIr::IfEq(lhs, rhs, then_instructions, else_instructions);
825                self.builder.operations.push(op);
826            }
827            IfCondition::EqI(lhs, rhs) => {
828                let op = DslIr::IfEqI(lhs, rhs, then_instructions, else_instructions);
829                self.builder.operations.push(op);
830            }
831            IfCondition::Ne(lhs, rhs) => {
832                let op = DslIr::IfNe(lhs, rhs, then_instructions, else_instructions);
833                self.builder.operations.push(op);
834            }
835            IfCondition::NeI(lhs, rhs) => {
836                let op = DslIr::IfNeI(lhs, rhs, then_instructions, else_instructions);
837                self.builder.operations.push(op);
838            }
839            _ => unreachable!("Const if should have returned early"),
840        }
841    }
842
    /// Reduces `lhs`, `rhs` and `is_eq` to an [`IfCondition`].
    ///
    /// Constants stay constants, so constant/constant pairs become `EqConst` / `NeConst` and
    /// any pair with one constant becomes `EqI` / `NeI` with the constant in the immediate
    /// position. Compound expressions are evaluated into fresh `Var`s on the parent builder,
    /// which may emit operations. The arms are ordered from most to least specific, so their
    /// order matters.
    fn condition(&mut self) -> IfCondition<C::N> {
        match (self.lhs.clone(), self.rhs.clone(), self.is_eq) {
            // Both sides constant: decidable at build time.
            (SymbolicVar::Const(lhs, _), SymbolicVar::Const(rhs, _), true) => {
                IfCondition::EqConst(lhs, rhs)
            }
            (SymbolicVar::Const(lhs, _), SymbolicVar::Const(rhs, _), false) => {
                IfCondition::NeConst(lhs, rhs)
            }
            // Constant on the left: the operands are swapped so the constant is the immediate.
            (SymbolicVar::Const(lhs, _), SymbolicVar::Val(rhs, _), true) => {
                IfCondition::EqI(rhs, lhs)
            }
            (SymbolicVar::Const(lhs, _), SymbolicVar::Val(rhs, _), false) => {
                IfCondition::NeI(rhs, lhs)
            }
            (SymbolicVar::Const(lhs, _), rhs, true) => {
                let rhs: Var<C::N> = self.builder.eval(rhs);
                IfCondition::EqI(rhs, lhs)
            }
            (SymbolicVar::Const(lhs, _), rhs, false) => {
                let rhs: Var<C::N> = self.builder.eval(rhs);
                IfCondition::NeI(rhs, lhs)
            }
            // Constant on the right. Note that a plain variable on the left is still passed
            // through `eval` here, unlike in the mirrored constant-on-the-left arms above.
            (SymbolicVar::Val(lhs, _), SymbolicVar::Const(rhs, _), true) => {
                let lhs: Var<C::N> = self.builder.eval(lhs);
                IfCondition::EqI(lhs, rhs)
            }
            (SymbolicVar::Val(lhs, _), SymbolicVar::Const(rhs, _), false) => {
                let lhs: Var<C::N> = self.builder.eval(lhs);
                IfCondition::NeI(lhs, rhs)
            }
            (lhs, SymbolicVar::Const(rhs, _), true) => {
                let lhs: Var<C::N> = self.builder.eval(lhs);
                IfCondition::EqI(lhs, rhs)
            }
            (lhs, SymbolicVar::Const(rhs, _), false) => {
                let lhs: Var<C::N> = self.builder.eval(lhs);
                IfCondition::NeI(lhs, rhs)
            }
            // No constants: plain variables are used directly, anything else is evaluated.
            (SymbolicVar::Val(lhs, _), SymbolicVar::Val(rhs, _), true) => IfCondition::Eq(lhs, rhs),
            (SymbolicVar::Val(lhs, _), SymbolicVar::Val(rhs, _), false) => {
                IfCondition::Ne(lhs, rhs)
            }
            (SymbolicVar::Val(lhs, _), rhs, true) => {
                let rhs: Var<C::N> = self.builder.eval(rhs);
                IfCondition::Eq(lhs, rhs)
            }
            (SymbolicVar::Val(lhs, _), rhs, false) => {
                let rhs: Var<C::N> = self.builder.eval(rhs);
                IfCondition::Ne(lhs, rhs)
            }
            (lhs, SymbolicVar::Val(rhs, _), true) => {
                let lhs: Var<C::N> = self.builder.eval(lhs);
                IfCondition::Eq(lhs, rhs)
            }
            (lhs, SymbolicVar::Val(rhs, _), false) => {
                let lhs: Var<C::N> = self.builder.eval(lhs);
                IfCondition::Ne(lhs, rhs)
            }
            (lhs, rhs, true) => {
                let lhs: Var<C::N> = self.builder.eval(lhs);
                let rhs: Var<C::N> = self.builder.eval(rhs);
                IfCondition::Eq(lhs, rhs)
            }
            (lhs, rhs, false) => {
                let lhs: Var<C::N> = self.builder.eval(lhs);
                let rhs: Var<C::N> = self.builder.eval(rhs);
                IfCondition::Ne(lhs, rhs)
            }
        }
    }
913}
914
/// Iterates through zipped pointers: several pointers are advanced in lockstep and handed to a
/// user callback on every step (see `for_each`).
pub struct IteratorBuilder<'a, C: Config> {
    /// Starting value of each zipped pointer; must be non-empty and the same length as
    /// `step_sizes` (both are asserted in `for_each`).
    starts: Vec<RVar<C::N>>,
    /// End marker for the *first* pointer only. In the unrolled path iteration stops once the
    /// first pointer equals this value; in the dynamic path it is forwarded to `DslIr::ZipFor`.
    end0: RVar<C::N>,
    /// Amount added to the corresponding pointer after each iteration.
    step_sizes: Vec<C::N>,
    /// Builder that receives the generated operations.
    builder: &'a mut Builder<C>,
}
922
923impl<C: Config> IteratorBuilder<'_, C> {
924    pub fn for_each(&mut self, mut f: impl FnMut(Vec<RVar<C::N>>, &mut Builder<C>)) {
925        assert!(self.starts.len() == self.step_sizes.len());
926        assert!(!self.starts.is_empty());
927
928        if self.starts.iter().all(|start| start.is_const()) && self.end0.is_const() {
929            self.for_each_unrolled(|ptrs, builder| {
930                f(ptrs, builder);
931            });
932            return;
933        }
934
935        self.for_each_dynamic(|ptrs, builder| {
936            f(ptrs, builder);
937        });
938    }
939
940    fn for_each_unrolled(&mut self, mut f: impl FnMut(Vec<RVar<C::N>>, &mut Builder<C>)) {
941        let mut ptrs: Vec<_> = self
942            .starts
943            .iter()
944            .map(|start| start.field_value())
945            .collect();
946        let end0 = self.end0.field_value();
947        while ptrs[0] != end0 {
948            f(
949                ptrs.iter().map(|ptr| RVar::Const(*ptr)).collect(),
950                self.builder,
951            );
952            for (ptr, step_size) in izip!(&mut ptrs, &self.step_sizes) {
953                *ptr += *step_size;
954            }
955        }
956    }
957
958    fn for_each_dynamic(&mut self, mut f: impl FnMut(Vec<RVar<C::N>>, &mut Builder<C>)) {
959        assert!(
960            !self.builder.flags.static_only,
961            "Cannot use dynamic loop in static mode"
962        );
963
964        let loop_variables: Vec<Var<C::N>> = (0..self.starts.len())
965            .map(|_| self.builder.uninit())
966            .collect();
967        let mut loop_body_builder = self.builder.create_sub_builder();
968
969        f(
970            loop_variables.iter().map(|&v| v.into()).collect(),
971            &mut loop_body_builder,
972        );
973
974        let loop_instructions = loop_body_builder.operations;
975        let op = DslIr::ZipFor(
976            self.starts.clone(),
977            self.end0,
978            self.step_sizes.clone(),
979            loop_variables,
980            loop_instructions,
981        );
982        self.builder.operations.push(op);
983    }
984}