1use std::{iter::Zip, vec::IntoIter};
2
3use backtrace::Backtrace;
4use itertools::izip;
5use openvm_native_compiler_derive::iter_zip;
6use openvm_stark_backend::p3_field::{BasedVectorSpace, Field, PrimeCharacteristicRing};
7use serde::{Deserialize, Serialize};
8
9use super::{
10 Array, Config, DslIr, Ext, Felt, FromConstant, MemIndex, MemVariable, RVar, SymbolicExt,
11 SymbolicFelt, SymbolicVar, Usize, Var, Variable, WitnessRef,
12};
13use crate::ir::{collections::ArrayLike, Ptr};
14
15#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct TracedVec<T> {
19 pub vec: Vec<T>,
20 pub traces: Vec<Option<Backtrace>>,
21}
22
23impl<T> Default for TracedVec<T> {
24 fn default() -> Self {
25 Self::new()
26 }
27}
28
29impl<T> From<Vec<T>> for TracedVec<T> {
30 fn from(vec: Vec<T>) -> Self {
31 let len = vec.len();
32 Self {
33 vec,
34 traces: vec![None; len],
35 }
36 }
37}
38
39impl<T> TracedVec<T> {
40 pub const fn new() -> Self {
41 Self {
42 vec: Vec::new(),
43 traces: Vec::new(),
44 }
45 }
46
47 pub fn push(&mut self, value: T) {
48 self.vec.push(value);
49 self.traces.push(None);
50 }
51
52 pub fn trace_push(&mut self, value: T) {
54 self.vec.push(value);
55 if std::env::var_os("RUST_BACKTRACE").is_none() {
56 self.traces.push(None);
57 } else {
58 self.traces.push(Some(Backtrace::new_unresolved()));
59 }
60 }
61
62 pub fn extend<I: IntoIterator<Item = (T, Option<Backtrace>)>>(&mut self, iter: I) {
63 let iter = iter.into_iter();
64 let len = iter.size_hint().0;
65 self.vec.reserve(len);
66 self.traces.reserve(len);
67 for (value, trace) in iter {
68 self.vec.push(value);
69 self.traces.push(trace);
70 }
71 }
72
73 pub fn is_empty(&self) -> bool {
74 self.vec.is_empty()
75 }
76}
77
78impl<T> IntoIterator for TracedVec<T> {
79 type Item = (T, Option<Backtrace>);
80 type IntoIter = Zip<IntoIter<T>, IntoIter<Option<Backtrace>>>;
81
82 fn into_iter(self) -> Self::IntoIter {
83 self.vec.into_iter().zip(self.traces)
84 }
85}
86
/// Mode switches of a [`Builder`]; copied into every sub builder it creates.
#[derive(Clone, Copy, Debug, Default)]
pub struct BuilderFlags {
    /// Debug switch. NOTE(review): not read anywhere in this file; confirm its
    /// consumer before relying on it.
    pub debug: bool,
    /// When set, builder methods that emit dynamic constructs (dynamic loops
    /// and branches, var-to-felt casts, range checks, dynamic public-value
    /// commits) panic at build time.
    pub static_only: bool,
}
93
94#[derive(Debug, Clone, Default)]
98pub struct Builder<C: Config> {
99 pub(crate) var_count: u32,
100 pub(crate) felt_count: u32,
101 pub(crate) ext_count: u32,
102 pub operations: TracedVec<DslIr<C>>,
103 pub(crate) nb_public_values: Option<Var<C::N>>,
104 pub(crate) witness_var_count: u32,
105 pub(crate) witness_felt_count: u32,
106 pub(crate) witness_ext_count: u32,
107 pub(crate) witness_space: Vec<Vec<WitnessRef>>,
108 pub flags: BuilderFlags,
109 pub is_sub_builder: bool,
110}
111
112impl<C: Config> Builder<C> {
113 pub fn create_sub_builder(&self) -> Self {
115 Self {
116 var_count: self.var_count,
117 felt_count: self.felt_count,
118 ext_count: self.ext_count,
119 witness_var_count: 0,
122 witness_felt_count: 0,
123 witness_ext_count: 0,
124 witness_space: Default::default(),
125 operations: Default::default(),
126 nb_public_values: self.nb_public_values,
127 flags: self.flags,
128 is_sub_builder: true,
129 }
130 }
131
132 pub fn push(&mut self, op: DslIr<C>) {
134 self.operations.push(op);
135 }
136
137 pub fn trace_push(&mut self, op: DslIr<C>) {
139 self.operations.trace_push(op);
140 }
141
142 pub fn uninit<V: Variable<C>>(&mut self) -> V {
144 V::uninit(self)
145 }
146
147 pub fn eval<V: Variable<C>, E: Into<V::Expression>>(&mut self, expr: E) -> V {
149 V::eval(self, expr)
150 }
151
152 pub fn eval_expr(&mut self, expr: impl Into<SymbolicVar<C::N>>) -> RVar<C::N> {
154 let expr = expr.into();
155 match expr {
156 SymbolicVar::Const(c, _) => RVar::Const(c),
157 SymbolicVar::Val(val, _) => RVar::Val(val),
158 _ => {
159 let ret: Var<_> = self.eval(expr);
160 RVar::Val(ret)
161 }
162 }
163 }
164
165 pub fn inc(&mut self, u: &Usize<C::N>) {
167 self.assign(u, u.clone() + RVar::one());
168 }
169
170 pub fn constant<V: FromConstant<C>>(&mut self, value: V::Constant) -> V {
172 V::constant(value, self)
173 }
174
175 pub fn assign<V: Variable<C>, E: Into<V::Expression>>(&mut self, dst: &V, expr: E) {
177 dst.assign(expr.into(), self);
178 }
179
180 pub fn cast_felt_to_var(&mut self, felt: Felt<C::F>) -> Var<C::N> {
182 let var: Var<_> = self.uninit();
183 self.operations.push(DslIr::CastFV(var, felt));
184 var
185 }
186 pub fn unsafe_cast_var_to_felt(&mut self, var: Var<C::N>) -> Felt<C::F> {
188 assert!(!self.flags.static_only, "dynamic mode only");
189 let felt: Felt<_> = self.uninit();
190 self.operations.push(DslIr::UnsafeCastVF(felt, var));
191 felt
192 }
193
194 pub fn assert_nonzero(&mut self, u: &Usize<C::N>) {
196 if self.flags.static_only {
197 assert_ne!(u.value(), 0, "assert_nonzero failed on constant zero");
198 } else {
199 self.operations.push(DslIr::AssertNonZero(u.clone()));
200 }
201 }
202
203 pub fn assert_eq<V: Variable<C>>(
205 &mut self,
206 lhs: impl Into<V::Expression>,
207 rhs: impl Into<V::Expression>,
208 ) {
209 V::assert_eq(lhs, rhs, self);
210 }
211
212 pub fn assert_var_eq<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
214 &mut self,
215 lhs: LhsExpr,
216 rhs: RhsExpr,
217 ) {
218 self.assert_eq::<Var<C::N>>(lhs, rhs);
219 }
220
221 pub fn assert_felt_eq<LhsExpr: Into<SymbolicFelt<C::F>>, RhsExpr: Into<SymbolicFelt<C::F>>>(
223 &mut self,
224 lhs: LhsExpr,
225 rhs: RhsExpr,
226 ) {
227 self.assert_eq::<Felt<C::F>>(lhs, rhs);
228 }
229
230 pub fn assert_ext_eq<
232 LhsExpr: Into<SymbolicExt<C::F, C::EF>>,
233 RhsExpr: Into<SymbolicExt<C::F, C::EF>>,
234 >(
235 &mut self,
236 lhs: LhsExpr,
237 rhs: RhsExpr,
238 ) {
239 self.assert_eq::<Ext<C::F, C::EF>>(lhs, rhs);
240 }
241
242 pub fn assert_usize_eq<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
243 &mut self,
244 lhs: LhsExpr,
245 rhs: RhsExpr,
246 ) {
247 self.assert_eq::<Usize<C::N>>(lhs, rhs);
248 }
249
250 pub fn assert_var_array_eq(&mut self, lhs: &Array<C, Var<C::N>>, rhs: &Array<C, Var<C::N>>) {
252 self.assert_var_eq(lhs.len(), rhs.len());
253 self.range(0, lhs.len()).for_each(|idx_vec, builder| {
254 let l = builder.get(lhs, idx_vec[0]);
255 let r = builder.get(rhs, idx_vec[0]);
256 builder.assert_var_eq(l, r);
257 });
258 }
259
260 pub fn if_eq<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
262 &mut self,
263 lhs: LhsExpr,
264 rhs: RhsExpr,
265 ) -> IfBuilder<'_, C> {
266 IfBuilder {
267 lhs: lhs.into(),
268 rhs: rhs.into(),
269 is_eq: true,
270 builder: self,
271 }
272 }
273
274 pub fn if_ne<LhsExpr: Into<SymbolicVar<C::N>>, RhsExpr: Into<SymbolicVar<C::N>>>(
276 &mut self,
277 lhs: LhsExpr,
278 rhs: RhsExpr,
279 ) -> IfBuilder<'_, C> {
280 IfBuilder {
281 lhs: lhs.into(),
282 rhs: rhs.into(),
283 is_eq: false,
284 builder: self,
285 }
286 }
287
288 pub fn assert_less_than_slow_small_rhs<
290 LhsExpr: Into<SymbolicVar<C::N>>,
291 RhsExpr: Into<SymbolicVar<C::N>>,
292 >(
293 &mut self,
294 lhs: LhsExpr,
295 rhs: RhsExpr,
296 ) {
297 let lhs: Usize<_> = self.eval(lhs.into());
298 let rhs: Usize<_> = self.eval(rhs.into());
299 let product: Usize<_> = self.eval(lhs.clone());
300 self.range(1, rhs).for_each(|i_vec, builder| {
301 let i = i_vec[0];
302 let diff: Usize<_> = builder.eval(lhs.clone() - i);
303 builder.assign(&product, product.clone() * diff);
304 });
305 self.assert_usize_eq(product, RVar::from(0));
306 }
307
308 pub fn assert_less_than_slow_small_lhs<
310 LhsExpr: Into<SymbolicVar<C::N>>,
311 RhsExpr: Into<SymbolicVar<C::N>>,
312 >(
313 &mut self,
314 lhs: LhsExpr,
315 rhs: RhsExpr,
316 ) {
317 let lhs: Usize<_> = self.eval(lhs.into());
318 let rhs: Usize<_> = self.eval(rhs.into());
319 let product: Usize<_> = self.eval(rhs.clone());
320 let lhs_plus_one: Usize<_> = self.eval(lhs.clone() + RVar::one());
321 self.range(1, lhs_plus_one).for_each(|i_vec, builder| {
322 let i = i_vec[0];
323 let diff: Usize<_> = builder.eval(rhs.clone() - i);
324 builder.assign(&product, product.clone() * diff);
325 });
326 self.assert_nonzero(&product);
327 }
328
329 pub fn assert_less_than_slow_bit_decomp(&mut self, lhs: Var<C::N>, rhs: Var<C::N>) {
336 let lhs = self.unsafe_cast_var_to_felt(lhs);
337 let rhs = self.unsafe_cast_var_to_felt(rhs);
338
339 let lhs_bits = self.num2bits_f(lhs, C::N::bits() as u32);
340 let rhs_bits = self.num2bits_f(rhs, C::N::bits() as u32);
341
342 let is_lt: Var<_> = self.eval(C::N::ZERO);
343
344 iter_zip!(self, lhs_bits, rhs_bits).for_each(|ptr_vec, builder| {
345 let lhs_bit = builder.iter_ptr_get(&lhs_bits, ptr_vec[0]);
346 let rhs_bit = builder.iter_ptr_get(&rhs_bits, ptr_vec[1]);
347
348 builder.if_ne(lhs_bit, rhs_bit).then(|builder| {
349 builder.assign(&is_lt, rhs_bit);
350 });
351 });
352 self.assert_var_eq(is_lt, C::N::ONE);
353 }
354
355 pub fn range_check_var(&mut self, x: Var<C::N>, num_bits: usize) {
357 assert!(!self.flags.static_only, "range_check_var is dynamic");
358 assert!(num_bits <= 30);
359 self.trace_push(DslIr::RangeCheckV(x, num_bits));
360 }
361
362 pub fn range(
364 &mut self,
365 start: impl Into<RVar<C::N>>,
366 end: impl Into<RVar<C::N>>,
367 ) -> IteratorBuilder<'_, C> {
368 self.range_with_step(start, end, C::N::ONE)
369 }
370 pub fn range_with_step(
372 &mut self,
373 start: impl Into<RVar<C::N>>,
374 end: impl Into<RVar<C::N>>,
375 step: C::N,
376 ) -> IteratorBuilder<'_, C> {
377 let start = start.into();
378 let end0 = end.into();
379 IteratorBuilder {
380 starts: vec![start],
381 end0,
382 step_sizes: vec![step],
383 builder: self,
384 }
385 }
386
387 pub fn zip<'a>(
388 &'a mut self,
389 arrays: &'a [Box<dyn ArrayLike<C> + 'a>],
390 ) -> IteratorBuilder<'a, C> {
391 assert!(!arrays.is_empty());
392 if arrays.iter().all(|array| array.is_fixed()) {
393 IteratorBuilder {
394 starts: vec![RVar::zero(); arrays.len()],
395 end0: arrays[0].len().into(),
396 step_sizes: vec![C::N::ONE; arrays.len()],
397 builder: self,
398 }
399 } else if arrays.iter().all(|array| !array.is_fixed()) {
400 IteratorBuilder {
401 starts: arrays
402 .iter()
403 .map(|array| array.ptr().address.into())
404 .collect(),
405 end0: {
406 let len: RVar<C::N> = arrays[0].len().into();
407 let size = arrays[0].element_size_of();
408 let end: Var<C::N> =
409 self.eval(arrays[0].ptr().address + len * RVar::from(size));
410 end.into()
411 },
412 step_sizes: arrays
413 .iter()
414 .map(|array| C::N::from_usize(array.element_size_of()))
415 .collect(),
416 builder: self,
417 }
418 } else {
419 panic!("Cannot use zipped pointer iterator with mixed arrays");
420 }
421 }
422
423 pub fn print_debug(&mut self, val: usize) {
424 let constant = self.eval(C::N::from_usize(val));
425 self.print_v(constant);
426 }
427
428 pub fn print_v(&mut self, dst: Var<C::N>) {
430 self.operations.push(DslIr::PrintV(dst));
431 }
432
433 pub fn print_f(&mut self, dst: Felt<C::F>) {
435 self.operations.push(DslIr::PrintF(dst));
436 }
437
438 pub fn print_e(&mut self, dst: Ext<C::F, C::EF>) {
440 self.operations.push(DslIr::PrintE(dst));
441 }
442
443 pub fn hint_var(&mut self) -> Var<C::N> {
444 let ptr = self.alloc(RVar::one(), 1);
445 self.operations.push(DslIr::HintFelt());
447 let index = MemIndex {
448 index: RVar::zero(),
449 offset: 0,
450 size: 1,
451 };
452 self.operations.push(DslIr::StoreHintWord(ptr, index));
453 let v: Var<C::N> = self.uninit();
454 self.load(v, ptr, index);
455 v
456 }
457
458 pub fn hint_felt(&mut self) -> Felt<C::F> {
459 let ptr = self.alloc(RVar::one(), 1);
460 self.operations.push(DslIr::HintFelt());
462 let index = MemIndex {
463 index: RVar::zero(),
464 offset: 0,
465 size: 1,
466 };
467 self.operations.push(DslIr::StoreHintWord(ptr, index));
468 let f: Felt<C::F> = self.uninit();
469 self.load(f, ptr, index);
470 f
471 }
472
473 pub fn hint_ext(&mut self) -> Ext<C::F, C::EF> {
474 let flattened = self.hint_felts_fixed(C::EF::DIMENSION);
475
476 let array: Array<C, Ext<_, _>> = match flattened {
478 Array::Fixed(_) => unreachable!(),
479 Array::Dyn(ptr, _) => Array::Dyn(ptr, Usize::from(1)),
480 };
481 self.get(&array, 0)
482 }
483
484 pub fn hint_vars(&mut self) -> Array<C, Var<C::N>> {
488 self.hint_words()
489 }
490
491 pub fn hint_felts(&mut self) -> Array<C, Felt<C::F>> {
493 self.hint_words()
494 }
495
496 pub fn hint_felts_fixed(&mut self, len: impl Into<RVar<C::N>>) -> Array<C, Felt<C::F>> {
497 self.hint_words_fixed(len)
498 }
499
500 fn hint_words<V: MemVariable<C>>(&mut self) -> Array<C, V> {
502 assert_eq!(V::size_of(), 1);
503
504 let ptr = self.alloc(RVar::one(), 1);
506
507 self.operations.push(DslIr::HintInputVec());
509
510 let index = MemIndex {
512 index: RVar::zero(),
513 offset: 0,
514 size: 1,
515 };
516 self.operations.push(DslIr::StoreHintWord(ptr, index));
518
519 let vlen: Var<C::N> = self.uninit();
520 self.load(vlen, ptr, index);
521 let arr = self.dyn_array(vlen);
522
523 iter_zip!(self, arr).for_each(|ptr_vec, builder| {
525 let index = MemIndex {
526 index: 0.into(),
527 offset: 0,
528 size: 1,
529 };
530 builder.operations.push(DslIr::StoreHintWord(
531 Ptr {
532 address: ptr_vec[0].variable(),
533 },
534 index,
535 ));
536 });
537 arr
538 }
539
540 fn hint_words_fixed<V: MemVariable<C>>(&mut self, len: impl Into<RVar<C::N>>) -> Array<C, V> {
542 assert_eq!(V::size_of(), 1);
543
544 let arr = self.dyn_array(len.into());
545 iter_zip!(self, arr).for_each(|ptr_vec, builder| {
547 let index = MemIndex {
548 index: 0.into(),
549 offset: 0,
550 size: 1,
551 };
552 builder.operations.push(DslIr::HintFelt());
553 builder.operations.push(DslIr::StoreHintWord(
554 Ptr {
555 address: ptr_vec[0].variable(),
556 },
557 index,
558 ));
559 });
560 arr
561 }
562
563 pub fn hint_exts(&mut self) -> Array<C, Ext<C::F, C::EF>> {
568 let len = self.hint_var();
569 let flattened = self.hint_felts();
570
571 let size = <Ext<C::F, C::EF> as MemVariable<C>>::size_of();
572 self.assert_usize_eq(flattened.len(), len * C::N::from_usize(size));
573
574 match flattened {
576 Array::Fixed(_) => unreachable!(),
577 Array::Dyn(ptr, _) => Array::Dyn(ptr, Usize::Var(len)),
578 }
579 }
580
581 pub fn hint_load(&mut self) -> Var<C::N> {
584 self.trace_push(DslIr::HintLoad());
585 let ptr = self.alloc(RVar::one(), 1);
586 let index = MemIndex {
587 index: RVar::zero(),
588 offset: 0,
589 size: 1,
590 };
591 self.operations.push(DslIr::StoreHintWord(ptr, index));
593 let id: Var<C::N> = self.uninit();
594 self.load(id, ptr, index);
595 id
596 }
597
598 pub fn witness_var(&mut self) -> Var<C::N> {
599 assert!(
600 !self.is_sub_builder,
601 "Cannot create a witness var with a sub builder"
602 );
603 let witness = self.uninit();
604 self.operations
605 .push(DslIr::WitnessVar(witness, self.witness_var_count));
606 self.witness_var_count += 1;
607 witness
608 }
609
610 pub fn witness_felt(&mut self) -> Felt<C::F> {
611 assert!(
612 !self.is_sub_builder,
613 "Cannot create a witness felt with a sub builder"
614 );
615 let witness = self.uninit();
616 self.operations
617 .push(DslIr::WitnessFelt(witness, self.witness_felt_count));
618 self.witness_felt_count += 1;
619 witness
620 }
621
622 pub fn witness_ext(&mut self) -> Ext<C::F, C::EF> {
623 assert!(
624 !self.is_sub_builder,
625 "Cannot create a witness ext with a sub builder"
626 );
627 let witness = self.uninit();
628 self.operations
629 .push(DslIr::WitnessExt(witness, self.witness_ext_count));
630 self.witness_ext_count += 1;
631 witness
632 }
633
634 pub fn witness_load(&mut self, witness_refs: Vec<WitnessRef>) -> Usize<C::N> {
635 assert!(
636 !self.is_sub_builder,
637 "Cannot load witness refs with a sub builder"
638 );
639 let ret = self.witness_space.len();
640 self.witness_space.push(witness_refs);
641 ret.into()
642 }
643
644 pub fn get_witness_refs(&self, id: Usize<C::N>) -> &[WitnessRef] {
645 self.witness_space.get(id.value()).unwrap()
646 }
647
648 pub fn error(&mut self) {
650 self.operations.trace_push(DslIr::Error());
651 }
652
653 fn get_nb_public_values(&mut self) -> Var<C::N> {
654 assert!(
655 !self.is_sub_builder,
656 "Cannot commit to public values with a sub builder"
657 );
658 if self.nb_public_values.is_none() {
659 self.nb_public_values = Some(self.eval(C::N::ZERO));
660 }
661 *self.nb_public_values.as_ref().unwrap()
662 }
663
664 fn commit_public_value_and_increment(&mut self, val: Felt<C::F>, nb_public_values: Var<C::N>) {
665 assert!(
666 !self.flags.static_only,
667 "Static mode should use static_commit_public_value"
668 );
669 self.operations.push(DslIr::Publish(val, nb_public_values));
670 self.assign(&nb_public_values, nb_public_values + C::N::ONE);
671 }
672
673 pub fn static_commit_public_value(&mut self, index: usize, val: Var<C::N>) {
676 assert!(
677 self.flags.static_only,
678 "Dynamic mode should use commit_public_value instead."
679 );
680 self.operations.push(DslIr::CircuitPublish(val, index));
681 }
682
683 pub fn commit_public_value(&mut self, val: Felt<C::F>) {
685 let nb_public_values = self.get_nb_public_values();
686 self.commit_public_value_and_increment(val, nb_public_values);
687 }
688
689 pub fn commit_public_values(&mut self, vals: &Array<C, Felt<C::F>>) {
691 let nb_public_values = self.get_nb_public_values();
692 let len = vals.len();
693 self.range(0, len).for_each(|idx_vec, builder| {
694 let val = builder.get(vals, idx_vec[0]);
695 builder.commit_public_value_and_increment(val, nb_public_values);
696 });
697 }
698
699 pub fn cycle_tracker_start(&mut self, name: &str) {
700 self.operations
701 .push(DslIr::CycleTrackerStart(name.to_string()));
702 }
703
704 pub fn cycle_tracker_end(&mut self, name: &str) {
705 self.operations
706 .push(DslIr::CycleTrackerEnd(name.to_string()));
707 }
708
709 pub fn halt(&mut self) {
710 self.operations.push(DslIr::Halt);
711 }
712}
713
714pub struct IfBuilder<'a, C: Config> {
716 lhs: SymbolicVar<C::N>,
717 rhs: SymbolicVar<C::N>,
718 is_eq: bool,
719 pub(crate) builder: &'a mut Builder<C>,
720}
721
722enum IfCondition<N> {
724 EqConst(N, N),
725 NeConst(N, N),
726 Eq(Var<N>, Var<N>),
727 EqI(Var<N>, N),
728 Ne(Var<N>, Var<N>),
729 NeI(Var<N>, N),
730}
731
732impl<C: Config> IfBuilder<'_, C> {
733 pub fn then(&mut self, mut f: impl FnMut(&mut Builder<C>)) {
734 let condition = self.condition();
736 match condition {
738 IfCondition::EqConst(lhs, rhs) => {
739 if lhs == rhs {
740 return f(self.builder);
741 }
742 return;
743 }
744 IfCondition::NeConst(lhs, rhs) => {
745 if lhs != rhs {
746 return f(self.builder);
747 }
748 return;
749 }
750 _ => (),
751 }
752 assert!(
753 !self.builder.flags.static_only,
754 "Cannot use dynamic branch in static mode"
755 );
756
757 let mut f_builder = self.builder.create_sub_builder();
759 f(&mut f_builder);
760 let then_instructions = f_builder.operations;
761
762 match condition {
764 IfCondition::Eq(lhs, rhs) => {
765 let op = DslIr::IfEq(lhs, rhs, then_instructions, Default::default());
766 self.builder.operations.push(op);
767 }
768 IfCondition::EqI(lhs, rhs) => {
769 let op = DslIr::IfEqI(lhs, rhs, then_instructions, Default::default());
770 self.builder.operations.push(op);
771 }
772 IfCondition::Ne(lhs, rhs) => {
773 let op = DslIr::IfNe(lhs, rhs, then_instructions, Default::default());
774 self.builder.operations.push(op);
775 }
776 IfCondition::NeI(lhs, rhs) => {
777 let op = DslIr::IfNeI(lhs, rhs, then_instructions, Default::default());
778 self.builder.operations.push(op);
779 }
780 _ => unreachable!("Const if should have returned early"),
781 }
782 }
783
784 pub fn then_or_else(
785 &mut self,
786 mut then_f: impl FnMut(&mut Builder<C>),
787 mut else_f: impl FnMut(&mut Builder<C>),
788 ) {
789 let condition = self.condition();
791 match condition {
793 IfCondition::EqConst(lhs, rhs) => {
794 if lhs == rhs {
795 return then_f(self.builder);
796 }
797 return else_f(self.builder);
798 }
799 IfCondition::NeConst(lhs, rhs) => {
800 if lhs != rhs {
801 return then_f(self.builder);
802 }
803 return else_f(self.builder);
804 }
805 _ => (),
806 }
807 assert!(
808 !self.builder.flags.static_only,
809 "Cannot use dynamic branch in static mode"
810 );
811 let mut then_builder = self.builder.create_sub_builder();
812
813 then_f(&mut then_builder);
815 let then_instructions = then_builder.operations;
816
817 let mut else_builder = self.builder.create_sub_builder();
818 else_f(&mut else_builder);
819 let else_instructions = else_builder.operations;
820
821 match condition {
823 IfCondition::Eq(lhs, rhs) => {
824 let op = DslIr::IfEq(lhs, rhs, then_instructions, else_instructions);
825 self.builder.operations.push(op);
826 }
827 IfCondition::EqI(lhs, rhs) => {
828 let op = DslIr::IfEqI(lhs, rhs, then_instructions, else_instructions);
829 self.builder.operations.push(op);
830 }
831 IfCondition::Ne(lhs, rhs) => {
832 let op = DslIr::IfNe(lhs, rhs, then_instructions, else_instructions);
833 self.builder.operations.push(op);
834 }
835 IfCondition::NeI(lhs, rhs) => {
836 let op = DslIr::IfNeI(lhs, rhs, then_instructions, else_instructions);
837 self.builder.operations.push(op);
838 }
839 _ => unreachable!("Const if should have returned early"),
840 }
841 }
842
843 fn condition(&mut self) -> IfCondition<C::N> {
844 match (self.lhs.clone(), self.rhs.clone(), self.is_eq) {
845 (SymbolicVar::Const(lhs, _), SymbolicVar::Const(rhs, _), true) => {
846 IfCondition::EqConst(lhs, rhs)
847 }
848 (SymbolicVar::Const(lhs, _), SymbolicVar::Const(rhs, _), false) => {
849 IfCondition::NeConst(lhs, rhs)
850 }
851 (SymbolicVar::Const(lhs, _), SymbolicVar::Val(rhs, _), true) => {
852 IfCondition::EqI(rhs, lhs)
853 }
854 (SymbolicVar::Const(lhs, _), SymbolicVar::Val(rhs, _), false) => {
855 IfCondition::NeI(rhs, lhs)
856 }
857 (SymbolicVar::Const(lhs, _), rhs, true) => {
858 let rhs: Var<C::N> = self.builder.eval(rhs);
859 IfCondition::EqI(rhs, lhs)
860 }
861 (SymbolicVar::Const(lhs, _), rhs, false) => {
862 let rhs: Var<C::N> = self.builder.eval(rhs);
863 IfCondition::NeI(rhs, lhs)
864 }
865 (SymbolicVar::Val(lhs, _), SymbolicVar::Const(rhs, _), true) => {
866 let lhs: Var<C::N> = self.builder.eval(lhs);
867 IfCondition::EqI(lhs, rhs)
868 }
869 (SymbolicVar::Val(lhs, _), SymbolicVar::Const(rhs, _), false) => {
870 let lhs: Var<C::N> = self.builder.eval(lhs);
871 IfCondition::NeI(lhs, rhs)
872 }
873 (lhs, SymbolicVar::Const(rhs, _), true) => {
874 let lhs: Var<C::N> = self.builder.eval(lhs);
875 IfCondition::EqI(lhs, rhs)
876 }
877 (lhs, SymbolicVar::Const(rhs, _), false) => {
878 let lhs: Var<C::N> = self.builder.eval(lhs);
879 IfCondition::NeI(lhs, rhs)
880 }
881 (SymbolicVar::Val(lhs, _), SymbolicVar::Val(rhs, _), true) => IfCondition::Eq(lhs, rhs),
882 (SymbolicVar::Val(lhs, _), SymbolicVar::Val(rhs, _), false) => {
883 IfCondition::Ne(lhs, rhs)
884 }
885 (SymbolicVar::Val(lhs, _), rhs, true) => {
886 let rhs: Var<C::N> = self.builder.eval(rhs);
887 IfCondition::Eq(lhs, rhs)
888 }
889 (SymbolicVar::Val(lhs, _), rhs, false) => {
890 let rhs: Var<C::N> = self.builder.eval(rhs);
891 IfCondition::Ne(lhs, rhs)
892 }
893 (lhs, SymbolicVar::Val(rhs, _), true) => {
894 let lhs: Var<C::N> = self.builder.eval(lhs);
895 IfCondition::Eq(lhs, rhs)
896 }
897 (lhs, SymbolicVar::Val(rhs, _), false) => {
898 let lhs: Var<C::N> = self.builder.eval(lhs);
899 IfCondition::Ne(lhs, rhs)
900 }
901 (lhs, rhs, true) => {
902 let lhs: Var<C::N> = self.builder.eval(lhs);
903 let rhs: Var<C::N> = self.builder.eval(rhs);
904 IfCondition::Eq(lhs, rhs)
905 }
906 (lhs, rhs, false) => {
907 let lhs: Var<C::N> = self.builder.eval(lhs);
908 let rhs: Var<C::N> = self.builder.eval(rhs);
909 IfCondition::Ne(lhs, rhs)
910 }
911 }
912 }
913}
914
915pub struct IteratorBuilder<'a, C: Config> {
917 starts: Vec<RVar<C::N>>,
918 end0: RVar<C::N>,
919 step_sizes: Vec<C::N>,
920 builder: &'a mut Builder<C>,
921}
922
923impl<C: Config> IteratorBuilder<'_, C> {
924 pub fn for_each(&mut self, mut f: impl FnMut(Vec<RVar<C::N>>, &mut Builder<C>)) {
925 assert!(self.starts.len() == self.step_sizes.len());
926 assert!(!self.starts.is_empty());
927
928 if self.starts.iter().all(|start| start.is_const()) && self.end0.is_const() {
929 self.for_each_unrolled(|ptrs, builder| {
930 f(ptrs, builder);
931 });
932 return;
933 }
934
935 self.for_each_dynamic(|ptrs, builder| {
936 f(ptrs, builder);
937 });
938 }
939
940 fn for_each_unrolled(&mut self, mut f: impl FnMut(Vec<RVar<C::N>>, &mut Builder<C>)) {
941 let mut ptrs: Vec<_> = self
942 .starts
943 .iter()
944 .map(|start| start.field_value())
945 .collect();
946 let end0 = self.end0.field_value();
947 while ptrs[0] != end0 {
948 f(
949 ptrs.iter().map(|ptr| RVar::Const(*ptr)).collect(),
950 self.builder,
951 );
952 for (ptr, step_size) in izip!(&mut ptrs, &self.step_sizes) {
953 *ptr += *step_size;
954 }
955 }
956 }
957
958 fn for_each_dynamic(&mut self, mut f: impl FnMut(Vec<RVar<C::N>>, &mut Builder<C>)) {
959 assert!(
960 !self.builder.flags.static_only,
961 "Cannot use dynamic loop in static mode"
962 );
963
964 let loop_variables: Vec<Var<C::N>> = (0..self.starts.len())
965 .map(|_| self.builder.uninit())
966 .collect();
967 let mut loop_body_builder = self.builder.create_sub_builder();
968
969 f(
970 loop_variables.iter().map(|&v| v.into()).collect(),
971 &mut loop_body_builder,
972 );
973
974 let loop_instructions = loop_body_builder.operations;
975 let op = DslIr::ZipFor(
976 self.starts.clone(),
977 self.end0,
978 self.step_sizes.clone(),
979 loop_variables,
980 loop_instructions,
981 );
982 self.builder.operations.push(op);
983 }
984}