ff_derive/
lib.rs

1#![recursion_limit = "1024"]
2
3extern crate proc_macro;
4extern crate proc_macro2;
5
6use num_bigint::BigUint;
7use num_integer::Integer;
8use num_traits::{One, ToPrimitive, Zero};
9use quote::quote;
10use quote::TokenStreamExt;
11use std::iter;
12use std::str::FromStr;
13
14mod pow_fixed;
15
/// Byte order of the generated `Repr` type used to (de)serialize field elements.
enum ReprEndianness {
    /// Most-significant byte first.
    Big,
    /// Least-significant byte first.
    Little,
}

/// Parses the value of the `PrimeFieldReprEndianness` attribute.
///
/// Only the exact lowercase strings `"big"` and `"little"` are accepted; anything else
/// yields `Err(())`.
impl FromStr for ReprEndianness {
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let parsed = match s {
            "big" => ReprEndianness::Big,
            "little" => ReprEndianness::Little,
            _ => return Err(()),
        };
        Ok(parsed)
    }
}
32
33impl ReprEndianness {
34    fn modulus_repr(&self, modulus: &BigUint, bytes: usize) -> Vec<u8> {
35        match self {
36            ReprEndianness::Big => {
37                let buf = modulus.to_bytes_be();
38                iter::repeat(0)
39                    .take(bytes - buf.len())
40                    .chain(buf.into_iter())
41                    .collect()
42            }
43            ReprEndianness::Little => {
44                let mut buf = modulus.to_bytes_le();
45                buf.extend(iter::repeat(0).take(bytes - buf.len()));
46                buf
47            }
48        }
49    }
50
51    fn from_repr(&self, name: &syn::Ident, limbs: usize) -> proc_macro2::TokenStream {
52        let read_repr = match self {
53            ReprEndianness::Big => quote! {
54                ::ff::derive::byteorder::BigEndian::read_u64_into(r.as_ref(), &mut inner[..]);
55                inner.reverse();
56            },
57            ReprEndianness::Little => quote! {
58                ::ff::derive::byteorder::LittleEndian::read_u64_into(r.as_ref(), &mut inner[..]);
59            },
60        };
61
62        quote! {
63            use ::ff::derive::byteorder::ByteOrder;
64
65            let r = {
66                let mut inner = [0u64; #limbs];
67                #read_repr
68                #name(inner)
69            };
70        }
71    }
72
73    fn to_repr(
74        &self,
75        repr: proc_macro2::TokenStream,
76        mont_reduce_self_params: &proc_macro2::TokenStream,
77        limbs: usize,
78    ) -> proc_macro2::TokenStream {
79        let bytes = limbs * 8;
80
81        let write_repr = match self {
82            ReprEndianness::Big => quote! {
83                r.0.reverse();
84                ::ff::derive::byteorder::BigEndian::write_u64_into(&r.0, &mut repr[..]);
85            },
86            ReprEndianness::Little => quote! {
87                ::ff::derive::byteorder::LittleEndian::write_u64_into(&r.0, &mut repr[..]);
88            },
89        };
90
91        quote! {
92            use ::ff::derive::byteorder::ByteOrder;
93
94            let mut r = *self;
95            r.mont_reduce(
96                #mont_reduce_self_params
97            );
98
99            let mut repr = [0u8; #bytes];
100            #write_repr
101            #repr(repr)
102        }
103    }
104
105    fn iter_be(&self) -> proc_macro2::TokenStream {
106        match self {
107            ReprEndianness::Big => quote! {self.0.iter()},
108            ReprEndianness::Little => quote! {self.0.iter().rev()},
109        }
110    }
111}
112
113/// Derive the `PrimeField` trait.
114#[proc_macro_derive(
115    PrimeField,
116    attributes(PrimeFieldModulus, PrimeFieldGenerator, PrimeFieldReprEndianness)
117)]
118pub fn prime_field(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
119    // Parse the type definition
120    let ast: syn::DeriveInput = syn::parse(input).unwrap();
121
122    // We're given the modulus p of the prime field
123    let modulus: BigUint = fetch_attr("PrimeFieldModulus", &ast.attrs)
124        .expect("Please supply a PrimeFieldModulus attribute")
125        .parse()
126        .expect("PrimeFieldModulus should be a number");
127
128    // We may be provided with a generator of p - 1 order. It is required that this generator be quadratic
129    // nonresidue.
130    // TODO: Compute this ourselves.
131    let generator: BigUint = fetch_attr("PrimeFieldGenerator", &ast.attrs)
132        .expect("Please supply a PrimeFieldGenerator attribute")
133        .parse()
134        .expect("PrimeFieldGenerator should be a number");
135
136    // Field element representations may be in little-endian or big-endian.
137    let endianness = fetch_attr("PrimeFieldReprEndianness", &ast.attrs)
138        .expect("Please supply a PrimeFieldReprEndianness attribute")
139        .parse()
140        .expect("PrimeFieldReprEndianness should be 'big' or 'little'");
141
142    // The arithmetic in this library only works if the modulus*2 is smaller than the backing
143    // representation. Compute the number of limbs we need.
144    let mut limbs = 1;
145    {
146        let mod2 = (&modulus) << 1; // modulus * 2
147        let mut cur = BigUint::one() << 64; // always 64-bit limbs for now
148        while cur < mod2 {
149            limbs += 1;
150            cur <<= 64;
151        }
152    }
153
154    // The struct we're deriving for must be a wrapper around `pub [u64; limbs]`.
155    if let Some(err) = validate_struct(&ast, limbs) {
156        return err.into();
157    }
158
159    // Generate the identifier for the "Repr" type we must construct.
160    let repr_ident = syn::Ident::new(
161        &format!("{}Repr", ast.ident),
162        proc_macro2::Span::call_site(),
163    );
164
165    let mut gen = proc_macro2::TokenStream::new();
166
167    let (constants_impl, sqrt_impl) =
168        prime_field_constants_and_sqrt(&ast.ident, &modulus, limbs, generator);
169
170    gen.extend(constants_impl);
171    gen.extend(prime_field_repr_impl(&repr_ident, &endianness, limbs * 8));
172    gen.extend(prime_field_impl(
173        &ast.ident,
174        &repr_ident,
175        &modulus,
176        &endianness,
177        limbs,
178        sqrt_impl,
179    ));
180
181    // Return the generated impl
182    gen.into()
183}
184
185/// Checks that `body` contains `pub [u64; limbs]`.
186fn validate_struct(ast: &syn::DeriveInput, limbs: usize) -> Option<proc_macro2::TokenStream> {
187    // The body should be a struct.
188    let variant_data = match &ast.data {
189        syn::Data::Struct(x) => x,
190        _ => {
191            return Some(
192                syn::Error::new_spanned(ast, "PrimeField derive only works for structs.")
193                    .to_compile_error(),
194            )
195        }
196    };
197
198    // The struct should contain a single unnamed field.
199    let fields = match &variant_data.fields {
200        syn::Fields::Unnamed(x) if x.unnamed.len() == 1 => x,
201        _ => {
202            return Some(
203                syn::Error::new_spanned(
204                    &ast.ident,
205                    format!(
206                        "The struct must contain an array of limbs. Change this to `{}([u64; {}])`",
207                        ast.ident, limbs,
208                    ),
209                )
210                .to_compile_error(),
211            )
212        }
213    };
214    let field = &fields.unnamed[0];
215
216    // The field should be an array.
217    let arr = match &field.ty {
218        syn::Type::Array(x) => x,
219        _ => {
220            return Some(
221                syn::Error::new_spanned(
222                    field,
223                    format!(
224                        "The inner field must be an array of limbs. Change this to `[u64; {}]`",
225                        limbs,
226                    ),
227                )
228                .to_compile_error(),
229            )
230        }
231    };
232
233    // The array's element type should be `u64`.
234    if match arr.elem.as_ref() {
235        syn::Type::Path(path) => path
236            .path
237            .get_ident()
238            .map(|x| x.to_string() != "u64")
239            .unwrap_or(true),
240        _ => true,
241    } {
242        return Some(
243            syn::Error::new_spanned(
244                arr,
245                format!(
246                    "PrimeField derive requires 64-bit limbs. Change this to `[u64; {}]",
247                    limbs
248                ),
249            )
250            .to_compile_error(),
251        );
252    }
253
254    // The array's length should be a literal int equal to `limbs`.
255    let expr_lit = match &arr.len {
256        syn::Expr::Lit(expr_lit) => Some(&expr_lit.lit),
257        syn::Expr::Group(expr_group) => match &*expr_group.expr {
258            syn::Expr::Lit(expr_lit) => Some(&expr_lit.lit),
259            _ => None,
260        },
261        _ => None,
262    };
263    let lit_int = match match expr_lit {
264        Some(syn::Lit::Int(lit_int)) => Some(lit_int),
265        _ => None,
266    } {
267        Some(x) => x,
268        _ => {
269            return Some(
270                syn::Error::new_spanned(
271                    arr,
272                    format!("To derive PrimeField, change this to `[u64; {}]`.", limbs),
273                )
274                .to_compile_error(),
275            )
276        }
277    };
278    if lit_int.base10_digits() != limbs.to_string() {
279        return Some(
280            syn::Error::new_spanned(
281                lit_int,
282                format!("The given modulus requires {} limbs.", limbs),
283            )
284            .to_compile_error(),
285        );
286    }
287
288    // The field should not be public.
289    match &field.vis {
290        syn::Visibility::Inherited => (),
291        _ => {
292            return Some(
293                syn::Error::new_spanned(&field.vis, "Field must not be public.").to_compile_error(),
294            )
295        }
296    }
297
298    // Valid!
299    None
300}
301
302/// Fetch an attribute string from the derived struct.
303fn fetch_attr(name: &str, attrs: &[syn::Attribute]) -> Option<String> {
304    for attr in attrs {
305        if let Ok(meta) = attr.parse_meta() {
306            match meta {
307                syn::Meta::NameValue(nv) => {
308                    if nv.path.get_ident().map(|i| i.to_string()) == Some(name.to_string()) {
309                        match nv.lit {
310                            syn::Lit::Str(ref s) => return Some(s.value()),
311                            _ => {
312                                panic!("attribute {} should be a string", name);
313                            }
314                        }
315                    }
316                }
317                _ => {
318                    panic!("attribute {} should be a string", name);
319                }
320            }
321        }
322    }
323
324    None
325}
326
327// Implement the wrapped ident `repr` with `bytes` bytes.
328fn prime_field_repr_impl(
329    repr: &syn::Ident,
330    endianness: &ReprEndianness,
331    bytes: usize,
332) -> proc_macro2::TokenStream {
333    let repr_iter_be = endianness.iter_be();
334
335    quote! {
336        #[derive(Copy, Clone)]
337        pub struct #repr(pub [u8; #bytes]);
338
339        impl ::ff::derive::subtle::ConstantTimeEq for #repr {
340            fn ct_eq(&self, other: &#repr) -> ::ff::derive::subtle::Choice {
341                self.0
342                    .iter()
343                    .zip(other.0.iter())
344                    .map(|(a, b)| a.ct_eq(b))
345                    .fold(1.into(), |acc, x| acc & x)
346            }
347        }
348
349        impl ::core::cmp::PartialEq for #repr {
350            fn eq(&self, other: &#repr) -> bool {
351                use ::ff::derive::subtle::ConstantTimeEq;
352                self.ct_eq(other).into()
353            }
354        }
355
356        impl ::core::cmp::Eq for #repr { }
357
358        impl ::core::default::Default for #repr {
359            fn default() -> #repr {
360                #repr([0u8; #bytes])
361            }
362        }
363
364        impl ::core::fmt::Debug for #repr
365        {
366            fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
367                write!(f, "0x")?;
368                for i in #repr_iter_be {
369                    write!(f, "{:02x}", *i)?;
370                }
371
372                Ok(())
373            }
374        }
375
376        impl AsRef<[u8]> for #repr {
377            #[inline(always)]
378            fn as_ref(&self) -> &[u8] {
379                &self.0
380            }
381        }
382
383        impl AsMut<[u8]> for #repr {
384            #[inline(always)]
385            fn as_mut(&mut self) -> &mut [u8] {
386                &mut self.0
387            }
388        }
389    }
390}
391
392/// Convert BigUint into a vector of 64-bit limbs.
393fn biguint_to_real_u64_vec(mut v: BigUint, limbs: usize) -> Vec<u64> {
394    let m = BigUint::one() << 64;
395    let mut ret = vec![];
396
397    while v > BigUint::zero() {
398        let limb: BigUint = &v % &m;
399        ret.push(limb.to_u64().unwrap());
400        v >>= 64;
401    }
402
403    while ret.len() < limbs {
404        ret.push(0);
405    }
406
407    assert!(ret.len() == limbs);
408
409    ret
410}
411
412/// Convert BigUint into a tokenized vector of 64-bit limbs.
413fn biguint_to_u64_vec(v: BigUint, limbs: usize) -> proc_macro2::TokenStream {
414    let ret = biguint_to_real_u64_vec(v, limbs);
415    quote!([#(#ret,)*])
416}
417
418fn biguint_num_bits(mut v: BigUint) -> u32 {
419    let mut bits = 0;
420
421    while v != BigUint::zero() {
422        v >>= 1;
423        bits += 1;
424    }
425
426    bits
427}
428
429/// BigUint modular exponentiation by square-and-multiply.
430fn exp(base: BigUint, exp: &BigUint, modulus: &BigUint) -> BigUint {
431    let mut ret = BigUint::one();
432
433    for i in exp
434        .to_bytes_be()
435        .into_iter()
436        .flat_map(|x| (0..8).rev().map(move |i| (x >> i).is_odd()))
437    {
438        ret = (&ret * &ret) % modulus;
439        if i {
440            ret = (ret * &base) % modulus;
441        }
442    }
443
444    ret
445}
446
#[test]
fn test_exp() {
    // Fixed regression vector: base^e mod modulus.
    let base = BigUint::from_str("4398572349857239485729348572983472345").unwrap();
    let e = BigUint::from_str("5489673498567349856734895").unwrap();
    let modulus = BigUint::from_str(
        "52435875175126190479447740508185965837690552500527637822603658699938581184513",
    )
    .unwrap();
    let expected = BigUint::from_str(
        "4371221214068404307866768905142520595925044802278091865033317963560480051536",
    )
    .unwrap();

    assert_eq!(exp(base, &e, &modulus), expected);
}
464
465fn prime_field_constants_and_sqrt(
466    name: &syn::Ident,
467    modulus: &BigUint,
468    limbs: usize,
469    generator: BigUint,
470) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
471    let bytes = limbs * 8;
472    let modulus_num_bits = biguint_num_bits(modulus.clone());
473
474    // The number of bits we should "shave" from a randomly sampled reputation, i.e.,
475    // if our modulus is 381 bits and our representation is 384 bits, we should shave
476    // 3 bits from the beginning of a randomly sampled 384 bit representation to
477    // reduce the cost of rejection sampling.
478    let repr_shave_bits = (64 * limbs as u32) - biguint_num_bits(modulus.clone());
479
480    // Compute R = 2**(64 * limbs) mod m
481    let r = (BigUint::one() << (limbs * 64)) % modulus;
482    let to_mont = |v| (v * &r) % modulus;
483
484    let two = BigUint::from_str("2").unwrap();
485    let p_minus_2 = modulus - &two;
486    let invert = |v| exp(v, &p_minus_2, &modulus);
487
488    // 2^-1 mod m
489    let two_inv = biguint_to_u64_vec(to_mont(invert(two)), limbs);
490
491    // modulus - 1 = 2^s * t
492    let mut s: u32 = 0;
493    let mut t = modulus - BigUint::from_str("1").unwrap();
494    while t.is_even() {
495        t >>= 1;
496        s += 1;
497    }
498
499    // Compute 2^s root of unity given the generator
500    let root_of_unity = exp(generator.clone(), &t, &modulus);
501    let root_of_unity_inv = biguint_to_u64_vec(to_mont(invert(root_of_unity.clone())), limbs);
502    let root_of_unity = biguint_to_u64_vec(to_mont(root_of_unity), limbs);
503    let delta = biguint_to_u64_vec(
504        to_mont(exp(generator.clone(), &(BigUint::one() << s), &modulus)),
505        limbs,
506    );
507    let generator = biguint_to_u64_vec(to_mont(generator), limbs);
508
509    let sqrt_impl =
510        if (modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() {
511            // Addition chain for (r + 1) // 4
512            let mod_plus_1_over_4 = pow_fixed::generate(
513                &quote! {self},
514                (modulus + BigUint::from_str("1").unwrap()) >> 2,
515            );
516
517            quote! {
518                use ::ff::derive::subtle::ConstantTimeEq;
519
520                // Because r = 3 (mod 4)
521                // sqrt can be done with only one exponentiation,
522                // via the computation of  self^((r + 1) // 4) (mod r)
523                let sqrt = {
524                    #mod_plus_1_over_4
525                };
526
527                ::ff::derive::subtle::CtOption::new(
528                    sqrt,
529                    (sqrt * &sqrt).ct_eq(self), // Only return Some if it's the square root.
530                )
531            }
532        } else if (modulus % BigUint::from_str("16").unwrap()) == BigUint::from_str("1").unwrap() {
533            // Addition chain for (t - 1) // 2
534            let t_minus_1_over_2 = if t == BigUint::one() {
535                quote!( #name::ONE )
536            } else {
537                pow_fixed::generate(&quote! {self}, (&t - BigUint::one()) >> 1)
538            };
539
540            quote! {
541                // Tonelli-Shank's algorithm for q mod 16 = 1
542                // https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5)
543                use ::ff::derive::subtle::{ConditionallySelectable, ConstantTimeEq};
544
545                // w = self^((t - 1) // 2)
546                let w = {
547                    #t_minus_1_over_2
548                };
549
550                let mut v = S;
551                let mut x = *self * &w;
552                let mut b = x * &w;
553
554                // Initialize z as the 2^S root of unity.
555                let mut z = ROOT_OF_UNITY;
556
557                for max_v in (1..=S).rev() {
558                    let mut k = 1;
559                    let mut tmp = b.square();
560                    let mut j_less_than_v: ::ff::derive::subtle::Choice = 1.into();
561
562                    for j in 2..max_v {
563                        let tmp_is_one = tmp.ct_eq(&#name::ONE);
564                        let squared = #name::conditional_select(&tmp, &z, tmp_is_one).square();
565                        tmp = #name::conditional_select(&squared, &tmp, tmp_is_one);
566                        let new_z = #name::conditional_select(&z, &squared, tmp_is_one);
567                        j_less_than_v &= !j.ct_eq(&v);
568                        k = u32::conditional_select(&j, &k, tmp_is_one);
569                        z = #name::conditional_select(&z, &new_z, j_less_than_v);
570                    }
571
572                    let result = x * &z;
573                    x = #name::conditional_select(&result, &x, b.ct_eq(&#name::ONE));
574                    z = z.square();
575                    b *= &z;
576                    v = k;
577                }
578
579                ::ff::derive::subtle::CtOption::new(
580                    x,
581                    (x * &x).ct_eq(self), // Only return Some if it's the square root.
582                )
583            }
584        } else {
585            syn::Error::new_spanned(
586                &name,
587                "ff_derive can't generate a square root function for this field.",
588            )
589            .to_compile_error()
590        };
591
592    // Compute R^2 mod m
593    let r2 = biguint_to_u64_vec((&r * &r) % modulus, limbs);
594
595    let r = biguint_to_u64_vec(r, limbs);
596    let modulus_le_bytes = ReprEndianness::Little.modulus_repr(modulus, limbs * 8);
597    let modulus_str = format!("0x{}", modulus.to_str_radix(16));
598    let modulus = biguint_to_real_u64_vec(modulus.clone(), limbs);
599
600    // Compute -m^-1 mod 2**64 by exponentiating by totient(2**64) - 1
601    let mut inv = 1u64;
602    for _ in 0..63 {
603        inv = inv.wrapping_mul(inv);
604        inv = inv.wrapping_mul(modulus[0]);
605    }
606    inv = inv.wrapping_neg();
607
608    (
609        quote! {
610            type REPR_BYTES = [u8; #bytes];
611            type REPR_BITS = REPR_BYTES;
612
613            /// This is the modulus m of the prime field
614            const MODULUS: REPR_BITS = [#(#modulus_le_bytes,)*];
615
616            /// This is the modulus m of the prime field in limb form
617            const MODULUS_LIMBS: #name = #name([#(#modulus,)*]);
618
619            /// This is the modulus m of the prime field in hex string form
620            const MODULUS_STR: &'static str = #modulus_str;
621
622            /// The number of bits needed to represent the modulus.
623            const MODULUS_BITS: u32 = #modulus_num_bits;
624
625            /// The number of bits that must be shaved from the beginning of
626            /// the representation when randomly sampling.
627            const REPR_SHAVE_BITS: u32 = #repr_shave_bits;
628
629            /// 2^{limbs*64} mod m
630            const R: #name = #name(#r);
631
632            /// 2^{limbs*64*2} mod m
633            const R2: #name = #name(#r2);
634
635            /// -(m^{-1} mod m) mod m
636            const INV: u64 = #inv;
637
638            /// 2^{-1} mod m
639            const TWO_INV: #name = #name(#two_inv);
640
641            /// Multiplicative generator of `MODULUS` - 1 order, also quadratic
642            /// nonresidue.
643            const GENERATOR: #name = #name(#generator);
644
645            /// 2^s * t = MODULUS - 1 with t odd
646            const S: u32 = #s;
647
648            /// 2^s root of unity computed by GENERATOR^t
649            const ROOT_OF_UNITY: #name = #name(#root_of_unity);
650
651            /// (2^s)^{-1} mod m
652            const ROOT_OF_UNITY_INV: #name = #name(#root_of_unity_inv);
653
654            /// GENERATOR^{2^s}
655            const DELTA: #name = #name(#delta);
656        },
657        sqrt_impl,
658    )
659}
660
661/// Implement PrimeField for the derived type.
662fn prime_field_impl(
663    name: &syn::Ident,
664    repr: &syn::Ident,
665    modulus: &BigUint,
666    endianness: &ReprEndianness,
667    limbs: usize,
668    sqrt_impl: proc_macro2::TokenStream,
669) -> proc_macro2::TokenStream {
670    // Returns r{n} as an ident.
671    fn get_temp(n: usize) -> syn::Ident {
672        syn::Ident::new(&format!("r{}", n), proc_macro2::Span::call_site())
673    }
674
675    // The parameter list for the mont_reduce() internal method.
676    // r0: u64, mut r1: u64, mut r2: u64, ...
677    let mut mont_paramlist = proc_macro2::TokenStream::new();
678    mont_paramlist.append_separated(
679        (0..(limbs * 2)).map(|i| (i, get_temp(i))).map(|(i, x)| {
680            if i != 0 {
681                quote! {mut #x: u64}
682            } else {
683                quote! {#x: u64}
684            }
685        }),
686        proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
687    );
688
689    // Implement montgomery reduction for some number of limbs
690    fn mont_impl(limbs: usize) -> proc_macro2::TokenStream {
691        let mut gen = proc_macro2::TokenStream::new();
692
693        for i in 0..limbs {
694            {
695                let temp = get_temp(i);
696                gen.extend(quote! {
697                    let k = #temp.wrapping_mul(INV);
698                    let (_, carry) = ::ff::derive::mac(#temp, k, MODULUS_LIMBS.0[0], 0);
699                });
700            }
701
702            for j in 1..limbs {
703                let temp = get_temp(i + j);
704                gen.extend(quote! {
705                    let (#temp, carry) = ::ff::derive::mac(#temp, k, MODULUS_LIMBS.0[#j], carry);
706                });
707            }
708
709            let temp = get_temp(i + limbs);
710
711            if i == 0 {
712                gen.extend(quote! {
713                    let (#temp, carry2) = ::ff::derive::adc(#temp, 0, carry);
714                });
715            } else {
716                gen.extend(quote! {
717                    let (#temp, carry2) = ::ff::derive::adc(#temp, carry2, carry);
718                });
719            }
720        }
721
722        for i in 0..limbs {
723            let temp = get_temp(limbs + i);
724
725            gen.extend(quote! {
726                self.0[#i] = #temp;
727            });
728        }
729
730        gen
731    }
732
733    fn sqr_impl(a: proc_macro2::TokenStream, limbs: usize) -> proc_macro2::TokenStream {
734        let mut gen = proc_macro2::TokenStream::new();
735
736        if limbs > 1 {
737            for i in 0..(limbs - 1) {
738                gen.extend(quote! {
739                    let carry = 0;
740                });
741
742                for j in (i + 1)..limbs {
743                    let temp = get_temp(i + j);
744                    if i == 0 {
745                        gen.extend(quote! {
746                            let (#temp, carry) = ::ff::derive::mac(0, #a.0[#i], #a.0[#j], carry);
747                        });
748                    } else {
749                        gen.extend(quote! {
750                            let (#temp, carry) = ::ff::derive::mac(#temp, #a.0[#i], #a.0[#j], carry);
751                        });
752                    }
753                }
754
755                let temp = get_temp(i + limbs);
756
757                gen.extend(quote! {
758                    let #temp = carry;
759                });
760            }
761
762            for i in 1..(limbs * 2) {
763                let temp0 = get_temp(limbs * 2 - i);
764                let temp1 = get_temp(limbs * 2 - i - 1);
765
766                if i == 1 {
767                    gen.extend(quote! {
768                        let #temp0 = #temp1 >> 63;
769                    });
770                } else if i == (limbs * 2 - 1) {
771                    gen.extend(quote! {
772                        let #temp0 = #temp0 << 1;
773                    });
774                } else {
775                    gen.extend(quote! {
776                        let #temp0 = (#temp0 << 1) | (#temp1 >> 63);
777                    });
778                }
779            }
780        } else {
781            let temp1 = get_temp(1);
782            gen.extend(quote! {
783                let #temp1 = 0;
784            });
785        }
786
787        for i in 0..limbs {
788            let temp0 = get_temp(i * 2);
789            let temp1 = get_temp(i * 2 + 1);
790            if i == 0 {
791                gen.extend(quote! {
792                    let (#temp0, carry) = ::ff::derive::mac(0, #a.0[#i], #a.0[#i], 0);
793                });
794            } else {
795                gen.extend(quote! {
796                    let (#temp0, carry) = ::ff::derive::mac(#temp0, #a.0[#i], #a.0[#i], carry);
797                });
798            }
799
800            gen.extend(quote! {
801                let (#temp1, carry) = ::ff::derive::adc(#temp1, 0, carry);
802            });
803        }
804
805        let mut mont_calling = proc_macro2::TokenStream::new();
806        mont_calling.append_separated(
807            (0..(limbs * 2)).map(get_temp),
808            proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
809        );
810
811        gen.extend(quote! {
812            let mut ret = *self;
813            ret.mont_reduce(#mont_calling);
814            ret
815        });
816
817        gen
818    }
819
820    fn mul_impl(
821        a: proc_macro2::TokenStream,
822        b: proc_macro2::TokenStream,
823        limbs: usize,
824    ) -> proc_macro2::TokenStream {
825        let mut gen = proc_macro2::TokenStream::new();
826
827        for i in 0..limbs {
828            gen.extend(quote! {
829                let carry = 0;
830            });
831
832            for j in 0..limbs {
833                let temp = get_temp(i + j);
834
835                if i == 0 {
836                    gen.extend(quote! {
837                        let (#temp, carry) = ::ff::derive::mac(0, #a.0[#i], #b.0[#j], carry);
838                    });
839                } else {
840                    gen.extend(quote! {
841                        let (#temp, carry) = ::ff::derive::mac(#temp, #a.0[#i], #b.0[#j], carry);
842                    });
843                }
844            }
845
846            let temp = get_temp(i + limbs);
847
848            gen.extend(quote! {
849                let #temp = carry;
850            });
851        }
852
853        let mut mont_calling = proc_macro2::TokenStream::new();
854        mont_calling.append_separated(
855            (0..(limbs * 2)).map(get_temp),
856            proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
857        );
858
859        gen.extend(quote! {
860            self.mont_reduce(#mont_calling);
861        });
862
863        gen
864    }
865
866    /// Generates an implementation of multiplicative inversion within the target prime
867    /// field.
868    fn inv_impl(a: proc_macro2::TokenStream, modulus: &BigUint) -> proc_macro2::TokenStream {
869        // Addition chain for p - 2
870        let mod_minus_2 = pow_fixed::generate(&a, modulus - BigUint::from(2u64));
871
872        quote! {
873            use ::ff::derive::subtle::ConstantTimeEq;
874
875            // By Euler's theorem, if `a` is coprime to `p` (i.e. `gcd(a, p) = 1`), then:
876            //     a^-1 ≡ a^(phi(p) - 1) mod p
877            //
878            // `ff_derive` requires that `p` is prime; in this case, `phi(p) = p - 1`, and
879            // thus:
880            //     a^-1 ≡ a^(p - 2) mod p
881            let inv = {
882                #mod_minus_2
883            };
884
885            ::ff::derive::subtle::CtOption::new(inv, !#a.is_zero())
886        }
887    }
888
889    let squaring_impl = sqr_impl(quote! {self}, limbs);
890    let multiply_impl = mul_impl(quote! {self}, quote! {other}, limbs);
891    let invert_impl = inv_impl(quote! {self}, modulus);
892    let montgomery_impl = mont_impl(limbs);
893
894    // self.0[0].ct_eq(&other.0[0]) & self.0[1].ct_eq(&other.0[1]) & ...
895    let mut ct_eq_impl = proc_macro2::TokenStream::new();
896    ct_eq_impl.append_separated(
897        (0..limbs).map(|i| quote! { self.0[#i].ct_eq(&other.0[#i]) }),
898        proc_macro2::Punct::new('&', proc_macro2::Spacing::Alone),
899    );
900
901    fn mont_reduce_params(a: proc_macro2::TokenStream, limbs: usize) -> proc_macro2::TokenStream {
902        // a.0[0], a.0[1], ..., 0, 0, 0, 0, ...
903        let mut mont_reduce_params = proc_macro2::TokenStream::new();
904        mont_reduce_params.append_separated(
905            (0..limbs)
906                .map(|i| quote! { #a.0[#i] })
907                .chain((0..limbs).map(|_| quote! {0})),
908            proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
909        );
910        mont_reduce_params
911    }
912
913    let mont_reduce_self_params = mont_reduce_params(quote! {self}, limbs);
914    let mont_reduce_other_params = mont_reduce_params(quote! {other}, limbs);
915
916    let from_repr_impl = endianness.from_repr(name, limbs);
917    let to_repr_impl = endianness.to_repr(quote! {#repr}, &mont_reduce_self_params, limbs);
918
919    cfg_if::cfg_if! {
920        if #[cfg(feature = "bits")] {
921            let to_le_bits_impl = ReprEndianness::Little.to_repr(
922                quote! {::ff::derive::bitvec::array::BitArray::new},
923                &mont_reduce_self_params,
924                limbs,
925            );
926
927            let prime_field_bits_impl = quote! {
928                impl ::ff::PrimeFieldBits for #name {
929                    type ReprBits = REPR_BITS;
930
931                    fn to_le_bits(&self) -> ::ff::FieldBits<REPR_BITS> {
932                        #to_le_bits_impl
933                    }
934
935                    fn char_le_bits() -> ::ff::FieldBits<REPR_BITS> {
936                        ::ff::FieldBits::new(MODULUS)
937                    }
938                }
939            };
940        } else {
941            let prime_field_bits_impl = quote! {};
942        }
943    };
944
945    let top_limb_index = limbs - 1;
946
947    quote! {
948        impl ::core::marker::Copy for #name { }
949
950        impl ::core::clone::Clone for #name {
951            fn clone(&self) -> #name {
952                *self
953            }
954        }
955
956        impl ::core::default::Default for #name {
957            fn default() -> #name {
958                use ::ff::Field;
959                #name::ZERO
960            }
961        }
962
963        impl ::ff::derive::subtle::ConstantTimeEq for #name {
964            fn ct_eq(&self, other: &#name) -> ::ff::derive::subtle::Choice {
965                use ::ff::PrimeField;
966                self.to_repr().ct_eq(&other.to_repr())
967            }
968        }
969
970        impl ::core::cmp::PartialEq for #name {
971            fn eq(&self, other: &#name) -> bool {
972                use ::ff::derive::subtle::ConstantTimeEq;
973                self.ct_eq(other).into()
974            }
975        }
976
977        impl ::core::cmp::Eq for #name { }
978
979        impl ::core::fmt::Debug for #name
980        {
981            fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
982                use ::ff::PrimeField;
983                write!(f, "{}({:?})", stringify!(#name), self.to_repr())
984            }
985        }
986
987        /// Elements are ordered lexicographically.
988        impl Ord for #name {
989            #[inline(always)]
990            fn cmp(&self, other: &#name) -> ::core::cmp::Ordering {
991                let mut a = *self;
992                a.mont_reduce(
993                    #mont_reduce_self_params
994                );
995
996                let mut b = *other;
997                b.mont_reduce(
998                    #mont_reduce_other_params
999                );
1000
1001                a.cmp_native(&b)
1002            }
1003        }
1004
1005        impl PartialOrd for #name {
1006            #[inline(always)]
1007            fn partial_cmp(&self, other: &#name) -> Option<::core::cmp::Ordering> {
1008                Some(self.cmp(other))
1009            }
1010        }
1011
1012        impl From<u64> for #name {
1013            #[inline(always)]
1014            fn from(val: u64) -> #name {
1015                let mut raw = [0u64; #limbs];
1016                raw[0] = val;
1017                #name(raw) * R2
1018            }
1019        }
1020
1021        impl From<#name> for #repr {
1022            fn from(e: #name) -> #repr {
1023                use ::ff::PrimeField;
1024                e.to_repr()
1025            }
1026        }
1027
1028        impl<'a> From<&'a #name> for #repr {
1029            fn from(e: &'a #name) -> #repr {
1030                use ::ff::PrimeField;
1031                e.to_repr()
1032            }
1033        }
1034
1035        impl ::ff::derive::subtle::ConditionallySelectable for #name {
1036            fn conditional_select(a: &#name, b: &#name, choice: ::ff::derive::subtle::Choice) -> #name {
1037                let mut res = [0u64; #limbs];
1038                for i in 0..#limbs {
1039                    res[i] = u64::conditional_select(&a.0[i], &b.0[i], choice);
1040                }
1041                #name(res)
1042            }
1043        }
1044
1045        impl ::core::ops::Neg for #name {
1046            type Output = #name;
1047
1048            #[inline]
1049            fn neg(self) -> #name {
1050                use ::ff::Field;
1051
1052                let mut ret = self;
1053                if !ret.is_zero_vartime() {
1054                    let mut tmp = MODULUS_LIMBS;
1055                    tmp.sub_noborrow(&ret);
1056                    ret = tmp;
1057                }
1058                ret
1059            }
1060        }
1061
1062        impl<'r> ::core::ops::Add<&'r #name> for #name {
1063            type Output = #name;
1064
1065            #[inline]
1066            fn add(self, other: &#name) -> #name {
1067                use ::core::ops::AddAssign;
1068
1069                let mut ret = self;
1070                ret.add_assign(other);
1071                ret
1072            }
1073        }
1074
1075        impl ::core::ops::Add for #name {
1076            type Output = #name;
1077
1078            #[inline]
1079            fn add(self, other: #name) -> Self {
1080                self + &other
1081            }
1082        }
1083
1084        impl<'r> ::core::ops::AddAssign<&'r #name> for #name {
1085            #[inline]
1086            fn add_assign(&mut self, other: &#name) {
1087                // This cannot exceed the backing capacity.
1088                self.add_nocarry(other);
1089
1090                // However, it may need to be reduced.
1091                self.reduce();
1092            }
1093        }
1094
1095        impl ::core::ops::AddAssign for #name {
1096            #[inline]
1097            fn add_assign(&mut self, other: #name) {
1098                self.add_assign(&other);
1099            }
1100        }
1101
1102        impl<'r> ::core::ops::Sub<&'r #name> for #name {
1103            type Output = #name;
1104
1105            #[inline]
1106            fn sub(self, other: &#name) -> Self {
1107                use ::core::ops::SubAssign;
1108
1109                let mut ret = self;
1110                ret.sub_assign(other);
1111                ret
1112            }
1113        }
1114
1115        impl ::core::ops::Sub for #name {
1116            type Output = #name;
1117
1118            #[inline]
1119            fn sub(self, other: #name) -> Self {
1120                self - &other
1121            }
1122        }
1123
1124        impl<'r> ::core::ops::SubAssign<&'r #name> for #name {
1125            #[inline]
1126            fn sub_assign(&mut self, other: &#name) {
1127                // If `other` is larger than `self`, we'll need to add the modulus to self first.
1128                if other.cmp_native(self) == ::core::cmp::Ordering::Greater {
1129                    self.add_nocarry(&MODULUS_LIMBS);
1130                }
1131
1132                self.sub_noborrow(other);
1133            }
1134        }
1135
1136        impl ::core::ops::SubAssign for #name {
1137            #[inline]
1138            fn sub_assign(&mut self, other: #name) {
1139                self.sub_assign(&other);
1140            }
1141        }
1142
1143        impl<'r> ::core::ops::Mul<&'r #name> for #name {
1144            type Output = #name;
1145
1146            #[inline]
1147            fn mul(self, other: &#name) -> Self {
1148                use ::core::ops::MulAssign;
1149
1150                let mut ret = self;
1151                ret.mul_assign(other);
1152                ret
1153            }
1154        }
1155
1156        impl ::core::ops::Mul for #name {
1157            type Output = #name;
1158
1159            #[inline]
1160            fn mul(self, other: #name) -> Self {
1161                self * &other
1162            }
1163        }
1164
1165        impl<'r> ::core::ops::MulAssign<&'r #name> for #name {
1166            #[inline]
1167            fn mul_assign(&mut self, other: &#name)
1168            {
1169                #multiply_impl
1170            }
1171        }
1172
1173        impl ::core::ops::MulAssign for #name {
1174            #[inline]
1175            fn mul_assign(&mut self, other: #name)
1176            {
1177                self.mul_assign(&other);
1178            }
1179        }
1180
1181        impl<T: ::core::borrow::Borrow<#name>> ::core::iter::Sum<T> for #name {
1182            fn sum<I: Iterator<Item = T>>(iter: I) -> Self {
1183                use ::ff::Field;
1184
1185                iter.fold(Self::ZERO, |acc, item| acc + item.borrow())
1186            }
1187        }
1188
1189        impl<T: ::core::borrow::Borrow<#name>> ::core::iter::Product<T> for #name {
1190            fn product<I: Iterator<Item = T>>(iter: I) -> Self {
1191                use ::ff::Field;
1192
1193                iter.fold(Self::ONE, |acc, item| acc * item.borrow())
1194            }
1195        }
1196
1197        impl ::ff::PrimeField for #name {
1198            type Repr = #repr;
1199
1200            fn from_repr(r: #repr) -> ::ff::derive::subtle::CtOption<#name> {
1201                #from_repr_impl
1202
1203                // Try to subtract the modulus
1204                let borrow = r.0.iter().zip(MODULUS_LIMBS.0.iter()).fold(0, |borrow, (a, b)| {
1205                    ::ff::derive::sbb(*a, *b, borrow).1
1206                });
1207
1208                // If the element is smaller than MODULUS then the
1209                // subtraction will underflow, producing a borrow value
1210                // of 0xffff...ffff. Otherwise, it'll be zero.
1211                let is_some = ::ff::derive::subtle::Choice::from((borrow as u8) & 1);
1212
1213                // Convert to Montgomery form by computing
1214                // (a.R^0 * R^2) / R = a.R
1215                ::ff::derive::subtle::CtOption::new(r * &R2, is_some)
1216            }
1217
1218            fn from_repr_vartime(r: #repr) -> Option<#name> {
1219                #from_repr_impl
1220
1221                if r.is_valid() {
1222                    Some(r * R2)
1223                } else {
1224                    None
1225                }
1226            }
1227
1228            fn to_repr(&self) -> #repr {
1229                #to_repr_impl
1230            }
1231
1232            #[inline(always)]
1233            fn is_odd(&self) -> ::ff::derive::subtle::Choice {
1234                let mut r = *self;
1235                r.mont_reduce(
1236                    #mont_reduce_self_params
1237                );
1238
1239                // TODO: This looks like a constant-time result, but r.mont_reduce() is
1240                // currently implemented using variable-time code.
1241                ::ff::derive::subtle::Choice::from((r.0[0] & 1) as u8)
1242            }
1243
1244            const MODULUS: &'static str = MODULUS_STR;
1245
1246            const NUM_BITS: u32 = MODULUS_BITS;
1247
1248            const CAPACITY: u32 = Self::NUM_BITS - 1;
1249
1250            const TWO_INV: Self = TWO_INV;
1251
1252            const MULTIPLICATIVE_GENERATOR: Self = GENERATOR;
1253
1254            const S: u32 = S;
1255
1256            const ROOT_OF_UNITY: Self = ROOT_OF_UNITY;
1257
1258            const ROOT_OF_UNITY_INV: Self = ROOT_OF_UNITY_INV;
1259
1260            const DELTA: Self = DELTA;
1261        }
1262
1263        #prime_field_bits_impl
1264
1265        impl ::ff::Field for #name {
1266            const ZERO: Self = #name([0; #limbs]);
1267            const ONE: Self = R;
1268
1269            /// Computes a uniformly random element using rejection sampling.
1270            fn random(mut rng: impl ::ff::derive::rand_core::RngCore) -> Self {
1271                loop {
1272                    let mut tmp = {
1273                        let mut repr = [0u64; #limbs];
1274                        for i in 0..#limbs {
1275                            repr[i] = rng.next_u64();
1276                        }
1277                        #name(repr)
1278                    };
1279
1280                    // Mask away the unused most-significant bits.
1281                    // Note: In some edge cases, `REPR_SHAVE_BITS` could be 64, in which case
1282                    // `0xfff... >> REPR_SHAVE_BITS` overflows. So use `checked_shr` instead.
1283                    // This is always sufficient because we will have at most one spare limb
1284                    // to accommodate values of up to twice the modulus.
1285                    tmp.0.as_mut()[#top_limb_index] &= 0xffffffffffffffffu64.checked_shr(REPR_SHAVE_BITS).unwrap_or(0);
1286
1287                    if tmp.is_valid() {
1288                        return tmp
1289                    }
1290                }
1291            }
1292
1293            #[inline]
1294            fn is_zero_vartime(&self) -> bool {
1295                self.0.iter().all(|&e| e == 0)
1296            }
1297
1298            #[inline]
1299            fn double(&self) -> Self {
1300                let mut ret = *self;
1301
1302                // This cannot exceed the backing capacity.
1303                let mut last = 0;
1304                for i in &mut ret.0 {
1305                    let tmp = *i >> 63;
1306                    *i <<= 1;
1307                    *i |= last;
1308                    last = tmp;
1309                }
1310
1311                // However, it may need to be reduced.
1312                ret.reduce();
1313
1314                ret
1315            }
1316
1317            fn invert(&self) -> ::ff::derive::subtle::CtOption<Self> {
1318                #invert_impl
1319            }
1320
1321            #[inline]
1322            fn square(&self) -> Self
1323            {
1324                #squaring_impl
1325            }
1326
1327            fn sqrt_ratio(num: &Self, div: &Self) -> (::ff::derive::subtle::Choice, Self) {
1328                ::ff::helpers::sqrt_ratio_generic(num, div)
1329            }
1330
1331            fn sqrt(&self) -> ::ff::derive::subtle::CtOption<Self> {
1332                #sqrt_impl
1333            }
1334        }
1335
1336        impl #name {
1337            /// Compares two elements in native representation. This is only used
1338            /// internally.
1339            #[inline(always)]
1340            fn cmp_native(&self, other: &#name) -> ::core::cmp::Ordering {
1341                for (a, b) in self.0.iter().rev().zip(other.0.iter().rev()) {
1342                    if a < b {
1343                        return ::core::cmp::Ordering::Less
1344                    } else if a > b {
1345                        return ::core::cmp::Ordering::Greater
1346                    }
1347                }
1348
1349                ::core::cmp::Ordering::Equal
1350            }
1351
1352            /// Determines if the element is really in the field. This is only used
1353            /// internally.
1354            #[inline(always)]
1355            fn is_valid(&self) -> bool {
1356                // The Ord impl calls `reduce`, which in turn calls `is_valid`, so we use
1357                // this internal function to eliminate the cycle.
1358                self.cmp_native(&MODULUS_LIMBS) == ::core::cmp::Ordering::Less
1359            }
1360
1361            #[inline(always)]
1362            fn add_nocarry(&mut self, other: &#name) {
1363                let mut carry = 0;
1364
1365                for (a, b) in self.0.iter_mut().zip(other.0.iter()) {
1366                    let (new_a, new_carry) = ::ff::derive::adc(*a, *b, carry);
1367                    *a = new_a;
1368                    carry = new_carry;
1369                }
1370            }
1371
1372            #[inline(always)]
1373            fn sub_noborrow(&mut self, other: &#name) {
1374                let mut borrow = 0;
1375
1376                for (a, b) in self.0.iter_mut().zip(other.0.iter()) {
1377                    let (new_a, new_borrow) = ::ff::derive::sbb(*a, *b, borrow);
1378                    *a = new_a;
1379                    borrow = new_borrow;
1380                }
1381            }
1382
1383            /// Subtracts the modulus from this element if this element is not in the
1384            /// field. Only used interally.
1385            #[inline(always)]
1386            fn reduce(&mut self) {
1387                if !self.is_valid() {
1388                    self.sub_noborrow(&MODULUS_LIMBS);
1389                }
1390            }
1391
1392            #[allow(clippy::too_many_arguments)]
1393            #[inline(always)]
1394            fn mont_reduce(
1395                &mut self,
1396                #mont_paramlist
1397            )
1398            {
1399                // The Montgomery reduction here is based on Algorithm 14.32 in
1400                // Handbook of Applied Cryptography
1401                // <http://cacr.uwaterloo.ca/hac/about/chap14.pdf>.
1402
1403                #montgomery_impl
1404
1405                self.reduce();
1406            }
1407        }
1408    }
1409}