halo2_ecc/ecc/pippenger.rs
use super::{
    ec_add_unequal, ec_double, ec_select, ec_sub_unequal, into_strict_point, load_random_point,
    strict_ec_select_from_bits, EcPoint,
};
use crate::{
    ecc::ec_sub_strict,
    fields::{FieldChip, Selectable},
};
use halo2_base::{
    gates::{
        flex_gate::threads::{parallelize_core, SinglePhaseCoreManager},
        GateInstructions,
    },
    utils::{BigPrimeField, CurveAffineExt},
    AssignedValue,
};

// Reference: https://jbootle.github.io/Misc/pippenger.pdf

// Reduction to multi-products
// Output:
// * new_points: length `points.len() * radix`
// * new_bool_scalars: 2d array `ceil(scalar_bits / radix)` by `points.len() * radix`
//
// Empirically `radix = 1` is best, so we don't use this function for now.
// (An out-of-circuit sketch of the `radix = 1` decomposition follows the commented-out code below.)
/*
pub fn decompose<F, FC>(
    chip: &FC,
    ctx: &mut Context<F>,
    points: &[EcPoint<F, FC::FieldPoint>],
    scalars: &[Vec<AssignedValue<F>>],
    max_scalar_bits_per_cell: usize,
    radix: usize,
) -> (Vec<EcPoint<F, FC::FieldPoint>>, Vec<Vec<AssignedValue<F>>>)
where
    F: PrimeField,
    FC: FieldChip<F>,
{
    assert_eq!(points.len(), scalars.len());
    let scalar_bits = max_scalar_bits_per_cell * scalars[0].len();
    let t = (scalar_bits + radix - 1) / radix;

    let mut new_points = Vec::with_capacity(radix * points.len());
    let mut new_bool_scalars = vec![Vec::with_capacity(radix * points.len()); t];

    let zero_cell = ctx.load_zero();
    for (point, scalar) in points.iter().zip(scalars.iter()) {
        assert_eq!(scalars[0].len(), scalar.len());
        let mut g = point.clone();
        new_points.push(g);
        for _ in 1..radix {
            // if radix > 1, this does not work if `points` contains identity point
            g = ec_double(chip, ctx, new_points.last().unwrap());
            new_points.push(g);
        }
        let mut bits = Vec::with_capacity(scalar_bits);
        for x in scalar {
            let mut new_bits = chip.gate().num_to_bits(ctx, *x, max_scalar_bits_per_cell);
            bits.append(&mut new_bits);
        }
        for k in 0..t {
            new_bool_scalars[k]
                .extend_from_slice(&bits[(radix * k)..std::cmp::min(radix * (k + 1), scalar_bits)]);
        }
        new_bool_scalars[t - 1].extend(vec![zero_cell.clone(); radix * t - scalar_bits]);
    }

    (new_points, new_bool_scalars)
}
*/
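
// Illustrative, out-of-circuit sketch (not used by the circuit code) of the `radix = 1`
// decomposition that `multi_exp_par` below performs in-circuit: each scalar is supplied as limbs of
// `max_scalar_bits_per_cell` bits, the bits are transposed into `bool_scalars[k][i]` = k-th bit of
// scalar i, and the MSM becomes sum_k 2^k * (sum_i bool_scalars[k][i] * points[i]).
// Plain integers stand in for curve points, since only the abelian group structure is used.
#[cfg(test)]
mod radix1_decompose_sketch {
    #[test]
    fn bit_transpose_matches_msm() {
        let points: Vec<i128> = vec![5, 7, 11];
        let scalars: Vec<u64> = vec![9, 4, 13];
        let scalar_bits = 8; // max_scalar_bits_per_cell * number of limbs; here one 8-bit limb per scalar

        // bool_scalars[k][i] = k-th bit of scalars[i] (the transpose computed in `multi_exp_par`)
        let bool_scalars: Vec<Vec<u64>> = (0..scalar_bits)
            .map(|k| scalars.iter().map(|s| (s >> k) & 1).collect())
            .collect();

        // multi-product per bit level, then binary recombination
        let expected: i128 = points.iter().zip(&scalars).map(|(p, s)| p * (*s as i128)).sum();
        let recombined: i128 = bool_scalars
            .iter()
            .enumerate()
            .map(|(k, bits)| {
                let level: i128 =
                    points.iter().zip(bits).map(|(p, b)| p * (*b as i128)).sum();
                level << k
            })
            .sum();
        assert_eq!(expected, recombined);
    }
}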

/* Left as reference; should always use msm_par
// Given points[i] and bool_scalars[j][i],
// compute G'[j] = sum_{i=0..points.len()} points[i] * bool_scalars[j][i]
// output is [ G'[j] + rand_point ]_{j=0..bool_scalars.len()}, rand_point
pub fn multi_product<F: PrimeField, FC, C>(
    chip: &FC,
    ctx: &mut Context<F>,
    points: &[EcPoint<F, FC::FieldPoint>],
    bool_scalars: &[Vec<AssignedValue<F>>],
    clumping_factor: usize,
) -> (Vec<StrictEcPoint<F, FC>>, EcPoint<F, FC::FieldPoint>)
where
    FC: FieldChip<F> + Selectable<F, FC::FieldPoint> + Selectable<F, FC::ReducedFieldPoint>,
    C: CurveAffineExt<Base = FC::FieldType>,
{
    let c = clumping_factor; // this is `b` in Section 3 of Bootle

    // to avoid adding two points that are equal or negative of each other,
    // we use a trick from halo2wrong where we load a random C point as witness
    // note that while we load a random point, an adversary could load a specifically chosen point, so we must carefully handle edge cases with constraints
    // TODO: an alternate approach is to use Fiat-Shamir transform (with Poseidon) to hash all the inputs (points, bool_scalars, ...) to get the random point. This could be worth it for large MSMs as we get savings from `add_unequal` in "non-strict" mode. Perhaps not worth the trouble / security concern, though.
    let any_base = load_random_point::<F, FC, C>(chip, ctx);

    let mut acc = Vec::with_capacity(bool_scalars.len());

    let mut bucket = Vec::with_capacity(1 << c);
    let mut any_point = any_base.clone();
    for (round, points_clump) in points.chunks(c).enumerate() {
        // compute all possible multi-products of elements in points[round * c .. round * (c+1)]

        // for later addition collision-prevention, we need a different random point per round
        // we take 2^round * rand_base
        if round > 0 {
            any_point = ec_double(chip, ctx, any_point);
        }
        // stores { rand_point, rand_point + points[0], rand_point + points[1], rand_point + points[0] + points[1] , ... }
        // since rand_point is random, we can always use add_unequal (with strict constraint checking that the points are indeed unequal and not negative of each other)
        bucket.clear();
        let strict_any_point = into_strict_point(chip, ctx, any_point.clone());
        bucket.push(strict_any_point);
        for (i, point) in points_clump.iter().enumerate() {
            // we allow for points[i] to be the point at infinity, represented by (0, 0) in affine coordinates
            // we detect this using the convention that points[i].y == 0 iff points[i] == O
            let is_infinity = chip.is_zero(ctx, &point.y);
            let point = into_strict_point(chip, ctx, point.clone());

            for j in 0..(1 << i) {
                let mut new_point = ec_add_unequal(chip, ctx, &bucket[j], &point, true);
                // if points[i] is point at infinity, do nothing
                new_point = ec_select(chip, ctx, (&bucket[j]).into(), new_point, is_infinity);
                let new_point = into_strict_point(chip, ctx, new_point);
                bucket.push(new_point);
            }
        }

        // for each j, select using clump in e[j][i=...]
        for (j, bits) in bool_scalars.iter().enumerate() {
            let multi_prod = strict_ec_select_from_bits(
                chip,
                ctx,
                &bucket,
                &bits[round * c..round * c + points_clump.len()],
            );
            // since `bucket` is all `StrictEcPoint` and we are selecting from it, we know `multi_prod` is StrictEcPoint
            // everything in bucket has already been enforced
            if round == 0 {
                acc.push(multi_prod);
            } else {
                let _acc = ec_add_unequal(chip, ctx, &acc[j], multi_prod, true);
                acc[j] = into_strict_point(chip, ctx, _acc);
            }
        }
    }

    // we have acc[j] = G'[j] + (2^num_rounds - 1) * rand_base
    any_point = ec_double(chip, ctx, any_point);
    any_point = ec_sub_unequal(chip, ctx, any_point, any_base, false);

    (acc, any_point)
}

/// Currently does not support the case where the final answer is the point at infinity (the constraints will fail in that case).
///
/// # Assumptions
/// * `points.len() == scalars.len()`
/// * `scalars[i].len() == scalars[j].len()` for all `i, j`
pub fn multi_exp<F: PrimeField, FC, C>(
    chip: &FC,
    ctx: &mut Context<F>,
    points: &[EcPoint<F, FC::FieldPoint>],
    scalars: Vec<Vec<AssignedValue<F>>>,
    max_scalar_bits_per_cell: usize,
    // radix: usize, // specialize to radix = 1
    clump_factor: usize,
) -> EcPoint<F, FC::FieldPoint>
where
    FC: FieldChip<F> + Selectable<F, FC::FieldPoint> + Selectable<F, FC::ReducedFieldPoint>,
    C: CurveAffineExt<Base = FC::FieldType>,
{
    // let (points, bool_scalars) = decompose::<F, _>(chip, ctx, points, scalars, max_scalar_bits_per_cell, radix);

    debug_assert_eq!(points.len(), scalars.len());
    let scalar_bits = max_scalar_bits_per_cell * scalars[0].len();
    // bool_scalars: 2d array `scalar_bits` by `points.len()`
    let mut bool_scalars = vec![Vec::with_capacity(points.len()); scalar_bits];
    for scalar in scalars {
        for (scalar_chunk, bool_chunk) in
            scalar.into_iter().zip(bool_scalars.chunks_mut(max_scalar_bits_per_cell))
        {
            let bits = chip.gate().num_to_bits(ctx, scalar_chunk, max_scalar_bits_per_cell);
            for (bit, bool_bit) in bits.into_iter().zip(bool_chunk.iter_mut()) {
                bool_bit.push(bit);
            }
        }
    }

    let (mut agg, any_point) =
        multi_product::<F, FC, C>(chip, ctx, points, &bool_scalars, clump_factor);
    // everything in agg has been enforced

    // compute sum_{k=0..t} agg[k] * 2^{radix * k} - (sum_k 2^{radix * k}) * rand_point
    // (sum_{k=0..t} 2^{radix * k}) = (2^{radix * t} - 1)/(2^radix - 1)
    let mut sum = agg.pop().unwrap().into();
    let mut any_sum = any_point.clone();
    for g in agg.iter().rev() {
        any_sum = ec_double(chip, ctx, any_sum);
        // cannot use ec_double_and_add_unequal because you cannot guarantee that `sum != g`
        sum = ec_double(chip, ctx, sum);
        sum = ec_add_unequal(chip, ctx, sum, g, true);
    }

    any_sum = ec_double(chip, ctx, any_sum);
    // assume 2^scalar_bits != +-1 mod modulus::<F>()
    any_sum = ec_sub_unequal(chip, ctx, any_sum, any_point, false);

    ec_sub_unequal(chip, ctx, sum, any_sum, true)
}
*/

/// Multi-thread witness generation for multi-scalar multiplication.
///
/// # Assumptions
/// * `points.len() == scalars.len()`
/// * `scalars[i].len() == scalars[j].len()` for all `i, j`
/// * `points` are all on the curve or the point at infinity
/// * `points[i]` is allowed to be (0, 0) to represent the point at infinity (identity point)
/// * The current implementation assumes that the only point on the curve with y-coordinate equal to `0` is the identity point
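///
/// # Bucket construction (illustrative sketch)
///
/// The following is an illustrative sketch, not a usage example of this function: integers stand in
/// for curve points (only the abelian group structure matters). For each clump of `c` points, the
/// bucket holds all `2^c` subset sums; each scalar bit-level then selects one bucket entry using its
/// clump of bits. In the circuit, every bucket entry is additionally offset by the "any point"
/// loaded below, so that `ec_add_unequal` can be used safely.
///
/// ```
/// let clump = [3i64, 5, 9]; // one clump of c = 3 points
/// let mut bucket = vec![0i64]; // bucket[j] = sum of clump[i] over the set bits i of j
/// for (i, p) in clump.iter().enumerate() {
///     for j in 0..(1usize << i) {
///         let prev = bucket[j];
///         bucket.push(prev + p);
///     }
/// }
/// assert_eq!(bucket.len(), 1 << clump.len());
/// // little-endian bits (1, 0, 1) of a scalar bit-level select index 0b101 = clump[0] + clump[2]
/// assert_eq!(bucket[0b101], 3 + 9);
/// ```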
pub fn multi_exp_par<F: BigPrimeField, FC, C>(
    chip: &FC,
    // these are the "threads" within a single Phase
    builder: &mut SinglePhaseCoreManager<F>,
    points: &[EcPoint<F, FC::FieldPoint>],
    scalars: Vec<Vec<AssignedValue<F>>>,
    max_scalar_bits_per_cell: usize,
    // radix: usize, // specialize to radix = 1
    clump_factor: usize,
) -> EcPoint<F, FC::FieldPoint>
where
    FC: FieldChip<F> + Selectable<F, FC::FieldPoint> + Selectable<F, FC::ReducedFieldPoint>,
    C: CurveAffineExt<Base = FC::FieldType>,
{
    // let (points, bool_scalars) = decompose::<F, _>(chip, ctx, points, scalars, max_scalar_bits_per_cell, radix);

    assert_eq!(points.len(), scalars.len());
    let scalar_bits = max_scalar_bits_per_cell * scalars[0].len();
    // bool_scalars: 2d array `scalar_bits` by `points.len()`
    let mut bool_scalars = vec![Vec::with_capacity(points.len()); scalar_bits];

    // get a main thread
    let ctx = builder.main();
    // single-threaded computation:
    for scalar in scalars {
        for (scalar_chunk, bool_chunk) in
            scalar.into_iter().zip(bool_scalars.chunks_mut(max_scalar_bits_per_cell))
        {
            let bits = chip.gate().num_to_bits(ctx, scalar_chunk, max_scalar_bits_per_cell);
            for (bit, bool_bit) in bits.into_iter().zip(bool_chunk.iter_mut()) {
                bool_bit.push(bit);
            }
        }
    }

    let c = clump_factor;
    let num_rounds = points.len().div_ceil(c);
    // to avoid adding two points that are equal or negative of each other,
    // we use a trick from halo2wrong where we load a "sufficiently generic" `C` point as witness
    // note that while we load a random point, an adversary could load a specifically chosen point, so we must carefully handle edge cases with constraints
    // we call it "any point" instead of "random point" to emphasize that "any" sufficiently generic point will do
    let any_base = load_random_point::<F, FC, C>(chip, ctx);
    let mut any_points = Vec::with_capacity(num_rounds);
    any_points.push(any_base);
    for _ in 1..num_rounds {
        any_points.push(ec_double(chip, ctx, any_points.last().unwrap()));
    }

    // now begins multi-threading
    // multi_prods is 2d vector of size `num_rounds` by `scalar_bits`
    let multi_prods = parallelize_core(
        builder,
        points.chunks(c).zip(any_points.iter()).enumerate().collect(),
        |ctx, (round, (points_clump, any_point))| {
            // compute all possible multi-products of elements in points[round * c .. round * (c+1)]
            // stores { any_point, any_point + points[0], any_point + points[1], any_point + points[0] + points[1] , ... }
            let mut bucket = Vec::with_capacity(1 << c);
            let any_point = into_strict_point(chip, ctx, any_point.clone());
            bucket.push(any_point);
            for (i, point) in points_clump.iter().enumerate() {
                // we allow for points[i] to be the point at infinity, represented by (0, 0) in affine coordinates
                // we detect this using the convention that points[i].y == 0 iff points[i] == O
                let is_infinity = chip.is_zero(ctx, &point.y);
                let point = into_strict_point(chip, ctx, point.clone());

                for j in 0..(1 << i) {
                    let mut new_point = ec_add_unequal(chip, ctx, &bucket[j], &point, true);
                    // if points[i] is point at infinity, do nothing
                    new_point = ec_select(chip, ctx, (&bucket[j]).into(), new_point, is_infinity);
                    let new_point = into_strict_point(chip, ctx, new_point);
                    bucket.push(new_point);
                }
            }
            bool_scalars
                .iter()
                .map(|bits| {
                    strict_ec_select_from_bits(
                        chip,
                        ctx,
                        &bucket,
                        &bits[round * c..round * c + points_clump.len()],
                    )
                })
                .collect::<Vec<_>>()
        },
    );

    // agg[j] = sum_{i=0..num_rounds} multi_prods[i][j] for j = 0..scalar_bits
    let mut agg = parallelize_core(builder, (0..scalar_bits).collect(), |ctx, i| {
        let mut acc = multi_prods[0][i].clone();
        for multi_prod in multi_prods.iter().skip(1) {
            let _acc = ec_add_unequal(chip, ctx, &acc, &multi_prod[i], true);
            acc = into_strict_point(chip, ctx, _acc);
        }
        acc
    });

    // gets the LAST thread for single-threaded work
    let ctx = builder.main();
    // we have agg[j] = G'[j] + (2^num_rounds - 1) * any_base
    // let any_point = (2^num_rounds - 1) * any_base
    // TODO: can we remove all these random point operations somehow?
    let mut any_point = ec_double(chip, ctx, any_points.last().unwrap());
    any_point = ec_sub_unequal(chip, ctx, any_point, &any_points[0], true);

    // compute sum_{k=0..scalar_bits} agg[k] * 2^k - (sum_{k=0..scalar_bits} 2^k) * rand_point
    // (sum_{k=0..scalar_bits} 2^k) = (2^scalar_bits - 1); see the integer-arithmetic sketch in the test module at the end of this file
    let mut sum = agg.pop().unwrap().into();
    let mut any_sum = any_point.clone();
    for g in agg.iter().rev() {
        any_sum = ec_double(chip, ctx, any_sum);
        // cannot use ec_double_and_add_unequal because you cannot guarantee that `sum != g`
        sum = ec_double(chip, ctx, sum);
        sum = ec_add_unequal(chip, ctx, sum, g, true);
    }

    any_sum = ec_double(chip, ctx, any_sum);
    any_sum = ec_sub_unequal(chip, ctx, any_sum, any_point, true);

    ec_sub_strict(chip, ctx, sum, any_sum)
}
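
// Out-of-circuit sanity sketch (not part of the circuit logic): plain integers stand in for curve
// points, since only the abelian group structure is used. It mirrors the bookkeeping above and
// checks that, after the double-and-add recombination, the accumulated offset equals
// (2^scalar_bits - 1) * (2^num_rounds - 1) * any_base, so the final subtraction recovers the MSM.
// All concrete values here (points, scalars, any_base, c) are arbitrary illustrative choices.
#[cfg(test)]
mod any_point_correction_sketch {
    #[test]
    fn correction_recovers_msm() {
        let points: Vec<i128> = vec![2, 3, 5, 7, 11];
        let scalars: Vec<u64> = vec![29, 6, 17, 12, 25];
        let scalar_bits: usize = 5; // all scalars fit in 5 bits
        let c = 2; // clump factor
        let num_rounds = points.len().div_ceil(c);
        let any_base: i128 = 1_000_003; // stand-in for the loaded "any point"

        // multi_prods[round][j] = (round's clump summed over points whose j-th scalar bit is 1) + 2^round * any_base
        let multi_prods: Vec<Vec<i128>> = points
            .chunks(c)
            .enumerate()
            .map(|(round, clump)| {
                (0..scalar_bits)
                    .map(|j| {
                        let clump_sum: i128 = clump
                            .iter()
                            .zip(&scalars[round * c..])
                            .map(|(p, s)| p * ((s >> j) & 1) as i128)
                            .sum();
                        clump_sum + (any_base << round)
                    })
                    .collect()
            })
            .collect();

        // agg[j] = sum over rounds = G'[j] + (2^num_rounds - 1) * any_base
        let mut agg: Vec<i128> =
            (0..scalar_bits).map(|j| multi_prods.iter().map(|row| row[j]).sum()).collect();
        let any_point = ((1i128 << num_rounds) - 1) * any_base;

        // binary recombination, mirroring the double-and-add loop in `multi_exp_par`
        let mut sum = agg.pop().unwrap();
        let mut any_sum = any_point;
        for g in agg.iter().rev() {
            any_sum *= 2;
            sum = 2 * sum + g;
        }
        any_sum = 2 * any_sum - any_point; // = (2^scalar_bits - 1) * any_point

        let expected: i128 = points.iter().zip(&scalars).map(|(p, s)| p * (*s as i128)).sum();
        assert_eq!(sum - any_sum, expected);
    }
}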