// halo2_ecc/bigint/carry_mod.rs
use std::{cmp::max, iter};
use halo2_base::{
gates::{GateInstructions, RangeInstructions},
utils::{decompose_bigint, BigPrimeField},
AssignedValue, Context,
QuantumCell::{Constant, Existing, Witness},
};
use num_bigint::BigInt;
use num_integer::Integer;
use num_traits::{One, Signed};
use super::{check_carry_to_zero, CRTInteger, OverflowInteger, ProperCrtUint, ProperUint};
// Input `a` is `CRTInteger` with `a.truncation` of length `k` with "signed" limbs
// Output is `out = a (mod modulus)` as CRTInteger with
// `out.value = a.value (mod modulus)`
// `out.trunction = (a (mod modulus)) % 2^t` a proper BigInt of length `k` with limbs in [0, 2^limb_bits)`
// The witness for `out.truncation` is a BigInt in [0, modulus), but we do not constrain the inequality
// `out.native = (a (mod modulus)) % (native_modulus::<F>)`
// We constrain `a = out + modulus * quotient` and range check `out` and `quotient`
//
// Assumption: the leading two bits (in big endian) are 1,
/// # Assumptions
/// * abs(a) <= 2<sup>n * k - 1 + F::NUM_BITS - 2</sup> (A weaker assumption is also enough, but this is good enough for forseeable use cases)
/// * `native_modulus::<F>` requires *exactly* `k = a.limbs.len()` limbs to represent
// This is currently optimized for limbs greater than 64 bits, so we need `F` to be a `BigPrimeField`
// In the future we'll need a slightly different implementation for limbs that fit in 32 or 64 bits (e.g., `F` is Goldilocks)
pub fn crt<F: BigPrimeField>(
range: &impl RangeInstructions<F>,
// chip: &BigIntConfig<F>,
ctx: &mut Context<F>,
a: CRTInteger<F>,
k_bits: usize, // = a.len().bits()
modulus: &BigInt,
mod_vec: &[F],
mod_native: F,
limb_bits: usize,
limb_bases: &[F],
limb_base_big: &BigInt,
) -> ProperCrtUint<F> {
let n = limb_bits;
let k = a.truncation.limbs.len();
let trunc_len = n * k;
debug_assert!(a.value.bits() as usize <= n * k - 1 + (F::NUM_BITS as usize) - 2);
// in order for CRT method to work, we need `abs(out + modulus * quotient - a) < 2^{trunc_len - 1} * native_modulus::<F>`
// this is ensured if `0 <= out < 2^{n*k}` and
// `abs(modulus * quotient) < 2^{trunc_len - 1} * native_modulus::<F> - abs(a)
// which is ensured if
// `abs(modulus * quotient) < 2^{trunc_len - 1 + F::NUM_BITS - 1} <= 2^{trunc_len - 1} * native_modulus::<F> - abs(a)` given our assumption `abs(a) <= 2^{n * k - 1 + F::NUM_BITS - 2}`
let quot_max_bits = trunc_len - 1 + (F::NUM_BITS as usize) - 1 - (modulus.bits() as usize);
debug_assert!(quot_max_bits < trunc_len);
// Let n' <= quot_max_bits - n(k-1) - 1
// If quot[i] <= 2^n for i < k - 1 and quot[k-1] <= 2^{n'} then
// quot < 2^{n(k-1)+1} + 2^{n' + n(k-1)} = (2+2^{n'}) 2^{n(k-1)} < 2^{n'+1} * 2^{n(k-1)} <= 2^{quot_max_bits - n(k-1)} * 2^{n(k-1)}
let quot_last_limb_bits = quot_max_bits - n * (k - 1);
let out_max_bits = modulus.bits() as usize;
// we assume `modulus` requires *exactly* `k` limbs to represent (if `< k` limbs ok, you should just be using that)
let out_last_limb_bits = out_max_bits - n * (k - 1);
// these are witness vectors:
// we need to find `out_vec` as a proper BigInt with k limbs
// we need to find `quot_vec` as a proper BigInt with k limbs
let (quot_val, out_val) = a.value.div_mod_floor(modulus);
debug_assert!(out_val < (BigInt::one() << (n * k)));
debug_assert!(quot_val.abs() < (BigInt::one() << quot_max_bits));
// decompose_bigint just throws away signed limbs in index >= k
let out_vec = decompose_bigint::<F>(&out_val, k, n);
let quot_vec = decompose_bigint::<F>("_val, k, n);
// we need to constrain that `sum_i out_vec[i] * 2^{n*i} = out_native` in `F`
// we need to constrain that `sum_i quot_vec[i] * 2^{n*i} = quot_native` in `F`
// assert!(modulus < &(BigUint::one() << (n * k)));
assert_eq!(mod_vec.len(), k);
// We need to show `out - a + modulus * quotient` is:
// - congruent to `0 (mod 2^trunc_len)`
// - equal to 0 in native field `F`
// Modulo 2^trunc_len, using OverflowInteger:
// ------------------------------------------
// Goal: assign cells to `out - a + modulus * quotient`
// 1. we effectively do mul_no_carry::truncate(mod_vec, quot_vec) while assigning `mod_vec` and `quot_vec` as we go
// call the output `prod` which has len k
// 2. for prod[i] we can compute `prod + out - a`
// where we assign `out_vec` as we go
let mut quot_assigned: Vec<AssignedValue<F>> = Vec::with_capacity(k);
let mut out_assigned: Vec<AssignedValue<F>> = Vec::with_capacity(k);
let mut check_assigned: Vec<AssignedValue<F>> = Vec::with_capacity(k);
// strategies where we carry out school-book multiplication in some form:
// BigIntStrategy::Simple => {
for (i, ((a_limb, quot_v), out_v)) in
a.truncation.limbs.into_iter().zip(quot_vec).zip(out_vec).enumerate()
{
let (prod, new_quot_cell) = range.gate().inner_product_left_last(
ctx,
quot_assigned.iter().map(|a| Existing(*a)).chain(iter::once(Witness(quot_v))),
mod_vec[..=i].iter().rev().map(|c| Constant(*c)),
);
// let gate_index = prod.column();
// perform step 2: compute prod - a + out
let temp1 = *prod.value() - a_limb.value();
let check_val = temp1 + out_v;
// transpose of:
// | prod | -1 | a | prod - a | 1 | out | prod - a + out
// where prod is at relative row `offset`
ctx.assign_region(
[
Constant(-F::ONE),
Existing(a_limb),
Witness(temp1),
Constant(F::ONE),
Witness(out_v),
Witness(check_val),
],
[-1, 2], // note the NEGATIVE index! this is using gate overlapping with the previous inner product call
);
let check_cell = ctx.last().unwrap();
let out_cell = ctx.get(-2);
quot_assigned.push(new_quot_cell);
out_assigned.push(out_cell);
check_assigned.push(check_cell);
}
// }
//}
// range check limbs of `out` are in [0, 2^n) except last limb should be in [0, 2^out_last_limb_bits)
for (out_index, out_cell) in out_assigned.iter().enumerate() {
let limb_bits = if out_index == k - 1 { out_last_limb_bits } else { n };
range.range_check(ctx, *out_cell, limb_bits);
}
// range check that quot_cell in quot_assigned is in [-2^n, 2^n) except for last cell check it's in [-2^quot_last_limb_bits, 2^quot_last_limb_bits)
for (q_index, quot_cell) in quot_assigned.iter().enumerate() {
let limb_bits = if q_index == k - 1 { quot_last_limb_bits } else { n };
let limb_base =
if q_index == k - 1 { range.gate().pow_of_two()[limb_bits] } else { limb_bases[1] };
// compute quot_cell + 2^n and range check with n + 1 bits
let quot_shift = range.gate().add(ctx, *quot_cell, Constant(limb_base));
range.range_check(ctx, quot_shift, limb_bits + 1);
}
let check_overflow_int = OverflowInteger::new(
check_assigned,
max(max(limb_bits, a.truncation.max_limb_bits) + 1, 2 * n + k_bits),
);
// check that `out - a + modulus * quotient == 0 mod 2^{trunc_len}` after carry
check_carry_to_zero::truncate::<F>(
range,
ctx,
check_overflow_int,
limb_bits,
limb_bases[1],
limb_base_big,
);
// Constrain `quot_native = sum_i quot_assigned[i] * 2^{n*i}` in `F`
let quot_native =
OverflowInteger::evaluate_native(ctx, range.gate(), quot_assigned, limb_bases);
// Constrain `out_native = sum_i out_assigned[i] * 2^{n*i}` in `F`
let out_native =
OverflowInteger::evaluate_native(ctx, range.gate(), out_assigned.clone(), limb_bases);
// We save 1 cell by connecting `out_native` computation with the following:
// Check `out + modulus * quotient - a = 0` in native field
// | out | modulus | quotient | a |
ctx.assign_region(
[Constant(mod_native), Existing(quot_native), Existing(a.native)],
[-1], // negative index because -1 relative offset is `out_native` assigned value
);
ProperCrtUint(CRTInteger::new(
ProperUint(out_assigned).into_overflow(limb_bits),
out_native,
out_val,
))
}