halo2_ecc/bigint/select.rs
1use super::{CRTInteger, OverflowInteger};
2use halo2_base::{gates::GateInstructions, utils::ScalarField, AssignedValue, Context};
3use itertools::Itertools;
4use std::cmp::max;
5
6pub fn assign<F: ScalarField>(
10 gate: &impl GateInstructions<F>,
11 ctx: &mut Context<F>,
12 a: OverflowInteger<F>,
13 b: OverflowInteger<F>,
14 sel: AssignedValue<F>,
15) -> OverflowInteger<F> {
16 let out_limbs = a
17 .limbs
18 .into_iter()
19 .zip_eq(b.limbs)
20 .map(|(a_limb, b_limb)| gate.select(ctx, a_limb, b_limb, sel))
21 .collect();
22
23 OverflowInteger::new(out_limbs, max(a.max_limb_bits, b.max_limb_bits))
24}
25
26pub fn crt<F: ScalarField>(
27 gate: &impl GateInstructions<F>,
28 ctx: &mut Context<F>,
29 a: CRTInteger<F>,
30 b: CRTInteger<F>,
31 sel: AssignedValue<F>,
32) -> CRTInteger<F> {
33 debug_assert_eq!(a.truncation.limbs.len(), b.truncation.limbs.len());
34 let out_limbs = a
35 .truncation
36 .limbs
37 .into_iter()
38 .zip_eq(b.truncation.limbs)
39 .map(|(a_limb, b_limb)| gate.select(ctx, a_limb, b_limb, sel))
40 .collect();
41
42 let out_trunc = OverflowInteger::new(
43 out_limbs,
44 max(a.truncation.max_limb_bits, b.truncation.max_limb_bits),
45 );
46
47 let out_native = gate.select(ctx, a.native, b.native, sel);
48 let out_val = if sel.value().is_zero_vartime() { b.value } else { a.value };
49 CRTInteger::new(out_trunc, out_native, out_val)
50}