1use std::{iter::Zip, vec::IntoIter};
2
3use backtrace::Backtrace;
4use itertools::izip;
5use openvm_native_compiler_derive::iter_zip;
6use openvm_stark_backend::p3_field::{Field, FieldAlgebra, FieldExtensionAlgebra};
7use serde::{Deserialize, Serialize};
8
9use super::{
10 Array, Config, DslIr, Ext, Felt, FromConstant, MemIndex, MemVariable, RVar, SymbolicExt,
11 SymbolicFelt, SymbolicVar, Usize, Var, Variable, WitnessRef,
12};
13use crate::ir::{collections::ArrayLike, Ptr};
14
/// A vector of values paired element-wise with optional backtraces.
///
/// `vec[i]` and `traces[i]` describe the same entry. The methods on this type push to
/// both vectors together so their lengths stay equal; because both fields are public,
/// direct mutation of either field can break that pairing.
///
/// A trace is `Some` only when it was captured by `trace_push` (while `RUST_BACKTRACE`
/// is set) or supplied explicitly through `extend`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TracedVec<T> {
    /// The stored values.
    pub vec: Vec<T>,
    /// One optional backtrace per entry of `vec`, recorded when the entry was added.
    pub traces: Vec<Option<Backtrace>>,
}
22
23impl<T> Default for TracedVec<T> {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29impl<T> From<Vec<T>> for TracedVec<T> {
30 fn from(vec: Vec<T>) -> Self {
31 let len = vec.len();
32 Self {
33 vec,
34 traces: vec![None; len],
35 }
36 }
37}
38
39impl<T> TracedVec<T> {
40 pub const fn new() -> Self {
41 Self {
42 vec: Vec::new(),
43 traces: Vec::new(),
44 }
45 }
46
47 pub fn push(&mut self, value: T) {
48 self.vec.push(value);
49 self.traces.push(None);
50 }
51
52 pub fn trace_push(&mut self, value: T) {
54 self.vec.push(value);
55 if std::env::var_os("RUST_BACKTRACE").is_none() {
56 self.traces.push(None);
57 } else {
58 self.traces.push(Some(Backtrace::new_unresolved()));
59 }
60 }
61
62 pub fn extend<I: IntoIterator<Item = (T, Option<Backtrace>)>>(&mut self, iter: I) {
63 let iter = iter.into_iter();
64 let len = iter.size_hint().0;
65 self.vec.reserve(len);
66 self.traces.reserve(len);
67 for (value, trace) in iter {
68 self.vec.push(value);
69 self.traces.push(trace);
70 }
71 }
72
73 pub fn is_empty(&self) -> bool {
74 self.vec.is_empty()
75 }
76}
77
78impl<T> IntoIterator for TracedVec<T> {
79 type Item = (T, Option<Backtrace>);
80 type IntoIter = Zip<IntoIter<T>, IntoIter<Option<Backtrace>>>;
81
82 fn into_iter(self) -> Self::IntoIter {
83 self.vec.into_iter().zip(self.traces)
84 }
85}
86
/// Mode switches that control which operations a `Builder` accepts.
#[derive(Debug, Copy, Clone, Default)]
pub struct BuilderFlags {
    /// Debug switch. Not read anywhere in this file — NOTE(review): confirm its consumer.
    pub debug: bool,
    /// When set, dynamic constructs are rejected with a panic. These include dynamic
    /// branches and loops, var-to-felt casts, `range_check_var`, and
    /// `commit_public_value(s)`. `static_commit_public_value`, conversely, requires it.
    pub static_only: bool,
}
93
/// Records a program as a list of `DslIr` operations.
///
/// Nested blocks (branch and loop bodies) are recorded by sub-builders created with
/// `create_sub_builder`. Their operation lists are then embedded into a single operation
/// pushed onto the parent.
#[derive(Debug, Clone, Default)]
pub struct Builder<C: Config> {
    // Counters for numbering fresh vars / felts / exts. A sub-builder starts from its
    // parent's current values. NOTE(review): presumably advanced by `Variable::uninit`;
    // confirm.
    pub(crate) var_count: u32,
    pub(crate) felt_count: u32,
    pub(crate) ext_count: u32,
    /// Recorded operations, in emission order.
    pub operations: TracedVec<DslIr<C>>,
    /// Runtime counter of committed public values; allocated lazily on first dynamic commit.
    pub(crate) nb_public_values: Option<Var<C::N>>,
    // Number of witness vars / felts / exts handed out so far; each new witness uses the
    // current count as its index.
    pub(crate) witness_var_count: u32,
    pub(crate) witness_felt_count: u32,
    pub(crate) witness_ext_count: u32,
    /// Groups of witness references registered via `witness_load`, indexed by the returned id.
    pub(crate) witness_space: Vec<Vec<WitnessRef>>,
    /// Mode switches; copied into sub-builders.
    pub flags: BuilderFlags,
    /// True for builders created by `create_sub_builder`; those may not create
    /// witnesses or allocate the public-value counter.
    pub is_sub_builder: bool,
}
111
112impl<C: Config> Builder<C> {
113 pub fn create_sub_builder(&self) -> Self {
115 Self {
116 var_count: self.var_count,
117 felt_count: self.felt_count,
118 ext_count: self.ext_count,
119 witness_var_count: 0,
122 witness_felt_count: 0,
123 witness_ext_count: 0,
124 witness_space: Default::default(),
125 operations: Default::default(),
126 nb_public_values: self.nb_public_values,
127 flags: self.flags,
128 is_sub_builder: true,
129 }
130 }
131
132 pub fn push(&mut self, op: DslIr<C>) {
134 self.operations.push(op);
135 }
136
137 pub fn trace_push(&mut self, op: DslIr<C>) {
139 self.operations.trace_push(op);
140 }
141
142 pub fn uninit<V: Variable<C>>(&mut self) -> V {
144 V::uninit(self)
145 }
146
147 pub fn eval<V: Variable<C>, E: Into<V::Expression>>(&mut self, expr: E) -> V {
149 V::eval(self, expr)
150 }
151
152 pub fn eval_expr(&mut self, expr: impl Into<SymbolicVar<C::N>>) -> RVar<C::N> {
154 let expr = expr.into();
155 match expr {
156 SymbolicVar::Const(c, _) => RVar::Const(c),
157 SymbolicVar::Val(val, _) => RVar::Val(val),
158 _ => {
159 let ret: Var<_> = self.eval(expr);
160 RVar::Val(ret)
161 }
162 }
163 }
164
165 pub fn inc(&mut self, u: &Usize<C::N>) {
167 self.assign(u, u.clone() + RVar::one());
168 }
169
170 pub fn constant<V: FromConstant<C>>(&mut self, value: V::Constant) -> V {
172 V::constant(value, self)
173 }
174
175 pub fn assign<V: Variable<C>, E: Into<V::Expression>>(&mut self, dst: &V, expr: E) {
177 dst.assign(expr.into(), self);
178 }
179
180 pub fn cast_felt_to_var(&mut self, felt: Felt<C::F>) -> Var<C::N> {
182 let var: Var<_> = self.uninit();
183 self.operations.push(DslIr::CastFV(var, felt));
184 var
185 }
186 pub fn unsafe_cast_var_to_felt(&mut self, var: Var<C::N>) -> Felt<C::F> {
188 assert!(!self.flags.static_only, "dynamic mode only");
189 let felt: Felt<_> = self.uninit();
190 self.operations.push(DslIr::UnsafeCastVF(felt, var));
191 felt
192 }
193
194 pub fn assert_nonzero(&mut self, u: &Usize<C::N>) {
196 self.operations.push(DslIr::AssertNonZero(u.clone()));
197 }
198
199 pub fn assert_eq<V: Variable<C>>(
201 &mut self,
202 lhs: impl Into<V::Expression>,
203 rhs: impl Into<V::Expression>,
204 ) {
205 V::assert_eq(lhs, rhs, self);
206 }
207
208 pub fn assert_var_eq<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
210 &mut self,
211 lhs: LhsExpr,
212 rhs: RhsExpr,
213 ) {
214 self.assert_eq::<Var<C::N>>(lhs, rhs);
215 }
216
217 pub fn assert_felt_eq<LhsExpr: Into<SymbolicFelt<C::F>>, RhsExpr: Into<SymbolicFelt<C::F>>>(
219 &mut self,
220 lhs: LhsExpr,
221 rhs: RhsExpr,
222 ) {
223 self.assert_eq::<Felt<C::F>>(lhs, rhs);
224 }
225
226 pub fn assert_ext_eq<
228 LhsExpr: Into<SymbolicExt<C::F, C::EF>>,
229 RhsExpr: Into<SymbolicExt<C::F, C::EF>>,
230 >(
231 &mut self,
232 lhs: LhsExpr,
233 rhs: RhsExpr,
234 ) {
235 self.assert_eq::<Ext<C::F, C::EF>>(lhs, rhs);
236 }
237
238 pub fn assert_usize_eq<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
239 &mut self,
240 lhs: LhsExpr,
241 rhs: RhsExpr,
242 ) {
243 self.assert_eq::<Usize<C::N>>(lhs, rhs);
244 }
245
246 pub fn assert_var_array_eq(&mut self, lhs: &Array<C, Var<C::N>>, rhs: &Array<C, Var<C::N>>) {
248 self.assert_var_eq(lhs.len(), rhs.len());
249 self.range(0, lhs.len()).for_each(|idx_vec, builder| {
250 let l = builder.get(lhs, idx_vec[0]);
251 let r = builder.get(rhs, idx_vec[0]);
252 builder.assert_var_eq(l, r);
253 });
254 }
255
256 pub fn if_eq<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
258 &mut self,
259 lhs: LhsExpr,
260 rhs: RhsExpr,
261 ) -> IfBuilder<C> {
262 IfBuilder {
263 lhs: lhs.into(),
264 rhs: rhs.into(),
265 is_eq: true,
266 builder: self,
267 }
268 }
269
270 pub fn if_ne<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
272 &mut self,
273 lhs: LhsExpr,
274 rhs: RhsExpr,
275 ) -> IfBuilder<C> {
276 IfBuilder {
277 lhs: lhs.into(),
278 rhs: rhs.into(),
279 is_eq: false,
280 builder: self,
281 }
282 }
283
284 pub fn assert_less_than_slow_small_rhs<
286 LhsExpr: Into<SymbolicVar<C::N>>,
287 RhsExpr: Into<SymbolicVar<C::N>>,
288 >(
289 &mut self,
290 lhs: LhsExpr,
291 rhs: RhsExpr,
292 ) {
293 let lhs: Usize<_> = self.eval(lhs.into());
294 let rhs: Usize<_> = self.eval(rhs.into());
295 let product: Usize<_> = self.eval(lhs.clone());
296 self.range(1, rhs).for_each(|i_vec, builder| {
297 let i = i_vec[0];
298 let diff: Usize<_> = builder.eval(lhs.clone() - i);
299 builder.assign(&product, product.clone() * diff);
300 });
301 self.assert_usize_eq(product, RVar::from(0));
302 }
303
304 pub fn assert_less_than_slow_bit_decomp(&mut self, lhs: Var<C::N>, rhs: Var<C::N>) {
311 let lhs = self.unsafe_cast_var_to_felt(lhs);
312 let rhs = self.unsafe_cast_var_to_felt(rhs);
313
314 let lhs_bits = self.num2bits_f(lhs, C::N::bits() as u32);
315 let rhs_bits = self.num2bits_f(rhs, C::N::bits() as u32);
316
317 let is_lt: Var<_> = self.eval(C::N::ZERO);
318
319 iter_zip!(self, lhs_bits, rhs_bits).for_each(|ptr_vec, builder| {
320 let lhs_bit = builder.iter_ptr_get(&lhs_bits, ptr_vec[0]);
321 let rhs_bit = builder.iter_ptr_get(&rhs_bits, ptr_vec[1]);
322
323 builder.if_ne(lhs_bit, rhs_bit).then(|builder| {
324 builder.assign(&is_lt, rhs_bit);
325 });
326 });
327 self.assert_var_eq(is_lt, C::N::ONE);
328 }
329
330 pub fn range_check_var(&mut self, x: Var<C::N>, num_bits: usize) {
332 assert!(!self.flags.static_only, "range_check_var is dynamic");
333 assert!(num_bits <= 30);
334 self.trace_push(DslIr::RangeCheckV(x, num_bits));
335 }
336
337 pub fn range(
339 &mut self,
340 start: impl Into<RVar<C::N>>,
341 end: impl Into<RVar<C::N>>,
342 ) -> IteratorBuilder<C> {
343 self.range_with_step(start, end, C::N::ONE)
344 }
345 pub fn range_with_step(
347 &mut self,
348 start: impl Into<RVar<C::N>>,
349 end: impl Into<RVar<C::N>>,
350 step: C::N,
351 ) -> IteratorBuilder<C> {
352 let start = start.into();
353 let end0 = end.into();
354 IteratorBuilder {
355 starts: vec![start],
356 end0,
357 step_sizes: vec![step],
358 builder: self,
359 }
360 }
361
362 pub fn zip<'a>(
363 &'a mut self,
364 arrays: &'a [Box<dyn ArrayLike<C> + 'a>],
365 ) -> IteratorBuilder<'a, C> {
366 assert!(!arrays.is_empty());
367 if arrays.iter().all(|array| array.is_fixed()) {
368 IteratorBuilder {
369 starts: vec![RVar::zero(); arrays.len()],
370 end0: arrays[0].len().into(),
371 step_sizes: vec![C::N::ONE; arrays.len()],
372 builder: self,
373 }
374 } else if arrays.iter().all(|array| !array.is_fixed()) {
375 IteratorBuilder {
376 starts: arrays
377 .iter()
378 .map(|array| array.ptr().address.into())
379 .collect(),
380 end0: {
381 let len: RVar<C::N> = arrays[0].len().into();
382 let size = arrays[0].element_size_of();
383 let end: Var<C::N> =
384 self.eval(arrays[0].ptr().address + len * RVar::from(size));
385 end.into()
386 },
387 step_sizes: arrays
388 .iter()
389 .map(|array| C::N::from_canonical_usize(array.element_size_of()))
390 .collect(),
391 builder: self,
392 }
393 } else {
394 panic!("Cannot use zipped pointer iterator with mixed arrays");
395 }
396 }
397
398 pub fn print_debug(&mut self, val: usize) {
399 let constant = self.eval(C::N::from_canonical_usize(val));
400 self.print_v(constant);
401 }
402
403 pub fn print_v(&mut self, dst: Var<C::N>) {
405 self.operations.push(DslIr::PrintV(dst));
406 }
407
408 pub fn print_f(&mut self, dst: Felt<C::F>) {
410 self.operations.push(DslIr::PrintF(dst));
411 }
412
413 pub fn print_e(&mut self, dst: Ext<C::F, C::EF>) {
415 self.operations.push(DslIr::PrintE(dst));
416 }
417
418 pub fn hint_var(&mut self) -> Var<C::N> {
419 let ptr = self.alloc(RVar::one(), 1);
420 self.operations.push(DslIr::HintFelt());
422 let index = MemIndex {
423 index: RVar::zero(),
424 offset: 0,
425 size: 1,
426 };
427 self.operations.push(DslIr::StoreHintWord(ptr, index));
428 let v: Var<C::N> = self.uninit();
429 self.load(v, ptr, index);
430 v
431 }
432
433 pub fn hint_felt(&mut self) -> Felt<C::F> {
434 let ptr = self.alloc(RVar::one(), 1);
435 self.operations.push(DslIr::HintFelt());
437 let index = MemIndex {
438 index: RVar::zero(),
439 offset: 0,
440 size: 1,
441 };
442 self.operations.push(DslIr::StoreHintWord(ptr, index));
443 let f: Felt<C::F> = self.uninit();
444 self.load(f, ptr, index);
445 f
446 }
447
448 pub fn hint_ext(&mut self) -> Ext<C::F, C::EF> {
449 let flattened = self.hint_felts_fixed(C::EF::D);
450
451 let array: Array<C, Ext<_, _>> = match flattened {
453 Array::Fixed(_) => unreachable!(),
454 Array::Dyn(ptr, _) => Array::Dyn(ptr, Usize::from(1)),
455 };
456 self.get(&array, 0)
457 }
458
459 pub fn hint_vars(&mut self) -> Array<C, Var<C::N>> {
463 self.hint_words()
464 }
465
466 pub fn hint_felts(&mut self) -> Array<C, Felt<C::F>> {
468 self.hint_words()
469 }
470
471 pub fn hint_felts_fixed(&mut self, len: impl Into<RVar<C::N>>) -> Array<C, Felt<C::F>> {
472 self.hint_words_fixed(len)
473 }
474
475 fn hint_words<V: MemVariable<C>>(&mut self) -> Array<C, V> {
477 assert_eq!(V::size_of(), 1);
478
479 let ptr = self.alloc(RVar::one(), 1);
481
482 self.operations.push(DslIr::HintInputVec());
484
485 let index = MemIndex {
487 index: RVar::zero(),
488 offset: 0,
489 size: 1,
490 };
491 self.operations.push(DslIr::StoreHintWord(ptr, index));
493
494 let vlen: Var<C::N> = self.uninit();
495 self.load(vlen, ptr, index);
496 let arr = self.dyn_array(vlen);
497
498 iter_zip!(self, arr).for_each(|ptr_vec, builder| {
500 let index = MemIndex {
501 index: 0.into(),
502 offset: 0,
503 size: 1,
504 };
505 builder.operations.push(DslIr::StoreHintWord(
506 Ptr {
507 address: ptr_vec[0].variable(),
508 },
509 index,
510 ));
511 });
512 arr
513 }
514
515 fn hint_words_fixed<V: MemVariable<C>>(&mut self, len: impl Into<RVar<C::N>>) -> Array<C, V> {
517 assert_eq!(V::size_of(), 1);
518
519 let arr = self.dyn_array(len.into());
520 iter_zip!(self, arr).for_each(|ptr_vec, builder| {
522 let index = MemIndex {
523 index: 0.into(),
524 offset: 0,
525 size: 1,
526 };
527 builder.operations.push(DslIr::HintFelt());
528 builder.operations.push(DslIr::StoreHintWord(
529 Ptr {
530 address: ptr_vec[0].variable(),
531 },
532 index,
533 ));
534 });
535 arr
536 }
537
538 pub fn hint_exts(&mut self) -> Array<C, Ext<C::F, C::EF>> {
543 let len = self.hint_var();
544 let flattened = self.hint_felts();
545
546 let size = <Ext<C::F, C::EF> as MemVariable<C>>::size_of();
547 self.assert_usize_eq(flattened.len(), len * C::N::from_canonical_usize(size));
548
549 match flattened {
551 Array::Fixed(_) => unreachable!(),
552 Array::Dyn(ptr, _) => Array::Dyn(ptr, Usize::Var(len)),
553 }
554 }
555
556 pub fn hint_load(&mut self) -> Var<C::N> {
559 self.trace_push(DslIr::HintLoad());
560 let ptr = self.alloc(RVar::one(), 1);
561 let index = MemIndex {
562 index: RVar::zero(),
563 offset: 0,
564 size: 1,
565 };
566 self.operations.push(DslIr::StoreHintWord(ptr, index));
568 let id: Var<C::N> = self.uninit();
569 self.load(id, ptr, index);
570 id
571 }
572
573 pub fn witness_var(&mut self) -> Var<C::N> {
574 assert!(
575 !self.is_sub_builder,
576 "Cannot create a witness var with a sub builder"
577 );
578 let witness = self.uninit();
579 self.operations
580 .push(DslIr::WitnessVar(witness, self.witness_var_count));
581 self.witness_var_count += 1;
582 witness
583 }
584
585 pub fn witness_felt(&mut self) -> Felt<C::F> {
586 assert!(
587 !self.is_sub_builder,
588 "Cannot create a witness felt with a sub builder"
589 );
590 let witness = self.uninit();
591 self.operations
592 .push(DslIr::WitnessFelt(witness, self.witness_felt_count));
593 self.witness_felt_count += 1;
594 witness
595 }
596
597 pub fn witness_ext(&mut self) -> Ext<C::F, C::EF> {
598 assert!(
599 !self.is_sub_builder,
600 "Cannot create a witness ext with a sub builder"
601 );
602 let witness = self.uninit();
603 self.operations
604 .push(DslIr::WitnessExt(witness, self.witness_ext_count));
605 self.witness_ext_count += 1;
606 witness
607 }
608
609 pub fn witness_load(&mut self, witness_refs: Vec<WitnessRef>) -> Usize<C::N> {
610 assert!(
611 !self.is_sub_builder,
612 "Cannot load witness refs with a sub builder"
613 );
614 let ret = self.witness_space.len();
615 self.witness_space.push(witness_refs);
616 ret.into()
617 }
618
619 pub fn get_witness_refs(&self, id: Usize<C::N>) -> &[WitnessRef] {
620 self.witness_space.get(id.value()).unwrap()
621 }
622
623 pub fn error(&mut self) {
625 self.operations.trace_push(DslIr::Error());
626 }
627
628 fn get_nb_public_values(&mut self) -> Var<C::N> {
629 assert!(
630 !self.is_sub_builder,
631 "Cannot commit to public values with a sub builder"
632 );
633 if self.nb_public_values.is_none() {
634 self.nb_public_values = Some(self.eval(C::N::ZERO));
635 }
636 *self.nb_public_values.as_ref().unwrap()
637 }
638
639 fn commit_public_value_and_increment(&mut self, val: Felt<C::F>, nb_public_values: Var<C::N>) {
640 assert!(
641 !self.flags.static_only,
642 "Static mode should use static_commit_public_value"
643 );
644 self.operations.push(DslIr::Publish(val, nb_public_values));
645 self.assign(&nb_public_values, nb_public_values + C::N::ONE);
646 }
647
648 pub fn static_commit_public_value(&mut self, index: usize, val: Var<C::N>) {
651 assert!(
652 self.flags.static_only,
653 "Dynamic mode should use commit_public_value instead."
654 );
655 self.operations.push(DslIr::CircuitPublish(val, index));
656 }
657
658 pub fn commit_public_value(&mut self, val: Felt<C::F>) {
660 let nb_public_values = self.get_nb_public_values();
661 self.commit_public_value_and_increment(val, nb_public_values);
662 }
663
664 pub fn commit_public_values(&mut self, vals: &Array<C, Felt<C::F>>) {
666 let nb_public_values = self.get_nb_public_values();
667 let len = vals.len();
668 self.range(0, len).for_each(|idx_vec, builder| {
669 let val = builder.get(vals, idx_vec[0]);
670 builder.commit_public_value_and_increment(val, nb_public_values);
671 });
672 }
673
674 pub fn cycle_tracker_start(&mut self, name: &str) {
675 self.operations
676 .push(DslIr::CycleTrackerStart(name.to_string()));
677 }
678
679 pub fn cycle_tracker_end(&mut self, name: &str) {
680 self.operations
681 .push(DslIr::CycleTrackerEnd(name.to_string()));
682 }
683
684 pub fn halt(&mut self) {
685 self.operations.push(DslIr::Halt);
686 }
687}
688
/// Pending conditional returned by `Builder::if_eq` / `Builder::if_ne`.
///
/// Nothing is emitted until `then` or `then_or_else` is called.
pub struct IfBuilder<'a, C: Config> {
    /// Left operand of the comparison.
    lhs: SymbolicVar<C::N>,
    /// Right operand of the comparison.
    rhs: SymbolicVar<C::N>,
    /// `true` for an equality test, `false` for an inequality test.
    is_eq: bool,
    /// Builder that receives the branch (or the inlined body for constant conditions).
    pub(crate) builder: &'a mut Builder<C>,
}
696
/// A branch condition whose operands have been reduced to constants or variables.
enum IfCondition<N> {
    /// `lhs == rhs`, both known at build time.
    EqConst(N, N),
    /// `lhs != rhs`, both known at build time.
    NeConst(N, N),
    /// Variable == variable.
    Eq(Var<N>, Var<N>),
    /// Variable == immediate constant (the variable is always the first field).
    EqI(Var<N>, N),
    /// Variable != variable.
    Ne(Var<N>, Var<N>),
    /// Variable != immediate constant (the variable is always the first field).
    NeI(Var<N>, N),
}
706
707impl<C: Config> IfBuilder<'_, C> {
708 pub fn then(&mut self, mut f: impl FnMut(&mut Builder<C>)) {
709 let condition = self.condition();
711 match condition {
713 IfCondition::EqConst(lhs, rhs) => {
714 if lhs == rhs {
715 return f(self.builder);
716 }
717 return;
718 }
719 IfCondition::NeConst(lhs, rhs) => {
720 if lhs != rhs {
721 return f(self.builder);
722 }
723 return;
724 }
725 _ => (),
726 }
727 assert!(
728 !self.builder.flags.static_only,
729 "Cannot use dynamic branch in static mode"
730 );
731
732 let mut f_builder = self.builder.create_sub_builder();
734 f(&mut f_builder);
735 let then_instructions = f_builder.operations;
736
737 match condition {
739 IfCondition::Eq(lhs, rhs) => {
740 let op = DslIr::IfEq(lhs, rhs, then_instructions, Default::default());
741 self.builder.operations.push(op);
742 }
743 IfCondition::EqI(lhs, rhs) => {
744 let op = DslIr::IfEqI(lhs, rhs, then_instructions, Default::default());
745 self.builder.operations.push(op);
746 }
747 IfCondition::Ne(lhs, rhs) => {
748 let op = DslIr::IfNe(lhs, rhs, then_instructions, Default::default());
749 self.builder.operations.push(op);
750 }
751 IfCondition::NeI(lhs, rhs) => {
752 let op = DslIr::IfNeI(lhs, rhs, then_instructions, Default::default());
753 self.builder.operations.push(op);
754 }
755 _ => unreachable!("Const if should have returned early"),
756 }
757 }
758
759 pub fn then_or_else(
760 &mut self,
761 mut then_f: impl FnMut(&mut Builder<C>),
762 mut else_f: impl FnMut(&mut Builder<C>),
763 ) {
764 let condition = self.condition();
766 match condition {
768 IfCondition::EqConst(lhs, rhs) => {
769 if lhs == rhs {
770 return then_f(self.builder);
771 }
772 return else_f(self.builder);
773 }
774 IfCondition::NeConst(lhs, rhs) => {
775 if lhs != rhs {
776 return then_f(self.builder);
777 }
778 return else_f(self.builder);
779 }
780 _ => (),
781 }
782 assert!(
783 !self.builder.flags.static_only,
784 "Cannot use dynamic branch in static mode"
785 );
786 let mut then_builder = self.builder.create_sub_builder();
787
788 then_f(&mut then_builder);
790 let then_instructions = then_builder.operations;
791
792 let mut else_builder = self.builder.create_sub_builder();
793 else_f(&mut else_builder);
794 let else_instructions = else_builder.operations;
795
796 match condition {
798 IfCondition::Eq(lhs, rhs) => {
799 let op = DslIr::IfEq(lhs, rhs, then_instructions, else_instructions);
800 self.builder.operations.push(op);
801 }
802 IfCondition::EqI(lhs, rhs) => {
803 let op = DslIr::IfEqI(lhs, rhs, then_instructions, else_instructions);
804 self.builder.operations.push(op);
805 }
806 IfCondition::Ne(lhs, rhs) => {
807 let op = DslIr::IfNe(lhs, rhs, then_instructions, else_instructions);
808 self.builder.operations.push(op);
809 }
810 IfCondition::NeI(lhs, rhs) => {
811 let op = DslIr::IfNeI(lhs, rhs, then_instructions, else_instructions);
812 self.builder.operations.push(op);
813 }
814 _ => unreachable!("Const if should have returned early"),
815 }
816 }
817
818 fn condition(&mut self) -> IfCondition<C::N> {
819 match (self.lhs.clone(), self.rhs.clone(), self.is_eq) {
820 (SymbolicVar::Const(lhs, _), SymbolicVar::Const(rhs, _), true) => {
821 IfCondition::EqConst(lhs, rhs)
822 }
823 (SymbolicVar::Const(lhs, _), SymbolicVar::Const(rhs, _), false) => {
824 IfCondition::NeConst(lhs, rhs)
825 }
826 (SymbolicVar::Const(lhs, _), SymbolicVar::Val(rhs, _), true) => {
827 IfCondition::EqI(rhs, lhs)
828 }
829 (SymbolicVar::Const(lhs, _), SymbolicVar::Val(rhs, _), false) => {
830 IfCondition::NeI(rhs, lhs)
831 }
832 (SymbolicVar::Const(lhs, _), rhs, true) => {
833 let rhs: Var<C::N> = self.builder.eval(rhs);
834 IfCondition::EqI(rhs, lhs)
835 }
836 (SymbolicVar::Const(lhs, _), rhs, false) => {
837 let rhs: Var<C::N> = self.builder.eval(rhs);
838 IfCondition::NeI(rhs, lhs)
839 }
840 (SymbolicVar::Val(lhs, _), SymbolicVar::Const(rhs, _), true) => {
841 let lhs: Var<C::N> = self.builder.eval(lhs);
842 IfCondition::EqI(lhs, rhs)
843 }
844 (SymbolicVar::Val(lhs, _), SymbolicVar::Const(rhs, _), false) => {
845 let lhs: Var<C::N> = self.builder.eval(lhs);
846 IfCondition::NeI(lhs, rhs)
847 }
848 (lhs, SymbolicVar::Const(rhs, _), true) => {
849 let lhs: Var<C::N> = self.builder.eval(lhs);
850 IfCondition::EqI(lhs, rhs)
851 }
852 (lhs, SymbolicVar::Const(rhs, _), false) => {
853 let lhs: Var<C::N> = self.builder.eval(lhs);
854 IfCondition::NeI(lhs, rhs)
855 }
856 (SymbolicVar::Val(lhs, _), SymbolicVar::Val(rhs, _), true) => IfCondition::Eq(lhs, rhs),
857 (SymbolicVar::Val(lhs, _), SymbolicVar::Val(rhs, _), false) => {
858 IfCondition::Ne(lhs, rhs)
859 }
860 (SymbolicVar::Val(lhs, _), rhs, true) => {
861 let rhs: Var<C::N> = self.builder.eval(rhs);
862 IfCondition::Eq(lhs, rhs)
863 }
864 (SymbolicVar::Val(lhs, _), rhs, false) => {
865 let rhs: Var<C::N> = self.builder.eval(rhs);
866 IfCondition::Ne(lhs, rhs)
867 }
868 (lhs, SymbolicVar::Val(rhs, _), true) => {
869 let lhs: Var<C::N> = self.builder.eval(lhs);
870 IfCondition::Eq(lhs, rhs)
871 }
872 (lhs, SymbolicVar::Val(rhs, _), false) => {
873 let lhs: Var<C::N> = self.builder.eval(lhs);
874 IfCondition::Ne(lhs, rhs)
875 }
876 (lhs, rhs, true) => {
877 let lhs: Var<C::N> = self.builder.eval(lhs);
878 let rhs: Var<C::N> = self.builder.eval(rhs);
879 IfCondition::Eq(lhs, rhs)
880 }
881 (lhs, rhs, false) => {
882 let lhs: Var<C::N> = self.builder.eval(lhs);
883 let rhs: Var<C::N> = self.builder.eval(rhs);
884 IfCondition::Ne(lhs, rhs)
885 }
886 }
887 }
888}
889
/// Pending loop returned by `Builder::range`, `Builder::range_with_step` and `Builder::zip`.
///
/// Holds one cursor per zipped sequence; only the first cursor is compared against `end0`.
pub struct IteratorBuilder<'a, C: Config> {
    /// Initial value of each cursor.
    starts: Vec<RVar<C::N>>,
    /// End bound for the first cursor; the body is not run when the cursor equals it.
    end0: RVar<C::N>,
    /// Per-cursor increment applied after each iteration.
    step_sizes: Vec<C::N>,
    /// Builder that receives the unrolled body or the `ZipFor` operation.
    builder: &'a mut Builder<C>,
}
897
898impl<C: Config> IteratorBuilder<'_, C> {
899 pub fn for_each(&mut self, mut f: impl FnMut(Vec<RVar<C::N>>, &mut Builder<C>)) {
900 assert!(self.starts.len() == self.step_sizes.len());
901 assert!(!self.starts.is_empty());
902
903 if self.starts.iter().all(|start| start.is_const()) && self.end0.is_const() {
904 self.for_each_unrolled(|ptrs, builder| {
905 f(ptrs, builder);
906 });
907 return;
908 }
909
910 self.for_each_dynamic(|ptrs, builder| {
911 f(ptrs, builder);
912 });
913 }
914
915 fn for_each_unrolled(&mut self, mut f: impl FnMut(Vec<RVar<C::N>>, &mut Builder<C>)) {
916 let mut ptrs: Vec<_> = self
917 .starts
918 .iter()
919 .map(|start| start.field_value())
920 .collect();
921 let end0 = self.end0.field_value();
922 while ptrs[0] != end0 {
923 f(
924 ptrs.iter().map(|ptr| RVar::Const(*ptr)).collect(),
925 self.builder,
926 );
927 for (ptr, step_size) in izip!(&mut ptrs, &self.step_sizes) {
928 *ptr += *step_size;
929 }
930 }
931 }
932
933 fn for_each_dynamic(&mut self, mut f: impl FnMut(Vec<RVar<C::N>>, &mut Builder<C>)) {
934 assert!(
935 !self.builder.flags.static_only,
936 "Cannot use dynamic loop in static mode"
937 );
938
939 let loop_variables: Vec<Var<C::N>> = (0..self.starts.len())
940 .map(|_| self.builder.uninit())
941 .collect();
942 let mut loop_body_builder = self.builder.create_sub_builder();
943
944 f(
945 loop_variables.iter().map(|&v| v.into()).collect(),
946 &mut loop_body_builder,
947 );
948
949 let loop_instructions = loop_body_builder.operations;
950 let op = DslIr::ZipFor(
951 self.starts.clone(),
952 self.end0,
953 self.step_sizes.clone(),
954 loop_variables,
955 loop_instructions,
956 );
957 self.builder.operations.push(op);
958 }
959}