1#![allow(non_snake_case)]
2use super::{ec_add_unequal, ec_select, ec_select_from_bits, EcPoint, EccChip};
3use crate::ecc::{ec_sub_strict, load_random_point};
4use crate::ff::Field;
5use crate::fields::{FieldChip, Selectable};
6use crate::group::Curve;
7use halo2_base::gates::flex_gate::threads::{parallelize_core, SinglePhaseCoreManager};
8use halo2_base::utils::BigPrimeField;
9use halo2_base::{gates::GateInstructions, utils::CurveAffineExt, AssignedValue, Context};
10use itertools::Itertools;
11use rayon::prelude::*;
12use std::cmp::min;
13
14pub fn scalar_multiply<F, FC, C>(
24 chip: &FC,
25 ctx: &mut Context<F>,
26 point: &C,
27 scalar: Vec<AssignedValue<F>>,
28 max_bits: usize,
29 window_bits: usize,
30) -> EcPoint<F, FC::FieldPoint>
31where
32 F: BigPrimeField,
33 C: CurveAffineExt,
34 FC: FieldChip<F, FieldType = C::Base> + Selectable<F, FC::FieldPoint>,
35{
36 if point.is_identity().into() {
37 let zero = chip.load_constant(ctx, C::Base::ZERO);
38 return EcPoint::new(zero.clone(), zero);
39 }
40 assert!(!scalar.is_empty());
41 assert!((max_bits as u32) <= F::NUM_BITS);
42
43 let total_bits = max_bits * scalar.len();
44 let num_windows = total_bits.div_ceil(window_bits);
45
46 let base_pt = point.to_curve();
48 let mut increment = base_pt;
52 let cached_points_jacobian = (0..num_windows)
53 .flat_map(|i| {
54 let mut curr = increment;
55 let cache_vec = std::iter::once(increment)
57 .chain((1..(1usize << min(window_bits, total_bits - i * window_bits))).map(|_| {
58 let prev = curr;
59 curr += increment;
60 prev
61 }))
62 .collect::<Vec<_>>();
63 increment = curr;
64 cache_vec
65 })
66 .collect::<Vec<_>>();
67 let mut cached_points_affine = vec![C::default(); cached_points_jacobian.len()];
70 C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine);
71
72 let cached_points = cached_points_affine
74 .into_iter()
75 .map(|point| {
76 let (x, y) = point.into_coordinates();
77 let [x, y] = [x, y].map(|x| chip.load_constant(ctx, x));
78 EcPoint::new(x, y)
79 })
80 .collect_vec();
81
82 let bits = scalar
83 .into_iter()
84 .flat_map(|scalar_chunk| chip.gate().num_to_bits(ctx, scalar_chunk, max_bits))
85 .collect::<Vec<_>>();
86
87 let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev();
88 let bit_window_rev = bits.chunks(window_bits).rev();
89 let any_point = load_random_point::<F, FC, C>(chip, ctx);
90 let mut curr_point = any_point.clone();
91 for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) {
92 let bit_sum = chip.gate().sum(ctx, bit_window.iter().copied());
93 let is_zero_window = chip.gate().is_zero(ctx, bit_sum);
95 curr_point = {
96 let add_point = ec_select_from_bits(chip, ctx, cached_point_window, bit_window);
97 let sum = ec_add_unequal(chip, ctx, &curr_point, &add_point, true);
98 ec_select(chip, ctx, curr_point, sum, is_zero_window)
99 };
100 }
101 ec_sub_strict(chip, ctx, curr_point, any_point)
102}
103
104pub fn msm_par<F, FC, C>(
115 chip: &EccChip<F, FC>,
116 builder: &mut SinglePhaseCoreManager<F>,
117 points: &[C],
118 scalars: Vec<Vec<AssignedValue<F>>>,
119 max_scalar_bits_per_cell: usize,
120 window_bits: usize,
121) -> EcPoint<F, FC::FieldPoint>
122where
123 F: BigPrimeField,
124 C: CurveAffineExt,
125 FC: FieldChip<F, FieldType = C::Base> + Selectable<F, FC::FieldPoint>,
126{
127 if points.is_empty() {
128 return chip.assign_constant_point(builder.main(), C::identity());
129 }
130 assert!((max_scalar_bits_per_cell as u32) <= F::NUM_BITS);
131 assert_eq!(points.len(), scalars.len());
132 assert!(!points.is_empty(), "fixed_base::msm_par requires at least one point");
133 let scalar_len = scalars[0].len();
134 let total_bits = max_scalar_bits_per_cell * scalar_len;
135 let num_windows = total_bits.div_ceil(window_bits);
136
137 let cached_points_jacobian = points
140 .par_iter()
141 .flat_map(|point| -> Vec<_> {
142 let base_pt = point.to_curve();
143 let mut increment = base_pt;
146 (0..num_windows)
147 .flat_map(|i| {
148 let mut curr = increment;
149 let cache_vec = std::iter::once(increment)
150 .chain((1..(1usize << min(window_bits, total_bits - i * window_bits))).map(
151 |_| {
152 let prev = curr;
153 curr += increment;
154 prev
155 },
156 ))
157 .collect::<Vec<_>>();
158 increment = curr;
159 cache_vec
160 })
161 .collect()
162 })
163 .collect::<Vec<_>>();
164 let mut cached_points_affine = vec![C::default(); cached_points_jacobian.len()];
167 C::Curve::batch_normalize(&cached_points_jacobian, &mut cached_points_affine);
168
169 let field_chip = chip.field_chip();
170 let ctx = builder.main();
171 let any_point = chip.load_random_point::<C>(ctx);
172
173 let scalar_mults = parallelize_core(
174 builder,
175 cached_points_affine
176 .chunks(cached_points_affine.len() / points.len())
177 .zip_eq(scalars)
178 .collect(),
179 |ctx, (cached_points, scalar)| {
180 let cached_points = cached_points
181 .iter()
182 .map(|point| chip.assign_constant_point(ctx, *point))
183 .collect_vec();
184 let cached_point_window_rev = cached_points.chunks(1usize << window_bits).rev();
185
186 assert_eq!(scalar.len(), scalar_len);
187 let bits = scalar
188 .into_iter()
189 .flat_map(|scalar_chunk| {
190 field_chip.gate().num_to_bits(ctx, scalar_chunk, max_scalar_bits_per_cell)
191 })
192 .collect::<Vec<_>>();
193 let bit_window_rev = bits.chunks(window_bits).rev();
194 let mut curr_point = any_point.clone();
195 for (cached_point_window, bit_window) in cached_point_window_rev.zip(bit_window_rev) {
196 let is_zero_window = {
197 let sum = field_chip.gate().sum(ctx, bit_window.iter().copied());
198 field_chip.gate().is_zero(ctx, sum)
199 };
200 curr_point = {
201 let add_point =
202 ec_select_from_bits(field_chip, ctx, cached_point_window, bit_window);
203 let sum = ec_add_unequal(field_chip, ctx, &curr_point, &add_point, true);
204 ec_select(field_chip, ctx, curr_point, sum, is_zero_window)
205 };
206 }
207 curr_point
208 },
209 );
210 let ctx = builder.main();
211 let any_point2 = chip.load_random_point::<C>(ctx);
213 let mut acc = any_point2.clone();
214 for point in scalar_mults {
215 let new_acc = chip.add_unequal(ctx, &acc, point, true);
216 acc = chip.sub_unequal(ctx, new_acc, &any_point, true);
217 }
218 ec_sub_strict(field_chip, ctx, acc, any_point2)
219}