1use std::sync::Arc;
2
3use itertools::Itertools;
4use num_bigint::{BigInt, BigUint};
5use num_integer::Integer;
6use openvm_stark_backend::p3_field::{
7 extension::{BinomialExtensionField, BinomiallyExtendable},
8 Field, FieldAlgebra, FieldExtensionAlgebra, PrimeField32, PrimeField64,
9};
10use openvm_stark_sdk::p3_baby_bear::BabyBear;
11use snark_verifier_sdk::snark_verifier::{
12 halo2_base::{
13 gates::{GateChip, GateInstructions, RangeChip, RangeInstructions},
14 halo2_proofs::halo2curves::bn256::Fr,
15 utils::{bigint_to_fe, biguint_to_fe, bit_length, fe_to_bigint, BigPrimeField},
16 AssignedValue, Context, QuantumCell,
17 },
18 util::arithmetic::{Field as _, PrimeField as _},
19};
20
/// Bit width used when range-checking assigned BabyBear values: the BabyBear
/// modulus fits in 31 bits, so every canonical representative does too.
pub(crate) const BABYBEAR_MAX_BITS: usize = 31;
/// Headroom subtracted from `Fr::CAPACITY` when deciding whether a lazily
/// accumulated value must be reduced before the next operation, so
/// intermediate results cannot wrap around the native field modulus.
const RESERVED_HIGH_BITS: usize = 2;
26
/// A BabyBear element held (possibly unreduced) in a single native-field cell.
#[derive(Copy, Clone, Debug)]
pub struct AssignedBabyBear {
    /// The assigned `Fr` cell carrying the value.
    pub value: AssignedValue<Fr>,
    /// Upper bound on the bit length of `value` interpreted as a signed
    /// integer (see the `debug_assert` in `BabyBearChip::reduce`); lets the
    /// chip defer modular reductions until overflow would be possible.
    pub max_bits: usize,
}
40
41impl AssignedBabyBear {
42 pub fn to_baby_bear(&self) -> BabyBear {
43 let mut b_int = fe_to_bigint(self.value.value()) % BabyBear::ORDER_U32;
44 if b_int < BigInt::from(0) {
45 b_int += BabyBear::ORDER_U32;
46 }
47 BabyBear::from_canonical_u32(b_int.try_into().unwrap())
48 }
49}
50
/// Chip for BabyBear arithmetic emulated over the native BN254 scalar field
/// `Fr`, using the wrapped range chip for all range/reduction checks.
pub struct BabyBearChip {
    /// Shared range chip (also provides the basic gate via `gate()`).
    pub range: Arc<RangeChip<Fr>>,
}
54
// BabyBear arithmetic with lazy reduction: values stay unreduced in `Fr`
// while `max_bits` tracks a signed bit-length bound; a modular reduction is
// emitted only when the next operation could overflow `Fr`'s capacity.
impl BabyBearChip {
    /// Creates a chip backed by the given shared range chip.
    pub fn new(range_chip: Arc<RangeChip<Fr>>) -> Self {
        BabyBearChip { range: range_chip }
    }

    /// The underlying basic-gate chip used for arithmetic constraints.
    pub fn gate(&self) -> &GateChip<Fr> {
        self.range.gate()
    }

    /// Assigns `value` as a witness and range-checks it to
    /// `BABYBEAR_MAX_BITS` bits, so the cell holds a canonical representative.
    pub fn load_witness(&self, ctx: &mut Context<Fr>, value: BabyBear) -> AssignedBabyBear {
        let value = ctx.load_witness(Fr::from(PrimeField64::as_canonical_u64(&value)));
        self.range.range_check(ctx, value, BABYBEAR_MAX_BITS);
        AssignedBabyBear {
            value,
            max_bits: BABYBEAR_MAX_BITS,
        }
    }

    /// Assigns `value` as a constant cell. No range check is needed for a
    /// constant; `max_bits` is the exact bit length of the canonical value.
    pub fn load_constant(&self, ctx: &mut Context<Fr>, value: BabyBear) -> AssignedBabyBear {
        let max_bits = bit_length(value.as_canonical_u64());
        let value = ctx.load_constant(Fr::from(PrimeField64::as_canonical_u64(&value)));
        AssignedBabyBear { value, max_bits }
    }

    /// Reduces `a` modulo the BabyBear prime, returning the canonical
    /// remainder (bounded by `BABYBEAR_MAX_BITS` bits).
    pub fn reduce(&self, ctx: &mut Context<Fr>, a: AssignedBabyBear) -> AssignedBabyBear {
        debug_assert!(fe_to_bigint(a.value.value()).bits() as usize <= a.max_bits);
        let (_, r) = signed_div_mod(&self.range, ctx, a.value, a.max_bits);
        let r = AssignedBabyBear {
            value: r,
            max_bits: BABYBEAR_MAX_BITS,
        };
        debug_assert_eq!(a.to_baby_bear(), r.to_baby_bear());
        r
    }

    /// Reduces `a` only if its tracked bound exceeds `BABYBEAR_MAX_BITS`;
    /// otherwise returns it unchanged (no constraints emitted).
    pub fn reduce_max_bits(&self, ctx: &mut Context<Fr>, a: AssignedBabyBear) -> AssignedBabyBear {
        if a.max_bits > BABYBEAR_MAX_BITS {
            self.reduce(ctx, a)
        } else {
            a
        }
    }

    /// BabyBear addition with lazy reduction: operands are reduced up front
    /// only if the sum's bound (`max + 1` bits) would exceed the safe `Fr`
    /// capacity; the result is reduced again if it crosses the threshold.
    pub fn add(
        &self,
        ctx: &mut Context<Fr>,
        mut a: AssignedBabyBear,
        mut b: AssignedBabyBear,
    ) -> AssignedBabyBear {
        // Reduce `a` first, then re-check in case `b` is the oversized one.
        if a.max_bits.max(b.max_bits) + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            a = self.reduce(ctx, a);
            if a.max_bits.max(b.max_bits) + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
                b = self.reduce(ctx, b);
            }
        }
        let value = self.gate().add(ctx, a.value, b.value);
        let max_bits = a.max_bits.max(b.max_bits) + 1;
        let mut c = AssignedBabyBear { value, max_bits };
        debug_assert_eq!(c.to_baby_bear(), a.to_baby_bear() + b.to_baby_bear());
        if c.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            c = self.reduce(ctx, c);
        }
        c
    }

    /// BabyBear negation. The signed bound is symmetric, so `max_bits`
    /// carries over unchanged.
    pub fn neg(&self, ctx: &mut Context<Fr>, a: AssignedBabyBear) -> AssignedBabyBear {
        let value = self.gate().neg(ctx, a.value);
        let b = AssignedBabyBear {
            value,
            max_bits: a.max_bits,
        };
        debug_assert_eq!(b.to_baby_bear(), -a.to_baby_bear());
        b
    }

    /// BabyBear subtraction with the same lazy-reduction policy as `add`.
    pub fn sub(
        &self,
        ctx: &mut Context<Fr>,
        mut a: AssignedBabyBear,
        mut b: AssignedBabyBear,
    ) -> AssignedBabyBear {
        if a.max_bits.max(b.max_bits) + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            a = self.reduce(ctx, a);
            if a.max_bits.max(b.max_bits) + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
                b = self.reduce(ctx, b);
            }
        }
        let value = self.gate().sub(ctx, a.value, b.value);
        let max_bits = a.max_bits.max(b.max_bits) + 1;
        let mut c = AssignedBabyBear { value, max_bits };
        debug_assert_eq!(c.to_baby_bear(), a.to_baby_bear() - b.to_baby_bear());
        if c.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            c = self.reduce(ctx, c);
        }
        c
    }

    /// BabyBear multiplication. The product's bound is the sum of the
    /// operand bounds, so operands are reduced first when that sum would
    /// exceed the safe `Fr` capacity.
    pub fn mul(
        &self,
        ctx: &mut Context<Fr>,
        mut a: AssignedBabyBear,
        mut b: AssignedBabyBear,
    ) -> AssignedBabyBear {
        // Ensure `a` is the operand with the larger bound so it is reduced
        // first (reducing it shrinks the product bound the most).
        if a.max_bits < b.max_bits {
            std::mem::swap(&mut a, &mut b);
        }
        if a.max_bits + b.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            a = self.reduce(ctx, a);
            if a.max_bits + b.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
                b = self.reduce(ctx, b);
            }
        }
        let value = self.gate().mul(ctx, a.value, b.value);
        let max_bits = a.max_bits + b.max_bits;

        let mut c = AssignedBabyBear { value, max_bits };
        if c.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            c = self.reduce(ctx, c);
        }
        debug_assert_eq!(c.to_baby_bear(), a.to_baby_bear() * b.to_baby_bear());
        c
    }

    /// Computes `a * b + c` with a single `mul_add` gate, pre-reducing any
    /// operand whose bound would make the result overflow.
    pub fn mul_add(
        &self,
        ctx: &mut Context<Fr>,
        mut a: AssignedBabyBear,
        mut b: AssignedBabyBear,
        mut c: AssignedBabyBear,
    ) -> AssignedBabyBear {
        if a.max_bits < b.max_bits {
            std::mem::swap(&mut a, &mut b);
        }
        if a.max_bits + b.max_bits + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            a = self.reduce(ctx, a);
            if a.max_bits + b.max_bits + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
                b = self.reduce(ctx, b);
            }
        }
        if c.max_bits + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            c = self.reduce(ctx, c)
        }
        let value = self.gate().mul_add(ctx, a.value, b.value, c.value);
        let max_bits = c.max_bits.max(a.max_bits + b.max_bits) + 1;

        let mut d = AssignedBabyBear { value, max_bits };
        if d.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            d = self.reduce(ctx, d);
        }
        debug_assert_eq!(
            d.to_baby_bear(),
            a.to_baby_bear() * b.to_baby_bear() + c.to_baby_bear()
        );
        d
    }

    /// BabyBear division: witnesses `c = a * b^{-1}` and constrains
    /// `a - b * c ≡ 0 (mod p)`. Panics during witness generation (off-circuit)
    /// if `b` represents zero.
    pub fn div(
        &self,
        ctx: &mut Context<Fr>,
        mut a: AssignedBabyBear,
        mut b: AssignedBabyBear,
    ) -> AssignedBabyBear {
        let b_val = b.to_baby_bear();
        let b_inv = b_val.try_inverse().unwrap();

        let mut c = self.load_witness(ctx, a.to_baby_bear() * b_inv);
        // Keep each term of `a - b*c` within the safe capacity.
        if a.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            a = self.reduce(ctx, a);
        }
        if b.max_bits + c.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            b = self.reduce(ctx, b);
        }
        if b.max_bits + c.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            c = self.reduce(ctx, c);
        }
        let diff = self.gate().sub_mul(ctx, a.value, b.value, c.value);
        let max_bits = a.max_bits.max(b.max_bits + c.max_bits) + 1;
        self.assert_zero(
            ctx,
            AssignedBabyBear {
                value: diff,
                max_bits,
            },
        );
        debug_assert_eq!(c.to_baby_bear(), a.to_baby_bear() / b.to_baby_bear());
        c
    }

    /// Computes the degree-`s` coefficient of the product of two degree-3
    /// polynomials, i.e. the convolution sum over `i + j == s` of
    /// `a[i] * b[j]`, as one `inner_product` gate. Operands are reduced in
    /// place as needed so the accumulated sum cannot overflow `Fr`; the
    /// `- len + i` margin leaves room for the carries from the remaining
    /// terms of the sum.
    fn special_inner_product(
        &self,
        ctx: &mut Context<Fr>,
        a: &mut [AssignedBabyBear],
        b: &mut [AssignedBabyBear],
        s: usize,
    ) -> AssignedBabyBear {
        assert!(a.len() == b.len());
        assert!(a.len() == 4);
        let mut max_bits = 0;
        // Index windows so that a[range] pairs with b[other_range] reversed,
        // giving exactly the pairs (a_i, b_{s-i}) with both indices in 0..4.
        let lb = if s > 3 { s - 3 } else { 0 };
        let ub = 4.min(s + 1);
        let range = lb..ub;
        let other_range = (s + 1 - ub)..(s + 1 - lb);
        // Number of terms contributing to coefficient `s`.
        let len = if s < 3 { s + 1 } else { 7 - s };
        for (i, (c, d)) in a[range.clone()]
            .iter_mut()
            .zip(b[other_range.clone()].iter_mut().rev())
            .enumerate()
        {
            if c.max_bits + d.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS - len + i {
                // Reduce the larger operand first; re-check afterwards.
                if c.max_bits >= d.max_bits {
                    *c = self.reduce(ctx, *c);
                    if c.max_bits + d.max_bits
                        > Fr::CAPACITY as usize - RESERVED_HIGH_BITS - len + i
                    {
                        *d = self.reduce(ctx, *d);
                    }
                } else {
                    *d = self.reduce(ctx, *d);
                    if c.max_bits + d.max_bits
                        > Fr::CAPACITY as usize - RESERVED_HIGH_BITS - len + i
                    {
                        *c = self.reduce(ctx, *c);
                    }
                }
            }
            // Track the running bound of the partial sum (+1 bit per add).
            if i == 0 {
                max_bits = c.max_bits + d.max_bits;
            } else {
                max_bits = max_bits.max(c.max_bits + d.max_bits) + 1
            }
        }
        let a_raw = a[range]
            .iter()
            .map(|a| QuantumCell::Existing(a.value))
            .collect_vec();
        let b_raw = b[other_range]
            .iter()
            .rev()
            .map(|b| QuantumCell::Existing(b.value))
            .collect_vec();
        let prod = self.gate().inner_product(ctx, a_raw, b_raw);
        AssignedBabyBear {
            value: prod,
            max_bits,
        }
    }

    /// Selects `a` when `cond` is 1, `b` when `cond` is 0; the resulting
    /// bound is the max of the operand bounds.
    pub fn select(
        &self,
        ctx: &mut Context<Fr>,
        cond: AssignedValue<Fr>,
        a: AssignedBabyBear,
        b: AssignedBabyBear,
    ) -> AssignedBabyBear {
        let value = self.gate().select(ctx, a.value, b.value, cond);
        let max_bits = a.max_bits.max(b.max_bits);
        AssignedBabyBear { value, max_bits }
    }

    /// Constrains `a ≡ 0 (mod p)` for the BabyBear prime `p` by exhibiting a
    /// witness `div` with `a = p * div` and range-checking `div`.
    pub fn assert_zero(&self, ctx: &mut Context<Fr>, a: AssignedBabyBear) {
        debug_assert_eq!(a.to_baby_bear(), BabyBear::ZERO);
        assert!(a.max_bits <= Fr::CAPACITY as usize - RESERVED_HIGH_BITS);
        let a_num_bits = a.max_bits;
        let b: BigUint = BabyBear::ORDER_U32.into();
        let a_val = fe_to_bigint(a.value.value());
        assert!(a_val.bits() <= a_num_bits as u64);
        let (div, _) = a_val.div_mod_floor(&b.clone().into());
        let div = bigint_to_fe(&div);
        // Region layout [0, p, div, a] with a gate at offset 0 enforces
        // a = 0 + p * div (the remainder is the constant zero).
        ctx.assign_region(
            [
                QuantumCell::Constant(Fr::ZERO),
                QuantumCell::Constant(biguint_to_fe(&b)),
                QuantumCell::Witness(div),
                a.value.into(),
            ],
            [0],
        );
        // Fetch the just-assigned `div` cell by relative offset.
        let div = ctx.get(-2);
        // |div| <= bound = 2^a_num_bits / p, so div + bound ∈ [0, 2*bound].
        let bound = (BigUint::from(1u32) << (a_num_bits as u32)) / &b;
        let shifted_div =
            self.range
                .gate()
                .add(ctx, div, QuantumCell::Constant(biguint_to_fe(&bound)));
        debug_assert!(*shifted_div.value() < biguint_to_fe(&(&bound * 2u32 + 1u32)));
        self.range
            .range_check(ctx, shifted_div, (bound * 2u32 + 1u32).bits() as usize);
    }

    /// Constrains `a ≡ b (mod p)` by asserting their difference is zero.
    pub fn assert_equal(&self, ctx: &mut Context<Fr>, a: AssignedBabyBear, b: AssignedBabyBear) {
        debug_assert_eq!(a.to_baby_bear(), b.to_baby_bear());
        let diff = self.sub(ctx, a, b);
        self.assert_zero(ctx, diff);
    }
}
354
/// Witnesses and constrains a Euclidean division of `a` by the BabyBear
/// prime `p`: returns `(div, rem)` with `a = p * div + rem` and
/// `0 <= rem < p`. `a` is interpreted as a signed integer whose absolute
/// value has at most `a_num_bits` bits (checked in debug builds).
fn signed_div_mod<F>(
    range: &RangeChip<F>,
    ctx: &mut Context<F>,
    a: impl Into<QuantumCell<F>>,
    a_num_bits: usize,
) -> (AssignedValue<F>, AssignedValue<F>)
where
    F: BigPrimeField,
{
    let a = a.into();
    let b = BigUint::from(BabyBear::ORDER_U32);
    let a_val = fe_to_bigint(a.value());
    assert!(a_val.bits() <= a_num_bits as u64);
    // Floor division keeps the remainder in [0, p) even for negative `a`.
    let (div, rem) = a_val.div_mod_floor(&b.clone().into());
    let [div, rem] = [div, rem].map(|v| bigint_to_fe(&v));
    // Region layout [rem, p, div, a] with a gate at offset 0 enforces
    // a = rem + p * div.
    ctx.assign_region(
        [
            QuantumCell::Witness(rem),
            QuantumCell::Constant(biguint_to_fe(&b)),
            QuantumCell::Witness(div),
            a,
        ],
        [0],
    );
    // Fetch the freshly assigned cells by relative offset.
    let rem = ctx.get(-4);
    let div = ctx.get(-2);
    // |div| <= bound = 2^a_num_bits / p, so div + bound lies in [0, 2*bound];
    // range-checking the shifted value proves `div` is within bounds.
    let bound = (BigUint::from(1u32) << (a_num_bits as u32)) / &b;
    let shifted_div = range
        .gate()
        .add(ctx, div, QuantumCell::Constant(biguint_to_fe(&bound)));
    debug_assert!(*shifted_div.value() < biguint_to_fe(&(&bound * 2u32 + 1u32)));
    range.range_check(ctx, shifted_div, (bound * 2u32 + 1u32).bits() as usize);
    debug_assert!(*rem.value() < biguint_to_fe(&b));
    range.check_big_less_than_safe(ctx, rem, b);
    (div, rem)
}
436
/// Chip for arithmetic in the degree-4 binomial extension of BabyBear,
/// built on top of the base-field `BabyBearChip`.
pub struct BabyBearExt4Chip {
    /// Shared base-field chip used for all coefficient arithmetic.
    pub base: Arc<BabyBearChip>,
}
441
/// An extension-field element as four assigned base-field coefficients
/// (coefficient order matches `BabyBearExt4::from_base_slice`).
#[derive(Copy, Clone, Debug)]
pub struct AssignedBabyBearExt4(pub [AssignedBabyBear; 4]);
/// The degree-4 binomial extension field of BabyBear.
pub type BabyBearExt4 = BinomialExtensionField<BabyBear, 4>;
445
446impl AssignedBabyBearExt4 {
447 pub fn to_extension_field(&self) -> BabyBearExt4 {
448 let b_val = (0..4).map(|i| self.0[i].to_baby_bear()).collect_vec();
449 BabyBearExt4::from_base_slice(&b_val)
450 }
451}
452
// Degree-4 extension arithmetic, delegating coefficient-wise work to the
// base `BabyBearChip` and handling the x^4 = W folding in `mul`.
impl BabyBearExt4Chip {
    /// Creates an extension-field chip over the given base-field chip.
    pub fn new(base_chip: Arc<BabyBearChip>) -> Self {
        BabyBearExt4Chip { base: base_chip }
    }
    /// Assigns the four coefficients of `value` as range-checked witnesses.
    pub fn load_witness(&self, ctx: &mut Context<Fr>, value: BabyBearExt4) -> AssignedBabyBearExt4 {
        AssignedBabyBearExt4(
            value
                .as_base_slice()
                .iter()
                .map(|x| self.base.load_witness(ctx, *x))
                .collect_vec()
                .try_into()
                .unwrap(),
        )
    }
    /// Assigns the four coefficients of `value` as constants.
    pub fn load_constant(
        &self,
        ctx: &mut Context<Fr>,
        value: BabyBearExt4,
    ) -> AssignedBabyBearExt4 {
        AssignedBabyBearExt4(
            value
                .as_base_slice()
                .iter()
                .map(|x| self.base.load_constant(ctx, *x))
                .collect_vec()
                .try_into()
                .unwrap(),
        )
    }
    /// Coefficient-wise addition.
    pub fn add(
        &self,
        ctx: &mut Context<Fr>,
        a: AssignedBabyBearExt4,
        b: AssignedBabyBearExt4,
    ) -> AssignedBabyBearExt4 {
        AssignedBabyBearExt4(
            a.0.iter()
                .zip(b.0.iter())
                .map(|(a, b)| self.base.add(ctx, *a, *b))
                .collect_vec()
                .try_into()
                .unwrap(),
        )
    }

    /// Coefficient-wise negation.
    pub fn neg(&self, ctx: &mut Context<Fr>, a: AssignedBabyBearExt4) -> AssignedBabyBearExt4 {
        AssignedBabyBearExt4(
            a.0.iter()
                .map(|x| self.base.neg(ctx, *x))
                .collect_vec()
                .try_into()
                .unwrap(),
        )
    }

    /// Coefficient-wise subtraction.
    pub fn sub(
        &self,
        ctx: &mut Context<Fr>,
        a: AssignedBabyBearExt4,
        b: AssignedBabyBearExt4,
    ) -> AssignedBabyBearExt4 {
        AssignedBabyBearExt4(
            a.0.iter()
                .zip(b.0.iter())
                .map(|(a, b)| self.base.sub(ctx, *a, *b))
                .collect_vec()
                .try_into()
                .unwrap(),
        )
    }

    /// Multiplies every coefficient of `a` by the base-field element `b`.
    pub fn scalar_mul(
        &self,
        ctx: &mut Context<Fr>,
        a: AssignedBabyBearExt4,
        b: AssignedBabyBear,
    ) -> AssignedBabyBearExt4 {
        AssignedBabyBearExt4(
            a.0.iter()
                .map(|x| self.base.mul(ctx, *x, b))
                .collect_vec()
                .try_into()
                .unwrap(),
        )
    }

    /// Coefficient-wise select: `a` when `cond` is 1, `b` when `cond` is 0.
    pub fn select(
        &self,
        ctx: &mut Context<Fr>,
        cond: AssignedValue<Fr>,
        a: AssignedBabyBearExt4,
        b: AssignedBabyBearExt4,
    ) -> AssignedBabyBearExt4 {
        AssignedBabyBearExt4(
            a.0.iter()
                .zip(b.0.iter())
                .map(|(a, b)| self.base.select(ctx, cond, *a, *b))
                .collect_vec()
                .try_into()
                .unwrap(),
        )
    }

    /// Constrains every coefficient of `a` to be ≡ 0 modulo the BabyBear prime.
    pub fn assert_zero(&self, ctx: &mut Context<Fr>, a: AssignedBabyBearExt4) {
        for x in a.0.iter() {
            self.base.assert_zero(ctx, *x);
        }
    }

    /// Constrains `a` and `b` to be equal, coefficient by coefficient.
    pub fn assert_equal(
        &self,
        ctx: &mut Context<Fr>,
        a: AssignedBabyBearExt4,
        b: AssignedBabyBearExt4,
    ) {
        for (a, b) in a.0.iter().zip(b.0.iter()) {
            self.base.assert_equal(ctx, *a, *b);
        }
    }

    /// Extension-field multiplication: computes the 7 coefficients of the
    /// degree-3 polynomial product via convolution (`special_inner_product`),
    /// then folds the top three back using x^4 = W for this binomial
    /// extension. `a` and `b` are `mut` because the convolution may reduce
    /// their coefficients in place.
    pub fn mul(
        &self,
        ctx: &mut Context<Fr>,
        mut a: AssignedBabyBearExt4,
        mut b: AssignedBabyBearExt4,
    ) -> AssignedBabyBearExt4 {
        let mut coeffs = Vec::with_capacity(7);
        for s in 0..7 {
            coeffs.push(self.base.special_inner_product(ctx, &mut a.0, &mut b.0, s));
        }
        let w = self
            .base
            .load_constant(ctx, <BabyBear as BinomiallyExtendable<4>>::W);
        // coeff[i] * x^i with i >= 4 contributes W * coeff[i] to coeff[i-4].
        for i in 4..7 {
            coeffs[i - 4] = self.base.mul_add(ctx, coeffs[i], w, coeffs[i - 4]);
        }
        coeffs.truncate(4);
        let c = AssignedBabyBearExt4(coeffs.try_into().unwrap());
        debug_assert_eq!(
            c.to_extension_field(),
            a.to_extension_field() * b.to_extension_field()
        );
        c
    }

    /// Extension-field division: witnesses `c = a / b` and constrains
    /// `b * c == a`. Panics during witness generation (off-circuit) when `b`
    /// is zero.
    pub fn div(
        &self,
        ctx: &mut Context<Fr>,
        a: AssignedBabyBearExt4,
        b: AssignedBabyBearExt4,
    ) -> AssignedBabyBearExt4 {
        let b_val = b.to_extension_field();
        let b_inv = b_val.try_inverse().unwrap();

        let c = self.load_witness(ctx, a.to_extension_field() * b_inv);
        let prod = self.mul(ctx, b, c);
        self.assert_equal(ctx, a, prod);

        debug_assert_eq!(
            c.to_extension_field(),
            a.to_extension_field() / b.to_extension_field()
        );
        c
    }

    /// Coefficient-wise `reduce_max_bits` (see the base chip).
    pub fn reduce_max_bits(
        &self,
        ctx: &mut Context<Fr>,
        a: AssignedBabyBearExt4,
    ) -> AssignedBabyBearExt4 {
        AssignedBabyBearExt4(
            a.0.into_iter()
                .map(|x| self.base.reduce_max_bits(ctx, x))
                .collect::<Vec<_>>()
                .try_into()
                .unwrap(),
        )
    }
}
633}