halo2_ecc/ecc/fixed_base.rs

#![allow(non_snake_case)]
use super::{ec_add_unequal, ec_select, ec_select_from_bits, EcPoint, EccChip};
use crate::ecc::{ec_sub_strict, load_random_point};
use crate::ff::Field;
use crate::fields::{FieldChip, Selectable};
use crate::group::Curve;
use halo2_base::gates::flex_gate::threads::{parallelize_core, SinglePhaseCoreManager};
use halo2_base::utils::BigPrimeField;
use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context};
use itertools::Itertools;
use rayon::prelude::*;
use std::cmp::min;

/// Computes `[scalar] * P` on `y^2 = x^3 + b`, where `P` is fixed (constant).
/// - `scalar` is represented as a non-empty vector of `AssignedValue`s
/// - `scalar = sum_i scalar_i * 2^{max_bits * i}`
/// - a vector of length > 1 is needed when `scalar` exceeds the modulus of the scalar field `F`
///
/// # Assumptions
/// - `scalar_i < 2^{max_bits}` for all `i` (constrained by `num_to_bits`)
/// - `scalar > 0`
/// - `max_bits <= modulus::<F>.bits()`
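///
/// A minimal usage sketch (`fp_chip`, `ctx`, and the bn254 types `Fr`/`G1Affine` are
/// assumed to be set up by the caller; names are illustrative, not part of this module):
///
/// ```ignore
/// // assumed setup: `fp_chip` is an `FpChip` for the bn254 base field, `ctx: &mut Context<Fr>`
/// let point = G1Affine::generator();
/// // the scalar fits in one field element, so a single limb with max_bits = Fr::NUM_BITS
/// let scalar = vec![ctx.load_witness(Fr::from(12345u64))];
/// let out = scalar_multiply(&fp_chip, ctx, &point, scalar, Fr::NUM_BITS as usize, 4);
/// ```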
pub fn scalar_multiply<F, FC, C>(
    chip: &FC,
    ctx: &mut Context<F>,
    point: &C,
    scalar: Vec<AssignedValue<F>>,
    max_bits: usize,
    window_bits: usize,
) -> EcPoint<F, FC::FieldPoint>
where
    F: BigPrimeField,
    C: CurveAffineExt,
    FC: FieldChip<F, FieldType = C::Base> + Selectable<F, FC::FieldPoint>,
{
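    // a fixed identity point yields the identity for any scalar; this codebase
    // represents the affine identity as (0, 0)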
    if point.is_identity().into() {
        let zero = chip.load_constant(ctx, C::Base::ZERO);
        return EcPoint::new(zero.clone(), zero);
    }
    assert!(!scalar.is_empty());
    assert!((max_bits as u32) <= F::NUM_BITS);

    let total_bits = max_bits * scalar.len();
    let num_windows = total_bits.div_ceil(window_bits);
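    // e.g., one 254-bit limb with window_bits = 4 gives num_windows = ceil(254 / 4) = 64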

    // Jacobian coordinates
    let base_pt = point.to_curve();
    // cached_points[i * 2^w + j] holds `[j * 2^(i * w)] * point` for j in {1, ..., 2^w - 1}

    // first we compute all cached points in Jacobian coordinates since it's fastest
    let mut increment = base_pt;
    let cached_points_jacobian = (0..num_windows)
        .flat_map(|i| {
            let mut curr = increment;
            // start with `increment` at index 0 instead of the identity, just as a dummy value to avoid divide-by-zero issues
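            // e.g., with window_bits = 2 the first window caches [P, P, 2P, 3P]
            // (index 0 is the dummy) and the second caches [4P, 4P, 8P, 12P]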
            let cache_vec = std::iter::once(increment)
                .chain((1..(1usize << min(window_bits, total_bits - i * window_bits))).map(|_| {
                    let prev = curr;
                    curr += increment;
                    prev
                }))
                .collect::<Vec<_>>();
            increment = curr;
            cache_vec
        })
        .collect::<Vec<_>>();
    // for use in circuits we need affine coordinates, so we do a batch normalize:
    // this is much more efficient than calling `to_affine` one by one since field inversion is very expensive
    // initialize to all 0s
    let mut cached_points_affine = vec![C::default(); cached_points_jacobian.len()];
    C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine);

    // TODO: do not assign and use select_from_bits on Constant(_) QuantumCells
    let cached_points = cached_points_affine
        .into_iter()
        .map(|point| {
            let (x, y) = point.into_coordinates();
            let [x, y] = [x, y].map(|x| chip.load_constant(ctx, x));
            EcPoint::new(x, y)
        })
        .collect_vec();

    let bits = scalar
        .into_iter()
        .flat_map(|scalar_chunk| chip.gate().num_to_bits(ctx, scalar_chunk, max_bits))
        .collect::<Vec<_>>();

    let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev();
    let bit_window_rev = bits.chunks(window_bits).rev();
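    // start the accumulator at a random point: `ec_add_unequal` below is incomplete addition,
    // and a random starting point keeps the two summands distinct with overwhelming probability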
    let any_point = load_random_point::<F, FC, C>(chip, ctx);
    let mut curr_point = any_point.clone();
    for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) {
        let bit_sum = chip.gate().sum(ctx, bit_window.iter().copied());
        // are we just adding a window of all 0s? if so, skip
        let is_zero_window = chip.gate().is_zero(ctx, bit_sum);
        curr_point = {
            let add_point = ec_select_from_bits(chip, ctx, cached_point_window, bit_window);
            let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, true);
            ec_select(chip, ctx, curr_point, sum, is_zero_window)
        };
    }
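    // subtract off the random starting point; `ec_sub_strict` allows the result to be the identity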
    ec_sub_strict(chip, ctx, curr_point, any_point)
}

// This is essentially a sum of individual `fixed_base::scalar_multiply` calls, except that we batch-normalize
// all cached points at once to further save inversion time during witness generation.
// We also use the random accumulator for some extra efficiency (this would also work in the scalar multiply case, but that is TODO).

/// # Assumptions
/// * `points.len() = scalars.len()`
/// * `scalars[i].len() = scalars[j].len()` for all `i, j`
/// * `points` are all on the curve
/// * `points[i]` is not the point at infinity (0, 0); these should be filtered out beforehand
/// * The integer value of `scalars[i]` is less than the order of `points[i]`
/// * The output may be the point at infinity, in which case (0, 0) is returned
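///
/// A minimal usage sketch (`chip`, `builder`, and the bn254 types `Fr`/`G1Affine` are
/// assumed to be set up by the caller; names are illustrative, not part of this module):
///
/// ```ignore
/// let points = vec![G1Affine::generator(), (G1Affine::generator() * Fr::from(7)).to_affine()];
/// let scalars = {
///     let ctx = builder.main();
///     vec![
///         vec![ctx.load_witness(Fr::from(3u64))],
///         vec![ctx.load_witness(Fr::from(5u64))],
///     ]
/// };
/// // computes [3] * points[0] + [5] * points[1]
/// let msm = msm_par(&chip, builder, &points, scalars, Fr::NUM_BITS as usize, 4);
/// ```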
pub fn msm_par<F, FC, C>(
    chip: &EccChip<F, FC>,
    builder: &mut SinglePhaseCoreManager<F>,
    points: &[C],
    scalars: Vec<Vec<AssignedValue<F>>>,
    max_scalar_bits_per_cell: usize,
    window_bits: usize,
) -> EcPoint<F, FC::FieldPoint>
where
    F: BigPrimeField,
    C: CurveAffineExt,
    FC: FieldChip<F, FieldType = C::Base> + Selectable<F, FC::FieldPoint>,
{
    if points.is_empty() {
        return chip.assign_constant_point(builder.main(), C::identity());
    }
    assert!((max_scalar_bits_per_cell as u32) <= F::NUM_BITS);
    assert_eq!(points.len(), scalars.len());
    let scalar_len = scalars[0].len();
    let total_bits = max_scalar_bits_per_cell * scalar_len;
    let num_windows = total_bits.div_ceil(window_bits);

    // `cached_points` is a flattened 2d vector
    // first we compute all cached points in Jacobian coordinates since it's fastest
    let cached_points_jacobian = points
        .par_iter()
        .flat_map(|point| -> Vec<_> {
            let base_pt = point.to_curve();
            // cached_points[idx][i * 2^w + j] holds `[j * 2^(i * w)] * points[idx]` for j in {1, ..., 2^w - 1}
            // EXCEPT that index `i * 2^w` (j = 0) holds the dummy value `[2^(i * w)] * points[idx]`
            let mut increment = base_pt;
            (0..num_windows)
                .flat_map(|i| {
                    let mut curr = increment;
                    let cache_vec = std::iter::once(increment)
                        .chain((1..(1usize << min(window_bits, total_bits - i * window_bits))).map(
                            |_| {
                                let prev = curr;
                                curr += increment;
                                prev
                            },
                        ))
                        .collect::<Vec<_>>();
                    increment = curr;
                    cache_vec
                })
                .collect()
        })
        .collect::<Vec<_>>();
    // for use in circuits we need affine coordinates, so we do a batch normalize:
    // this is much more efficient than calling `to_affine` one by one since field inversion is very expensive
    // initialize to all 0s
    let mut cached_points_affine = vec![C::default(); cached_points_jacobian.len()];
    C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine);

    let field_chip = chip.field_chip();
    let ctx = builder.main();
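    // as in `scalar_multiply`, every per-point accumulator starts at this random point;
    // it is subtracted back out in the final summation below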
    let any_point = chip.load_random_point::<C>(ctx);

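    // every input point contributed the same number of cached points, so equal-size chunks
    // of `cached_points_affine` recover the per-point caches; each (caches, scalar) pair
    // is then processed on its own core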
    let scalar_mults = parallelize_core(
        builder,
        cached_points_affine
            .chunks(cached_points_affine.len() / points.len())
            .zip_eq(scalars)
            .collect(),
        |ctx, (cached_points, scalar)| {
            let cached_points = cached_points
                .iter()
                .map(|point| chip.assign_constant_point(ctx, *point))
                .collect_vec();
            let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev();

            assert_eq!(scalar.len(), scalar_len);
            let bits = scalar
                .into_iter()
                .flat_map(|scalar_chunk| {
                    field_chip.gate().num_to_bits(ctx, scalar_chunk, max_scalar_bits_per_cell)
                })
                .collect::<Vec<_>>();
            let bit_window_rev = bits.chunks(window_bits).rev();
            let mut curr_point = any_point.clone();
            for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) {
                let is_zero_window = {
                    let sum = field_chip.gate().sum(ctx, bit_window.iter().copied());
                    field_chip.gate().is_zero(ctx, sum)
                };
                curr_point = {
                    let add_point =
                        ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window);
                    let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, true);
                    ec_select(field_chip, ctx, curr_point, sum, is_zero_window)
                };
            }
            curr_point
        },
    );
    let ctx = builder.main();
    // sum `scalar_mults`, taking into account the possibility of identity points
    let any_point2 = chip.load_random_point::<C>(ctx);
    let mut acc = any_point2.clone();
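    // each `point` still contains one copy of `any_point` from its starting accumulator,
    // so `any_point` is subtracted back out after every addition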
    for point in scalar_mults {
        let new_acc = chip.add_unequal(ctx, &acc, point, true);
        acc = chip.sub_unequal(ctx, new_acc, &any_point, true);
    }
    ec_sub_strict(field_chip, ctx, acc, any_point2)
}