halo2_ecc/bigint/check_carry_to_zero.rs
1use super::OverflowInteger;
2use halo2_base::{
3 gates::{GateInstructions, RangeInstructions},
4 utils::{bigint_to_fe, fe_to_bigint, BigPrimeField},
5 Context,
6 QuantumCell::{Constant, Existing, Witness},
7};
8use num_bigint::BigInt;
9
10// check that `a` carries to `0 mod 2^{a.limb_bits * a.limbs.len()}`
11// same as `assign` above except we need to provide `c_{k - 1}` witness as well
12// checks there exist d_i = -c_i so that
13// a_0 = c_0 * 2^n
14// a_i + c_{i - 1} = c_i * 2^n for i = 1..=k - 1
15// and c_i \in [-2^{m - n + EPSILON}, 2^{m - n + EPSILON}], with EPSILON >= 1 for i = 0..=k-1
16// where m = a.max_limb_size.bits() and we choose EPSILON to round up to the next multiple of the range check table size
17//
18// translated to d_i, this becomes:
19// a_0 + d_0 * 2^n = 0
20// a_i + d_i * 2^n = d_{i - 1} for i = 1.. k - 1
21
22// aztec optimization:
23// note that a_i + c_{i - 1} = c_i * 2^n can be expanded to
24// a_i * 2^{n*w} + a_{i - 1} * 2^{n*(w-1)} + ... + a_{i - w} + c_{i - w - 1} = c_i * 2^{n*(w+1)}
25// which is valid as long as `(m - n + EPSILON) + n * (w+1) < native_modulus::<F>().bits() - 1`
26// so we only need to range check `c_i` every `w + 1` steps, starting with `i = w`
27pub fn truncate<F: BigPrimeField>(
28 range: &impl RangeInstructions<F>,
29 ctx: &mut Context<F>,
30 a: OverflowInteger<F>,
31 limb_bits: usize,
32 limb_base: F,
33 limb_base_big: &BigInt,
34) {
35 let k = a.limbs.len();
36 let max_limb_bits = a.max_limb_bits;
37
38 let mut carries = Vec::with_capacity(k);
39
40 for a_limb in a.limbs.iter() {
41 let a_val_big = fe_to_bigint(a_limb.value());
42 let carry = if let Some(carry_val) = carries.last() {
43 (a_val_big + carry_val) / limb_base_big
44 } else {
45 // warning: using >> on negative integer produces undesired effect
46 a_val_big / limb_base_big
47 };
48 carries.push(carry);
49 }
50
51 // round `max_limb_bits - limb_bits + EPSILON + 1` up to the next multiple of range.lookup_bits
52 const EPSILON: usize = 1;
53 let range_bits = max_limb_bits - limb_bits + EPSILON;
54 let range_bits =
55 ((range_bits + range.lookup_bits()) / range.lookup_bits()) * range.lookup_bits() - 1;
56 // `window = w + 1` valid as long as `range_bits + n * (w+1) < native_modulus::<F>().bits() - 1`
57 // let window = (F::NUM_BITS as usize - 2 - range_bits) / limb_bits;
58 // assert!(window > 0);
59 // In practice, we are currently always using window = 1 so the above is commented out
60
61 let shift_val = range.gate().pow_of_two()[range_bits];
62 // let num_windows = (k - 1) / window + 1; // = ((k - 1) - (window - 1) + window - 1) / window + 1;
63
64 let mut previous = None;
65 for (a_limb, carry) in a.limbs.into_iter().zip(carries) {
66 let neg_carry_val = bigint_to_fe(&-carry);
67 ctx.assign_region(
68 [
69 Existing(a_limb),
70 Witness(neg_carry_val),
71 Constant(limb_base),
72 previous.map(Existing).unwrap_or_else(|| Constant(F::ZERO)),
73 ],
74 [0],
75 );
76 let neg_carry = ctx.get(-3);
77
78 // i in 0..num_windows {
79 // let idx = std::cmp::min(window * i + window - 1, k - 1);
80 // let carry_cell = &neg_carry_assignments[idx];
81 let shifted_carry = range.gate().add(ctx, neg_carry, Constant(shift_val));
82 range.range_check(ctx, shifted_carry, range_bits + 1);
83
84 previous = Some(neg_carry);
85 }
86}