1use crate::halo2_curves;
4use crate::util::Itertools;
5pub use halo2_curves::{
6 group::{
7 ff::{BatchInvert, Field, FromUniformBytes, PrimeField},
8 prime::PrimeCurveAffine,
9 Curve, Group, GroupEncoding,
10 },
11 Coordinates, CurveAffine, CurveExt,
12};
13use num_bigint::BigUint;
14use num_traits::One;
15pub use pairing::MillerLoopResult;
16use serde::{Deserialize, Serialize};
17use std::{
18 cmp::Ordering,
19 fmt::Debug,
20 iter, mem,
21 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
22};
23
24pub trait MultiMillerLoop: pairing::MultiMillerLoop + Debug {}
26
27impl<M: pairing::MultiMillerLoop + Debug> MultiMillerLoop for M {}
28
29pub trait FieldExt: PrimeField + FromUniformBytes<64> + Ord {}
31
32impl<F: PrimeField + FromUniformBytes<64> + Ord> FieldExt for F {}
33
/// Arithmetic interface for field-like values.
///
/// For each of `+`, `-` and `*` the implementor must support the by-value form,
/// the form taking the right-hand side by reference, and the matching compound
/// assignments (again by value and by reference). Negation and a fallible
/// inversion complete the interface.
pub trait FieldOps:
    Sized
    + Neg<Output = Self>
    + Add<Output = Self>
    + for<'a> Add<&'a Self, Output = Self>
    + AddAssign
    + for<'a> AddAssign<&'a Self>
    + Sub<Output = Self>
    + for<'a> Sub<&'a Self, Output = Self>
    + SubAssign
    + for<'a> SubAssign<&'a Self>
    + Mul<Output = Self>
    + for<'a> Mul<&'a Self, Output = Self>
    + MulAssign
    + for<'a> MulAssign<&'a Self>
{
    /// Returns the multiplicative inverse of `self`, or `None` when no inverse
    /// exists (presumably only for zero — the trait itself does not enforce this).
    fn invert(&self) -> Option<Self>;
}
54
55pub fn batch_invert_and_mul<F: PrimeField>(values: &mut [F], coeff: &F) {
57 if values.is_empty() {
58 return;
59 }
60 let products = values
61 .iter()
62 .scan(F::ONE, |acc, value| {
63 *acc *= value;
64 Some(*acc)
65 })
66 .collect_vec();
67
68 let mut all_product_inv = Option::<F>::from(products.last().unwrap().invert())
69 .expect("Attempted to batch invert an array containing zero")
70 * coeff;
71
72 for (value, product) in
73 values.iter_mut().rev().zip(products.into_iter().rev().skip(1).chain(Some(F::ONE)))
74 {
75 let mut inv = all_product_inv * product;
76 mem::swap(value, &mut inv);
77 all_product_inv *= inv;
78 }
79}
80
81pub fn batch_invert<F: PrimeField>(values: &mut [F]) {
83 batch_invert_and_mul(values, &F::ONE)
84}
85
86pub fn root_of_unity<F: PrimeField>(k: usize) -> F {
94 assert!(k <= F::S as usize);
95
96 iter::successors(Some(F::ROOT_OF_UNITY), |acc| Some(acc.square()))
97 .take(F::S as usize - k + 1)
98 .last()
99 .unwrap()
100}
101
102#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)]
104pub struct Rotation(pub i32);
105
106impl Rotation {
107 pub fn cur() -> Self {
109 Rotation(0)
110 }
111
112 pub fn prev() -> Self {
114 Rotation(-1)
115 }
116
117 pub fn next() -> Self {
119 Rotation(1)
120 }
121}
122
123impl From<i32> for Rotation {
124 fn from(rotation: i32) -> Self {
125 Self(rotation)
126 }
127}
128
129#[derive(Clone, Debug, Serialize, Deserialize)]
131pub struct Domain<F: PrimeField> {
132 pub k: usize,
134 pub n: usize,
136 pub n_inv: F,
138 pub gen: F,
140 pub gen_inv: F,
142}
143
144impl<F: PrimeField> Domain<F> {
145 pub fn new(k: usize, gen: F) -> Self {
147 let n = 1 << k;
148 let n_inv = F::from(n as u64).invert().unwrap();
149 let gen_inv = gen.invert().unwrap();
150
151 Self { k, n, n_inv, gen, gen_inv }
152 }
153
154 pub fn rotate_scalar(&self, scalar: F, rotation: Rotation) -> F {
156 match rotation.0.cmp(&0) {
157 Ordering::Equal => scalar,
158 Ordering::Greater => scalar * self.gen.pow_vartime([rotation.0 as u64]),
159 Ordering::Less => scalar * self.gen_inv.pow_vartime([(-rotation.0) as u64]),
160 }
161 }
162}
163
164#[derive(Clone, Debug)]
166pub struct Fraction<T> {
167 numer: Option<T>,
168 denom: T,
169 eval: Option<T>,
170 inv: bool,
171}
172
173impl<T> Fraction<T> {
174 pub fn new(numer: T, denom: T) -> Self {
176 Self { numer: Some(numer), denom, eval: None, inv: false }
177 }
178
179 pub fn one_over(denom: T) -> Self {
181 Self { numer: None, denom, eval: None, inv: false }
182 }
183
184 pub fn denom(&self) -> Option<&T> {
186 if !self.inv {
187 Some(&self.denom)
188 } else {
189 None
190 }
191 }
192
193 #[must_use = "To be inverted"]
194 pub fn denom_mut(&mut self) -> Option<&mut T> {
196 if !self.inv {
197 self.inv = true;
198 Some(&mut self.denom)
199 } else {
200 None
201 }
202 }
203}
204
205impl<T: FieldOps + Clone> Fraction<T> {
206 pub fn evaluate(&mut self) {
212 assert!(self.inv);
213
214 if self.eval.is_none() {
215 self.eval = Some(
216 self.numer
217 .take()
218 .map(|numer| numer * &self.denom)
219 .unwrap_or_else(|| self.denom.clone()),
220 );
221 }
222 }
223
224 pub fn evaluated(&self) -> &T {
230 assert!(self.eval.is_some());
231
232 self.eval.as_ref().unwrap()
233 }
234}
235
236pub fn modulus<F: PrimeField>() -> BigUint {
238 fe_to_big(-F::ONE) + 1usize
239}
240
241pub fn fe_from_big<F: PrimeField>(big: BigUint) -> F {
243 let bytes = big.to_bytes_le();
244 let mut repr = F::Repr::default();
245 assert!(bytes.len() <= repr.as_ref().len());
246 repr.as_mut()[..bytes.len()].clone_from_slice(bytes.as_slice());
247 F::from_repr(repr).unwrap()
248}
249
250pub fn fe_to_big<F: PrimeField>(fe: F) -> BigUint {
252 BigUint::from_bytes_le(fe.to_repr().as_ref())
253}
254
255pub fn fe_to_fe<F1: PrimeField, F2: PrimeField>(fe: F1) -> F2 {
257 fe_from_big(fe_to_big(fe) % modulus::<F2>())
258}
259
260pub fn fe_from_limbs<F1: PrimeField, F2: PrimeField, const LIMBS: usize, const BITS: usize>(
263 limbs: [F1; LIMBS],
264) -> F2 {
265 fe_from_big(
266 limbs
267 .iter()
268 .map(|limb| BigUint::from_bytes_le(limb.to_repr().as_ref()))
269 .zip((0usize..).step_by(BITS))
270 .map(|(limb, shift)| limb << shift)
271 .reduce(|acc, shifted| acc + shifted)
272 .unwrap(),
273 )
274}
275
276pub fn fe_to_limbs<F1: PrimeField, F2: PrimeField, const LIMBS: usize, const BITS: usize>(
279 fe: F1,
280) -> [F2; LIMBS] {
281 let big = BigUint::from_bytes_le(fe.to_repr().as_ref());
282 let mask = &((BigUint::one() << BITS) - 1usize);
283 (0usize..)
284 .step_by(BITS)
285 .take(LIMBS)
286 .map(|shift| fe_from_big((&big >> shift) & mask))
287 .collect_vec()
288 .try_into()
289 .unwrap()
290}
291
292pub fn powers<F: Field>(scalar: F) -> impl Iterator<Item = F> {
294 iter::successors(Some(F::ONE), move |power| Some(scalar * power))
295}
296
297pub fn inner_product<F: Field>(lhs: &[F], rhs: &[F]) -> F {
299 lhs.iter()
300 .zip_eq(rhs.iter())
301 .map(|(lhs, rhs)| *lhs * rhs)
302 .reduce(|acc, product| acc + product)
303 .unwrap_or_default()
304}