// halo2_ecc/bigint/carry_mod.rs

1use std::{cmp::max, iter};
2
3use halo2_base::{
4    gates::{GateInstructions, RangeInstructions},
5    utils::{decompose_bigint, BigPrimeField},
6    AssignedValue, Context,
7    QuantumCell::{Constant, Existing, Witness},
8};
9use num_bigint::BigInt;
10use num_integer::Integer;
11use num_traits::{One, Signed};
12
13use super::{check_carry_to_zero, CRTInteger, OverflowInteger, ProperCrtUint, ProperUint};
14
15// Input `a` is `CRTInteger` with `a.truncation` of length `k` with "signed" limbs
16// Output is `out = a (mod modulus)` as CRTInteger with
17// `out.value = a.value (mod modulus)`
18// `out.trunction = (a (mod modulus)) % 2^t` a proper BigInt of length `k` with limbs in [0, 2^limb_bits)`
19// The witness for `out.truncation` is a BigInt in [0, modulus), but we do not constrain the inequality
20// `out.native = (a (mod modulus)) % (native_modulus::<F>)`
21// We constrain `a = out + modulus * quotient` and range check `out` and `quotient`
22//
23// Assumption: the leading two bits (in big endian) are 1,
24/// # Assumptions
25/// * abs(a) <= 2<sup>n * k - 1 + F::NUM_BITS - 2</sup> (A weaker assumption is also enough, but this is good enough for forseeable use cases)
26/// * `native_modulus::<F>` requires *exactly* `k = a.limbs.len()` limbs to represent
27// This is currently optimized for limbs greater than 64 bits, so we need `F` to be a `BigPrimeField`
28// In the future we'll need a slightly different implementation for limbs that fit in 32 or 64 bits (e.g., `F` is Goldilocks)
29pub fn crt<F: BigPrimeField>(
30    range: &impl RangeInstructions<F>,
31    // chip: &BigIntConfig<F>,
32    ctx: &mut Context<F>,
33    a: CRTInteger<F>,
34    k_bits: usize, // = a.len().bits()
35    modulus: &BigInt,
36    mod_vec: &[F],
37    mod_native: F,
38    limb_bits: usize,
39    limb_bases: &[F],
40    limb_base_big: &BigInt,
41) -> ProperCrtUint<F> {
42    let n = limb_bits;
43    let k = a.truncation.limbs.len();
44    let trunc_len = n * k;
45
46    debug_assert!(a.value.bits() as usize <= n * k - 1 + (F::NUM_BITS as usize) - 2);
47
48    // in order for CRT method to work, we need `abs(out + modulus * quotient - a) < 2^{trunc_len - 1} * native_modulus::<F>`
49    // this is ensured if `0 <= out < 2^{n*k}` and
50    // `abs(modulus * quotient) < 2^{trunc_len - 1} * native_modulus::<F> - abs(a)
51    // which is ensured if
52    // `abs(modulus * quotient) < 2^{trunc_len - 1 + F::NUM_BITS - 1} <= 2^{trunc_len - 1} * native_modulus::<F> - abs(a)` given our assumption `abs(a) <= 2^{n * k - 1 + F::NUM_BITS - 2}`
53    let quot_max_bits = trunc_len - 1 + (F::NUM_BITS as usize) - 1 - (modulus.bits() as usize);
54    debug_assert!(quot_max_bits < trunc_len);
55    // Let n' <= quot_max_bits - n(k-1) - 1
56    // If quot[i] <= 2^n for i < k - 1 and quot[k-1] <= 2^{n'} then
57    // quot < 2^{n(k-1)+1} + 2^{n' + n(k-1)} = (2+2^{n'}) 2^{n(k-1)} < 2^{n'+1} * 2^{n(k-1)} <= 2^{quot_max_bits - n(k-1)} * 2^{n(k-1)}
58    let quot_last_limb_bits = quot_max_bits - n * (k - 1);
59
60    let out_max_bits = modulus.bits() as usize;
61    // we assume `modulus` requires *exactly* `k` limbs to represent (if `< k` limbs ok, you should just be using that)
62    let out_last_limb_bits = out_max_bits - n * (k - 1);
63
64    // these are witness vectors:
65    // we need to find `out_vec` as a proper BigInt with k limbs
66    // we need to find `quot_vec` as a proper BigInt with k limbs
67
68    let (quot_val, out_val) = a.value.div_mod_floor(modulus);
69
70    debug_assert!(out_val < (BigInt::one() << (n * k)));
71    debug_assert!(quot_val.abs() < (BigInt::one() << quot_max_bits));
72
73    // decompose_bigint just throws away signed limbs in index >= k
74    let out_vec = decompose_bigint::<F>(&out_val, k, n);
75    let quot_vec = decompose_bigint::<F>(&quot_val, k, n);
76
77    // we need to constrain that `sum_i out_vec[i] * 2^{n*i} = out_native` in `F`
78    // we need to constrain that `sum_i quot_vec[i] * 2^{n*i} = quot_native` in `F`
79
80    // assert!(modulus < &(BigUint::one() << (n * k)));
81    assert_eq!(mod_vec.len(), k);
82    // We need to show `out - a + modulus * quotient` is:
83    // - congruent to `0 (mod 2^trunc_len)`
84    // - equal to 0 in native field `F`
85
86    // Modulo 2^trunc_len, using OverflowInteger:
87    // ------------------------------------------
88    // Goal: assign cells to `out - a + modulus * quotient`
89    // 1. we effectively do mul_no_carry::truncate(mod_vec, quot_vec) while assigning `mod_vec` and `quot_vec` as we go
90    //    call the output `prod` which has len k
91    // 2. for prod[i] we can compute `prod + out - a`
92    //    where we assign `out_vec` as we go
93
94    let mut quot_assigned: Vec<AssignedValue<F>> = Vec::with_capacity(k);
95    let mut out_assigned: Vec<AssignedValue<F>> = Vec::with_capacity(k);
96    let mut check_assigned: Vec<AssignedValue<F>> = Vec::with_capacity(k);
97
98    // strategies where we carry out school-book multiplication in some form:
99    //    BigIntStrategy::Simple => {
100    for (i, ((a_limb, quot_v), out_v)) in
101        a.truncation.limbs.into_iter().zip(quot_vec).zip(out_vec).enumerate()
102    {
103        let (prod, new_quot_cell) = range.gate().inner_product_left_last(
104            ctx,
105            quot_assigned.iter().map(|a| Existing(*a)).chain(iter::once(Witness(quot_v))),
106            mod_vec[..=i].iter().rev().map(|c| Constant(*c)),
107        );
108        // let gate_index = prod.column();
109
110        // perform step 2: compute prod - a + out
111        let temp1 = *prod.value() - a_limb.value();
112        let check_val = temp1 + out_v;
113
114        // transpose of:
115        // | prod | -1 | a | prod - a | 1 | out | prod - a + out
116        // where prod is at relative row `offset`
117        ctx.assign_region(
118            [
119                Constant(-F::ONE),
120                Existing(a_limb),
121                Witness(temp1),
122                Constant(F::ONE),
123                Witness(out_v),
124                Witness(check_val),
125            ],
126            [-1, 2], // note the NEGATIVE index! this is using gate overlapping with the previous inner product call
127        );
128        let check_cell = ctx.last().unwrap();
129        let out_cell = ctx.get(-2);
130
131        quot_assigned.push(new_quot_cell);
132        out_assigned.push(out_cell);
133        check_assigned.push(check_cell);
134    }
135    //    }
136    //}
137
138    // range check limbs of `out` are in [0, 2^n) except last limb should be in [0, 2^out_last_limb_bits)
139    for (out_index, out_cell) in out_assigned.iter().enumerate() {
140        let limb_bits = if out_index == k - 1 { out_last_limb_bits } else { n };
141        range.range_check(ctx, *out_cell, limb_bits);
142    }
143
144    // range check that quot_cell in quot_assigned is in [-2^n, 2^n) except for last cell check it's in [-2^quot_last_limb_bits, 2^quot_last_limb_bits)
145    for (q_index, quot_cell) in quot_assigned.iter().enumerate() {
146        let limb_bits = if q_index == k - 1 { quot_last_limb_bits } else { n };
147        let limb_base =
148            if q_index == k - 1 { range.gate().pow_of_two()[limb_bits] } else { limb_bases[1] };
149
150        // compute quot_cell + 2^n and range check with n + 1 bits
151        let quot_shift = range.gate().add(ctx, *quot_cell, Constant(limb_base));
152        range.range_check(ctx, quot_shift, limb_bits + 1);
153    }
154
155    let check_overflow_int = OverflowInteger::new(
156        check_assigned,
157        max(max(limb_bits, a.truncation.max_limb_bits) + 1, 2 * n + k_bits),
158    );
159
160    // check that `out - a + modulus * quotient == 0 mod 2^{trunc_len}` after carry
161    check_carry_to_zero::truncate::<F>(
162        range,
163        ctx,
164        check_overflow_int,
165        limb_bits,
166        limb_bases[1],
167        limb_base_big,
168    );
169
170    // Constrain `quot_native = sum_i quot_assigned[i] * 2^{n*i}` in `F`
171    let quot_native =
172        OverflowInteger::evaluate_native(ctx, range.gate(), quot_assigned, limb_bases);
173
174    // Constrain `out_native = sum_i out_assigned[i] * 2^{n*i}` in `F`
175    let out_native =
176        OverflowInteger::evaluate_native(ctx, range.gate(), out_assigned.clone(), limb_bases);
177    // We save 1 cell by connecting `out_native` computation with the following:
178
179    // Check `out + modulus * quotient - a = 0` in native field
180    // | out | modulus | quotient | a |
181    ctx.assign_region(
182        [Constant(mod_native), Existing(quot_native), Existing(a.native)],
183        [-1], // negative index because -1 relative offset is `out_native` assigned value
184    );
185
186    ProperCrtUint(CRTInteger::new(
187        ProperUint(out_assigned).into_overflow(limb_bits),
188        out_native,
189        out_val,
190    ))
191}