1use std::{iter::Zip, vec::IntoIter};
2
3use backtrace::Backtrace;
4use itertools::izip;
5use openvm_native_compiler_derive::iter_zip;
6use openvm_stark_backend::p3_field::{Field, FieldAlgebra, FieldExtensionAlgebra};
7use serde::{Deserialize, Serialize};
8
9use super::{
10 Array, Config, DslIr, Ext, Felt, FromConstant, MemIndex, MemVariable, RVar, SymbolicExt,
11 SymbolicFelt, SymbolicVar, Usize, Var, Variable, WitnessRef,
12};
13use crate::ir::{collections::ArrayLike, Ptr};
14
/// A list of values, each paired with an optional backtrace of where it was added.
///
/// `traces[i]` belongs to `vec[i]`. The methods on this type push to both vectors together so
/// their lengths stay equal; because both fields are public, code that edits them directly is
/// responsible for keeping that pairing intact.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TracedVec<T> {
    /// The stored values.
    pub vec: Vec<T>,
    /// One entry per value in `vec`; `None` when no backtrace was recorded for it.
    pub traces: Vec<Option<Backtrace>>,
}
22
23impl<T> Default for TracedVec<T> {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29impl<T> From<Vec<T>> for TracedVec<T> {
30 fn from(vec: Vec<T>) -> Self {
31 let len = vec.len();
32 Self {
33 vec,
34 traces: vec![None; len],
35 }
36 }
37}
38
39impl<T> TracedVec<T> {
40 pub const fn new() -> Self {
41 Self {
42 vec: Vec::new(),
43 traces: Vec::new(),
44 }
45 }
46
47 pub fn push(&mut self, value: T) {
48 self.vec.push(value);
49 self.traces.push(None);
50 }
51
52 pub fn trace_push(&mut self, value: T) {
54 self.vec.push(value);
55 if std::env::var_os("RUST_BACKTRACE").is_none() {
56 self.traces.push(None);
57 } else {
58 self.traces.push(Some(Backtrace::new_unresolved()));
59 }
60 }
61
62 pub fn extend<I: IntoIterator<Item = (T, Option<Backtrace>)>>(&mut self, iter: I) {
63 let iter = iter.into_iter();
64 let len = iter.size_hint().0;
65 self.vec.reserve(len);
66 self.traces.reserve(len);
67 for (value, trace) in iter {
68 self.vec.push(value);
69 self.traces.push(trace);
70 }
71 }
72
73 pub fn is_empty(&self) -> bool {
74 self.vec.is_empty()
75 }
76}
77
78impl<T> IntoIterator for TracedVec<T> {
79 type Item = (T, Option<Backtrace>);
80 type IntoIter = Zip<IntoIter<T>, IntoIter<Option<Backtrace>>>;
81
82 fn into_iter(self) -> Self::IntoIter {
83 self.vec.into_iter().zip(self.traces)
84 }
85}
86
/// Options that a `Builder` copies unchanged into every sub-builder it creates.
#[derive(Debug, Copy, Clone, Default)]
pub struct BuilderFlags {
    /// Debug switch. NOTE(review): not read in this part of the file; its effect is
    /// defined elsewhere — confirm before relying on it.
    pub debug: bool,
    /// When set, constructs that need runtime control flow or dynamic state panic at build
    /// time instead of being emitted. Examples include dynamic branches and loops,
    /// `unsafe_cast_var_to_felt`, `range_check_var` and `commit_public_value(s)`.
    /// `static_commit_public_value` conversely requires this flag.
    pub static_only: bool,
}
93
/// Records a program as a list of DSL IR operations.
///
/// Nested scopes (branch and loop bodies) are recorded in separate sub-builders obtained from
/// `create_sub_builder`; their operations are then embedded into a single operation of the
/// parent (see `IfBuilder` and `IteratorBuilder`).
#[derive(Debug, Clone, Default)]
pub struct Builder<C: Config> {
    // Allocation counters for `Var`, `Felt` and `Ext` variables. They are copied into
    // sub-builders; presumably advanced by the variables' `uninit` implementations (not
    // visible here).
    pub(crate) var_count: u32,
    pub(crate) felt_count: u32,
    pub(crate) ext_count: u32,
    /// The recorded operations, each with an optional backtrace.
    pub operations: TracedVec<DslIr<C>>,
    /// Runtime counter of committed public values; created lazily on first dynamic commit.
    pub(crate) nb_public_values: Option<Var<C::N>>,
    // Next witness index handed out by `witness_var` / `witness_felt` / `witness_ext`.
    pub(crate) witness_var_count: u32,
    pub(crate) witness_felt_count: u32,
    pub(crate) witness_ext_count: u32,
    /// Groups of witness references registered by `witness_load`, indexed by the id it returns.
    pub(crate) witness_space: Vec<Vec<WitnessRef>>,
    /// Build options, shared with sub-builders.
    pub flags: BuilderFlags,
    /// `true` for builders created by `create_sub_builder`; witness and dynamic
    /// public-value methods refuse to run on them.
    pub is_sub_builder: bool,
}
111
112impl<C: Config> Builder<C> {
113 pub fn create_sub_builder(&self) -> Self {
115 Self {
116 var_count: self.var_count,
117 felt_count: self.felt_count,
118 ext_count: self.ext_count,
119 witness_var_count: 0,
122 witness_felt_count: 0,
123 witness_ext_count: 0,
124 witness_space: Default::default(),
125 operations: Default::default(),
126 nb_public_values: self.nb_public_values,
127 flags: self.flags,
128 is_sub_builder: true,
129 }
130 }
131
132 pub fn push(&mut self, op: DslIr<C>) {
134 self.operations.push(op);
135 }
136
137 pub fn trace_push(&mut self, op: DslIr<C>) {
139 self.operations.trace_push(op);
140 }
141
142 pub fn uninit<V: Variable<C>>(&mut self) -> V {
144 V::uninit(self)
145 }
146
147 pub fn eval<V: Variable<C>, E: Into<V::Expression>>(&mut self, expr: E) -> V {
149 V::eval(self, expr)
150 }
151
152 pub fn eval_expr(&mut self, expr: impl Into<SymbolicVar<C::N>>) -> RVar<C::N> {
154 let expr = expr.into();
155 match expr {
156 SymbolicVar::Const(c, _) => RVar::Const(c),
157 SymbolicVar::Val(val, _) => RVar::Val(val),
158 _ => {
159 let ret: Var<_> = self.eval(expr);
160 RVar::Val(ret)
161 }
162 }
163 }
164
165 pub fn inc(&mut self, u: &Usize<C::N>) {
167 self.assign(u, u.clone() + RVar::one());
168 }
169
170 pub fn constant<V: FromConstant<C>>(&mut self, value: V::Constant) -> V {
172 V::constant(value, self)
173 }
174
175 pub fn assign<V: Variable<C>, E: Into<V::Expression>>(&mut self, dst: &V, expr: E) {
177 dst.assign(expr.into(), self);
178 }
179
180 pub fn cast_felt_to_var(&mut self, felt: Felt<C::F>) -> Var<C::N> {
182 let var: Var<_> = self.uninit();
183 self.operations.push(DslIr::CastFV(var, felt));
184 var
185 }
186 pub fn unsafe_cast_var_to_felt(&mut self, var: Var<C::N>) -> Felt<C::F> {
188 assert!(!self.flags.static_only, "dynamic mode only");
189 let felt: Felt<_> = self.uninit();
190 self.operations.push(DslIr::UnsafeCastVF(felt, var));
191 felt
192 }
193
194 pub fn assert_nonzero(&mut self, u: &Usize<C::N>) {
196 self.operations.push(DslIr::AssertNonZero(u.clone()));
197 }
198
199 pub fn assert_eq<V: Variable<C>>(
201 &mut self,
202 lhs: impl Into<V::Expression>,
203 rhs: impl Into<V::Expression>,
204 ) {
205 V::assert_eq(lhs, rhs, self);
206 }
207
208 pub fn assert_var_eq<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
210 &mut self,
211 lhs: LhsExpr,
212 rhs: RhsExpr,
213 ) {
214 self.assert_eq::<Var<C::N>>(lhs, rhs);
215 }
216
217 pub fn assert_felt_eq<LhsExpr: Into<SymbolicFelt<C::F>>, RhsExpr: Into<SymbolicFelt<C::F>>>(
219 &mut self,
220 lhs: LhsExpr,
221 rhs: RhsExpr,
222 ) {
223 self.assert_eq::<Felt<C::F>>(lhs, rhs);
224 }
225
226 pub fn assert_ext_eq<
228 LhsExpr: Into<SymbolicExt<C::F, C::EF>>,
229 RhsExpr: Into<SymbolicExt<C::F, C::EF>>,
230 >(
231 &mut self,
232 lhs: LhsExpr,
233 rhs: RhsExpr,
234 ) {
235 self.assert_eq::<Ext<C::F, C::EF>>(lhs, rhs);
236 }
237
238 pub fn assert_usize_eq<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
239 &mut self,
240 lhs: LhsExpr,
241 rhs: RhsExpr,
242 ) {
243 self.assert_eq::<Usize<C::N>>(lhs, rhs);
244 }
245
246 pub fn assert_var_array_eq(&mut self, lhs: &Array<C, Var<C::N>>, rhs: &Array<C, Var<C::N>>) {
248 self.assert_var_eq(lhs.len(), rhs.len());
249 self.range(0, lhs.len()).for_each(|idx_vec, builder| {
250 let l = builder.get(lhs, idx_vec[0]);
251 let r = builder.get(rhs, idx_vec[0]);
252 builder.assert_var_eq(l, r);
253 });
254 }
255
256 pub fn if_eq<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
258 &mut self,
259 lhs: LhsExpr,
260 rhs: RhsExpr,
261 ) -> IfBuilder<C> {
262 IfBuilder {
263 lhs: lhs.into(),
264 rhs: rhs.into(),
265 is_eq: true,
266 builder: self,
267 }
268 }
269
270 pub fn if_ne<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
272 &mut self,
273 lhs: LhsExpr,
274 rhs: RhsExpr,
275 ) -> IfBuilder<C> {
276 IfBuilder {
277 lhs: lhs.into(),
278 rhs: rhs.into(),
279 is_eq: false,
280 builder: self,
281 }
282 }
283
284 pub fn assert_less_than_slow_small_rhs<
286 LhsExpr: Into<SymbolicVar<C::N>>,
287 RhsExpr: Into<SymbolicVar<C::N>>,
288 >(
289 &mut self,
290 lhs: LhsExpr,
291 rhs: RhsExpr,
292 ) {
293 let lhs: Usize<_> = self.eval(lhs.into());
294 let rhs: Usize<_> = self.eval(rhs.into());
295 let product: Usize<_> = self.eval(lhs.clone());
296 self.range(1, rhs).for_each(|i_vec, builder| {
297 let i = i_vec[0];
298 let diff: Usize<_> = builder.eval(lhs.clone() - i);
299 builder.assign(&product, product.clone() * diff);
300 });
301 self.assert_usize_eq(product, RVar::from(0));
302 }
303
304 pub fn assert_less_than_slow_bit_decomp(&mut self, lhs: Var<C::N>, rhs: Var<C::N>) {
311 let lhs = self.unsafe_cast_var_to_felt(lhs);
312 let rhs = self.unsafe_cast_var_to_felt(rhs);
313
314 let lhs_bits = self.num2bits_f(lhs, C::N::bits() as u32);
315 let rhs_bits = self.num2bits_f(rhs, C::N::bits() as u32);
316
317 let is_lt: Var<_> = self.eval(C::N::ZERO);
318
319 iter_zip!(self, lhs_bits, rhs_bits).for_each(|ptr_vec, builder| {
320 let lhs_bit = builder.iter_ptr_get(&lhs_bits, ptr_vec[0]);
321 let rhs_bit = builder.iter_ptr_get(&rhs_bits, ptr_vec[1]);
322
323 builder.if_ne(lhs_bit, rhs_bit).then(|builder| {
324 builder.assign(&is_lt, rhs_bit);
325 });
326 });
327 self.assert_var_eq(is_lt, C::N::ONE);
328 }
329
330 pub fn range_check_var(&mut self, x: Var<C::N>, num_bits: usize) {
332 assert!(!self.flags.static_only, "range_check_var is dynamic");
333 assert!(num_bits <= 30);
334 self.trace_push(DslIr::RangeCheckV(x, num_bits));
335 }
336
337 pub fn range(
339 &mut self,
340 start: impl Into<RVar<C::N>>,
341 end: impl Into<RVar<C::N>>,
342 ) -> IteratorBuilder<C> {
343 self.range_with_step(start, end, C::N::ONE)
344 }
345 pub fn range_with_step(
347 &mut self,
348 start: impl Into<RVar<C::N>>,
349 end: impl Into<RVar<C::N>>,
350 step: C::N,
351 ) -> IteratorBuilder<C> {
352 let start = start.into();
353 let end0 = end.into();
354 IteratorBuilder {
355 starts: vec![start],
356 end0,
357 step_sizes: vec![step],
358 builder: self,
359 }
360 }
361
362 pub fn zip<'a>(
363 &'a mut self,
364 arrays: &'a [Box<dyn ArrayLike<C> + 'a>],
365 ) -> IteratorBuilder<'a, C> {
366 assert!(!arrays.is_empty());
367 if arrays.iter().all(|array| array.is_fixed()) {
368 IteratorBuilder {
369 starts: vec![RVar::zero(); arrays.len()],
370 end0: arrays[0].len().into(),
371 step_sizes: vec![C::N::ONE; arrays.len()],
372 builder: self,
373 }
374 } else if arrays.iter().all(|array| !array.is_fixed()) {
375 IteratorBuilder {
376 starts: arrays
377 .iter()
378 .map(|array| array.ptr().address.into())
379 .collect(),
380 end0: {
381 let len: RVar<C::N> = arrays[0].len().into();
382 let size = arrays[0].element_size_of();
383 let end: Var<C::N> =
384 self.eval(arrays[0].ptr().address + len * RVar::from(size));
385 end.into()
386 },
387 step_sizes: arrays
388 .iter()
389 .map(|array| C::N::from_canonical_usize(array.element_size_of()))
390 .collect(),
391 builder: self,
392 }
393 } else {
394 panic!("Cannot use zipped pointer iterator with mixed arrays");
395 }
396 }
397
398 pub fn print_debug(&mut self, val: usize) {
399 let constant = self.eval(C::N::from_canonical_usize(val));
400 self.print_v(constant);
401 }
402
403 pub fn print_v(&mut self, dst: Var<C::N>) {
405 self.operations.push(DslIr::PrintV(dst));
406 }
407
408 pub fn print_f(&mut self, dst: Felt<C::F>) {
410 self.operations.push(DslIr::PrintF(dst));
411 }
412
413 pub fn print_e(&mut self, dst: Ext<C::F, C::EF>) {
415 self.operations.push(DslIr::PrintE(dst));
416 }
417
418 pub fn hint_var(&mut self) -> Var<C::N> {
419 let ptr = self.alloc(RVar::one(), 1);
420 self.operations.push(DslIr::HintFelt());
422 let index = MemIndex {
423 index: RVar::zero(),
424 offset: 0,
425 size: 1,
426 };
427 self.operations.push(DslIr::StoreHintWord(ptr, index));
428 let v: Var<C::N> = self.uninit();
429 self.load(v, ptr, index);
430 v
431 }
432
433 pub fn hint_felt(&mut self) -> Felt<C::F> {
434 let ptr = self.alloc(RVar::one(), 1);
435 self.operations.push(DslIr::HintFelt());
437 let index = MemIndex {
438 index: RVar::zero(),
439 offset: 0,
440 size: 1,
441 };
442 self.operations.push(DslIr::StoreHintWord(ptr, index));
443 let f: Felt<C::F> = self.uninit();
444 self.load(f, ptr, index);
445 f
446 }
447
448 pub fn hint_ext(&mut self) -> Ext<C::F, C::EF> {
449 let flattened = self.hint_felts_fixed(C::EF::D);
450
451 let array: Array<C, Ext<_, _>> = match flattened {
453 Array::Fixed(_) => unreachable!(),
454 Array::Dyn(ptr, _) => Array::Dyn(ptr, Usize::from(1)),
455 };
456 self.get(&array, 0)
457 }
458
459 pub fn hint_vars(&mut self) -> Array<C, Var<C::N>> {
463 self.hint_words()
464 }
465
466 pub fn hint_felts(&mut self) -> Array<C, Felt<C::F>> {
468 self.hint_words()
469 }
470
471 pub fn hint_felts_fixed(&mut self, len: impl Into<RVar<C::N>>) -> Array<C, Felt<C::F>> {
472 self.hint_words_fixed(len)
473 }
474
475 fn hint_words<V: MemVariable<C>>(&mut self) -> Array<C, V> {
477 assert_eq!(V::size_of(), 1);
478
479 let ptr = self.alloc(RVar::one(), 1);
481
482 self.operations.push(DslIr::HintInputVec());
484
485 let index = MemIndex {
487 index: RVar::zero(),
488 offset: 0,
489 size: 1,
490 };
491 self.operations.push(DslIr::StoreHintWord(ptr, index));
493
494 let vlen: Var<C::N> = self.uninit();
495 self.load(vlen, ptr, index);
496 let arr = self.dyn_array(vlen);
497
498 iter_zip!(self, arr).for_each(|ptr_vec, builder| {
500 let index = MemIndex {
501 index: 0.into(),
502 offset: 0,
503 size: 1,
504 };
505 builder.operations.push(DslIr::StoreHintWord(
506 Ptr {
507 address: ptr_vec[0].variable(),
508 },
509 index,
510 ));
511 });
512 arr
513 }
514
515 fn hint_words_fixed<V: MemVariable<C>>(&mut self, len: impl Into<RVar<C::N>>) -> Array<C, V> {
517 assert_eq!(V::size_of(), 1);
518
519 let arr = self.dyn_array(len.into());
520 iter_zip!(self, arr).for_each(|ptr_vec, builder| {
522 let index = MemIndex {
523 index: 0.into(),
524 offset: 0,
525 size: 1,
526 };
527 builder.operations.push(DslIr::HintFelt());
528 builder.operations.push(DslIr::StoreHintWord(
529 Ptr {
530 address: ptr_vec[0].variable(),
531 },
532 index,
533 ));
534 });
535 arr
536 }
537
538 pub fn hint_exts(&mut self) -> Array<C, Ext<C::F, C::EF>> {
543 let len = self.hint_var();
544 let flattened = self.hint_felts();
545
546 let size = <Ext<C::F, C::EF> as MemVariable<C>>::size_of();
547 self.assert_usize_eq(flattened.len(), len * C::N::from_canonical_usize(size));
548
549 match flattened {
551 Array::Fixed(_) => unreachable!(),
552 Array::Dyn(ptr, _) => Array::Dyn(ptr, Usize::Var(len)),
553 }
554 }
555
556 pub fn hint_load(&mut self) -> Var<C::N> {
559 self.trace_push(DslIr::HintLoad());
560 let ptr = self.alloc(RVar::one(), 1);
561 let index = MemIndex {
562 index: RVar::zero(),
563 offset: 0,
564 size: 1,
565 };
566 self.operations.push(DslIr::StoreHintWord(ptr, index));
568 let id: Var<C::N> = self.uninit();
569 self.load(id, ptr, index);
570 id
571 }
572
573 pub fn witness_var(&mut self) -> Var<C::N> {
574 assert!(
575 !self.is_sub_builder,
576 "Cannot create a witness var with a sub builder"
577 );
578 let witness = self.uninit();
579 self.operations
580 .push(DslIr::WitnessVar(witness, self.witness_var_count));
581 self.witness_var_count += 1;
582 witness
583 }
584
585 pub fn witness_felt(&mut self) -> Felt<C::F> {
586 assert!(
587 !self.is_sub_builder,
588 "Cannot create a witness felt with a sub builder"
589 );
590 let witness = self.uninit();
591 self.operations
592 .push(DslIr::WitnessFelt(witness, self.witness_felt_count));
593 self.witness_felt_count += 1;
594 witness
595 }
596
597 pub fn witness_ext(&mut self) -> Ext<C::F, C::EF> {
598 assert!(
599 !self.is_sub_builder,
600 "Cannot create a witness ext with a sub builder"
601 );
602 let witness = self.uninit();
603 self.operations
604 .push(DslIr::WitnessExt(witness, self.witness_ext_count));
605 self.witness_ext_count += 1;
606 witness
607 }
608
609 pub fn witness_load(&mut self, witness_refs: Vec<WitnessRef>) -> Usize<C::N> {
610 assert!(
611 !self.is_sub_builder,
612 "Cannot load witness refs with a sub builder"
613 );
614 let ret = self.witness_space.len();
615 self.witness_space.push(witness_refs);
616 ret.into()
617 }
618
619 pub fn get_witness_refs(&self, id: Usize<C::N>) -> &[WitnessRef] {
620 self.witness_space.get(id.value()).unwrap()
621 }
622
623 pub fn error(&mut self) {
625 self.operations.trace_push(DslIr::Error());
626 }
627
628 fn get_nb_public_values(&mut self) -> Var<C::N> {
629 assert!(
630 !self.is_sub_builder,
631 "Cannot commit to public values with a sub builder"
632 );
633 if self.nb_public_values.is_none() {
634 self.nb_public_values = Some(self.eval(C::N::ZERO));
635 }
636 *self.nb_public_values.as_ref().unwrap()
637 }
638
639 fn commit_public_value_and_increment(&mut self, val: Felt<C::F>, nb_public_values: Var<C::N>) {
640 assert!(
641 !self.flags.static_only,
642 "Static mode should use static_commit_public_value"
643 );
644 self.operations.push(DslIr::Publish(val, nb_public_values));
645 self.assign(&nb_public_values, nb_public_values + C::N::ONE);
646 }
647
648 pub fn static_commit_public_value(&mut self, index: usize, val: Var<C::N>) {
650 assert!(
651 self.flags.static_only,
652 "Dynamic mode should use commit_public_value instead."
653 );
654 self.operations.push(DslIr::CircuitPublish(val, index));
655 }
656
657 pub fn commit_public_value(&mut self, val: Felt<C::F>) {
659 let nb_public_values = self.get_nb_public_values();
660 self.commit_public_value_and_increment(val, nb_public_values);
661 }
662
663 pub fn commit_public_values(&mut self, vals: &Array<C, Felt<C::F>>) {
665 let nb_public_values = self.get_nb_public_values();
666 let len = vals.len();
667 self.range(0, len).for_each(|idx_vec, builder| {
668 let val = builder.get(vals, idx_vec[0]);
669 builder.commit_public_value_and_increment(val, nb_public_values);
670 });
671 }
672
673 pub fn cycle_tracker_start(&mut self, name: &str) {
674 self.operations
675 .push(DslIr::CycleTrackerStart(name.to_string()));
676 }
677
678 pub fn cycle_tracker_end(&mut self, name: &str) {
679 self.operations
680 .push(DslIr::CycleTrackerEnd(name.to_string()));
681 }
682
683 pub fn halt(&mut self) {
684 self.operations.push(DslIr::Halt);
685 }
686}
687
/// A pending conditional created by `Builder::if_eq` / `Builder::if_ne`.
///
/// Holds the two operands and whether the branch is taken on equality (`is_eq == true`) or on
/// inequality; the body is supplied through `then` or `then_or_else`.
pub struct IfBuilder<'a, C: Config> {
    lhs: SymbolicVar<C::N>,
    rhs: SymbolicVar<C::N>,
    is_eq: bool,
    pub(crate) builder: &'a mut Builder<C>,
}
695
/// Normalized form of an `IfBuilder` condition.
enum IfCondition<N> {
    /// Both operands are constants, compared for equality; resolved without emitting a branch.
    EqConst(N, N),
    /// Both operands are constants, compared for inequality; resolved without emitting a branch.
    NeConst(N, N),
    /// Variable == variable.
    Eq(Var<N>, Var<N>),
    /// Variable == immediate.
    EqI(Var<N>, N),
    /// Variable != variable.
    Ne(Var<N>, Var<N>),
    /// Variable != immediate.
    NeI(Var<N>, N),
}
705
706impl<C: Config> IfBuilder<'_, C> {
707 pub fn then(&mut self, mut f: impl FnMut(&mut Builder<C>)) {
708 let condition = self.condition();
710 match condition {
712 IfCondition::EqConst(lhs, rhs) => {
713 if lhs == rhs {
714 return f(self.builder);
715 }
716 return;
717 }
718 IfCondition::NeConst(lhs, rhs) => {
719 if lhs != rhs {
720 return f(self.builder);
721 }
722 return;
723 }
724 _ => (),
725 }
726 assert!(
727 !self.builder.flags.static_only,
728 "Cannot use dynamic branch in static mode"
729 );
730
731 let mut f_builder = self.builder.create_sub_builder();
733 f(&mut f_builder);
734 let then_instructions = f_builder.operations;
735
736 match condition {
738 IfCondition::Eq(lhs, rhs) => {
739 let op = DslIr::IfEq(lhs, rhs, then_instructions, Default::default());
740 self.builder.operations.push(op);
741 }
742 IfCondition::EqI(lhs, rhs) => {
743 let op = DslIr::IfEqI(lhs, rhs, then_instructions, Default::default());
744 self.builder.operations.push(op);
745 }
746 IfCondition::Ne(lhs, rhs) => {
747 let op = DslIr::IfNe(lhs, rhs, then_instructions, Default::default());
748 self.builder.operations.push(op);
749 }
750 IfCondition::NeI(lhs, rhs) => {
751 let op = DslIr::IfNeI(lhs, rhs, then_instructions, Default::default());
752 self.builder.operations.push(op);
753 }
754 _ => unreachable!("Const if should have returned early"),
755 }
756 }
757
758 pub fn then_or_else(
759 &mut self,
760 mut then_f: impl FnMut(&mut Builder<C>),
761 mut else_f: impl FnMut(&mut Builder<C>),
762 ) {
763 let condition = self.condition();
765 match condition {
767 IfCondition::EqConst(lhs, rhs) => {
768 if lhs == rhs {
769 return then_f(self.builder);
770 }
771 return else_f(self.builder);
772 }
773 IfCondition::NeConst(lhs, rhs) => {
774 if lhs != rhs {
775 return then_f(self.builder);
776 }
777 return else_f(self.builder);
778 }
779 _ => (),
780 }
781 assert!(
782 !self.builder.flags.static_only,
783 "Cannot use dynamic branch in static mode"
784 );
785 let mut then_builder = self.builder.create_sub_builder();
786
787 then_f(&mut then_builder);
789 let then_instructions = then_builder.operations;
790
791 let mut else_builder = self.builder.create_sub_builder();
792 else_f(&mut else_builder);
793 let else_instructions = else_builder.operations;
794
795 match condition {
797 IfCondition::Eq(lhs, rhs) => {
798 let op = DslIr::IfEq(lhs, rhs, then_instructions, else_instructions);
799 self.builder.operations.push(op);
800 }
801 IfCondition::EqI(lhs, rhs) => {
802 let op = DslIr::IfEqI(lhs, rhs, then_instructions, else_instructions);
803 self.builder.operations.push(op);
804 }
805 IfCondition::Ne(lhs, rhs) => {
806 let op = DslIr::IfNe(lhs, rhs, then_instructions, else_instructions);
807 self.builder.operations.push(op);
808 }
809 IfCondition::NeI(lhs, rhs) => {
810 let op = DslIr::IfNeI(lhs, rhs, then_instructions, else_instructions);
811 self.builder.operations.push(op);
812 }
813 _ => unreachable!("Const if should have returned early"),
814 }
815 }
816
817 fn condition(&mut self) -> IfCondition<C::N> {
818 match (self.lhs.clone(), self.rhs.clone(), self.is_eq) {
819 (SymbolicVar::Const(lhs, _), SymbolicVar::Const(rhs, _), true) => {
820 IfCondition::EqConst(lhs, rhs)
821 }
822 (SymbolicVar::Const(lhs, _), SymbolicVar::Const(rhs, _), false) => {
823 IfCondition::NeConst(lhs, rhs)
824 }
825 (SymbolicVar::Const(lhs, _), SymbolicVar::Val(rhs, _), true) => {
826 IfCondition::EqI(rhs, lhs)
827 }
828 (SymbolicVar::Const(lhs, _), SymbolicVar::Val(rhs, _), false) => {
829 IfCondition::NeI(rhs, lhs)
830 }
831 (SymbolicVar::Const(lhs, _), rhs, true) => {
832 let rhs: Var<C::N> = self.builder.eval(rhs);
833 IfCondition::EqI(rhs, lhs)
834 }
835 (SymbolicVar::Const(lhs, _), rhs, false) => {
836 let rhs: Var<C::N> = self.builder.eval(rhs);
837 IfCondition::NeI(rhs, lhs)
838 }
839 (SymbolicVar::Val(lhs, _), SymbolicVar::Const(rhs, _), true) => {
840 let lhs: Var<C::N> = self.builder.eval(lhs);
841 IfCondition::EqI(lhs, rhs)
842 }
843 (SymbolicVar::Val(lhs, _), SymbolicVar::Const(rhs, _), false) => {
844 let lhs: Var<C::N> = self.builder.eval(lhs);
845 IfCondition::NeI(lhs, rhs)
846 }
847 (lhs, SymbolicVar::Const(rhs, _), true) => {
848 let lhs: Var<C::N> = self.builder.eval(lhs);
849 IfCondition::EqI(lhs, rhs)
850 }
851 (lhs, SymbolicVar::Const(rhs, _), false) => {
852 let lhs: Var<C::N> = self.builder.eval(lhs);
853 IfCondition::NeI(lhs, rhs)
854 }
855 (SymbolicVar::Val(lhs, _), SymbolicVar::Val(rhs, _), true) => IfCondition::Eq(lhs, rhs),
856 (SymbolicVar::Val(lhs, _), SymbolicVar::Val(rhs, _), false) => {
857 IfCondition::Ne(lhs, rhs)
858 }
859 (SymbolicVar::Val(lhs, _), rhs, true) => {
860 let rhs: Var<C::N> = self.builder.eval(rhs);
861 IfCondition::Eq(lhs, rhs)
862 }
863 (SymbolicVar::Val(lhs, _), rhs, false) => {
864 let rhs: Var<C::N> = self.builder.eval(rhs);
865 IfCondition::Ne(lhs, rhs)
866 }
867 (lhs, SymbolicVar::Val(rhs, _), true) => {
868 let lhs: Var<C::N> = self.builder.eval(lhs);
869 IfCondition::Eq(lhs, rhs)
870 }
871 (lhs, SymbolicVar::Val(rhs, _), false) => {
872 let lhs: Var<C::N> = self.builder.eval(lhs);
873 IfCondition::Ne(lhs, rhs)
874 }
875 (lhs, rhs, true) => {
876 let lhs: Var<C::N> = self.builder.eval(lhs);
877 let rhs: Var<C::N> = self.builder.eval(rhs);
878 IfCondition::Eq(lhs, rhs)
879 }
880 (lhs, rhs, false) => {
881 let lhs: Var<C::N> = self.builder.eval(lhs);
882 let rhs: Var<C::N> = self.builder.eval(rhs);
883 IfCondition::Ne(lhs, rhs)
884 }
885 }
886 }
887}
888
/// A pending loop over one or more sequences advanced in lockstep, created by
/// `Builder::range`, `Builder::range_with_step` or `Builder::zip`.
///
/// Cursor `i` starts at `starts[i]` and advances by `step_sizes[i]`. In the unrolled case,
/// iteration stops when the first cursor equals `end0`; the emitted `ZipFor` op presumably has
/// the same semantics — confirm. Run the loop with `for_each`.
pub struct IteratorBuilder<'a, C: Config> {
    starts: Vec<RVar<C::N>>,
    end0: RVar<C::N>,
    step_sizes: Vec<C::N>,
    builder: &'a mut Builder<C>,
}
896
897impl<C: Config> IteratorBuilder<'_, C> {
898 pub fn for_each(&mut self, mut f: impl FnMut(Vec<RVar<C::N>>, &mut Builder<C>)) {
899 assert!(self.starts.len() == self.step_sizes.len());
900 assert!(!self.starts.is_empty());
901
902 if self.starts.iter().all(|start| start.is_const()) && self.end0.is_const() {
903 self.for_each_unrolled(|ptrs, builder| {
904 f(ptrs, builder);
905 });
906 return;
907 }
908
909 self.for_each_dynamic(|ptrs, builder| {
910 f(ptrs, builder);
911 });
912 }
913
914 fn for_each_unrolled(&mut self, mut f: impl FnMut(Vec<RVar<C::N>>, &mut Builder<C>)) {
915 let mut ptrs: Vec<_> = self
916 .starts
917 .iter()
918 .map(|start| start.field_value())
919 .collect();
920 let end0 = self.end0.field_value();
921 while ptrs[0] != end0 {
922 f(
923 ptrs.iter().map(|ptr| RVar::Const(*ptr)).collect(),
924 self.builder,
925 );
926 for (ptr, step_size) in izip!(&mut ptrs, &self.step_sizes) {
927 *ptr += *step_size;
928 }
929 }
930 }
931
932 fn for_each_dynamic(&mut self, mut f: impl FnMut(Vec<RVar<C::N>>, &mut Builder<C>)) {
933 assert!(
934 !self.builder.flags.static_only,
935 "Cannot use dynamic loop in static mode"
936 );
937
938 let loop_variables: Vec<Var<C::N>> = (0..self.starts.len())
939 .map(|_| self.builder.uninit())
940 .collect();
941 let mut loop_body_builder = self.builder.create_sub_builder();
942
943 f(
944 loop_variables.iter().map(|&v| v.into()).collect(),
945 &mut loop_body_builder,
946 );
947
948 let loop_instructions = loop_body_builder.operations;
949 let op = DslIr::ZipFor(
950 self.starts.clone(),
951 self.end0,
952 self.step_sizes.clone(),
953 loop_variables,
954 loop_instructions,
955 );
956 self.builder.operations.push(op);
957 }
958}