halo2_ecc/bigint/check_carry_to_zero.rs

use super::OverflowInteger;
use halo2_base::{
    gates::{GateInstructions, RangeInstructions},
    utils::{bigint_to_fe, fe_to_bigint, BigPrimeField},
    Context,
    QuantumCell::{Constant, Existing, Witness},
};
use num_bigint::BigInt;

// check that `a` carries to `0 mod 2^{a.limb_bits * a.limbs.len()}`
// same as `assign` above except we need to provide `c_{k - 1}` witness as well
// checks there exist d_i = -c_i so that
// a_0 = c_0 * 2^n
// a_i + c_{i - 1} = c_i * 2^n for i = 1..=k - 1
// and c_i \in [-2^{m - n + EPSILON}, 2^{m - n + EPSILON}], with EPSILON >= 1, for i = 0..=k - 1
// where m = a.max_limb_bits and we choose EPSILON to round up to the next multiple of the range check table size
//
// translated to d_i, this becomes:
// a_0 + d_0 * 2^n = 0
// a_i + d_i * 2^n = d_{i - 1} for i = 1..=k - 1
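//
// e.g. (illustrative numbers only): with n = 8 and limbs a = [512, -2, 0], the represented
// integer is 512 - 2 * 2^8 + 0 * 2^16 = 0, the carries are c = [2, 0, 0] (so d = [-2, 0, 0]),
// and the equations above read 512 - 2 * 2^8 = 0, -2 + 0 * 2^8 = -2 = d_0, 0 + 0 * 2^8 = 0 = d_1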

// aztec optimization:
// note that a_i + c_{i - 1} = c_i * 2^n can be expanded to
// a_i * 2^{n*w} + a_{i - 1} * 2^{n*(w-1)} + ... + a_{i - w} + c_{i - w - 1} = c_i * 2^{n*(w+1)}
// which is valid as long as `(m - n + EPSILON) + n * (w+1) < native_modulus::<F>().bits() - 1`
// so we only need to range check `c_i` every `w + 1` steps, starting with `i = w`
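//
// e.g. (illustrative) with w = 1: multiplying a_i + c_{i - 1} = c_i * 2^n by 2^n and substituting
// a_{i - 1} + c_{i - 2} = c_{i - 1} * 2^n gives a_i * 2^n + a_{i - 1} + c_{i - 2} = c_i * 2^{2n},
// so only every other carry would need a range check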
pub fn truncate<F: BigPrimeField>(
    range: &impl RangeInstructions<F>,
    ctx: &mut Context<F>,
    a: OverflowInteger<F>,
    limb_bits: usize,
    limb_base: F,
    limb_base_big: &BigInt,
) {
    let k = a.limbs.len();
    let max_limb_bits = a.max_limb_bits;

    let mut carries = Vec::with_capacity(k);

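    // out-of-circuit witness computation: the carries c_i are computed here over `BigInt` and are
    // only assigned (as `Witness(-c_i)`) and constrained in the second loop below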
    for a_limb in a.limbs.iter() {
        let a_val_big = fe_to_bigint(a_limb.value());
        let carry = if let Some(carry_val) = carries.last() {
            (a_val_big + carry_val) / limb_base_big
        } else {
            // warning: using `>>` on a negative integer rounds toward -infinity rather than
            // toward zero, so we divide by `limb_base_big` instead of right-shifting
            a_val_big / limb_base_big
        };
        carries.push(carry);
    }

    // round `max_limb_bits - limb_bits + EPSILON + 1` up to the next multiple of range.lookup_bits(),
    // then subtract 1, so that the `range_bits + 1`-bit range check below decomposes into full lookups
    const EPSILON: usize = 1;
    let range_bits = max_limb_bits - limb_bits + EPSILON;
    let range_bits =
        ((range_bits + range.lookup_bits()) / range.lookup_bits()) * range.lookup_bits() - 1;
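    // e.g. (illustrative numbers only): with max_limb_bits = 90, limb_bits = 88, EPSILON = 1 and
    // lookup_bits = 17, we get 90 - 88 + 1 = 3, then ((3 + 17) / 17) * 17 - 1 = 16, so the shifted
    // carry below is range checked to 16 + 1 = 17 bits, i.e. exactly one full lookup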
    // `window = w + 1` valid as long as `range_bits + n * (w+1) < native_modulus::<F>().bits() - 1`
    // let window = (F::NUM_BITS as usize - 2 - range_bits) / limb_bits;
    // assert!(window > 0);
    // In practice, we are currently always using window = 1, so the above is commented out

    let shift_val = range.gate().pow_of_two()[range_bits];
    // let num_windows = (k - 1) / window + 1; // = ((k - 1) - (window - 1) + window - 1) / window + 1;

    let mut previous = None;
    for (a_limb, carry) in a.limbs.into_iter().zip(carries) {
        let neg_carry_val = bigint_to_fe(&-carry);
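        // the gate enabled at offset 0 enforces a_limb + neg_carry * limb_base = previous,
        // i.e. a_i + d_i * 2^n = d_{i - 1}, with d_{-1} = 0 for the first limb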
        ctx.assign_region(
            [
                Existing(a_limb),
                Witness(neg_carry_val),
                Constant(limb_base),
                previous.map(Existing).unwrap_or_else(|| Constant(F::ZERO)),
            ],
            [0],
        );
        let neg_carry = ctx.get(-3);

        // i in 0..num_windows {
        // let idx = std::cmp::min(window * i + window - 1, k - 1);
        // let carry_cell = &neg_carry_assignments[idx];
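        // shift the signed carry into the nonnegative range before the lookup range check:
        // d_i + 2^range_bits lies in [0, 2^{range_bits + 1}) exactly when d_i lies in
        // [-2^range_bits, 2^range_bits)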
        let shifted_carry = range.gate().add(ctx, neg_carry, Constant(shift_val));
        range.range_check(ctx, shifted_carry, range_bits + 1);

        previous = Some(neg_carry);
    }
}
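
// Out-of-circuit sanity check of the carry recurrence computed above (a minimal sketch, not part
// of the original module): it exercises only the `BigInt` witness arithmetic, not the circuit
// constraints, and the module name and limb values are made up for illustration.
#[cfg(test)]
mod carry_recurrence_sketch {
    use num_bigint::BigInt;

    #[test]
    fn toy_carry_recurrence() {
        // n = 8 bits per limb; limbs [512, -2, 0] represent 512 - 2 * 2^8 + 0 * 2^16 = 0,
        // so exact integer carries must exist and the final carry must be 0
        let base = BigInt::from(256); // 2^n with n = 8
        let limbs = [BigInt::from(512), BigInt::from(-2), BigInt::from(0)];

        // same recurrence as the witness computation in `truncate`
        let mut carries: Vec<BigInt> = Vec::new();
        for a in &limbs {
            let carry = match carries.last() {
                Some(prev) => (a + prev) / &base,
                None => a / &base,
            };
            carries.push(carry);
        }

        // a_0 = c_0 * 2^n and a_i + c_{i - 1} = c_i * 2^n must hold exactly
        assert_eq!(limbs[0], &carries[0] * &base);
        for i in 1..limbs.len() {
            assert_eq!(&limbs[i] + &carries[i - 1], &carries[i] * &base);
        }
        assert_eq!(carries.last().unwrap(), &BigInt::from(0));
    }
}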