halo2_ecc/bigint/
select_by_indicator.rs

1use super::{CRTInteger, OverflowInteger};
2use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context};
3use num_bigint::BigInt;
4use num_traits::Zero;
5use std::cmp::max;
6
7/// only use case is when coeffs has only a single 1, rest are 0
8pub fn assign<F: ScalarField>(
9    gate: &impl GateInstructions<F>,
10    ctx: &mut Context<F>,
11    a: &[OverflowInteger<F>],
12    coeffs: &[AssignedValue<F>],
13) -> OverflowInteger<F> {
14    let k = a[0].limbs.len();
15
16    let out_limbs = (0..k)
17        .map(|idx| {
18            let int_limbs = a.iter().map(|a| a.limbs[idx]);
19            gate.select_by_indicator(ctx, int_limbs, coeffs.iter().copied())
20        })
21        .collect();
22
23    let max_limb_bits = a.iter().fold(0, |acc, x| max(acc, x.max_limb_bits));
24
25    OverflowInteger::new(out_limbs, max_limb_bits)
26}
27
28/// only use case is when coeffs has only a single 1, rest are 0
29pub fn crt<F: ScalarField>(
30    gate: &impl GateInstructions<F>,
31    ctx: &mut Context<F>,
32    a: &[impl AsRef<CRTInteger<F>>],
33    coeffs: &[AssignedValue<F>],
34    limb_bases: &[F],
35) -> CRTInteger<F> {
36    assert_eq!(a.len(), coeffs.len());
37    let k = a[0].as_ref().truncation.limbs.len();
38
39    let out_limbs = (0..k)
40        .map(|idx| {
41            let int_limbs = a.iter().map(|a| a.as_ref().truncation.limbs[idx]);
42            gate.select_by_indicator(ctx, int_limbs, coeffs.iter().copied())
43        })
44        .collect();
45
46    let max_limb_bits = a.iter().fold(0, |acc, x| max(acc, x.as_ref().truncation.max_limb_bits));
47
48    let out_trunc = OverflowInteger::new(out_limbs, max_limb_bits);
49    let out_native = if a.len() > k {
50        OverflowInteger::evaluate_native(
51            ctx,
52            gate,
53            out_trunc.limbs.iter().copied(),
54            &limb_bases[..k],
55        )
56    } else {
57        let a_native = a.iter().map(|x| x.as_ref().native);
58        gate.select_by_indicator(ctx, a_native, coeffs.iter().copied())
59    };
60    let out_val = a.iter().zip(coeffs.iter()).fold(BigInt::zero(), |acc, (x, y)| {
61        if y.value().is_zero_vartime() {
62            acc
63        } else {
64            x.as_ref().value.clone()
65        }
66    });
67
68    CRTInteger::new(out_trunc, out_native, out_val)
69}