1mod arith;
2#[cfg(feature = "asm")]
3mod asm;
4
5use num_bigint::BigUint;
6use num_integer::Integer;
7use num_traits::{Num, One};
8use proc_macro::TokenStream;
9use proc_macro2::Span;
10use quote::quote;
11use syn::Token;
12
13struct FieldConfig {
14 identifier: String,
15 field: syn::Ident,
16 modulus: BigUint,
17 mul_gen: BigUint,
18 zeta: BigUint,
19 endian: String,
20 from_uniform: Vec<usize>,
21}
22
23impl syn::parse::Parse for FieldConfig {
24 fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
25 let identifier: syn::Ident = input.parse()?;
26 let identifier = identifier.to_string();
27 input.parse::<syn::Token![,]>()?;
28
29 let field: syn::Ident = input.parse()?;
30 input.parse::<syn::Token![,]>()?;
31
32 let get_big = |is_key: &str| -> Result<BigUint, syn::Error> {
33 let key: syn::Ident = input.parse()?;
34 assert_eq!(key.to_string(), is_key);
35 input.parse::<Token![=]>()?;
36 let n: syn::LitStr = input.parse()?;
37 let n = BigUint::from_str_radix(&n.value(), 16)
38 .map_err(|err| syn::Error::new(Span::call_site(), err.to_string()))?;
39 input.parse::<Token![,]>()?;
40 Ok(n)
41 };
42
43 let get_str = |is_key: &str| -> Result<String, syn::Error> {
44 let key: syn::Ident = input.parse()?;
45 assert_eq!(key.to_string(), is_key);
46 input.parse::<Token![=]>()?;
47 let n: syn::LitStr = input.parse()?;
48 let n = n.value();
49 input.parse::<Token![,]>()?;
50 Ok(n)
51 };
52
53 let get_usize_list = |is_key: &str| -> Result<Vec<usize>, syn::Error> {
54 let key: syn::Ident = input.parse()?;
55 assert_eq!(key.to_string(), is_key);
56 input.parse::<Token![=]>()?;
57
58 let content;
60 syn::bracketed!(content in input);
61 let punctuated: syn::punctuated::Punctuated<syn::LitInt, Token![,]> =
62 content.parse_terminated(syn::LitInt::parse)?;
63 let values = punctuated
64 .into_iter()
65 .map(|lit| lit.base10_parse::<usize>())
66 .collect::<Result<Vec<_>, _>>()?;
67 input.parse::<Token![,]>()?;
68 Ok(values)
69 };
70
71 let modulus = get_big("modulus")?;
72 let mul_gen = get_big("mul_gen")?;
73 let zeta = get_big("zeta")?;
74 let from_uniform = get_usize_list("from_uniform")?;
75 let endian = get_str("endian")?;
76 assert!(endian == "little" || endian == "big");
77 assert!(input.is_empty());
78
79 Ok(FieldConfig {
80 identifier,
81 field,
82 modulus,
83 mul_gen,
84 zeta,
85 from_uniform,
86 endian,
87 })
88 }
89}
90
91pub(crate) fn impl_field(input: TokenStream) -> TokenStream {
92 use crate::utils::{big_to_token, mod_inv};
93 let FieldConfig {
94 identifier,
95 field,
96 modulus,
97 mul_gen,
98 zeta,
99 from_uniform,
100 endian,
101 } = syn::parse_macro_input!(input as FieldConfig);
102 let _ = identifier;
103
104 let num_bits = modulus.bits() as u32;
105 let limb_size = 64;
106 let num_limbs = ((num_bits - 1) / limb_size + 1) as usize;
107 let size = num_limbs * 8;
108 let modulus_limbs = crate::utils::big_to_limbs(&modulus, num_limbs);
109 let modulus_str = format!("0x{}", modulus.to_str_radix(16));
110 let modulus_limbs_ident = quote! {[#(#modulus_limbs,)*]};
111
112 let modulus_limbs_32 = crate::utils::big_to_limbs_32(&modulus, num_limbs * 2);
113 let modulus_limbs_32_ident = quote! {[#(#modulus_limbs_32,)*]};
114
115 let to_token = |e: &BigUint| big_to_token(e, num_limbs);
116 let half_modulus = (&modulus - 1usize) >> 1;
117 let half_modulus = to_token(&half_modulus);
118
119 let t = BigUint::from(1u64) << (num_limbs * limb_size as usize);
121 let r1: BigUint = &t % &modulus;
123 let mont = |v: &BigUint| (v * &r1) % &modulus;
124 let r2: BigUint = (&r1 * &r1) % &modulus;
126 let r3: BigUint = (&r1 * &r1 * &r1) % &modulus;
128
129 let r1 = to_token(&r1);
130 let r2 = to_token(&r2);
131 let r3 = to_token(&r3);
132
133 let mut inv64 = 1u64;
135 for _ in 0..63 {
136 inv64 = inv64.wrapping_mul(inv64);
137 inv64 = inv64.wrapping_mul(modulus_limbs[0]);
138 }
139 inv64 = inv64.wrapping_neg();
140
141 let mut by_inverter_constant: usize = 2;
142 loop {
143 let t = BigUint::from(1u64) << (62 * by_inverter_constant - 64);
144 if t > modulus {
145 break;
146 }
147 by_inverter_constant += 1;
148 }
149
150 let mut jacobi_constant: usize = 1;
151 loop {
152 let t = BigUint::from(1u64) << (64 * jacobi_constant - 31);
153 if t > modulus {
154 break;
155 }
156 jacobi_constant += 1;
157 }
158
159 let mut s: u32 = 0;
160 let mut t = &modulus - BigUint::one();
161 while t.is_even() {
162 t >>= 1;
163 s += 1;
164 }
165
166 let two_inv = mod_inv(&BigUint::from(2usize), &modulus);
167
168 let sqrt_impl = {
169 if &modulus % 16u64 == BigUint::from(1u64) {
170 let tm1o2 = ((&t - 1usize) * &two_inv) % &modulus;
171 let tm1o2 = big_to_token(&tm1o2, num_limbs);
172 quote! {
173 fn sqrt(&self) -> subtle::CtOption<Self> {
174 ff::helpers::sqrt_tonelli_shanks(self, #tm1o2)
175 }
176 }
177 } else if &modulus % 4u64 == BigUint::from(3u64) {
178 let exp = (&modulus + 1usize) >> 2;
179 let exp = big_to_token(&exp, num_limbs);
180 quote! {
181 fn sqrt(&self) -> subtle::CtOption<Self> {
182 use subtle::ConstantTimeEq;
183 let t = self.pow(#exp);
184 subtle::CtOption::new(t, t.square().ct_eq(self))
185 }
186 }
187 } else {
188 panic!("unsupported modulus")
189 }
190 };
191
192 let root_of_unity = mul_gen.modpow(&t, &modulus);
193 let root_of_unity_inv = mod_inv(&root_of_unity, &modulus);
194 let delta = mul_gen.modpow(&(BigUint::one() << s), &modulus);
195
196 let root_of_unity = to_token(&mont(&root_of_unity));
197 let root_of_unity_inv = to_token(&mont(&root_of_unity_inv));
198 let two_inv = to_token(&mont(&two_inv));
199 let mul_gen = to_token(&mont(&mul_gen));
200 let delta = to_token(&mont(&delta));
201 let zeta = to_token(&mont(&zeta));
202
203 let endian = match endian.as_str() {
204 "little" => {
205 quote! { LE }
206 }
207 "big" => {
208 quote! { BE }
209 }
210 _ => {
211 unreachable!()
212 }
213 };
214
215 let impl_field = quote! {
216 #[derive(Clone, Copy, PartialEq, Eq, Hash, Default)]
217 pub struct #field(pub(crate) [u64; #num_limbs]);
218
219 impl core::fmt::Debug for #field {
220 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
221 use ff::PrimeField;
222 let tmp = self.to_repr();
223 write!(f, "0x")?;
224 for &b in tmp.as_ref().iter().rev() {
225 write!(f, "{:02x}", b)?;
226 }
227 Ok(())
228 }
229 }
230
231 impl ConstantTimeEq for #field {
232 fn ct_eq(&self, other: &Self) -> Choice {
233 Choice::from(
234 self.0
235 .iter()
236 .zip(other.0)
237 .all(|(a, b)| bool::from(a.ct_eq(&b))) as u8,
238 )
239 }
240 }
241
242 impl ConditionallySelectable for #field {
243 fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
244 let limbs = (0..#num_limbs)
245 .map(|i| u64::conditional_select(&a.0[i], &b.0[i], choice))
246 .collect::<Vec<_>>()
247 .try_into()
248 .unwrap();
249 #field(limbs)
250 }
251 }
252
253 impl core::cmp::PartialOrd for #field {
254 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
255 Some(self.cmp(other))
256 }
257 }
258
259 impl core::cmp::Ord for #field {
260 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
261 use ff::PrimeField;
262 let left = self.to_repr();
263 let right = other.to_repr();
264 left.as_ref().iter()
265 .zip(right.as_ref().iter())
266 .rev()
267 .find_map(|(left_byte, right_byte)| match left_byte.cmp(right_byte) {
268 core::cmp::Ordering::Equal => None,
269 res => Some(res),
270 })
271 .unwrap_or(core::cmp::Ordering::Equal)
272 }
273 }
274
275 impl<T: ::core::borrow::Borrow<#field>> ::core::iter::Sum<T> for #field {
276 fn sum<I: Iterator<Item = T>>(iter: I) -> Self {
277 iter.fold(Self::zero(), |acc, item| acc + item.borrow())
278 }
279 }
280
281 impl<T: ::core::borrow::Borrow<#field>> ::core::iter::Product<T> for #field {
282 fn product<I: Iterator<Item = T>>(iter: I) -> Self {
283 iter.fold(Self::one(), |acc, item| acc * item.borrow())
284 }
285 }
286
287 impl crate::serde::endian::EndianRepr for #field {
288 const ENDIAN: crate::serde::endian::Endian = crate::serde::endian::Endian::#endian;
289
290 fn to_bytes(&self) -> Vec<u8> {
291 self.to_bytes().to_vec()
292 }
293
294 fn from_bytes(bytes: &[u8]) -> subtle::CtOption<Self> {
295 #field::from_bytes(bytes[..#field::SIZE].try_into().unwrap())
296 }
297 }
298
299 impl #field {
300 pub const SIZE: usize = #num_limbs * 8;
301 pub const NUM_LIMBS: usize = #num_limbs;
302 pub(crate) const MODULUS_LIMBS: [u64; Self::NUM_LIMBS] = #modulus_limbs_ident;
303 pub(crate) const MODULUS_LIMBS_32: [u32; Self::NUM_LIMBS*2] = #modulus_limbs_32_ident;
304 const R: Self = Self(#r1);
305 const R2: Self = Self(#r2);
306 const R3: Self = Self(#r3);
307
308 #[inline(always)]
310 pub const fn zero() -> #field {
311 #field([0; Self::NUM_LIMBS])
312 }
313
314 #[inline(always)]
316 pub const fn one() -> #field {
317 Self::R
318 }
319
320 pub const fn from_raw(val: [u64; Self::NUM_LIMBS]) -> Self {
323 Self(val).mul_const(&Self::R2)
324 }
325
326 pub fn from_bytes(bytes: &[u8; Self::SIZE]) -> subtle::CtOption<Self> {
329 use crate::serde::endian::EndianRepr;
330 let mut el = #field::default();
331 #field::ENDIAN.from_bytes(bytes, &mut el.0);
332 subtle::CtOption::new(el * Self::R2, subtle::Choice::from(Self::is_less_than_modulus(&el.0) as u8))
333 }
334
335
336 pub fn to_bytes(&self) -> [u8; Self::SIZE] {
339 use crate::serde::endian::EndianRepr;
340 let el = self.from_mont();
341 let mut res = [0; Self::SIZE];
342 #field::ENDIAN.to_bytes(&mut res, &el);
343 res.into()
344 }
345
346
347 #[inline(always)]
353 fn jacobi(&self) -> i64 {
354 crate::ff_ext::jacobi::jacobi::<#jacobi_constant>(&self.0, &#modulus_limbs_ident)
355 }
356
357
358 #[inline(always)]
359 pub(crate) fn is_less_than_modulus(limbs: &[u64; Self::NUM_LIMBS]) -> bool {
360 let borrow = limbs.iter().enumerate().fold(0, |borrow, (i, limb)| {
361 crate::arithmetic::sbb(*limb, Self::MODULUS_LIMBS[i], borrow).1
362 });
363 (borrow as u8) & 1 == 1
364 }
365
366 pub fn lexicographically_largest(&self) -> Choice {
369 const HALF_MODULUS: [u64; #num_limbs]= #half_modulus;
370 let tmp = self.from_mont();
371 let borrow = tmp
372 .into_iter()
373 .zip(HALF_MODULUS.into_iter())
374 .fold(0, |borrow, (t, m)| crate::arithmetic::sbb(t, m, borrow).1);
375 !Choice::from((borrow as u8) & 1)
376 }
377 }
378
379 impl ff::Field for #field {
380 const ZERO: Self = Self::zero();
381 const ONE: Self = Self::one();
382
383 fn random(mut rng: impl RngCore) -> Self {
384 let mut wide = [0u8; Self::SIZE * 2];
385 rng.fill_bytes(&mut wide);
386 <#field as ff::FromUniformBytes<{ #field::SIZE * 2 }>>::from_uniform_bytes(&wide)
387 }
388
389 #[inline(always)]
390 #[must_use]
391 fn double(&self) -> Self {
392 self.double()
393 }
394
395 #[inline(always)]
396 #[must_use]
397 fn square(&self) -> Self {
398 self.square()
399 }
400
401 #[inline(always)]
403 fn invert(&self) -> CtOption<Self> {
404 const BYINVERTOR: crate::ff_ext::inverse::BYInverter<#by_inverter_constant> =
405 crate::ff_ext::inverse::BYInverter::<#by_inverter_constant>::new(&#modulus_limbs_ident, &#r2);
406
407 if let Some(inverse) = BYINVERTOR.invert::<{ Self::NUM_LIMBS }>(&self.0) {
408 subtle::CtOption::new(Self(inverse), subtle::Choice::from(1))
409 } else {
410 subtle::CtOption::new(Self::zero(), subtle::Choice::from(0))
411 }
412 }
413
414 #sqrt_impl
415
416 fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
417 ff::helpers::sqrt_ratio_generic(num, div)
418 }
419 }
420 };
421
422 let impl_prime_field = quote! {
423
424 impl From<#field> for crate::serde::Repr<{ #field::SIZE }> {
426 fn from(value: #field) -> crate::serde::Repr<{ #field::SIZE }> {
427 use ff::PrimeField;
428 value.to_repr()
429 }
430 }
431
432 impl<'a> From<&'a #field> for crate::serde::Repr<{ #field::SIZE }> {
433 fn from(value: &'a #field) -> crate::serde::Repr<{ #field::SIZE }> {
434 use ff::PrimeField;
435 value.to_repr()
436 }
437 }
438
439 impl ff::PrimeField for #field {
440 const NUM_BITS: u32 = #num_bits;
441 const CAPACITY: u32 = #num_bits-1;
442 const TWO_INV :Self = Self(#two_inv);
443 const MULTIPLICATIVE_GENERATOR: Self = Self(#mul_gen);
444 const S: u32 = #s;
445 const ROOT_OF_UNITY: Self = Self(#root_of_unity);
446 const ROOT_OF_UNITY_INV: Self = Self(#root_of_unity_inv);
447 const DELTA: Self = Self(#delta);
448 const MODULUS: &'static str = #modulus_str;
449
450 type Repr = crate::serde::Repr<{ #field::SIZE }>;
451
452 fn from_u128(v: u128) -> Self {
453 Self::R2 * Self(
454 [v as u64, (v >> 64) as u64]
455 .into_iter()
456 .chain(std::iter::repeat(0))
457 .take(Self::NUM_LIMBS)
458 .collect::<Vec<_>>()
459 .try_into()
460 .unwrap(),
461 )
462 }
463
464 fn from_repr(repr: Self::Repr) -> subtle::CtOption<Self> {
465 let mut el = #field::default();
466 crate::serde::endian::Endian::LE.from_bytes(repr.as_ref(), &mut el.0);
467 subtle::CtOption::new(el * Self::R2, subtle::Choice::from(Self::is_less_than_modulus(&el.0) as u8))
468 }
469
470 fn to_repr(&self) -> Self::Repr {
471 use crate::serde::endian::Endian;
472 let el = self.from_mont();
473 let mut res = [0; #size];
474 crate::serde::endian::Endian::LE.to_bytes(&mut res, &el);
475 res.into()
476 }
477
478 fn is_odd(&self) -> Choice {
479 Choice::from(self.to_repr()[0] & 1)
480 }
481 }
482 };
483
484 let impl_serde_object = quote! {
485 impl crate::serde::SerdeObject for #field {
486 fn from_raw_bytes_unchecked(bytes: &[u8]) -> Self {
487 debug_assert_eq!(bytes.len(), #size);
488
489 let inner = (0..#num_limbs)
490 .map(|off| {
491 u64::from_le_bytes(bytes[off * 8..(off + 1) * 8].try_into().unwrap())
492 })
493 .collect::<Vec<_>>();
494 Self(inner.try_into().unwrap())
495 }
496
497 fn from_raw_bytes(bytes: &[u8]) -> Option<Self> {
498 if bytes.len() != #size {
499 return None;
500 }
501 let elt = Self::from_raw_bytes_unchecked(bytes);
502 Self::is_less_than_modulus(&elt.0).then(|| elt)
503 }
504
505 fn to_raw_bytes(&self) -> Vec<u8> {
506 let mut res = Vec::with_capacity(#num_limbs * 4);
507 for limb in self.0.iter() {
508 res.extend_from_slice(&limb.to_le_bytes());
509 }
510 res
511 }
512
513 fn read_raw_unchecked<R: std::io::Read>(reader: &mut R) -> Self {
514 let inner = [(); #num_limbs].map(|_| {
515 let mut buf = [0; 8];
516 reader.read_exact(&mut buf).unwrap();
517 u64::from_le_bytes(buf)
518 });
519 Self(inner)
520 }
521
522 fn read_raw<R: std::io::Read>(reader: &mut R) -> std::io::Result<Self> {
523 let mut inner = [0u64; #num_limbs];
524 for limb in inner.iter_mut() {
525 let mut buf = [0; 8];
526 reader.read_exact(&mut buf)?;
527 *limb = u64::from_le_bytes(buf);
528 }
529 let elt = Self(inner);
530 Self::is_less_than_modulus(&elt.0)
531 .then(|| elt)
532 .ok_or_else(|| {
533 std::io::Error::new(
534 std::io::ErrorKind::InvalidData,
535 "input number is not less than field modulus",
536 )
537 })
538 }
539 fn write_raw<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
540 for limb in self.0.iter() {
541 writer.write_all(&limb.to_le_bytes())?;
542 }
543 Ok(())
544 }
545 }
546 };
547
548 #[cfg(feature = "asm")]
549 let impl_arith = {
550 if num_limbs == 4 && num_bits < 256 {
551 println!("implementing asm, {}", identifier);
552 asm::limb4::impl_arith(&field, inv64)
553 } else {
554 arith::impl_arith(&field, num_limbs, inv64)
555 }
556 };
557 #[cfg(not(feature = "asm"))]
558 let impl_arith = arith::impl_arith(&field, num_limbs, inv64);
559
560 let impl_arith_always_const = arith::impl_arith_always_const(&field, num_limbs, inv64);
561
562 let impl_from_uniform_bytes = from_uniform
563 .iter()
564 .map(|input_size| {
565 assert!(*input_size >= size);
566 assert!(*input_size <= size*2);
567 quote! {
568 impl ff::FromUniformBytes<#input_size> for #field {
569 fn from_uniform_bytes(bytes: &[u8; #input_size]) -> Self {
570 let mut wide = [0u8; Self::SIZE * 2];
571 wide[..#input_size].copy_from_slice(bytes);
572 let (a0, a1) = wide.split_at(Self::SIZE);
573
574 let a0: [u64; Self::NUM_LIMBS] = (0..Self::NUM_LIMBS)
575 .map(|off| u64::from_le_bytes(a0[off * 8..(off + 1) * 8].try_into().unwrap()))
576 .collect::<Vec<_>>()
577 .try_into()
578 .unwrap();
579 let a0 = #field(a0);
580
581 let a1: [u64; Self::NUM_LIMBS] = (0..Self::NUM_LIMBS)
582 .map(|off| u64::from_le_bytes(a1[off * 8..(off + 1) * 8].try_into().unwrap()))
583 .collect::<Vec<_>>()
584 .try_into()
585 .unwrap();
586 let a1 = #field(a1);
587
588 a0.mul_const(&Self::R2) + a1.mul_const(&Self::R3)
590
591 }
592 }
593 }
594 })
595 .collect::<proc_macro2::TokenStream>();
596
597 let impl_zeta = quote! {
598 impl ff::WithSmallOrderMulGroup<3> for #field {
599 const ZETA: Self = Self(#zeta);
600 }
601 };
602
603 let output = quote! {
604 #impl_arith
605 #impl_arith_always_const
606 #impl_field
607 #impl_prime_field
608 #impl_serde_object
609 #impl_from_uniform_bytes
610 #impl_zeta
611 };
612
613 output.into()
614}