openvm_ff_derive/
lib.rs

1#![recursion_limit = "1024"]
2#![allow(clippy::manual_repeat_n)]
3
4extern crate proc_macro;
5extern crate proc_macro2;
6
7use std::{iter, str::FromStr};
8
9use num_bigint::BigUint;
10use num_integer::Integer;
11use num_traits::{One, ToPrimitive, Zero};
12use quote::{quote, TokenStreamExt};
13
14mod pow_fixed;
15
/// Byte order selected through the `PrimeFieldReprEndianness` attribute.
enum ReprEndianness {
    Big,
    Little,
}

impl FromStr for ReprEndianness {
    type Err = ();

    /// Accepts exactly `"big"` or `"little"` (case-sensitive); any other input is `Err(())`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "big" {
            Ok(Self::Big)
        } else if s == "little" {
            Ok(Self::Little)
        } else {
            Err(())
        }
    }
}
32
33impl ReprEndianness {
34    fn modulus_repr(&self, modulus: &BigUint, bytes: usize) -> Vec<u8> {
35        match self {
36            ReprEndianness::Big => {
37                let buf = modulus.to_bytes_be();
38                iter::repeat(0).take(bytes - buf.len()).chain(buf).collect()
39            }
40            ReprEndianness::Little => {
41                let mut buf = modulus.to_bytes_le();
42                buf.extend(iter::repeat(0).take(bytes - buf.len()));
43                buf
44            }
45        }
46    }
47
48    // Clippy things methods named from_* don't take self as a parameter
49    #[allow(clippy::wrong_self_convention)]
50    fn from_repr(&self, name: &syn::Ident, limbs: usize) -> proc_macro2::TokenStream {
51        let read_repr = match self {
52            ReprEndianness::Big => quote! {
53                <#name as ::openvm_algebra_guest::IntMod>::from_be_bytes(&r.as_ref()).unwrap()
54            },
55            ReprEndianness::Little => quote! {
56                <#name as ::openvm_algebra_guest::IntMod>::from_le_bytes(&r.as_ref()).unwrap()
57            },
58        };
59
60        let zkvm_impl = quote! {
61            #read_repr
62        };
63
64        let read_repr = match self {
65            ReprEndianness::Big => quote! {
66                ::ff::derive::byteorder::BigEndian::read_u64_into(r.as_ref(), &mut inner[..]);
67                inner.reverse();
68            },
69            ReprEndianness::Little => quote! {
70                ::ff::derive::byteorder::LittleEndian::read_u64_into(r.as_ref(), &mut inner[..]);
71            },
72        };
73
74        let non_zkvm_impl = quote! {
75            {
76                use ::ff::derive::byteorder::ByteOrder;
77
78                let mut inner = [0u64; #limbs];
79                #read_repr
80                #name(inner)
81            }
82        };
83
84        quote! {
85            #[cfg(target_os = "zkvm")]
86                let r = #zkvm_impl;
87            #[cfg(not(target_os = "zkvm"))]
88                let r = #non_zkvm_impl;
89        }
90    }
91
92    fn to_repr(
93        &self,
94        repr: proc_macro2::TokenStream,
95        mont_reduce_self_params: &proc_macro2::TokenStream,
96        limbs: usize,
97    ) -> proc_macro2::TokenStream {
98        let bytes = limbs * 8;
99
100        let write_repr = match self {
101            ReprEndianness::Big => quote! {
102                <Self as ::openvm_algebra_guest::IntMod>::to_be_bytes(self)[..#bytes].try_into().unwrap()
103            },
104            ReprEndianness::Little => quote! {
105                <Self as ::openvm_algebra_guest::IntMod>::as_le_bytes(self)[..#bytes].try_into().unwrap()
106            },
107        };
108
109        let zkvm_impl = quote! {
110            #repr(#write_repr)
111        };
112
113        let write_repr = match self {
114            ReprEndianness::Big => quote! {
115                r.0.reverse();
116                ::ff::derive::byteorder::BigEndian::write_u64_into(&r.0, &mut repr[..]);
117            },
118            ReprEndianness::Little => quote! {
119                ::ff::derive::byteorder::LittleEndian::write_u64_into(&r.0, &mut repr[..]);
120            },
121        };
122
123        let non_zkvm_impl = quote! {
124            use ::ff::derive::byteorder::ByteOrder;
125
126            let mut r = *self;
127            r.mont_reduce(
128                #mont_reduce_self_params
129            );
130
131            let mut repr = [0u8; #bytes];
132            #write_repr
133            #repr(repr)
134        };
135
136        quote! {
137            #[cfg(target_os = "zkvm")]
138            {
139                #zkvm_impl
140            }
141            #[cfg(not(target_os = "zkvm"))]
142            {
143                #non_zkvm_impl
144            }
145        }
146    }
147
148    fn iter_be(&self) -> proc_macro2::TokenStream {
149        // We aren't implementing for zkvm here because this function is only used in the prime
150        // field repr impl which is a plain array of bytes
151        match self {
152            ReprEndianness::Big => quote! {self.0.iter()},
153            ReprEndianness::Little => quote! {self.0.iter().rev()},
154        }
155    }
156}
157
158/// Derive the `PrimeField` trait.
159/// **Warning**: This macro removes the struct definition and inserts our own zkvm-compatible struct
160/// into the token stream.
161/// Currently the memory layout of the new struct will always be either 32 or 48-bytes, where the
162/// smallest that will fit the field's modulus is used. Moduli that are larger than 48-bytes are not
163/// yet supported.
164// Required attributes: PrimeFieldModulus, PrimeFieldGenerator, PrimeFieldReprEndianness
165// Note: In our fork, we changed the macro from a derive macro to an attribute-style macro because
166// we need to be able to remove the struct definition and insert our own into the token stream.
167#[proc_macro_attribute]
168pub fn openvm_prime_field(
169    _: proc_macro::TokenStream,
170    input: proc_macro::TokenStream,
171) -> proc_macro::TokenStream {
172    // Parse the type definition
173    let ast: syn::Item = syn::parse_macro_input!(input);
174
175    // The attribute should be applied to a struct.
176    let mut ast = match ast {
177        syn::Item::Struct(ast) => ast,
178        _ => {
179            return syn::Error::new_spanned(ast, "PrimeField derive only works for structs.")
180                .to_compile_error()
181                .into();
182        }
183    };
184
185    // We're given the modulus p of the prime field
186    let modulus: BigUint = fetch_attr("PrimeFieldModulus", &ast.attrs)
187        .expect("Please supply a PrimeFieldModulus attribute")
188        .parse()
189        .expect("PrimeFieldModulus should be a number");
190
191    // We may be provided with a generator of p - 1 order. It is required that this generator be
192    // quadratic nonresidue.
193    // TODO: Compute this ourselves.
194    let generator: BigUint = fetch_attr("PrimeFieldGenerator", &ast.attrs)
195        .expect("Please supply a PrimeFieldGenerator attribute")
196        .parse()
197        .expect("PrimeFieldGenerator should be a number");
198
199    // Field element representations may be in little-endian or big-endian.
200    let endianness = fetch_attr("PrimeFieldReprEndianness", &ast.attrs)
201        .expect("Please supply a PrimeFieldReprEndianness attribute")
202        .parse()
203        .expect("PrimeFieldReprEndianness should be 'big' or 'little'");
204
205    // The arithmetic in this library only works if the modulus*2 is smaller than the backing
206    // representation. Compute the number of 64-bit limbs we need.
207    let mut limbs = 1;
208    {
209        let mod2 = (&modulus) << 1; // modulus * 2
210        let mut cur = BigUint::one() << 64; // always 64-bit limbs for now
211        while cur < mod2 {
212            limbs += 1;
213            cur <<= 64;
214        }
215    }
216
217    let bytes = modulus.bits().div_ceil(8);
218    let zkvm_limbs = if bytes <= 32 {
219        32
220    } else if bytes <= 48 {
221        48
222    } else {
223        // A limitation of our zkvm implementation is that we only support moduli up to 48 bytes.
224        return syn::Error::new_spanned(
225            ast,
226            "PrimeField modulus is too large. Only 48 byte moduli are supported.",
227        )
228        .to_compile_error()
229        .into();
230    };
231
232    // The struct we're deriving for must be a wrapper around `pub [u64; limbs]`.
233    if let Some(err) = validate_struct(&ast, limbs) {
234        return err.into();
235    }
236
237    // Generate the identifier for the "Repr" type we must construct.
238    let repr_ident = syn::Ident::new(
239        &format!("{}Repr", ast.ident),
240        proc_macro2::Span::call_site(),
241    );
242
243    let mut gen = proc_macro2::TokenStream::new();
244
245    // Remove the attributes from the struct so we can insert it back into the code
246    ast.attrs.clear();
247
248    // Call moduli_declare! to define the struct
249    let openvm_struct = openvm_struct_impl(&ast, &modulus);
250
251    // TODO: test the non-zkvm case
252    gen.extend(quote! {
253        #[cfg(target_os = "zkvm")]
254            #openvm_struct
255        #[cfg(not(target_os = "zkvm"))]
256            #ast
257    });
258
259    let (constants_impl, sqrt_impl) =
260        prime_field_constants_and_sqrt(&ast.ident, &modulus, limbs, zkvm_limbs, generator);
261
262    gen.extend(constants_impl);
263    gen.extend(prime_field_repr_impl(&repr_ident, &endianness, limbs * 8));
264    gen.extend(prime_field_impl(
265        &ast.ident,
266        &repr_ident,
267        &modulus,
268        &endianness,
269        limbs,
270        zkvm_limbs,
271        sqrt_impl,
272    ));
273
274    // Return the generated impl
275    gen.into()
276}
277
278fn openvm_struct_impl(ast: &syn::ItemStruct, modulus: &BigUint) -> proc_macro2::TokenStream {
279    let struct_ident = &ast.ident;
280    let modulus_str = modulus.to_str_radix(10);
281    quote! {
282        ::openvm_algebra_moduli_macros::moduli_declare! {
283            #struct_ident {
284                modulus = #modulus_str
285            }
286        }
287    }
288}
289
290/// Checks that `body` contains `pub [u64; limbs]`.
291fn validate_struct(ast: &syn::ItemStruct, limbs: usize) -> Option<proc_macro2::TokenStream> {
292    // The struct should contain a single unnamed field.
293    let fields = match &ast.fields {
294        syn::Fields::Unnamed(x) if x.unnamed.len() == 1 => x,
295        _ => {
296            return Some(
297                syn::Error::new_spanned(
298                    &ast.ident,
299                    format!(
300                        "The struct must contain an array of limbs. Change this to `{}([u64; {}])`",
301                        ast.ident, limbs,
302                    ),
303                )
304                .to_compile_error(),
305            )
306        }
307    };
308    let field = &fields.unnamed[0];
309
310    // The field should be an array.
311    let arr = match &field.ty {
312        syn::Type::Array(x) => x,
313        _ => {
314            return Some(
315                syn::Error::new_spanned(
316                    field,
317                    format!(
318                    "The inner field must be an array of limbs. Change this to `[u64; {limbs}]`",
319                ),
320                )
321                .to_compile_error(),
322            )
323        }
324    };
325
326    // The array's element type should be `u64`.
327    if match arr.elem.as_ref() {
328        syn::Type::Path(path) => path.path.get_ident().map(|x| *x != "u64").unwrap_or(true),
329        _ => true,
330    } {
331        return Some(
332            syn::Error::new_spanned(
333                arr,
334                format!("PrimeField derive requires 64-bit limbs. Change this to `[u64; {limbs}]"),
335            )
336            .to_compile_error(),
337        );
338    }
339
340    // The array's length should be a literal int equal to `limbs`.
341    let expr_lit = match &arr.len {
342        syn::Expr::Lit(expr_lit) => Some(&expr_lit.lit),
343        syn::Expr::Group(expr_group) => match &*expr_group.expr {
344            syn::Expr::Lit(expr_lit) => Some(&expr_lit.lit),
345            _ => None,
346        },
347        _ => None,
348    };
349    let lit_int = match match expr_lit {
350        Some(syn::Lit::Int(lit_int)) => Some(lit_int),
351        _ => None,
352    } {
353        Some(x) => x,
354        _ => {
355            return Some(
356                syn::Error::new_spanned(
357                    arr,
358                    format!("To derive PrimeField, change this to `[u64; {limbs}]`."),
359                )
360                .to_compile_error(),
361            )
362        }
363    };
364    if lit_int.base10_digits() != limbs.to_string() {
365        return Some(
366            syn::Error::new_spanned(
367                lit_int,
368                format!("The given modulus requires {limbs} limbs."),
369            )
370            .to_compile_error(),
371        );
372    }
373
374    // The field should not be public.
375    match &field.vis {
376        syn::Visibility::Inherited => (),
377        _ => {
378            return Some(
379                syn::Error::new_spanned(&field.vis, "Field must not be public.").to_compile_error(),
380            )
381        }
382    }
383
384    // Valid!
385    None
386}
387
388/// Fetch an attribute string from the derived struct.
389fn fetch_attr(name: &str, attrs: &[syn::Attribute]) -> Option<String> {
390    for attr in attrs {
391        if let Ok(meta) = attr.parse_meta() {
392            match meta {
393                syn::Meta::NameValue(nv) => {
394                    if nv.path.get_ident().map(|i| i.to_string()) == Some(name.to_string()) {
395                        match nv.lit {
396                            syn::Lit::Str(ref s) => return Some(s.value()),
397                            _ => {
398                                panic!("attribute {name} should be a string");
399                            }
400                        }
401                    }
402                }
403                _ => {
404                    panic!("attribute {name} should be a string");
405                }
406            }
407        }
408    }
409
410    None
411}
412
413// Implement the wrapped ident `repr` with `bytes` bytes.
414fn prime_field_repr_impl(
415    repr: &syn::Ident,
416    endianness: &ReprEndianness,
417    bytes: usize,
418) -> proc_macro2::TokenStream {
419    let repr_iter_be = endianness.iter_be();
420
421    quote! {
422        #[derive(Copy, Clone)]
423        pub struct #repr(pub [u8; #bytes]);
424
425        impl ::ff::derive::subtle::ConstantTimeEq for #repr {
426            fn ct_eq(&self, other: &#repr) -> ::ff::derive::subtle::Choice {
427                self.0
428                    .iter()
429                    .zip(other.0.iter())
430                    .map(|(a, b)| a.ct_eq(b))
431                    .fold(1.into(), |acc, x| acc & x)
432            }
433        }
434
435        impl ::core::cmp::PartialEq for #repr {
436            fn eq(&self, other: &#repr) -> bool {
437                use ::ff::derive::subtle::ConstantTimeEq;
438                self.ct_eq(other).into()
439            }
440        }
441
442        impl ::core::cmp::Eq for #repr { }
443
444        impl ::core::default::Default for #repr {
445            fn default() -> #repr {
446                #repr([0u8; #bytes])
447            }
448        }
449
450        impl ::core::fmt::Debug for #repr
451        {
452            fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
453                write!(f, "0x")?;
454                for i in #repr_iter_be {
455                    write!(f, "{:02x}", *i)?;
456                }
457
458                Ok(())
459            }
460        }
461
462        impl AsRef<[u8]> for #repr {
463            #[inline(always)]
464            fn as_ref(&self) -> &[u8] {
465                &self.0
466            }
467        }
468
469        impl AsMut<[u8]> for #repr {
470            #[inline(always)]
471            fn as_mut(&mut self) -> &mut [u8] {
472                &mut self.0
473            }
474        }
475    }
476}
477
478/// Convert BigUint into a vector of 64-bit limbs.
479fn biguint_to_real_u64_vec(mut v: BigUint, limbs: usize) -> Vec<u64> {
480    let m = BigUint::one() << 64;
481    let mut ret = vec![];
482
483    while v > BigUint::zero() {
484        let limb: BigUint = &v % &m;
485        ret.push(limb.to_u64().unwrap());
486        v >>= 64;
487    }
488
489    while ret.len() < limbs {
490        ret.push(0);
491    }
492
493    assert!(ret.len() == limbs);
494
495    ret
496}
497
498/// Convert BigUint into a tokenized vector of 64-bit limbs.
499fn biguint_to_u64_vec(v: BigUint, limbs: usize) -> proc_macro2::TokenStream {
500    let ret = biguint_to_real_u64_vec(v, limbs);
501    quote!([#(#ret,)*])
502}
503
504/// Returns a token stream containing a little-endian bytes representation of `v`.
505fn biguint_to_u8_vec(v: BigUint, limbs: usize) -> proc_macro2::TokenStream {
506    let mut bytes = v.to_bytes_le();
507    while bytes.len() < limbs {
508        bytes.push(0);
509    }
510    quote!([#(#bytes,)*])
511}
512
513fn biguint_num_bits(mut v: BigUint) -> u32 {
514    let mut bits = 0;
515
516    while v != BigUint::zero() {
517        v >>= 1;
518        bits += 1;
519    }
520
521    bits
522}
523
524/// BigUint modular exponentiation by square-and-multiply.
525fn exp(base: BigUint, exp: &BigUint, modulus: &BigUint) -> BigUint {
526    let mut ret = BigUint::one();
527
528    for i in exp
529        .to_bytes_be()
530        .into_iter()
531        .flat_map(|x| (0..8).rev().map(move |i| (x >> i).is_odd()))
532    {
533        ret = (&ret * &ret) % modulus;
534        if i {
535            ret = (ret * &base) % modulus;
536        }
537    }
538
539    ret
540}
541
/// Checks `exp` against one fixed base, exponent and modulus with a known result.
#[test]
fn test_exp() {
    let parse = |digits: &str| BigUint::from_str(digits).unwrap();

    let base = parse("4398572349857239485729348572983472345");
    let exponent = parse("5489673498567349856734895");
    let modulus =
        parse("52435875175126190479447740508185965837690552500527637822603658699938581184513");
    let expected =
        parse("4371221214068404307866768905142520595925044802278091865033317963560480051536");

    assert_eq!(exp(base, &exponent, &modulus), expected);
}
559
560fn prime_field_constants_and_sqrt(
561    name: &syn::Ident,
562    modulus: &BigUint,
563    limbs: usize,
564    zkvm_limbs: usize,
565    generator: BigUint,
566) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
567    let bytes = limbs * 8;
568    let modulus_num_bits = biguint_num_bits(modulus.clone());
569
570    // The number of bits we should "shave" from a randomly sampled representation, i.e.,
571    // if our modulus is 381 bits and our representation is 384 bits, we should shave
572    // 3 bits from the beginning of a randomly sampled 384 bit representation to
573    // reduce the cost of rejection sampling.
574    let repr_shave_bits = (64 * limbs as u32) - biguint_num_bits(modulus.clone());
575
576    // Compute R = 2**(64 * limbs) mod m
577    let r = (BigUint::one() << (limbs * 64)) % modulus;
578    let to_mont = |v| (v * &r) % modulus;
579
580    let two = BigUint::from_str("2").unwrap();
581    let p_minus_2 = modulus - &two;
582    let invert = |v| exp(v, &p_minus_2, modulus);
583
584    // 2^-1 mod m
585    let two_inv = biguint_to_u64_vec(to_mont(invert(two.clone())), limbs);
586    let two_inv_bytes = biguint_to_u8_vec(invert(two), zkvm_limbs);
587
588    // modulus - 1 = 2^s * t
589    let mut s: u32 = 0;
590    let mut t = modulus - BigUint::from_str("1").unwrap();
591    while t.is_even() {
592        t >>= 1;
593        s += 1;
594    }
595
596    // Compute 2^s root of unity given the generator
597    let root_of_unity_biguint = exp(generator.clone(), &t, modulus);
598
599    let root_of_unity_inv_biguint = invert(root_of_unity_biguint.clone());
600    let root_of_unity_inv = biguint_to_u64_vec(to_mont(root_of_unity_inv_biguint.clone()), limbs);
601    let root_of_unity_inv_bytes = biguint_to_u8_vec(root_of_unity_inv_biguint, zkvm_limbs);
602
603    let root_of_unity = biguint_to_u64_vec(to_mont(root_of_unity_biguint.clone()), limbs);
604    let root_of_unity_bytes = biguint_to_u8_vec(root_of_unity_biguint, zkvm_limbs);
605
606    let delta_biguint = exp(generator.clone(), &(BigUint::one() << s), modulus);
607    let delta = biguint_to_u64_vec(to_mont(delta_biguint.clone()), limbs);
608    let delta_bytes = biguint_to_u8_vec(delta_biguint, zkvm_limbs);
609
610    let generator_u64_limbs = biguint_to_u64_vec(to_mont(generator.clone()), limbs);
611    let generator_bytes = biguint_to_u8_vec(generator, zkvm_limbs);
612
613    let sqrt_impl =
614        if (modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() {
615            // Addition chain for (r + 1) // 4
616            let mod_plus_1_over_4 = pow_fixed::generate(
617                &quote! {self},
618                (modulus + BigUint::from_str("1").unwrap()) >> 2,
619            );
620
621            quote! {
622                use ::ff::derive::subtle::ConstantTimeEq;
623
624                // Because r = 3 (mod 4)
625                // sqrt can be done with only one exponentiation,
626                // via the computation of  self^((r + 1) // 4) (mod r)
627                let sqrt = {
628                    #mod_plus_1_over_4
629                };
630
631                ::ff::derive::subtle::CtOption::new(
632                    sqrt,
633                    (sqrt * &sqrt).ct_eq(self), // Only return Some if it's the square root.
634                )
635            }
636        } else {
637            // Addition chain for (t - 1) // 2
638            let t_minus_1_over_2 = if t == BigUint::one() {
639                quote!( #name::ONE )
640            } else {
641                pow_fixed::generate(&quote! {self}, (&t - BigUint::one()) >> 1)
642            };
643
644            quote! {
645                // Tonelli-Shanks algorithm works for every remaining odd prime.
646                // https://eprint.iacr.org/2012/685.pdf (page 12, algorithm 5)
647                use ::ff::derive::subtle::{ConditionallySelectable, ConstantTimeEq};
648
649                // w = self^((t - 1) // 2)
650                let w = {
651                    #t_minus_1_over_2
652                };
653
654                let mut v = S;
655                let mut x = *self * &w;
656                let mut b = x * &w;
657
658                // Initialize z as the 2^S root of unity.
659                let mut z = ROOT_OF_UNITY;
660
661                for max_v in (1..=S).rev() {
662                    let mut k = 1;
663                    let mut tmp = b.square();
664                    let mut j_less_than_v: ::ff::derive::subtle::Choice = 1.into();
665
666                    for j in 2..max_v {
667                        let tmp_is_one = tmp.ct_eq(&#name::ONE);
668                        let squared = #name::conditional_select(&tmp, &z, tmp_is_one).square();
669                        tmp = #name::conditional_select(&squared, &tmp, tmp_is_one);
670                        let new_z = #name::conditional_select(&z, &squared, tmp_is_one);
671                        j_less_than_v &= !j.ct_eq(&v);
672                        k = u32::conditional_select(&j, &k, tmp_is_one);
673                        z = #name::conditional_select(&z, &new_z, j_less_than_v);
674                    }
675
676                    let result = x * &z;
677                    x = #name::conditional_select(&result, &x, b.ct_eq(&#name::ONE));
678                    z = z.square();
679                    b *= &z;
680                    v = k;
681                }
682
683                ::ff::derive::subtle::CtOption::new(
684                    x,
685                    (x * &x).ct_eq(self), // Only return Some if it's the square root.
686                )
687            }
688        };
689
690    // Compute R^2 mod m
691    let r2 = biguint_to_u64_vec((&r * &r) % modulus, limbs);
692
693    let r = biguint_to_u64_vec(r, limbs);
694    let modulus_le_bytes = ReprEndianness::Little.modulus_repr(modulus, limbs * 8);
695    let modulus_str = format!("0x{}", modulus.to_str_radix(16));
696    let modulus = biguint_to_real_u64_vec(modulus.clone(), limbs);
697
698    // Compute -m^-1 mod 2**64 by exponentiating by totient(2**64) - 1
699    let mut inv = 1u64;
700    for _ in 0..63 {
701        inv = inv.wrapping_mul(inv);
702        inv = inv.wrapping_mul(modulus[0]);
703    }
704    inv = inv.wrapping_neg();
705
706    (
707        quote! {
708            type REPR_BYTES = [u8; #bytes];
709            type REPR_BITS = REPR_BYTES;
710
711            /// This is the modulus m of the prime field
712            const MODULUS: REPR_BITS = [#(#modulus_le_bytes,)*];
713
714            /// This is the modulus m of the prime field in limb form
715            #[cfg(not(target_os = "zkvm"))]
716            const MODULUS_LIMBS: #name = #name([#(#modulus,)*]);
717
718            /// This is the modulus m of the prime field in hex string form
719            const MODULUS_STR: &'static str = #modulus_str;
720
721            /// The number of bits needed to represent the modulus.
722            const MODULUS_BITS: u32 = #modulus_num_bits;
723
724            /// The number of bits that must be shaved from the beginning of
725            /// the representation when randomly sampling.
726            const REPR_SHAVE_BITS: u32 = #repr_shave_bits;
727
728            /// 2^{limbs*64} mod m
729            #[cfg(not(target_os = "zkvm"))]
730            const R: #name = #name(#r);
731
732            /// 2^{limbs*64*2} mod m
733            #[cfg(not(target_os = "zkvm"))]
734            const R2: #name = #name(#r2);
735
736            /// -(m^{-1} mod m) mod m
737            #[cfg(not(target_os = "zkvm"))]
738            const INV: u64 = #inv;
739
740            /// 2^{-1} mod m
741            #[cfg(target_os = "zkvm")]
742            const TWO_INV: #name = <#name>::from_const_bytes(#two_inv_bytes);
743            #[cfg(not(target_os = "zkvm"))]
744            const TWO_INV: #name = #name(#two_inv);
745
746            /// Multiplicative generator of `MODULUS` - 1 order, also quadratic
747            /// nonresidue.
748            #[cfg(target_os = "zkvm")]
749            const GENERATOR: #name = <#name>::from_const_bytes(#generator_bytes);
750            #[cfg(not(target_os = "zkvm"))]
751            const GENERATOR: #name = #name(#generator_u64_limbs);
752
753            /// 2^s * t = MODULUS - 1 with t odd
754            const S: u32 = #s;
755
756            /// 2^s root of unity computed by GENERATOR^t
757            #[cfg(target_os = "zkvm")]
758            const ROOT_OF_UNITY: #name = <#name>::from_const_bytes(#root_of_unity_bytes);
759            #[cfg(not(target_os = "zkvm"))]
760            const ROOT_OF_UNITY: #name = #name(#root_of_unity);
761
762            /// (2^s)^{-1} mod m
763            #[cfg(target_os = "zkvm")]
764            const ROOT_OF_UNITY_INV: #name = <#name>::from_const_bytes(#root_of_unity_inv_bytes);
765            #[cfg(not(target_os = "zkvm"))]
766            const ROOT_OF_UNITY_INV: #name = #name(#root_of_unity_inv);
767
768            /// GENERATOR^{2^s}
769            #[cfg(target_os = "zkvm")]
770            const DELTA: #name = <#name>::from_const_bytes(#delta_bytes);
771            #[cfg(not(target_os = "zkvm"))]
772            const DELTA: #name = #name(#delta);
773        },
774        sqrt_impl,
775    )
776}
777
778/// Implement PrimeField for the derived type.
779fn prime_field_impl(
780    name: &syn::Ident,
781    repr: &syn::Ident,
782    modulus: &BigUint,
783    endianness: &ReprEndianness,
784    limbs: usize,
785    zkvm_limbs: usize,
786    sqrt_impl: proc_macro2::TokenStream,
787) -> proc_macro2::TokenStream {
788    // Returns r{n} as an ident.
789    fn get_temp(n: usize) -> syn::Ident {
790        syn::Ident::new(&format!("r{n}"), proc_macro2::Span::call_site())
791    }
792
793    // The parameter list for the mont_reduce() internal method.
794    // r0: u64, mut r1: u64, mut r2: u64, ...
795    let mut mont_paramlist = proc_macro2::TokenStream::new();
796    mont_paramlist.append_separated(
797        (0..(limbs * 2)).map(|i| (i, get_temp(i))).map(|(i, x)| {
798            if i != 0 {
799                quote! {mut #x: u64}
800            } else {
801                quote! {#x: u64}
802            }
803        }),
804        proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
805    );
806
807    // Implement montgomery reduction for some number of limbs
808    fn mont_impl(limbs: usize) -> proc_macro2::TokenStream {
809        #[cfg(target_os = "zkvm")]
810        {
811            quote! {
812                unimplemented!();
813            }
814        }
815        #[cfg(not(target_os = "zkvm"))]
816        {
817            let mut gen = proc_macro2::TokenStream::new();
818
819            for i in 0..limbs {
820                {
821                    let temp = get_temp(i);
822                    gen.extend(quote! {
823                        let k = #temp.wrapping_mul(INV);
824                        let (_, carry) = ::ff::derive::mac(#temp, k, MODULUS_LIMBS.0[0], 0);
825                    });
826                }
827
828                for j in 1..limbs {
829                    let temp = get_temp(i + j);
830                    gen.extend(quote! {
831                    let (#temp, carry) = ::ff::derive::mac(#temp, k, MODULUS_LIMBS.0[#j], carry);
832                });
833                }
834
835                let temp = get_temp(i + limbs);
836
837                if i == 0 {
838                    gen.extend(quote! {
839                        let (#temp, carry2) = ::ff::derive::adc(#temp, 0, carry);
840                    });
841                } else {
842                    gen.extend(quote! {
843                        let (#temp, carry2) = ::ff::derive::adc(#temp, carry2, carry);
844                    });
845                }
846            }
847
848            for i in 0..limbs {
849                let temp = get_temp(limbs + i);
850
851                gen.extend(quote! {
852                    self.0[#i] = #temp;
853                });
854            }
855
856            gen
857        }
858    }
859
    /// Emits the statements that square `a` and Montgomery-reduce the result.
    ///
    /// `a` is an expression whose `.0` field is indexed as an array of `limbs`
    /// 64-bit limbs. The emitted code works on temporaries named by `get_temp`
    /// and proceeds in four steps:
    ///
    /// 1. accumulate every cross product `a[i] * a[j]` with `i < j`;
    /// 2. double that partial sum with a one-bit left shift carried across limbs;
    /// 3. add the diagonal products `a[i] * a[i]`;
    /// 4. copy `*self`, pass all `2 * limbs` temporaries to `mont_reduce`, and
    ///    evaluate to the reduced copy.
    ///
    /// `::ff::derive::mac` / `adc` are the multiply-accumulate and
    /// add-with-carry helpers exported by the `ff` crate.
    fn sqr_impl(a: proc_macro2::TokenStream, limbs: usize) -> proc_macro2::TokenStream {
        let mut gen = proc_macro2::TokenStream::new();

        if limbs > 1 {
            // Step 1: cross products. Row `i` multiplies limb `i` with every
            // higher limb `j`, accumulating into temporary `i + j`.
            for i in 0..(limbs - 1) {
                gen.extend(quote! {
                    let carry = 0;
                });

                for j in (i + 1)..limbs {
                    let temp = get_temp(i + j);
                    if i == 0 {
                        // First row: the temporary does not exist yet, so start from 0.
                        gen.extend(quote! {
                            let (#temp, carry) = ::ff::derive::mac(0, #a.0[#i], #a.0[#j], carry);
                        });
                    } else {
                        gen.extend(quote! {
                            let (#temp, carry) = ::ff::derive::mac(#temp, #a.0[#i], #a.0[#j], carry);
                        });
                    }
                }

                // The final carry of the row becomes a fresh, higher temporary.
                let temp = get_temp(i + limbs);

                gen.extend(quote! {
                    let #temp = carry;
                });
            }

            // Step 2: double the cross-product sum. Walking from the highest
            // temporary downwards lets each one pick up the top bit of the
            // still-unshifted temporary below it.
            for i in 1..(limbs * 2) {
                let temp0 = get_temp(limbs * 2 - i);
                let temp1 = get_temp(limbs * 2 - i - 1);

                if i == 1 {
                    // Highest temporary: only receives the bit shifted out from below.
                    gen.extend(quote! {
                        let #temp0 = #temp1 >> 63;
                    });
                } else if i == (limbs * 2 - 1) {
                    // Temporary 1: nothing below it has been written by step 1.
                    gen.extend(quote! {
                        let #temp0 = #temp0 << 1;
                    });
                } else {
                    gen.extend(quote! {
                        let #temp0 = (#temp0 << 1) | (#temp1 >> 63);
                    });
                }
            }
        } else {
            // Single limb: there are no cross products, so the only odd
            // temporary starts at zero.
            let temp1 = get_temp(1);
            gen.extend(quote! {
                let #temp1 = 0;
            });
        }

        // Step 3: add each square `a[i] * a[i]` into the even temporary `2 * i`
        // and propagate the carry through the odd temporary `2 * i + 1`.
        for i in 0..limbs {
            let temp0 = get_temp(i * 2);
            let temp1 = get_temp(i * 2 + 1);
            if i == 0 {
                // Temporary 0 is first defined here; there is no incoming carry.
                gen.extend(quote! {
                    let (#temp0, carry) = ::ff::derive::mac(0, #a.0[#i], #a.0[#i], 0);
                });
            } else {
                gen.extend(quote! {
                    let (#temp0, carry) = ::ff::derive::mac(#temp0, #a.0[#i], #a.0[#i], carry);
                });
            }

            gen.extend(quote! {
                let (#temp1, carry) = ::ff::derive::adc(#temp1, 0, carry);
            });
        }

        // Step 4: comma-separated list of all temporaries, in order, as the
        // arguments of `mont_reduce`.
        let mut mont_calling = proc_macro2::TokenStream::new();
        mont_calling.append_separated(
            (0..(limbs * 2)).map(get_temp),
            proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
        );

        gen.extend(quote! {
            let mut ret = *self;
            ret.mont_reduce(#mont_calling);
            ret
        });

        gen
    }
946
947    fn mul_impl(
948        a: proc_macro2::TokenStream,
949        b: proc_macro2::TokenStream,
950        limbs: usize,
951    ) -> proc_macro2::TokenStream {
952        let mut gen = proc_macro2::TokenStream::new();
953
954        for i in 0..limbs {
955            gen.extend(quote! {
956                let carry = 0;
957            });
958
959            for j in 0..limbs {
960                let temp = get_temp(i + j);
961
962                if i == 0 {
963                    gen.extend(quote! {
964                        let (#temp, carry) = ::ff::derive::mac(0, #a.0[#i], #b.0[#j], carry);
965                    });
966                } else {
967                    gen.extend(quote! {
968                        let (#temp, carry) = ::ff::derive::mac(#temp, #a.0[#i], #b.0[#j], carry);
969                    });
970                }
971            }
972
973            let temp = get_temp(i + limbs);
974
975            gen.extend(quote! {
976                let #temp = carry;
977            });
978        }
979
980        let mut mont_calling = proc_macro2::TokenStream::new();
981        mont_calling.append_separated(
982            (0..(limbs * 2)).map(get_temp),
983            proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
984        );
985
986        gen.extend(quote! {
987            self.mont_reduce(#mont_calling);
988        });
989
990        gen
991    }
992
993    /// Generates an implementation of multiplicative inversion within the target prime
994    /// field.
995    fn inv_impl(a: proc_macro2::TokenStream, modulus: &BigUint) -> proc_macro2::TokenStream {
996        // Addition chain for p - 2
997        let mod_minus_2 = pow_fixed::generate(&a, modulus - BigUint::from(2u64));
998
999        quote! {
1000            use ::ff::derive::subtle::ConstantTimeEq;
1001
1002            // By Euler's theorem, if `a` is coprime to `p` (i.e. `gcd(a, p) = 1`), then:
1003            //     a^-1 ≡ a^(phi(p) - 1) mod p
1004            //
1005            // `ff_derive` requires that `p` is prime; in this case, `phi(p) = p - 1`, and
1006            // thus:
1007            //     a^-1 ≡ a^(p - 2) mod p
1008            let inv = {
1009                #mod_minus_2
1010            };
1011
1012            ::ff::derive::subtle::CtOption::new(inv, !#a.is_zero())
1013        }
1014    }
1015
    // Pre-render the arithmetic bodies that are spliced into the generated
    // impls further down. `self` and `other` are the operand names used inside
    // the emitted methods.
    let squaring_impl = sqr_impl(quote! {self}, limbs);
    let multiply_impl = mul_impl(quote! {self}, quote! {other}, limbs);
    let invert_impl = inv_impl(quote! {self}, modulus);
    let montgomery_impl = mont_impl(limbs);
1020
1021    fn mont_reduce_params(a: proc_macro2::TokenStream, limbs: usize) -> proc_macro2::TokenStream {
1022        // a.0[0], a.0[1], ..., 0, 0, 0, 0, ...
1023        let mut mont_reduce_params = proc_macro2::TokenStream::new();
1024        mont_reduce_params.append_separated(
1025            (0..limbs)
1026                .map(|i| quote! { #a.0[#i] })
1027                .chain((0..limbs).map(|_| quote! {0})),
1028            proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone),
1029        );
1030        mont_reduce_params
1031    }
1032
    // Argument lists for `mont_reduce` that feed in an element's own limbs
    // padded with zeros; the generated `Ord` and `is_odd` code uses them on a
    // copy of `self` / `other` before inspecting the limbs.
    let mont_reduce_self_params = mont_reduce_params(quote! {self}, limbs);
    let mont_reduce_other_params = mont_reduce_params(quote! {other}, limbs);

    // Bodies of the byte-representation conversions, rendered for the
    // endianness selected for this field.
    let from_repr_impl = endianness.from_repr(name, limbs);
    let to_repr_impl = endianness.to_repr(quote! {#repr}, &mont_reduce_self_params, limbs);
1038
1039    let prime_field_bits_impl = if cfg!(feature = "bits") {
1040        let to_le_bits_impl = ReprEndianness::Little.to_repr(
1041            quote! {::ff::derive::bitvec::array::BitArray::new},
1042            &mont_reduce_self_params,
1043            limbs,
1044        );
1045
1046        Some(quote! {
1047            impl ::ff::PrimeFieldBits for #name {
1048                type ReprBits = REPR_BITS;
1049
1050                fn to_le_bits(&self) -> ::ff::FieldBits<REPR_BITS> {
1051                    #to_le_bits_impl
1052                }
1053
1054                fn char_le_bits() -> ::ff::FieldBits<REPR_BITS> {
1055                    ::ff::FieldBits::new(MODULUS)
1056                }
1057            }
1058        })
1059    } else {
1060        None
1061    };
1062
    // Index of the most significant limb; the generated `random` masks the
    // unused high bits of that limb.
    let top_limb_index = limbs - 1;

    // Since moduli_declare! already implements some traits, we need to conditionally
    // compile some of the trait impls depending on whether we're in zkvm or not.
    // So, we create a new module with #[cfg(not(target_os = "zkvm"))] and place the impls in there.
    let impl_module_ident =
        syn::Ident::new(&format!("impl_{name}"), proc_macro2::Span::call_site());

    // `Field::ZERO`: delegated to `IntMod::ZERO` inside the zkvm, an all-zero
    // limb array on the host.
    let zero_impl = quote! {
        #[cfg(target_os = "zkvm")]
            const ZERO: Self = <Self as ::openvm_algebra_guest::IntMod>::ZERO;
        #[cfg(not(target_os = "zkvm"))]
            const ZERO: Self = #name([0; #limbs]);
    };
    // `Field::ONE`: delegated to `IntMod::ONE` inside the zkvm; on the host it
    // is the constant `R`, which is defined outside this section (presumably
    // the Montgomery representation of one).
    let one_impl = quote! {
        #[cfg(target_os = "zkvm")]
            const ONE: Self = <Self as ::openvm_algebra_guest::IntMod>::ONE;
        #[cfg(not(target_os = "zkvm"))]
            const ONE: Self = R;
    };
1083
1084    quote! {
        // Elements are copied by value throughout the generated code (`*self`).
        impl ::core::marker::Copy for #name { }
1086
1087        impl ::core::default::Default for #name {
1088            fn default() -> #name {
1089                use ::ff::Field;
1090                #name::ZERO
1091            }
1092        }
1093
1094        impl ::ff::derive::subtle::ConstantTimeEq for #name {
1095            fn ct_eq(&self, other: &#name) -> ::ff::derive::subtle::Choice {
1096                use ::ff::PrimeField;
1097                self.to_repr().ct_eq(&other.to_repr())
1098            }
1099        }
1100
1101        /// Elements are ordered lexicographically.
1102        impl Ord for #name {
1103            #[inline(always)]
1104            fn cmp(&self, other: &#name) -> ::core::cmp::Ordering {
1105                #[cfg(target_os = "zkvm")]
1106                {
1107                    <Self as ::openvm_algebra_guest::IntMod>::assert_reduced(self);
1108                    <Self as ::openvm_algebra_guest::IntMod>::assert_reduced(other);
1109
1110                    self.cmp_native(other)
1111                }
1112                #[cfg(not(target_os = "zkvm"))]
1113                {
1114                    let mut a = *self;
1115                    a.mont_reduce(
1116                        #mont_reduce_self_params
1117                    );
1118
1119                    let mut b = *other;
1120                    b.mont_reduce(
1121                        #mont_reduce_other_params
1122                    );
1123
1124                    a.cmp_native(&b)
1125                }
1126            }
1127        }
1128
1129        impl PartialOrd for #name {
1130            #[inline(always)]
1131            fn partial_cmp(&self, other: &#name) -> Option<::core::cmp::Ordering> {
1132                Some(self.cmp(other))
1133            }
1134        }
1135
1136        impl From<u64> for #name {
1137            #[inline(always)]
1138            fn from(val: u64) -> #name {
1139                #[cfg(target_os = "zkvm")]
1140                {
1141                    <#name as ::openvm_algebra_guest::IntMod>::from_u64(val)
1142                }
1143                #[cfg(not(target_os = "zkvm"))]
1144                {
1145                    let mut raw = [0u64; #limbs];
1146                    raw[0] = val;
1147                    #name(raw) * R2
1148                }
1149            }
1150        }
1151
        // Owned element -> byte representation, via `PrimeField::to_repr`.
        impl From<#name> for #repr {
            fn from(e: #name) -> #repr {
                use ::ff::PrimeField;
                e.to_repr()
            }
        }
1158
        // Borrowed element -> byte representation, via `PrimeField::to_repr`.
        impl<'a> From<&'a #name> for #repr {
            fn from(e: &'a #name) -> #repr {
                use ::ff::PrimeField;
                e.to_repr()
            }
        }
1165
1166        impl ::ff::derive::subtle::ConditionallySelectable for #name {
1167            fn conditional_select(a: &#name, b: &#name, choice: ::ff::derive::subtle::Choice) -> #name {
1168                #[cfg(target_os = "zkvm")]
1169                {
1170                    let mut res = [0u8; #zkvm_limbs];
1171                    let a_le_bytes = <Self as ::openvm_algebra_guest::IntMod>::as_le_bytes(a);
1172                    let b_le_bytes = <Self as ::openvm_algebra_guest::IntMod>::as_le_bytes(b);
1173                    for i in 0..#zkvm_limbs {
1174                        res[i] = u8::conditional_select(&a_le_bytes[i], &b_le_bytes[i], choice);
1175                    }
1176                    #name(res)
1177                }
1178                #[cfg(not(target_os = "zkvm"))]
1179                {
1180                    let mut res = [0u64; #limbs];
1181                    for i in 0..#limbs {
1182                        res[i] = u64::conditional_select(&a.0[i], &b.0[i], choice);
1183                    }
1184                    #name(res)
1185                }
1186            }
1187        }
1188
        // All the traits that are implemented in this module are already implemented
        // on our zkvm-compatible struct, so we need to conditionally implement them
        #[cfg(not(target_os = "zkvm"))]
        mod #impl_module_ident {
            use super::{#name, MODULUS_LIMBS};

            // The type is `Copy`, so cloning is a plain copy.
            impl ::core::clone::Clone for #name {
                fn clone(&self) -> #name {
                    *self
                }
            }

            // Equality is the `ConstantTimeEq` result converted to `bool`.
            impl ::core::cmp::PartialEq for #name {
                fn eq(&self, other: &#name) -> bool {
                    use ::ff::derive::subtle::ConstantTimeEq;
                    self.ct_eq(other).into()
                }
            }

            impl ::core::cmp::Eq for #name { }

            // Prints the type name followed by the `to_repr()` value.
            impl ::core::fmt::Debug for #name
            {
                fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
                    use ::ff::PrimeField;
                    write!(f, "{}({:?})", stringify!(#name), self.to_repr())
                }
            }


            // Negation: zero stays zero, any other element becomes
            // `MODULUS_LIMBS - self`.
            impl ::core::ops::Neg for #name {
                type Output = #name;

                #[inline]
                fn neg(self) -> #name {
                    use ::ff::Field;

                    let mut ret = self;
                    if !ret.is_zero_vartime() {
                        let mut tmp = MODULUS_LIMBS;
                        tmp.sub_noborrow(&ret);
                        ret = tmp;
                    }
                    ret
                }
            }

            // The by-value / by-reference operator impls below all funnel into
            // the corresponding `*Assign<&Self>` impl.
            impl<'r> ::core::ops::Add<&'r #name> for #name {
                type Output = #name;

                #[inline]
                fn add(self, other: &#name) -> #name {
                    use ::core::ops::AddAssign;

                    let mut ret = self;
                    ret.add_assign(other);
                    ret
                }
            }

            impl ::core::ops::Add for #name {
                type Output = #name;

                #[inline]
                fn add(self, other: #name) -> Self {
                    self + &other
                }
            }

            // Limb-wise addition followed by one conditional subtraction of the modulus.
            impl<'r> ::core::ops::AddAssign<&'r #name> for #name {
                #[inline]
                fn add_assign(&mut self, other: &#name) {
                    // This cannot exceed the backing capacity.
                    self.add_nocarry(other);

                    // However, it may need to be reduced.
                    self.reduce();
                }
            }

            impl ::core::ops::AddAssign for #name {
                #[inline]
                fn add_assign(&mut self, other: #name) {
                    self.add_assign(&other);
                }
            }

            impl<'r> ::core::ops::Sub<&'r #name> for #name {
                type Output = #name;

                #[inline]
                fn sub(self, other: &#name) -> Self {
                    use ::core::ops::SubAssign;

                    let mut ret = self;
                    ret.sub_assign(other);
                    ret
                }
            }

            impl ::core::ops::Sub for #name {
                type Output = #name;

                #[inline]
                fn sub(self, other: #name) -> Self {
                    self - &other
                }
            }

            // Limb-wise subtraction, preceded by adding the modulus when the
            // subtrahend is the larger operand.
            impl<'r> ::core::ops::SubAssign<&'r #name> for #name {
                #[inline]
                fn sub_assign(&mut self, other: &#name) {
                    // If `other` is larger than `self`, we'll need to add the modulus to self first.
                    if other.cmp_native(self) == ::core::cmp::Ordering::Greater {
                        self.add_nocarry(&MODULUS_LIMBS);
                    }

                    self.sub_noborrow(other);
                }
            }

            impl ::core::ops::SubAssign for #name {
                #[inline]
                fn sub_assign(&mut self, other: #name) {
                    self.sub_assign(&other);
                }
            }

            impl<'r> ::core::ops::Mul<&'r #name> for #name {
                type Output = #name;

                #[inline]
                fn mul(self, other: &#name) -> Self {
                    use ::core::ops::MulAssign;

                    let mut ret = self;
                    ret.mul_assign(other);
                    ret
                }
            }

            impl ::core::ops::Mul for #name {
                type Output = #name;

                #[inline]
                fn mul(self, other: #name) -> Self {
                    self * &other
                }
            }

            // The body is the token stream rendered by `mul_impl` for
            // `self` and `other`.
            impl<'r> ::core::ops::MulAssign<&'r #name> for #name {
                #[inline]
                fn mul_assign(&mut self, other: &#name)
                {
                    #multiply_impl
                }
            }

            impl ::core::ops::MulAssign for #name {
                #[inline]
                fn mul_assign(&mut self, other: #name)
                {
                    self.mul_assign(&other);
                }
            }

            // Sum of an iterator of owned or borrowed elements, starting from `ZERO`.
            impl<T: ::core::borrow::Borrow<#name>> ::core::iter::Sum<T> for #name {
                fn sum<I: Iterator<Item = T>>(iter: I) -> Self {
                    use ::ff::Field;

                    iter.fold(Self::ZERO, |acc, item| acc + item.borrow())
                }
            }

            // Product of an iterator of owned or borrowed elements, starting from `ONE`.
            impl<T: ::core::borrow::Borrow<#name>> ::core::iter::Product<T> for #name {
                fn product<I: Iterator<Item = T>>(iter: I) -> Self {
                    use ::ff::Field;

                    iter.fold(Self::ONE, |acc, item| acc * item.borrow())
                }
            }
        }
1371
        impl ::ff::PrimeField for #name {
            type Repr = #repr;

            // Decodes a byte representation, yielding `None` (as a `CtOption`)
            // when the encoded integer is not below the modulus.
            fn from_repr(r: #repr) -> ::ff::derive::subtle::CtOption<#name> {
                // Spliced decoding code; the statements below use `r` as a
                // `#name`, so this presumably rebinds `r` to the decoded element.
                #from_repr_impl

                #[cfg(target_os = "zkvm")]
                {
                    ::ff::derive::subtle::CtOption::new(r, r.constant_time_is_reduced())
                }
                #[cfg(not(target_os = "zkvm"))]
                {
                    // Try to subtract the modulus
                    let borrow = r.0.iter().zip(MODULUS_LIMBS.0.iter()).fold(0, |borrow, (a, b)| {
                        ::ff::derive::sbb(*a, *b, borrow).1
                    });

                    // If the element is smaller than MODULUS then the
                    // subtraction will underflow, producing a borrow value
                    // of 0xffff...ffff. Otherwise, it'll be zero.
                    let is_some = ::ff::derive::subtle::Choice::from((borrow as u8) & 1);

                    // Convert to Montgomery form by computing
                    // (a.R^0 * R^2) / R = a.R
                    ::ff::derive::subtle::CtOption::new(r * &R2, is_some)
                }
            }

            // Variable-time counterpart of `from_repr` returning a plain `Option`.
            fn from_repr_vartime(r: #repr) -> Option<#name> {
                #from_repr_impl

                #[cfg(target_os = "zkvm")]
                {
                    if <Self as ::openvm_algebra_guest::IntMod>::is_reduced(&r) {
                        Some(r)
                    } else {
                        None
                    }
                }
                #[cfg(not(target_os = "zkvm"))]
                {
                    if r.is_valid() {
                        Some(r * R2)
                    } else {
                        None
                    }

                }
            }

            // Encodes the element; the body is rendered by the `to_repr`
            // generator for the configured endianness.
            fn to_repr(&self) -> #repr {
                #to_repr_impl
            }

            // Returns the least significant bit of the element's integer value.
            #[inline(always)]
            fn is_odd(&self) -> ::ff::derive::subtle::Choice {
                #[cfg(target_os = "zkvm")]
                {
                    ::ff::derive::subtle::Choice::from((<Self as ::openvm_algebra_guest::IntMod>::as_le_bytes(self)[0] & 1) as u8)
                }
                #[cfg(not(target_os = "zkvm"))]
                {
                    // Reduce a copy first so that limb 0 holds the low bits of
                    // the plain (non-Montgomery) value.
                    let mut r = *self;
                    r.mont_reduce(
                        #mont_reduce_self_params
                    );

                    // TODO: This looks like a constant-time result, but r.mont_reduce() is
                    // currently implemented using variable-time code.
                    ::ff::derive::subtle::Choice::from((r.0[0] & 1) as u8)
                }
            }

            // The associated constants below forward to same-named items that
            // are defined outside this section of the generated code.
            const MODULUS: &'static str = MODULUS_STR;

            const NUM_BITS: u32 = MODULUS_BITS;

            const CAPACITY: u32 = Self::NUM_BITS - 1;

            const TWO_INV: Self = TWO_INV;

            const MULTIPLICATIVE_GENERATOR: Self = GENERATOR;

            const S: u32 = S;

            const ROOT_OF_UNITY: Self = ROOT_OF_UNITY;

            const ROOT_OF_UNITY_INV: Self = ROOT_OF_UNITY_INV;

            const DELTA: Self = DELTA;
        }
1463
1464        #prime_field_bits_impl
1465
1466        impl ::ff::Field for #name {
1467            #zero_impl
1468            #one_impl
1469
1470            /// Computes a uniformly random element using rejection sampling.
1471            fn random(mut rng: impl ::ff::derive::rand_core::RngCore) -> Self {
1472                #[cfg(target_os = "zkvm")]
1473                {
1474                    panic!("randomn is not implemented for the zkvm");
1475                }
1476                #[cfg(not(target_os = "zkvm"))]
1477                {
1478                    loop {
1479                        let mut tmp = {
1480                            let mut repr = [0u64; #limbs];
1481                            for i in 0..#limbs {
1482                                repr[i] = rng.next_u64();
1483                            }
1484                            #name(repr)
1485                        };
1486
1487                        // Mask away the unused most-significant bits.
1488                        // Note: In some edge cases, `REPR_SHAVE_BITS` could be 64, in which case
1489                        // `0xfff... >> REPR_SHAVE_BITS` overflows. So use `checked_shr` instead.
1490                        // This is always sufficient because we will have at most one spare limb
1491                        // to accommodate values of up to twice the modulus.
1492                        tmp.0[#top_limb_index] &= 0xffffffffffffffffu64.checked_shr(REPR_SHAVE_BITS).unwrap_or(0);
1493
1494                        if tmp.is_valid() {
1495                            return tmp
1496                        }
1497                    }
1498                }
1499            }
1500
1501            #[inline]
1502            fn is_zero_vartime(&self) -> bool {
1503                #[cfg(target_os = "zkvm")]
1504                {
1505                    self == &Self::ZERO
1506                }
1507                #[cfg(not(target_os = "zkvm"))]
1508                {
1509                    self.0.iter().all(|&e| e == 0)
1510                }
1511            }
1512
1513            #[inline]
1514            fn double(&self) -> Self {
1515                #[cfg(target_os = "zkvm")]
1516                {
1517                    <Self as ::openvm_algebra_guest::IntMod>::double(self)
1518                }
1519                #[cfg(not(target_os = "zkvm"))]
1520                {
1521                    let mut ret = *self;
1522
1523                    // This cannot exceed the backing capacity.
1524                    let mut last = 0;
1525                    for i in &mut ret.0 {
1526                        let tmp = *i >> 63;
1527                        *i <<= 1;
1528                        *i |= last;
1529                        last = tmp;
1530                    }
1531
1532                    // However, it may need to be reduced.
1533                    ret.reduce();
1534
1535                    ret
1536                }
1537            }
1538
1539            /// Note that invert is not constant-time in the zkvm.
1540            fn invert(&self) -> ::ff::derive::subtle::CtOption<Self> {
1541                #[cfg(target_os = "zkvm")]
1542                {
1543                    let is_self_zero = self.is_zero_vartime();
1544                    let res = if is_self_zero {
1545                        <Self as ::openvm_algebra_guest::IntMod>::ZERO
1546                    } else {
1547                        use ::openvm_algebra_guest::DivUnsafe;
1548                        <Self as ::openvm_algebra_guest::IntMod>::ONE.div_unsafe(self)
1549                    };
1550                    ::ff::derive::subtle::CtOption::new(res, (!is_self_zero as u8).into())
1551                }
1552                #[cfg(not(target_os = "zkvm"))]
1553                {
1554                    #invert_impl
1555                }
1556            }
1557
1558            #[inline]
1559            fn square(&self) -> Self
1560            {
1561                #[cfg(target_os = "zkvm")]
1562                {
1563                    <Self as ::openvm_algebra_guest::IntMod>::square(self)
1564                }
1565                #[cfg(not(target_os = "zkvm"))]
1566                {
1567                    #squaring_impl
1568                }
1569            }
1570
1571            fn sqrt_ratio(num: &Self, div: &Self) -> (::ff::derive::subtle::Choice, Self) {
1572                ::ff::helpers::sqrt_ratio_generic(num, div)
1573            }
1574
1575            /// Note that sqrt is not constant-time in the zkvm
1576            fn sqrt(&self) -> ::ff::derive::subtle::CtOption<Self> {
1577                #[cfg(target_os = "zkvm")]
1578                {
1579                    use ::openvm_algebra_guest::Sqrt;
1580                    match Sqrt::sqrt(self) {
1581                        Some(sqrt) => ::ff::derive::subtle::CtOption::new(sqrt, 1u8.into()),
1582                        None => ::ff::derive::subtle::CtOption::new(Self::ZERO, 0u8.into()),
1583                    }
1584                }
1585                #[cfg(not(target_os = "zkvm"))]
1586                {
1587                    #sqrt_impl
1588                }
1589            }
1590        }
1591
        impl #name {
            /// Compares two elements in native representation. This is only used
            /// internally.
            #[inline(always)]
            fn cmp_native(&self, other: &#name) -> ::core::cmp::Ordering {
                // Walk the limbs from the last index down to the first; the
                // first differing pair decides the ordering.
                for (a, b) in self.0.iter().rev().zip(other.0.iter().rev()) {
                    if a < b {
                        return ::core::cmp::Ordering::Less
                    } else if a > b {
                        return ::core::cmp::Ordering::Greater
                    }
                }

                ::core::cmp::Ordering::Equal
            }

            /// Determines if the element is really in the field. This is only used
            /// internally.
            #[inline(always)]
            #[cfg(not(target_os = "zkvm"))]
            fn is_valid(&self) -> bool {
                // The Ord impl calls `reduce`, which in turn calls `is_valid`, so we use
                // this internal function to eliminate the cycle.
                self.cmp_native(&MODULUS_LIMBS) == ::core::cmp::Ordering::Less
            }

            // Limb-wise addition of `other` into `self`. The carry out of the
            // last limb is dropped.
            #[inline(always)]
            #[cfg(not(target_os = "zkvm"))]
            fn add_nocarry(&mut self, other: &#name) {
                let mut carry = 0;

                for (a, b) in self.0.iter_mut().zip(other.0.iter()) {
                    let (new_a, new_carry) = ::ff::derive::adc(*a, *b, carry);
                    *a = new_a;
                    carry = new_carry;
                }
            }

            // Limb-wise subtraction of `other` from `self`. The borrow out of
            // the last limb is dropped.
            #[inline(always)]
            #[cfg(not(target_os = "zkvm"))]
            fn sub_noborrow(&mut self, other: &#name) {
                let mut borrow = 0;

                for (a, b) in self.0.iter_mut().zip(other.0.iter()) {
                    let (new_a, new_borrow) = ::ff::derive::sbb(*a, *b, borrow);
                    *a = new_a;
                    borrow = new_borrow;
                }
            }

            /// Subtracts the modulus from this element if this element is not in the
            /// field. Only used internally.
            #[inline(always)]
            #[cfg(not(target_os = "zkvm"))]
            fn reduce(&mut self) {
                if !self.is_valid() {
                    self.sub_noborrow(&MODULUS_LIMBS);
                }
            }

            // Montgomery reduction. The parameter list and the body are spliced
            // in from token streams built outside this section; a final
            // `reduce` is applied afterwards.
            #[allow(clippy::too_many_arguments)]
            #[inline(always)]
            #[cfg(not(target_os = "zkvm"))]
            fn mont_reduce(
                &mut self,
                #mont_paramlist
            )
            {
                // The Montgomery reduction here is based on Algorithm 14.32 in
                // Handbook of Applied Cryptography
                // <http://cacr.uwaterloo.ca/hac/about/chap14.pdf>.

                #montgomery_impl

                self.reduce();
            }

            // A variant of IntMod::is_reduced that runs in constant time
            // NOTE(review): the loop always visits every limb, but each
            // comparison is an ordinary `if`/`else if` on the limb values;
            // confirm that this meets the constant-time claim above.
            #[cfg(target_os = "zkvm")]
            fn constant_time_is_reduced(&self) -> ::ff::derive::subtle::Choice {
                let mut is_less = 0u8.into();
                // Iterate over limbs in little endian order and retain the result of the last non-equal comparison.
                for (x_limb, p_limb) in self.0.iter().zip(<Self as ::openvm_algebra_guest::IntMod>::MODULUS.iter()) {
                    if x_limb < p_limb {
                        is_less = 1u8.into();
                    } else if x_limb > p_limb {
                        is_less = 0u8.into();
                    }
                }
                // If all limbs are equal, is_less is false
                is_less
            }
        }
1685    }
1686}