// halo2_ecc/ecc/pippenger.rs

1use super::{
2    ec_add_unequal, ec_double, ec_select, ec_sub_unequal, into_strict_point, load_random_point,
3    strict_ec_select_from_bits, EcPoint,
4};
5use crate::{
6    ecc::ec_sub_strict,
7    fields::{FieldChip, Selectable},
8};
9use halo2_base::{
10    gates::{
11        flex_gate::threads::{parallelize_core, SinglePhaseCoreManager},
12        GateInstructions,
13    },
14    utils::{BigPrimeField, CurveAffineExt},
15    AssignedValue,
16};

// Reference: https://jbootle.github.io/Misc/pippenger.pdf

// Reduction to multi-products
// Output:
// * new_points: length `points.len() * radix`
// * new_bool_scalars: 2d array `ceil(scalar_bits / radix)` by `points.len() * radix`
//
// Empirically `radix = 1` is best, so we don't use this function for now
/*
pub fn decompose<F, FC>(
    chip: &FC,
    ctx: &mut Context<F>,
    points: &[EcPoint<F, FC::FieldPoint>],
    scalars: &[Vec<AssignedValue<F>>],
    max_scalar_bits_per_cell: usize,
    radix: usize,
) -> (Vec<EcPoint<F, FC::FieldPoint>>, Vec<Vec<AssignedValue<F>>>)
where
    F: PrimeField,
    FC: FieldChip<F>,
{
    assert_eq!(points.len(), scalars.len());
    let scalar_bits = max_scalar_bits_per_cell * scalars[0].len();
    let t = (scalar_bits + radix - 1) / radix;

    let mut new_points = Vec::with_capacity(radix * points.len());
    let mut new_bool_scalars = vec![Vec::with_capacity(radix * points.len()); t];

    let zero_cell = ctx.load_zero();
    for (point, scalar) in points.iter().zip(scalars.iter()) {
        assert_eq!(scalars[0].len(), scalar.len());
        let mut g = point.clone();
        new_points.push(g);
        for _ in 1..radix {
            // if radix > 1, this does not work if `points` contains identity point
            g = ec_double(chip, ctx, new_points.last().unwrap());
            new_points.push(g);
        }
        let mut bits = Vec::with_capacity(scalar_bits);
        for x in scalar {
            let mut new_bits = chip.gate().num_to_bits(ctx, *x, max_scalar_bits_per_cell);
            bits.append(&mut new_bits);
        }
        for k in 0..t {
            new_bool_scalars[k]
                .extend_from_slice(&bits[(radix * k)..std::cmp::min(radix * (k + 1), scalar_bits)]);
        }
        new_bool_scalars[t - 1].extend(vec![zero_cell.clone(); radix * t - scalar_bits]);
    }

    (new_points, new_bool_scalars)
}
*/

/* Left as reference; should always use msm_par
// Given points[i] and bool_scalars[j][i],
// compute G'[j] = sum_{i=0..points.len()} points[i] * bool_scalars[j][i]
// output is [ G'[j] + rand_point ]_{j=0..bool_scalars.len()}, rand_point
pub fn multi_product<F: PrimeField, FC, C>(
    chip: &FC,
    ctx: &mut Context<F>,
    points: &[EcPoint<F, FC::FieldPoint>],
    bool_scalars: &[Vec<AssignedValue<F>>],
    clumping_factor: usize,
) -> (Vec<StrictEcPoint<F, FC>>, EcPoint<F, FC::FieldPoint>)
where
    FC: FieldChip<F> + Selectable<F, FC::FieldPoint> + Selectable<F, FC::ReducedFieldPoint>,
    C: CurveAffineExt<Base = FC::FieldType>,
{
    let c = clumping_factor; // this is `b` in Section 3 of Bootle

    // to avoid adding two points that are equal or negative of each other,
    // we use a trick from halo2wrong where we load a random C point as witness
    // note that while we load a random point, an adversary could load a specifically chosen point, so we must carefully handle edge cases with constraints
    // TODO: an alternate approach is to use Fiat-Shamir transform (with Poseidon) to hash all the inputs (points, bool_scalars, ...) to get the random point. This could be worth it for large MSMs as we get savings from `add_unequal` in "non-strict" mode. Perhaps not worth the trouble / security concern, though.
    let any_base = load_random_point::<F, FC, C>(chip, ctx);

    let mut acc = Vec::with_capacity(bool_scalars.len());

    let mut bucket = Vec::with_capacity(1 << c);
    let mut any_point = any_base.clone();
    for (round, points_clump) in points.chunks(c).enumerate() {
        // compute all possible multi-products of elements in points[round * c .. (round + 1) * c]

        // for later addition collision-prevention, we need a different random point per round
        // we take 2^round * rand_base
        if round > 0 {
            any_point = ec_double(chip, ctx, any_point);
        }
        // stores { rand_point, rand_point + points[0], rand_point + points[1], rand_point + points[0] + points[1] , ... }
        // since rand_point is random, we can always use add_unequal (with strict constraint checking that the points are indeed unequal and not negative of each other)
        bucket.clear();
        let strict_any_point = into_strict_point(chip, ctx, any_point.clone());
        bucket.push(strict_any_point);
        for (i, point) in points_clump.iter().enumerate() {
            // we allow for points[i] to be the point at infinity, represented by (0, 0) in affine coordinates
            // this can be checked by points[i].y == 0 iff points[i] == O
            let is_infinity = chip.is_zero(ctx, &point.y);
            let point = into_strict_point(chip, ctx, point.clone());

            for j in 0..(1 << i) {
                let mut new_point = ec_add_unequal(chip, ctx, &bucket[j], &point, true);
                // if points[i] is point at infinity, do nothing
                new_point = ec_select(chip, ctx, (&bucket[j]).into(), new_point, is_infinity);
                let new_point = into_strict_point(chip, ctx, new_point);
                bucket.push(new_point);
            }
        }

        // for each j, select using clump in e[j][i=...]
        for (j, bits) in bool_scalars.iter().enumerate() {
            let multi_prod = strict_ec_select_from_bits(
                chip,
                ctx,
                &bucket,
                &bits[round * c..round * c + points_clump.len()],
            );
            // since `bucket` is all `StrictEcPoint` and we are selecting from it, we know `multi_prod` is StrictEcPoint
            // everything in bucket has already been enforced
            if round == 0 {
                acc.push(multi_prod);
            } else {
                let _acc = ec_add_unequal(chip, ctx, &acc[j], multi_prod, true);
                acc[j] = into_strict_point(chip, ctx, _acc);
            }
        }
    }

    // we have acc[j] = G'[j] + (2^num_rounds - 1) * rand_base
    any_point = ec_double(chip, ctx, any_point);
    any_point = ec_sub_unequal(chip, ctx, any_point, any_base, false);

    (acc, any_point)
}

/// Currently does not support if the final answer is actually the point at infinity (meaning constraints will fail in that case)
///
/// # Assumptions
/// * `points.len() == scalars.len()`
/// * `scalars[i].len() == scalars[j].len()` for all `i, j`
pub fn multi_exp<F: PrimeField, FC, C>(
    chip: &FC,
    ctx: &mut Context<F>,
    points: &[EcPoint<F, FC::FieldPoint>],
    scalars: Vec<Vec<AssignedValue<F>>>,
    max_scalar_bits_per_cell: usize,
    // radix: usize, // specialize to radix = 1
    clump_factor: usize,
) -> EcPoint<F, FC::FieldPoint>
where
    FC: FieldChip<F> + Selectable<F, FC::FieldPoint> + Selectable<F, FC::ReducedFieldPoint>,
    C: CurveAffineExt<Base = FC::FieldType>,
{
    // let (points, bool_scalars) = decompose::<F, _>(chip, ctx, points, scalars, max_scalar_bits_per_cell, radix);

    debug_assert_eq!(points.len(), scalars.len());
    let scalar_bits = max_scalar_bits_per_cell * scalars[0].len();
    // bool_scalars: 2d array `scalar_bits` by `points.len()`
    let mut bool_scalars = vec![Vec::with_capacity(points.len()); scalar_bits];
    for scalar in scalars {
        for (scalar_chunk, bool_chunk) in
            scalar.into_iter().zip(bool_scalars.chunks_mut(max_scalar_bits_per_cell))
        {
            let bits = chip.gate().num_to_bits(ctx, scalar_chunk, max_scalar_bits_per_cell);
            for (bit, bool_bit) in bits.into_iter().zip(bool_chunk.iter_mut()) {
                bool_bit.push(bit);
            }
        }
    }

    let (mut agg, any_point) =
        multi_product::<F, FC, C>(chip, ctx, points, &bool_scalars, clump_factor);
    // everything in agg has been enforced

    // compute sum_{k=0..t} agg[k] * 2^{radix * k} - (sum_k 2^{radix * k}) * rand_point
    // (sum_{k=0..t} 2^{radix * k}) = (2^{radix * t} - 1)/(2^radix - 1)
    let mut sum = agg.pop().unwrap().into();
    let mut any_sum = any_point.clone();
    for g in agg.iter().rev() {
        any_sum = ec_double(chip, ctx, any_sum);
        // cannot use ec_double_and_add_unequal because you cannot guarantee that `sum != g`
        sum = ec_double(chip, ctx, sum);
        sum = ec_add_unequal(chip, ctx, sum, g, true);
    }

    any_sum = ec_double(chip, ctx, any_sum);
    // assume 2^scalar_bits != +-1 mod modulus::<F>()
    any_sum = ec_sub_unequal(chip, ctx, any_sum, any_point, false);

    ec_sub_unequal(chip, ctx, sum, any_sum, true)
}
*/

211/// Multi-thread witness generation for multi-scalar multiplication.
212///
213/// # Assumptions
214/// * `points.len() == scalars.len()`
215/// * `scalars[i].len() == scalars[j].len()` for all `i, j`
216/// * `points` are all on the curve or the point at infinity
217/// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point)
218/// * Currently implementation assumes that the only point on curve with y-coordinate equal to `0` is identity point
219pub fn multi_exp_par<F: BigPrimeField, FC, C>(
220    chip: &FC,
221    // these are the "threads" within a single Phase
222    builder: &mut SinglePhaseCoreManager<F>,
223    points: &[EcPoint<F, FC::FieldPoint>],
224    scalars: Vec<Vec<AssignedValue<F>>>,
225    max_scalar_bits_per_cell: usize,
226    // radix: usize, // specialize to radix = 1
227    clump_factor: usize,
228) -> EcPoint<F, FC::FieldPoint>
229where
230    FC: FieldChip<F> + Selectable<F, FC::FieldPoint> + Selectable<F, FC::ReducedFieldPoint>,
231    C: CurveAffineExt<Base = FC::FieldType>,
232{
233    // let (points, bool_scalars) = decompose::<F, _>(chip, ctx, points, scalars, max_scalar_bits_per_cell, radix);
234
235    assert_eq!(points.len(), scalars.len());
236    let scalar_bits = max_scalar_bits_per_cell * scalars[0].len();
237    // bool_scalars: 2d array `scalar_bits` by `points.len()`
238    let mut bool_scalars = vec![Vec::with_capacity(points.len()); scalar_bits];
239
240    // get a main thread
241    let ctx = builder.main();
242    // single-threaded computation:
243    for scalar in scalars {
244        for (scalar_chunk, bool_chunk) in
245            scalar.into_iter().zip(bool_scalars.chunks_mut(max_scalar_bits_per_cell))
246        {
247            let bits = chip.gate().num_to_bits(ctx, scalar_chunk, max_scalar_bits_per_cell);
248            for (bit, bool_bit) in bits.into_iter().zip(bool_chunk.iter_mut()) {
249                bool_bit.push(bit);
250            }
251        }
252    }
253
254    let c = clump_factor;
255    let num_rounds = points.len().div_ceil(c);
256    // to avoid adding two points that are equal or negative of each other,
257    // we use a trick from halo2wrong where we load a "sufficiently generic" `C` point as witness
258    // note that while we load a random point, an adversary could load a specifically chosen point, so we must carefully handle edge cases with constraints
259    // we call it "any point" instead of "random point" to emphasize that "any" sufficiently generic point will do
260    let any_base = load_random_point::<F, FC, C>(chip, ctx);
261    let mut any_points = Vec::with_capacity(num_rounds);
262    any_points.push(any_base);
263    for _ in 1..num_rounds {
264        any_points.push(ec_double(chip, ctx, any_points.last().unwrap()));
265    }
266
267    // now begins multi-threading
268    // multi_prods is 2d vector of size `num_rounds` by `scalar_bits`
269    let multi_prods = parallelize_core(
270        builder,
271        points.chunks(c).zip(any_points.iter()).enumerate().collect(),
272        |ctx, (round, (points_clump, any_point))| {
273            // compute all possible multi-products of elements in points[round * c .. round * (c+1)]
274            // stores { any_point, any_point + points[0], any_point + points[1], any_point + points[0] + points[1] , ... }
275            let mut bucket = Vec::with_capacity(1 << c);
276            let any_point = into_strict_point(chip, ctx, any_point.clone());
277            bucket.push(any_point);
278            for (i, point) in points_clump.iter().enumerate() {
279                // we allow for points[i] to be the point at infinity, represented by (0, 0) in affine coordinates
280                // this can be checked by points[i].y == 0 iff points[i] == O
281                let is_infinity = chip.is_zero(ctx, &point.y);
282                let point = into_strict_point(chip, ctx, point.clone());
283
284                for j in 0..(1 << i) {
285                    let mut new_point = ec_add_unequal(chip, ctx, &bucket[j], &point, true);
286                    // if points[i] is point at infinity, do nothing
287                    new_point = ec_select(chip, ctx, (&bucket[j]).into(), new_point, is_infinity);
288                    let new_point = into_strict_point(chip, ctx, new_point);
289                    bucket.push(new_point);
290                }
291            }
292            bool_scalars
293                .iter()
294                .map(|bits| {
295                    strict_ec_select_from_bits(
296                        chip,
297                        ctx,
298                        &bucket,
299                        &bits[round * c..round * c + points_clump.len()],
300                    )
301                })
302                .collect::<Vec<_>>()
303        },
304    );
305
306    // agg[j] = sum_{i=0..num_rounds} multi_prods[i][j] for j = 0..scalar_bits
307    let mut agg = parallelize_core(builder, (0..scalar_bits).collect(), |ctx, i| {
308        let mut acc = multi_prods[0][i].clone();
309        for multi_prod in multi_prods.iter().skip(1) {
310            let _acc = ec_add_unequal(chip, ctx, &acc, &multi_prod[i], true);
311            acc = into_strict_point(chip, ctx, _acc);
312        }
313        acc
314    });
315
316    // gets the LAST thread for single threaded work
317    let ctx = builder.main();
318    // we have agg[j] = G'[j] + (2^num_rounds - 1) * any_base
319    // let any_point = (2^num_rounds - 1) * any_base
320    // TODO: can we remove all these random point operations somehow?
321    let mut any_point = ec_double(chip, ctx, any_points.last().unwrap());
322    any_point = ec_sub_unequal(chip, ctx, any_point, &any_points[0], true);
323
324    // compute sum_{k=0..scalar_bits} agg[k] * 2^k - (sum_{k=0..scalar_bits} 2^k) * rand_point
325    // (sum_{k=0..scalar_bits} 2^k) = (2^scalar_bits - 1)
326    let mut sum = agg.pop().unwrap().into();
327    let mut any_sum = any_point.clone();
328    for g in agg.iter().rev() {
329        any_sum = ec_double(chip, ctx, any_sum);
330        // cannot use ec_double_and_add_unequal because you cannot guarantee that `sum != g`
331        sum = ec_double(chip, ctx, sum);
332        sum = ec_add_unequal(chip, ctx, sum, g, true);
333    }
334
335    any_sum = ec_double(chip, ctx, any_sum);
336    any_sum = ec_sub_unequal(chip, ctx, any_sum, any_point, true);
337
338    ec_sub_strict(chip, ctx, sum, any_sum)
339}