halo2_ecc/bigint/
check_carry_mod_to_zero.rs

use super::{check_carry_to_zero, CRTInteger, OverflowInteger};
use halo2_base::{
    gates::{GateInstructions, RangeInstructions},
    utils::{decompose_bigint, BigPrimeField},
    AssignedValue, Context,
    QuantumCell::{Constant, Existing, Witness},
};
use num_bigint::BigInt;
use num_integer::Integer;
use num_traits::{One, Signed, Zero};
use std::{cmp::max, iter};
12
13// same as carry_mod::crt but `out = 0` so no need to range check
14//
15// Assumption: the leading two bits (in big endian) are 1, and `a.max_size <= 2^{n * k - 1 + F::NUM_BITS - 2}` (A weaker assumption is also enough)
16pub fn crt<F: BigPrimeField>(
17    range: &impl RangeInstructions<F>,
18    ctx: &mut Context<F>,
19    a: CRTInteger<F>,
20    k_bits: usize, // = a.len().bits()
21    modulus: &BigInt,
22    mod_vec: &[F],
23    mod_native: F,
24    limb_bits: usize,
25    limb_bases: &[F],
26    limb_base_big: &BigInt,
27) {
28    let n = limb_bits;
29    let k = a.truncation.limbs.len();
30    let trunc_len = n * k;
31
32    debug_assert!(a.value.bits() as usize <= n * k - 1 + (F::NUM_BITS as usize) - 2);
33
34    // see carry_mod.rs for explanation
35    let quot_max_bits = trunc_len - 1 + (F::NUM_BITS as usize) - 1 - (modulus.bits() as usize);
36    assert!(quot_max_bits < trunc_len);
37    let quot_last_limb_bits = quot_max_bits - n * (k - 1);
38
39    // these are witness vectors:
40    // we need to find `quot_vec` as a proper BigInt with k limbs
41    // we need to find `quot_native` as a native F element
42
43    // we need to constrain that `sum_i quot_vec[i] * 2^{n*i} = quot_native` in `F`
44    let (quot_val, _out_val) = a.value.div_mod_floor(modulus);
45
46    // only perform safety checks in debug mode
47    debug_assert_eq!(_out_val, BigInt::zero());
48    debug_assert!(quot_val.abs() < (BigInt::one() << quot_max_bits));
49
50    let quot_vec = decompose_bigint::<F>(&quot_val, k, n);
51
52    debug_assert!(modulus < &(BigInt::one() << (n * k)));
53
54    // We need to show `modulus * quotient - a` is:
55    // - congruent to `0 (mod 2^trunc_len)`
56    // - equal to 0 in native field `F`
57
58    // Modulo 2^trunc_len, using OverflowInteger:
59    // ------------------------------------------
60    // Goal: assign cells to `modulus * quotient - a`
61    // 1. we effectively do mul_no_carry::truncate(mod_vec, quot_vec) while assigning `mod_vec` and `quot_vec` as we go
62    //    call the output `prod` which has len k
63    // 2. for prod[i] we can compute prod - a by using the transpose of
64    //    | prod | -1 | a | prod - a |
65
66    let mut quot_assigned: Vec<AssignedValue<F>> = Vec::with_capacity(k);
67    let mut check_assigned: Vec<AssignedValue<F>> = Vec::with_capacity(k);
68
69    // match chip.strategy {
70    //    BigIntStrategy::Simple => {
71    for (i, (a_limb, quot_v)) in a.truncation.limbs.into_iter().zip(quot_vec).enumerate() {
72        let (prod, new_quot_cell) = range.gate().inner_product_left_last(
73            ctx,
74            quot_assigned.iter().map(|x| Existing(*x)).chain(iter::once(Witness(quot_v))),
75            mod_vec[0..=i].iter().rev().map(|c| Constant(*c)),
76        );
77
78        // perform step 2: compute prod - a + out
79        // transpose of:
80        // | prod | -1 | a | prod - a |
81        let check_val = *prod.value() - a_limb.value();
82        let check_cell =
83            ctx.assign_region_last([Constant(-F::ONE), Existing(a_limb), Witness(check_val)], [-1]);
84
85        quot_assigned.push(new_quot_cell);
86        check_assigned.push(check_cell);
87    }
88    //    }
89    // }
90
91    // range check that quot_cell in quot_assigned is in [-2^n, 2^n) except for last cell check it's in [-2^quot_last_limb_bits, 2^quot_last_limb_bits)
92    for (q_index, quot_cell) in quot_assigned.iter().enumerate() {
93        let limb_bits = if q_index == k - 1 { quot_last_limb_bits } else { n };
94        let limb_base =
95            if q_index == k - 1 { range.gate().pow_of_two()[limb_bits] } else { limb_bases[1] };
96
97        // compute quot_cell + 2^n and range check with n + 1 bits
98        let quot_shift = range.gate().add(ctx, *quot_cell, Constant(limb_base));
99        range.range_check(ctx, quot_shift, limb_bits + 1);
100    }
101
102    let check_overflow_int =
103        OverflowInteger::new(check_assigned, max(a.truncation.max_limb_bits, 2 * n + k_bits));
104
105    // check that `modulus * quotient - a == 0 mod 2^{trunc_len}` after carry
106    check_carry_to_zero::truncate::<F>(
107        range,
108        ctx,
109        check_overflow_int,
110        limb_bits,
111        limb_bases[1],
112        limb_base_big,
113    );
114
115    // Constrain `quot_native = sum_i out_assigned[i] * 2^{n*i}` in `F`
116    let quot_native =
117        OverflowInteger::evaluate_native(ctx, range.gate(), quot_assigned, limb_bases);
118
119    // Check `0 + modulus * quotient - a = 0` in native field
120    // | 0 | modulus | quotient | a |
121    ctx.assign_region(
122        [Constant(F::ZERO), Constant(mod_native), Existing(quot_native), Existing(a.native)],
123        [0],
124    );
125}