1use std::sync::Arc;
2
3use itertools::Itertools;
4use num_bigint::{BigInt, BigUint};
5use num_integer::Integer;
6use openvm_stark_backend::p3_field::{
7 extension::{BinomialExtensionField, BinomiallyExtendable},
8 Field, FieldAlgebra, FieldExtensionAlgebra, PrimeField32, PrimeField64,
9};
10use openvm_stark_sdk::p3_baby_bear::BabyBear;
11use snark_verifier_sdk::snark_verifier::{
12 halo2_base::{
13 gates::{GateChip, GateInstructions, RangeChip, RangeInstructions},
14 halo2_proofs::halo2curves::bn256::Fr,
15 utils::{bigint_to_fe, biguint_to_fe, bit_length, fe_to_bigint, BigPrimeField},
16 AssignedValue, Context, QuantumCell,
17 },
18 util::arithmetic::{Field as _, PrimeField as _},
19};
20
/// Bit length sufficient to hold any canonical BabyBear element (the BabyBear
/// modulus fits in 31 bits). Used as the range-check width for freshly loaded
/// or reduced values.
pub(crate) const BABYBEAR_MAX_BITS: usize = 31;
/// Safety margin kept below `Fr::CAPACITY`: operands are reduced before an
/// intermediate result's bit bound could come within this many bits of the
/// native field capacity, so the lazy (unreduced) arithmetic never wraps
/// around in `Fr`.
const RESERVED_HIGH_BITS: usize = 2;
26
/// A BabyBear element assigned in the circuit, with lazy reduction: the
/// underlying `Fr` cell may hold an unreduced (and, after subtraction or
/// negation, negative) signed integer representative of the element.
#[derive(Copy, Clone, Debug)]
pub struct AssignedBabyBear {
    /// The assigned native-field cell holding a signed representative of the
    /// BabyBear element (not necessarily in canonical `[0, p)` form).
    pub value: AssignedValue<Fr>,
    /// Upper bound on the bit length of the signed integer representative in
    /// `value` (checked via `debug_assert!` in `BabyBearChip::reduce`). The
    /// arithmetic methods use this to decide when a reduction is needed.
    pub max_bits: usize,
}
40
41impl AssignedBabyBear {
42 pub fn to_baby_bear(&self) -> BabyBear {
43 let mut b_int = fe_to_bigint(self.value.value()) % BabyBear::ORDER_U32;
44 if b_int < BigInt::from(0) {
45 b_int += BabyBear::ORDER_U32;
46 }
47 BabyBear::from_canonical_u32(b_int.try_into().unwrap())
48 }
49}
50
/// Chip for BabyBear arithmetic emulated over the native BN254 scalar field
/// `Fr`, using lazy reduction driven by the `max_bits` bound on each
/// `AssignedBabyBear`.
pub struct BabyBearChip {
    /// Shared range chip used for range checks and basic gates.
    pub range: Arc<RangeChip<Fr>>,
}
54
impl BabyBearChip {
    /// Creates a BabyBear chip backed by the given range chip.
    pub fn new(range_chip: Arc<RangeChip<Fr>>) -> Self {
        BabyBearChip { range: range_chip }
    }

    /// Returns the underlying gate chip used for native-field gates.
    pub fn gate(&self) -> &GateChip<Fr> {
        self.range.gate()
    }

    /// Assigns `value` (in canonical form) as a witness and range checks it
    /// to `BABYBEAR_MAX_BITS` bits.
    pub fn load_witness(&self, ctx: &mut Context<Fr>, value: BabyBear) -> AssignedBabyBear {
        let value = ctx.load_witness(Fr::from(PrimeField64::as_canonical_u64(&value)));
        self.range.range_check(ctx, value, BABYBEAR_MAX_BITS);
        AssignedBabyBear {
            value,
            max_bits: BABYBEAR_MAX_BITS,
        }
    }

    /// Assigns `value` as a constant. Since the value is known, `max_bits`
    /// is the exact bit length of the constant (no range check needed).
    pub fn load_constant(&self, ctx: &mut Context<Fr>, value: BabyBear) -> AssignedBabyBear {
        let max_bits = bit_length(value.as_canonical_u64());
        let value = ctx.load_constant(Fr::from(PrimeField64::as_canonical_u64(&value)));
        AssignedBabyBear { value, max_bits }
    }

    /// Reduces `a` to its canonical residue modulo the BabyBear prime via a
    /// witnessed floored division (see `signed_div_mod`). The result is
    /// range checked, so its `max_bits` drops to `BABYBEAR_MAX_BITS`.
    pub fn reduce(&self, ctx: &mut Context<Fr>, a: AssignedBabyBear) -> AssignedBabyBear {
        debug_assert!(fe_to_bigint(a.value.value()).bits() as usize <= a.max_bits);
        let (_, r) = signed_div_mod(&self.range, ctx, a.value, a.max_bits);
        let r = AssignedBabyBear {
            value: r,
            max_bits: BABYBEAR_MAX_BITS,
        };
        debug_assert_eq!(a.to_baby_bear(), r.to_baby_bear());
        r
    }

    /// Reduces `a` only if it is not already within `BABYBEAR_MAX_BITS` bits.
    pub fn reduce_max_bits(&self, ctx: &mut Context<Fr>, a: AssignedBabyBear) -> AssignedBabyBear {
        if a.max_bits > BABYBEAR_MAX_BITS {
            self.reduce(ctx, a)
        } else {
            a
        }
    }

    /// Lazy addition: adds the representatives in `Fr`, reducing operands
    /// first if the sum's bit bound (`max + 1`) would exceed the safe
    /// capacity `Fr::CAPACITY - RESERVED_HIGH_BITS`.
    pub fn add(
        &self,
        ctx: &mut Context<Fr>,
        mut a: AssignedBabyBear,
        mut b: AssignedBabyBear,
    ) -> AssignedBabyBear {
        if a.max_bits.max(b.max_bits) + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            a = self.reduce(ctx, a);
            // Reducing `a` may already be enough; only reduce `b` if the
            // bound is still too large.
            if a.max_bits.max(b.max_bits) + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
                b = self.reduce(ctx, b);
            }
        }
        let value = self.gate().add(ctx, a.value, b.value);
        let max_bits = a.max_bits.max(b.max_bits) + 1;
        let mut c = AssignedBabyBear { value, max_bits };
        debug_assert_eq!(c.to_baby_bear(), a.to_baby_bear() + b.to_baby_bear());
        if c.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            c = self.reduce(ctx, c);
        }
        c
    }

    /// Negation of the representative in `Fr`. The signed representative's
    /// magnitude is unchanged, so `max_bits` carries over.
    pub fn neg(&self, ctx: &mut Context<Fr>, a: AssignedBabyBear) -> AssignedBabyBear {
        let value = self.gate().neg(ctx, a.value);
        let b = AssignedBabyBear {
            value,
            max_bits: a.max_bits,
        };
        debug_assert_eq!(b.to_baby_bear(), -a.to_baby_bear());
        b
    }

    /// Lazy subtraction; mirrors `add` (the result's representative may be
    /// negative, which `signed_div_mod` handles on later reduction).
    pub fn sub(
        &self,
        ctx: &mut Context<Fr>,
        mut a: AssignedBabyBear,
        mut b: AssignedBabyBear,
    ) -> AssignedBabyBear {
        if a.max_bits.max(b.max_bits) + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            a = self.reduce(ctx, a);
            if a.max_bits.max(b.max_bits) + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
                b = self.reduce(ctx, b);
            }
        }
        let value = self.gate().sub(ctx, a.value, b.value);
        let max_bits = a.max_bits.max(b.max_bits) + 1;
        let mut c = AssignedBabyBear { value, max_bits };
        debug_assert_eq!(c.to_baby_bear(), a.to_baby_bear() - b.to_baby_bear());
        if c.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            c = self.reduce(ctx, c);
        }
        c
    }

    /// Lazy multiplication: product bit bound is `a.max_bits + b.max_bits`.
    /// Reduces the larger operand first (it shrinks the bound the most).
    pub fn mul(
        &self,
        ctx: &mut Context<Fr>,
        mut a: AssignedBabyBear,
        mut b: AssignedBabyBear,
    ) -> AssignedBabyBear {
        // Ensure `a` is the wider operand so it is the one reduced first.
        if a.max_bits < b.max_bits {
            std::mem::swap(&mut a, &mut b);
        }
        if a.max_bits + b.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            a = self.reduce(ctx, a);
            if a.max_bits + b.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
                b = self.reduce(ctx, b);
            }
        }
        let value = self.gate().mul(ctx, a.value, b.value);
        let max_bits = a.max_bits + b.max_bits;

        let mut c = AssignedBabyBear { value, max_bits };
        if c.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            c = self.reduce(ctx, c);
        }
        debug_assert_eq!(c.to_baby_bear(), a.to_baby_bear() * b.to_baby_bear());
        c
    }

    /// Computes `a * b + c` with a single `mul_add` gate, reducing operands
    /// as needed so the bound `max(c, a*b) + 1` stays within capacity.
    pub fn mul_add(
        &self,
        ctx: &mut Context<Fr>,
        mut a: AssignedBabyBear,
        mut b: AssignedBabyBear,
        mut c: AssignedBabyBear,
    ) -> AssignedBabyBear {
        if a.max_bits < b.max_bits {
            std::mem::swap(&mut a, &mut b);
        }
        if a.max_bits + b.max_bits + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            a = self.reduce(ctx, a);
            if a.max_bits + b.max_bits + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
                b = self.reduce(ctx, b);
            }
        }
        if c.max_bits + 1 > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            c = self.reduce(ctx, c)
        }
        let value = self.gate().mul_add(ctx, a.value, b.value, c.value);
        let max_bits = c.max_bits.max(a.max_bits + b.max_bits) + 1;

        let mut d = AssignedBabyBear { value, max_bits };
        if d.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            d = self.reduce(ctx, d);
        }
        debug_assert_eq!(
            d.to_baby_bear(),
            a.to_baby_bear() * b.to_baby_bear() + c.to_baby_bear()
        );
        d
    }

    /// Division: witnesses `c = a * b^{-1}` out-of-circuit, then constrains
    /// `a - b * c ≡ 0 (mod p)` via `assert_zero`.
    ///
    /// Panics (in witness generation) if `b` is zero, via `try_inverse().unwrap()`.
    pub fn div(
        &self,
        ctx: &mut Context<Fr>,
        mut a: AssignedBabyBear,
        mut b: AssignedBabyBear,
    ) -> AssignedBabyBear {
        let b_val = b.to_baby_bear();
        let b_inv = b_val.try_inverse().unwrap();

        let mut c = self.load_witness(ctx, a.to_baby_bear() * b_inv);
        // Keep each term of `a - b*c` within the safe capacity so the
        // difference cannot wrap in `Fr`.
        if a.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            a = self.reduce(ctx, a);
        }
        if b.max_bits + c.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            b = self.reduce(ctx, b);
        }
        if b.max_bits + c.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS {
            c = self.reduce(ctx, c);
        }
        let diff = self.gate().sub_mul(ctx, a.value, b.value, c.value);
        let max_bits = a.max_bits.max(b.max_bits + c.max_bits) + 1;
        self.assert_zero(
            ctx,
            AssignedBabyBear {
                value: diff,
                max_bits,
            },
        );
        debug_assert_eq!(c.to_baby_bear(), a.to_baby_bear() / b.to_baby_bear());
        c
    }

    /// Computes the `s`-th coefficient of the product of two degree-3
    /// polynomials with coefficient slices `a` and `b`, i.e.
    /// `sum_{i + j = s} a[i] * b[j]`, as a single `inner_product` gate.
    /// Operands are reduced pair-by-pair so the accumulated sum (of `len`
    /// terms) stays within capacity.
    fn special_inner_product(
        &self,
        ctx: &mut Context<Fr>,
        a: &mut [AssignedBabyBear],
        b: &mut [AssignedBabyBear],
        s: usize,
    ) -> AssignedBabyBear {
        assert!(a.len() == b.len());
        assert!(a.len() == 4);
        let mut max_bits = 0;
        // Valid `a` indices i with 0 <= s - i <= 3.
        let lb = s.saturating_sub(3);
        let ub = 4.min(s + 1);
        let range = lb..ub;
        // Matching `b` indices; reversed below so a[i] pairs with b[s - i].
        let other_range = (s + 1 - ub)..(s + 1 - lb);
        // Number of terms in the sum for this coefficient.
        let len = if s < 3 { s + 1 } else { 7 - s };
        for (i, (c, d)) in a[range.clone()]
            .iter_mut()
            .zip(b[other_range.clone()].iter_mut().rev())
            .enumerate()
        {
            // Each pair's product must leave headroom for the remaining
            // `len - i` additions; reduce the wider factor first.
            if c.max_bits + d.max_bits > Fr::CAPACITY as usize - RESERVED_HIGH_BITS - len + i {
                if c.max_bits >= d.max_bits {
                    *c = self.reduce(ctx, *c);
                    if c.max_bits + d.max_bits
                        > Fr::CAPACITY as usize - RESERVED_HIGH_BITS - len + i
                    {
                        *d = self.reduce(ctx, *d);
                    }
                } else {
                    *d = self.reduce(ctx, *d);
                    if c.max_bits + d.max_bits
                        > Fr::CAPACITY as usize - RESERVED_HIGH_BITS - len + i
                    {
                        *c = self.reduce(ctx, *c);
                    }
                }
            }
            // Track the running bit bound of the partial sum.
            if i == 0 {
                max_bits = c.max_bits + d.max_bits;
            } else {
                max_bits = max_bits.max(c.max_bits + d.max_bits) + 1
            }
        }
        let a_raw = a[range]
            .iter()
            .map(|a| QuantumCell::Existing(a.value))
            .collect_vec();
        let b_raw = b[other_range]
            .iter()
            .rev()
            .map(|b| QuantumCell::Existing(b.value))
            .collect_vec();
        let prod = self.gate().inner_product(ctx, a_raw, b_raw);
        AssignedBabyBear {
            value: prod,
            max_bits,
        }
    }

    /// Selects `a` if `cond` is 1, else `b`. The bound is the max of the
    /// two operand bounds since either representative may be chosen.
    pub fn select(
        &self,
        ctx: &mut Context<Fr>,
        cond: AssignedValue<Fr>,
        a: AssignedBabyBear,
        b: AssignedBabyBear,
    ) -> AssignedBabyBear {
        let value = self.gate().select(ctx, a.value, b.value, cond);
        let max_bits = a.max_bits.max(b.max_bits);
        AssignedBabyBear { value, max_bits }
    }

    /// Constrains `a ≡ 0 (mod p)` where `p` is the BabyBear prime: witnesses
    /// the (possibly negative) quotient `div = a / p`, enforces
    /// `a = 0 + p * div` in one gate row, and range checks the quotient
    /// shifted by `bound = 2^max_bits / p` so that its magnitude is bounded.
    pub fn assert_zero(&self, ctx: &mut Context<Fr>, a: AssignedBabyBear) {
        debug_assert_eq!(a.to_baby_bear(), BabyBear::ZERO);
        assert!(a.max_bits <= Fr::CAPACITY as usize - RESERVED_HIGH_BITS);
        let a_num_bits = a.max_bits;
        let b: BigUint = BabyBear::ORDER_U32.into();
        let a_val = fe_to_bigint(a.value.value());
        assert!(a_val.bits() <= a_num_bits as u64);
        let (div, _) = a_val.div_mod_floor(&b.clone().into());
        let div = bigint_to_fe(&div);
        // Gate row [0, p, div, a] with offset [0] enforces a = 0 + p * div.
        ctx.assign_region(
            [
                QuantumCell::Constant(Fr::ZERO),
                QuantumCell::Constant(biguint_to_fe(&b)),
                QuantumCell::Witness(div),
                a.value.into(),
            ],
            [0],
        );
        // Retrieve the assigned quotient cell (3rd of the 4 cells above).
        let div = ctx.get(-2);
        let bound = (BigUint::from(1u32) << (a_num_bits as u32)) / &b;
        // Shift so a quotient in [-bound, bound] becomes non-negative,
        // then range check it into [0, 2*bound].
        let shifted_div =
            self.range
                .gate()
                .add(ctx, div, QuantumCell::Constant(biguint_to_fe(&bound)));
        debug_assert!(*shifted_div.value() < biguint_to_fe(&(&bound * 2u32 + 1u32)));
        self.range
            .range_check(ctx, shifted_div, (bound * 2u32 + 1u32).bits() as usize);
    }

    /// Constrains `a ≡ b (mod p)` by asserting `a - b ≡ 0`.
    pub fn assert_equal(&self, ctx: &mut Context<Fr>, a: AssignedBabyBear, b: AssignedBabyBear) {
        debug_assert_eq!(a.to_baby_bear(), b.to_baby_bear());
        let diff = self.sub(ctx, a, b);
        self.assert_zero(ctx, diff);
    }
}
356
/// Constrains and returns `(div, rem)` such that `a = p * div + rem` with
/// `0 <= rem < p`, where `p` is the BabyBear prime and `a` is a signed
/// representative of at most `a_num_bits` bits (so `div` may be negative;
/// floored division makes `rem` canonical even then).
///
/// Soundness: one gate row enforces `a = rem + p * div`; `rem` is range
/// checked below `p`, and `div` is range checked after shifting by
/// `bound = 2^a_num_bits / p` so its magnitude is bounded.
fn signed_div_mod<F>(
    range: &RangeChip<F>,
    ctx: &mut Context<F>,
    a: impl Into<QuantumCell<F>>,
    a_num_bits: usize,
) -> (AssignedValue<F>, AssignedValue<F>)
where
    F: BigPrimeField,
{
    let a = a.into();
    let b = BigUint::from(BabyBear::ORDER_U32);
    let a_val = fe_to_bigint(a.value());
    assert!(a_val.bits() <= a_num_bits as u64);
    // Floored division keeps 0 <= rem < p even for negative a.
    let (div, rem) = a_val.div_mod_floor(&b.clone().into());
    let [div, rem] = [div, rem].map(|v| bigint_to_fe(&v));
    // Gate row [rem, p, div, a] with offset [0] enforces a = rem + p * div.
    ctx.assign_region(
        [
            QuantumCell::Witness(rem),
            QuantumCell::Constant(biguint_to_fe(&b)),
            QuantumCell::Witness(div),
            a,
        ],
        [0],
    );
    // Retrieve the assigned cells by negative offset from the row end.
    let rem = ctx.get(-4);
    let div = ctx.get(-2);
    let bound = (BigUint::from(1u32) << (a_num_bits as u32)) / &b;
    // Shift the (possibly negative) quotient into [0, 2*bound] and range
    // check it there.
    let shifted_div = range
        .gate()
        .add(ctx, div, QuantumCell::Constant(biguint_to_fe(&bound)));
    debug_assert!(*shifted_div.value() < biguint_to_fe(&(&bound * 2u32 + 1u32)));
    range.range_check(ctx, shifted_div, (bound * 2u32 + 1u32).bits() as usize);
    debug_assert!(*rem.value() < biguint_to_fe(&b));
    range.check_big_less_than_safe(ctx, rem, b);
    (div, rem)
}
438
/// Chip for arithmetic over the degree-4 binomial extension of BabyBear,
/// built coefficient-wise on top of [`BabyBearChip`].
pub struct BabyBearExt4Chip {
    /// The base-field chip used for all coefficient operations.
    pub base: Arc<BabyBearChip>,
}
443
/// An assigned degree-4 extension-field element, stored as its four BabyBear
/// base-field coefficients (lowest degree first, matching `as_base_slice`).
#[derive(Copy, Clone, Debug)]
pub struct AssignedBabyBearExt4(pub [AssignedBabyBear; 4]);
/// The native counterpart of [`AssignedBabyBearExt4`].
pub type BabyBearExt4 = BinomialExtensionField<BabyBear, 4>;
447
448impl AssignedBabyBearExt4 {
449 pub fn to_extension_field(&self) -> BabyBearExt4 {
450 let b_val = (0..4).map(|i| self.0[i].to_baby_bear()).collect_vec();
451 BabyBearExt4::from_base_slice(&b_val)
452 }
453}
454
impl BabyBearExt4Chip {
    /// Creates an extension-field chip over the given base-field chip.
    pub fn new(base_chip: Arc<BabyBearChip>) -> Self {
        BabyBearExt4Chip { base: base_chip }
    }
    /// Assigns each base coefficient of `value` as a range-checked witness.
    pub fn load_witness(&self, ctx: &mut Context<Fr>, value: BabyBearExt4) -> AssignedBabyBearExt4 {
        AssignedBabyBearExt4(
            value
                .as_base_slice()
                .iter()
                .map(|x| self.base.load_witness(ctx, *x))
                .collect_vec()
                .try_into()
                .unwrap(),
        )
    }
    /// Assigns each base coefficient of `value` as a constant.
    pub fn load_constant(
        &self,
        ctx: &mut Context<Fr>,
        value: BabyBearExt4,
    ) -> AssignedBabyBearExt4 {
        AssignedBabyBearExt4(
            value
                .as_base_slice()
                .iter()
                .map(|x| self.base.load_constant(ctx, *x))
                .collect_vec()
                .try_into()
                .unwrap(),
        )
    }
    /// Coefficient-wise addition.
    pub fn add(
        &self,
        ctx: &mut Context<Fr>,
        a: AssignedBabyBearExt4,
        b: AssignedBabyBearExt4,
    ) -> AssignedBabyBearExt4 {
        AssignedBabyBearExt4(
            a.0.iter()
                .zip(b.0.iter())
                .map(|(a, b)| self.base.add(ctx, *a, *b))
                .collect_vec()
                .try_into()
                .unwrap(),
        )
    }

    /// Coefficient-wise negation.
    pub fn neg(&self, ctx: &mut Context<Fr>, a: AssignedBabyBearExt4) -> AssignedBabyBearExt4 {
        AssignedBabyBearExt4(
            a.0.iter()
                .map(|x| self.base.neg(ctx, *x))
                .collect_vec()
                .try_into()
                .unwrap(),
        )
    }

    /// Coefficient-wise subtraction.
    pub fn sub(
        &self,
        ctx: &mut Context<Fr>,
        a: AssignedBabyBearExt4,
        b: AssignedBabyBearExt4,
    ) -> AssignedBabyBearExt4 {
        AssignedBabyBearExt4(
            a.0.iter()
                .zip(b.0.iter())
                .map(|(a, b)| self.base.sub(ctx, *a, *b))
                .collect_vec()
                .try_into()
                .unwrap(),
        )
    }

    /// Multiplies every coefficient of `a` by the base-field scalar `b`.
    pub fn scalar_mul(
        &self,
        ctx: &mut Context<Fr>,
        a: AssignedBabyBearExt4,
        b: AssignedBabyBear,
    ) -> AssignedBabyBearExt4 {
        AssignedBabyBearExt4(
            a.0.iter()
                .map(|x| self.base.mul(ctx, *x, b))
                .collect_vec()
                .try_into()
                .unwrap(),
        )
    }

    /// Coefficient-wise select: `a` if `cond` is 1, else `b`.
    pub fn select(
        &self,
        ctx: &mut Context<Fr>,
        cond: AssignedValue<Fr>,
        a: AssignedBabyBearExt4,
        b: AssignedBabyBearExt4,
    ) -> AssignedBabyBearExt4 {
        AssignedBabyBearExt4(
            a.0.iter()
                .zip(b.0.iter())
                .map(|(a, b)| self.base.select(ctx, cond, *a, *b))
                .collect_vec()
                .try_into()
                .unwrap(),
        )
    }

    /// Constrains every coefficient of `a` to be ≡ 0 mod the BabyBear prime.
    pub fn assert_zero(&self, ctx: &mut Context<Fr>, a: AssignedBabyBearExt4) {
        for x in a.0.iter() {
            self.base.assert_zero(ctx, *x);
        }
    }

    /// Constrains `a` and `b` to be equal coefficient-wise.
    pub fn assert_equal(
        &self,
        ctx: &mut Context<Fr>,
        a: AssignedBabyBearExt4,
        b: AssignedBabyBearExt4,
    ) {
        for (a, b) in a.0.iter().zip(b.0.iter()) {
            self.base.assert_equal(ctx, *a, *b);
        }
    }

    /// Extension-field multiplication: computes the 7 coefficients of the
    /// schoolbook product of the two degree-3 polynomials (via
    /// `special_inner_product`), then reduces modulo `x^4 - W` by folding
    /// coefficients 4..6 back into 0..2 with `coeff[i-4] += W * coeff[i]`.
    ///
    /// Takes `a` and `b` by value mutably because the coefficient bounds may
    /// be reduced in place during the inner products.
    pub fn mul(
        &self,
        ctx: &mut Context<Fr>,
        mut a: AssignedBabyBearExt4,
        mut b: AssignedBabyBearExt4,
    ) -> AssignedBabyBearExt4 {
        let mut coeffs = Vec::with_capacity(7);
        for s in 0..7 {
            coeffs.push(self.base.special_inner_product(ctx, &mut a.0, &mut b.0, s));
        }
        let w = self
            .base
            .load_constant(ctx, <BabyBear as BinomiallyExtendable<4>>::W);
        for i in 4..7 {
            coeffs[i - 4] = self.base.mul_add(ctx, coeffs[i], w, coeffs[i - 4]);
        }
        coeffs.truncate(4);
        let c = AssignedBabyBearExt4(coeffs.try_into().unwrap());
        debug_assert_eq!(
            c.to_extension_field(),
            a.to_extension_field() * b.to_extension_field()
        );
        c
    }

    /// Division: witnesses `c = a * b^{-1}` out-of-circuit, then constrains
    /// `b * c = a` via in-circuit multiplication and equality.
    ///
    /// Panics (in witness generation) if `b` is zero, via `try_inverse().unwrap()`.
    pub fn div(
        &self,
        ctx: &mut Context<Fr>,
        a: AssignedBabyBearExt4,
        b: AssignedBabyBearExt4,
    ) -> AssignedBabyBearExt4 {
        let b_val = b.to_extension_field();
        let b_inv = b_val.try_inverse().unwrap();

        let c = self.load_witness(ctx, a.to_extension_field() * b_inv);
        let prod = self.mul(ctx, b, c);
        self.assert_equal(ctx, a, prod);

        debug_assert_eq!(
            c.to_extension_field(),
            a.to_extension_field() / b.to_extension_field()
        );
        c
    }

    /// Reduces each coefficient to canonical bit width if needed.
    pub fn reduce_max_bits(
        &self,
        ctx: &mut Context<Fr>,
        a: AssignedBabyBearExt4,
    ) -> AssignedBabyBearExt4 {
        AssignedBabyBearExt4(
            a.0.into_iter()
                .map(|x| self.base.reduce_max_bits(ctx, x))
                .collect::<Vec<_>>()
                .try_into()
                .unwrap(),
        )
    }
}