halo2_ecc/bigint/
select_by_indicator.rs
1use super::{CRTInteger, OverflowInteger};
2use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context};
3use num_bigint::BigInt;
4use num_traits::Zero;
5use std::cmp::max;
6
7pub fn assign<F: ScalarField>(
9 gate: &impl GateInstructions<F>,
10 ctx: &mut Context<F>,
11 a: &[OverflowInteger<F>],
12 coeffs: &[AssignedValue<F>],
13) -> OverflowInteger<F> {
14 let k = a[0].limbs.len();
15
16 let out_limbs = (0..k)
17 .map(|idx| {
18 let int_limbs = a.iter().map(|a| a.limbs[idx]);
19 gate.select_by_indicator(ctx, int_limbs, coeffs.iter().copied())
20 })
21 .collect();
22
23 let max_limb_bits = a.iter().fold(0, |acc, x| max(acc, x.max_limb_bits));
24
25 OverflowInteger::new(out_limbs, max_limb_bits)
26}
27
28pub fn crt<F: ScalarField>(
30 gate: &impl GateInstructions<F>,
31 ctx: &mut Context<F>,
32 a: &[impl AsRef<CRTInteger<F>>],
33 coeffs: &[AssignedValue<F>],
34 limb_bases: &[F],
35) -> CRTInteger<F> {
36 assert_eq!(a.len(), coeffs.len());
37 let k = a[0].as_ref().truncation.limbs.len();
38
39 let out_limbs = (0..k)
40 .map(|idx| {
41 let int_limbs = a.iter().map(|a| a.as_ref().truncation.limbs[idx]);
42 gate.select_by_indicator(ctx, int_limbs, coeffs.iter().copied())
43 })
44 .collect();
45
46 let max_limb_bits = a.iter().fold(0, |acc, x| max(acc, x.as_ref().truncation.max_limb_bits));
47
48 let out_trunc = OverflowInteger::new(out_limbs, max_limb_bits);
49 let out_native = if a.len() > k {
50 OverflowInteger::evaluate_native(
51 ctx,
52 gate,
53 out_trunc.limbs.iter().copied(),
54 &limb_bases[..k],
55 )
56 } else {
57 let a_native = a.iter().map(|x| x.as_ref().native);
58 gate.select_by_indicator(ctx, a_native, coeffs.iter().copied())
59 };
60 let out_val = a.iter().zip(coeffs.iter()).fold(BigInt::zero(), |acc, (x, y)| {
61 if y.value().is_zero_vartime() {
62 acc
63 } else {
64 x.as_ref().value.clone()
65 }
66 });
67
68 CRTInteger::new(out_trunc, out_native, out_val)
69}