// halo2derive/field/mod.rs

1mod arith;
2#[cfg(feature = "asm")]
3mod asm;
4
5use num_bigint::BigUint;
6use num_integer::Integer;
7use num_traits::{Num, One};
8use proc_macro::TokenStream;
9use proc_macro2::Span;
10use quote::quote;
11use syn::Token;
12
13struct FieldConfig {
14    identifier: String,
15    field: syn::Ident,
16    modulus: BigUint,
17    mul_gen: BigUint,
18    zeta: BigUint,
19    endian: String,
20    from_uniform: Vec<usize>,
21}
22
23impl syn::parse::Parse for FieldConfig {
24    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
25        let identifier: syn::Ident = input.parse()?;
26        let identifier = identifier.to_string();
27        input.parse::<syn::Token![,]>()?;
28
29        let field: syn::Ident = input.parse()?;
30        input.parse::<syn::Token![,]>()?;
31
32        let get_big = |is_key: &str| -> Result<BigUint, syn::Error> {
33            let key: syn::Ident = input.parse()?;
34            assert_eq!(key.to_string(), is_key);
35            input.parse::<Token![=]>()?;
36            let n: syn::LitStr = input.parse()?;
37            let n = BigUint::from_str_radix(&n.value(), 16)
38                .map_err(|err| syn::Error::new(Span::call_site(), err.to_string()))?;
39            input.parse::<Token![,]>()?;
40            Ok(n)
41        };
42
43        let get_str = |is_key: &str| -> Result<String, syn::Error> {
44            let key: syn::Ident = input.parse()?;
45            assert_eq!(key.to_string(), is_key);
46            input.parse::<Token![=]>()?;
47            let n: syn::LitStr = input.parse()?;
48            let n = n.value();
49            input.parse::<Token![,]>()?;
50            Ok(n)
51        };
52
53        let get_usize_list = |is_key: &str| -> Result<Vec<usize>, syn::Error> {
54            let key: syn::Ident = input.parse()?;
55            assert_eq!(key.to_string(), is_key);
56            input.parse::<Token![=]>()?;
57
58            // chatgpt
59            let content;
60            syn::bracketed!(content in input);
61            let punctuated: syn::punctuated::Punctuated<syn::LitInt, Token![,]> =
62                content.parse_terminated(syn::LitInt::parse)?;
63            let values = punctuated
64                .into_iter()
65                .map(|lit| lit.base10_parse::<usize>())
66                .collect::<Result<Vec<_>, _>>()?;
67            input.parse::<Token![,]>()?;
68            Ok(values)
69        };
70
71        let modulus = get_big("modulus")?;
72        let mul_gen = get_big("mul_gen")?;
73        let zeta = get_big("zeta")?;
74        let from_uniform = get_usize_list("from_uniform")?;
75        let endian = get_str("endian")?;
76        assert!(endian == "little" || endian == "big");
77        assert!(input.is_empty());
78
79        Ok(FieldConfig {
80            identifier,
81            field,
82            modulus,
83            mul_gen,
84            zeta,
85            from_uniform,
86            endian,
87        })
88    }
89}
90
91pub(crate) fn impl_field(input: TokenStream) -> TokenStream {
92    use crate::utils::{big_to_token, mod_inv};
93    let FieldConfig {
94        identifier,
95        field,
96        modulus,
97        mul_gen,
98        zeta,
99        from_uniform,
100        endian,
101    } = syn::parse_macro_input!(input as FieldConfig);
102    let _ = identifier;
103
104    let num_bits = modulus.bits() as u32;
105    let limb_size = 64;
106    let num_limbs = ((num_bits - 1) / limb_size + 1) as usize;
107    let size = num_limbs * 8;
108    let modulus_limbs = crate::utils::big_to_limbs(&modulus, num_limbs);
109    let modulus_str = format!("0x{}", modulus.to_str_radix(16));
110    let modulus_limbs_ident = quote! {[#(#modulus_limbs,)*]};
111
112    let modulus_limbs_32 = crate::utils::big_to_limbs_32(&modulus, num_limbs * 2);
113    let modulus_limbs_32_ident = quote! {[#(#modulus_limbs_32,)*]};
114
115    let to_token = |e: &BigUint| big_to_token(e, num_limbs);
116    let half_modulus = (&modulus - 1usize) >> 1;
117    let half_modulus = to_token(&half_modulus);
118
119    // binary modulus
120    let t = BigUint::from(1u64) << (num_limbs * limb_size as usize);
121    // r1 = mont(1)
122    let r1: BigUint = &t % &modulus;
123    let mont = |v: &BigUint| (v * &r1) % &modulus;
124    // r2 = mont(r)
125    let r2: BigUint = (&r1 * &r1) % &modulus;
126    // r3 = mont(r^2)
127    let r3: BigUint = (&r1 * &r1 * &r1) % &modulus;
128
129    let r1 = to_token(&r1);
130    let r2 = to_token(&r2);
131    let r3 = to_token(&r3);
132
133    // inv = -(r^{-1} mod 2^64) mod 2^64
134    let mut inv64 = 1u64;
135    for _ in 0..63 {
136        inv64 = inv64.wrapping_mul(inv64);
137        inv64 = inv64.wrapping_mul(modulus_limbs[0]);
138    }
139    inv64 = inv64.wrapping_neg();
140
141    let mut by_inverter_constant: usize = 2;
142    loop {
143        let t = BigUint::from(1u64) << (62 * by_inverter_constant - 64);
144        if t > modulus {
145            break;
146        }
147        by_inverter_constant += 1;
148    }
149
150    let mut jacobi_constant: usize = 1;
151    loop {
152        let t = BigUint::from(1u64) << (64 * jacobi_constant - 31);
153        if t > modulus {
154            break;
155        }
156        jacobi_constant += 1;
157    }
158
159    let mut s: u32 = 0;
160    let mut t = &modulus - BigUint::one();
161    while t.is_even() {
162        t >>= 1;
163        s += 1;
164    }
165
166    let two_inv = mod_inv(&BigUint::from(2usize), &modulus);
167
168    let sqrt_impl = {
169        if &modulus % 16u64 == BigUint::from(1u64) {
170            let tm1o2 = ((&t - 1usize) * &two_inv) % &modulus;
171            let tm1o2 = big_to_token(&tm1o2, num_limbs);
172            quote! {
173                fn sqrt(&self) -> subtle::CtOption<Self> {
174                    ff::helpers::sqrt_tonelli_shanks(self, #tm1o2)
175                }
176            }
177        } else if &modulus % 4u64 == BigUint::from(3u64) {
178            let exp = (&modulus + 1usize) >> 2;
179            let exp = big_to_token(&exp, num_limbs);
180            quote! {
181                fn sqrt(&self) -> subtle::CtOption<Self> {
182                    use subtle::ConstantTimeEq;
183                    let t = self.pow(#exp);
184                    subtle::CtOption::new(t, t.square().ct_eq(self))
185                }
186            }
187        } else {
188            panic!("unsupported modulus")
189        }
190    };
191
192    let root_of_unity = mul_gen.modpow(&t, &modulus);
193    let root_of_unity_inv = mod_inv(&root_of_unity, &modulus);
194    let delta = mul_gen.modpow(&(BigUint::one() << s), &modulus);
195
196    let root_of_unity = to_token(&mont(&root_of_unity));
197    let root_of_unity_inv = to_token(&mont(&root_of_unity_inv));
198    let two_inv = to_token(&mont(&two_inv));
199    let mul_gen = to_token(&mont(&mul_gen));
200    let delta = to_token(&mont(&delta));
201    let zeta = to_token(&mont(&zeta));
202
203    let endian = match endian.as_str() {
204        "little" => {
205            quote! { LE }
206        }
207        "big" => {
208            quote! { BE }
209        }
210        _ => {
211            unreachable!()
212        }
213    };
214
215    let impl_field = quote! {
216        #[derive(Clone, Copy, PartialEq, Eq, Hash, Default)]
217        pub struct #field(pub(crate) [u64; #num_limbs]);
218
219        impl core::fmt::Debug for #field {
220            fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
221                use ff::PrimeField;
222                let tmp = self.to_repr();
223                write!(f, "0x")?;
224                for &b in tmp.as_ref().iter().rev() {
225                    write!(f, "{:02x}", b)?;
226                }
227                Ok(())
228            }
229        }
230
231        impl ConstantTimeEq for #field {
232            fn ct_eq(&self, other: &Self) -> Choice {
233                Choice::from(
234                    self.0
235                        .iter()
236                        .zip(other.0)
237                        .all(|(a, b)| bool::from(a.ct_eq(&b))) as u8,
238                )
239            }
240        }
241
242        impl ConditionallySelectable for #field {
243            fn conditional_select(a: &Self, b: &Self, choice: Choice) -> Self {
244                let limbs = (0..#num_limbs)
245                    .map(|i| u64::conditional_select(&a.0[i], &b.0[i], choice))
246                    .collect::<Vec<_>>()
247                    .try_into()
248                    .unwrap();
249                #field(limbs)
250            }
251        }
252
253        impl core::cmp::PartialOrd for #field {
254            fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
255                Some(self.cmp(other))
256            }
257        }
258
259        impl core::cmp::Ord for #field {
260            fn cmp(&self, other: &Self) -> core::cmp::Ordering {
261                use ff::PrimeField;
262                let left = self.to_repr();
263                let right = other.to_repr();
264                left.as_ref().iter()
265                    .zip(right.as_ref().iter())
266                    .rev()
267                    .find_map(|(left_byte, right_byte)| match left_byte.cmp(right_byte) {
268                        core::cmp::Ordering::Equal => None,
269                        res => Some(res),
270                    })
271                    .unwrap_or(core::cmp::Ordering::Equal)
272            }
273        }
274
275        impl<T: ::core::borrow::Borrow<#field>> ::core::iter::Sum<T> for #field {
276            fn sum<I: Iterator<Item = T>>(iter: I) -> Self {
277                iter.fold(Self::zero(), |acc, item| acc + item.borrow())
278            }
279        }
280
281        impl<T: ::core::borrow::Borrow<#field>> ::core::iter::Product<T> for #field {
282            fn product<I: Iterator<Item = T>>(iter: I) -> Self {
283                iter.fold(Self::one(), |acc, item| acc * item.borrow())
284            }
285        }
286
287        impl crate::serde::endian::EndianRepr for #field {
288            const ENDIAN: crate::serde::endian::Endian = crate::serde::endian::Endian::#endian;
289
290            fn to_bytes(&self) -> Vec<u8> {
291                self.to_bytes().to_vec()
292            }
293
294            fn from_bytes(bytes: &[u8]) -> subtle::CtOption<Self> {
295                #field::from_bytes(bytes[..#field::SIZE].try_into().unwrap())
296            }
297        }
298
299        impl #field {
300            pub const SIZE: usize = #num_limbs * 8;
301            pub const NUM_LIMBS: usize = #num_limbs;
302            pub(crate) const MODULUS_LIMBS: [u64; Self::NUM_LIMBS] = #modulus_limbs_ident;
303            pub(crate) const MODULUS_LIMBS_32: [u32; Self::NUM_LIMBS*2] = #modulus_limbs_32_ident;
304            const R: Self = Self(#r1);
305            const R2: Self = Self(#r2);
306            const R3: Self = Self(#r3);
307
308            /// Returns zero, the additive identity.
309            #[inline(always)]
310            pub const fn zero() -> #field {
311                #field([0; Self::NUM_LIMBS])
312            }
313
314            /// Returns one, the multiplicative identity.
315            #[inline(always)]
316            pub const fn one() -> #field {
317                Self::R
318            }
319
320            /// Converts from an integer represented in little endian
321            /// into its (congruent) `$field` representation.
322            pub const fn from_raw(val: [u64; Self::NUM_LIMBS]) -> Self {
323                Self(val).mul_const(&Self::R2)
324            }
325
326            /// Attempts to convert a <#endian>-endian byte representation of
327            /// a scalar into a `$field`, failing if the input is not canonical.
328            pub fn from_bytes(bytes: &[u8; Self::SIZE]) -> subtle::CtOption<Self> {
329                use crate::serde::endian::EndianRepr;
330                let mut el = #field::default();
331                #field::ENDIAN.from_bytes(bytes, &mut el.0);
332                subtle::CtOption::new(el * Self::R2, subtle::Choice::from(Self::is_less_than_modulus(&el.0) as u8))
333            }
334
335
336            /// Converts an element of `$field` into a byte representation in
337            /// <#endian>-endian byte order.
338            pub fn to_bytes(&self) -> [u8; Self::SIZE] {
339                use crate::serde::endian::EndianRepr;
340                let el = self.from_mont();
341                let mut res = [0; Self::SIZE];
342                #field::ENDIAN.to_bytes(&mut res, &el);
343                res.into()
344            }
345
346
347            // Returns the Jacobi symbol, where the numerator and denominator
348            // are the element and the characteristic of the field, respectively.
349            // The Jacobi symbol is applicable to odd moduli
350            // while the Legendre symbol is applicable to prime moduli.
351            // They are equivalent for prime moduli.
352            #[inline(always)]
353            fn jacobi(&self) -> i64 {
354                crate::ff_ext::jacobi::jacobi::<#jacobi_constant>(&self.0, &#modulus_limbs_ident)
355            }
356
357
358            #[inline(always)]
359            pub(crate) fn is_less_than_modulus(limbs: &[u64; Self::NUM_LIMBS]) -> bool {
360                let borrow = limbs.iter().enumerate().fold(0, |borrow, (i, limb)| {
361                    crate::arithmetic::sbb(*limb, Self::MODULUS_LIMBS[i], borrow).1
362                });
363                (borrow as u8) & 1 == 1
364            }
365
366            /// Returns whether or not this element is strictly lexicographically
367            /// larger than its negation.
368            pub fn lexicographically_largest(&self) -> Choice {
369                const HALF_MODULUS: [u64; #num_limbs]= #half_modulus;
370                let tmp = self.from_mont();
371                let borrow = tmp
372                    .into_iter()
373                    .zip(HALF_MODULUS.into_iter())
374                    .fold(0, |borrow, (t, m)| crate::arithmetic::sbb(t, m, borrow).1);
375                !Choice::from((borrow as u8) & 1)
376            }
377        }
378
379        impl ff::Field for #field {
380            const ZERO: Self = Self::zero();
381            const ONE: Self = Self::one();
382
383            fn random(mut rng: impl RngCore) -> Self {
384                let mut wide = [0u8; Self::SIZE * 2];
385                rng.fill_bytes(&mut wide);
386                <#field as ff::FromUniformBytes<{ #field::SIZE * 2 }>>::from_uniform_bytes(&wide)
387            }
388
389            #[inline(always)]
390            #[must_use]
391            fn double(&self) -> Self {
392                self.double()
393            }
394
395            #[inline(always)]
396            #[must_use]
397            fn square(&self) -> Self {
398                self.square()
399            }
400
401            // Returns the multiplicative inverse of the element. If it is zero, the method fails.
402            #[inline(always)]
403            fn invert(&self) -> CtOption<Self> {
404                const BYINVERTOR: crate::ff_ext::inverse::BYInverter<#by_inverter_constant> =
405                crate::ff_ext::inverse::BYInverter::<#by_inverter_constant>::new(&#modulus_limbs_ident, &#r2);
406
407                if let Some(inverse) = BYINVERTOR.invert::<{ Self::NUM_LIMBS }>(&self.0) {
408                    subtle::CtOption::new(Self(inverse), subtle::Choice::from(1))
409                } else {
410                    subtle::CtOption::new(Self::zero(), subtle::Choice::from(0))
411                }
412            }
413
414            #sqrt_impl
415
416            fn sqrt_ratio(num: &Self, div: &Self) -> (Choice, Self) {
417                ff::helpers::sqrt_ratio_generic(num, div)
418            }
419        }
420    };
421
422    let impl_prime_field = quote! {
423
424        // TODO use ::core::borrow::Borrow or AsRef
425        impl From<#field> for crate::serde::Repr<{ #field::SIZE }> {
426            fn from(value: #field) -> crate::serde::Repr<{ #field::SIZE }> {
427                use ff::PrimeField;
428                value.to_repr()
429            }
430        }
431
432        impl<'a> From<&'a #field> for crate::serde::Repr<{ #field::SIZE }> {
433            fn from(value: &'a #field) -> crate::serde::Repr<{ #field::SIZE }> {
434                use ff::PrimeField;
435                value.to_repr()
436            }
437        }
438
439        impl ff::PrimeField for #field {
440            const NUM_BITS: u32 = #num_bits;
441            const CAPACITY: u32 = #num_bits-1;
442            const TWO_INV :Self = Self(#two_inv);
443            const MULTIPLICATIVE_GENERATOR: Self = Self(#mul_gen);
444            const S: u32 = #s;
445            const ROOT_OF_UNITY: Self = Self(#root_of_unity);
446            const ROOT_OF_UNITY_INV: Self = Self(#root_of_unity_inv);
447            const DELTA: Self = Self(#delta);
448            const MODULUS: &'static str = #modulus_str;
449
450            type Repr = crate::serde::Repr<{ #field::SIZE }>;
451
452            fn from_u128(v: u128) -> Self {
453                Self::R2 * Self(
454                    [v as u64, (v >> 64) as u64]
455                        .into_iter()
456                        .chain(std::iter::repeat(0))
457                        .take(Self::NUM_LIMBS)
458                        .collect::<Vec<_>>()
459                        .try_into()
460                        .unwrap(),
461                )
462            }
463
464            fn from_repr(repr: Self::Repr) -> subtle::CtOption<Self> {
465                let mut el = #field::default();
466                crate::serde::endian::Endian::LE.from_bytes(repr.as_ref(), &mut el.0);
467                subtle::CtOption::new(el * Self::R2, subtle::Choice::from(Self::is_less_than_modulus(&el.0) as u8))
468            }
469
470            fn to_repr(&self) -> Self::Repr {
471                use crate::serde::endian::Endian;
472                let el = self.from_mont();
473                let mut res = [0; #size];
474                crate::serde::endian::Endian::LE.to_bytes(&mut res, &el);
475                res.into()
476            }
477
478            fn is_odd(&self) -> Choice {
479                Choice::from(self.to_repr()[0] & 1)
480            }
481        }
482    };
483
484    let impl_serde_object = quote! {
485        impl crate::serde::SerdeObject for #field {
486            fn from_raw_bytes_unchecked(bytes: &[u8]) -> Self {
487                debug_assert_eq!(bytes.len(), #size);
488
489                let inner = (0..#num_limbs)
490                    .map(|off| {
491                        u64::from_le_bytes(bytes[off * 8..(off + 1) * 8].try_into().unwrap())
492                    })
493                    .collect::<Vec<_>>();
494                Self(inner.try_into().unwrap())
495            }
496
497            fn from_raw_bytes(bytes: &[u8]) -> Option<Self> {
498                if bytes.len() != #size {
499                    return None;
500                }
501                let elt = Self::from_raw_bytes_unchecked(bytes);
502                Self::is_less_than_modulus(&elt.0).then(|| elt)
503            }
504
505            fn to_raw_bytes(&self) -> Vec<u8> {
506                let mut res = Vec::with_capacity(#num_limbs * 4);
507                for limb in self.0.iter() {
508                    res.extend_from_slice(&limb.to_le_bytes());
509                }
510                res
511            }
512
513            fn read_raw_unchecked<R: std::io::Read>(reader: &mut R) -> Self {
514                let inner = [(); #num_limbs].map(|_| {
515                    let mut buf = [0; 8];
516                    reader.read_exact(&mut buf).unwrap();
517                    u64::from_le_bytes(buf)
518                });
519                Self(inner)
520            }
521
522            fn read_raw<R: std::io::Read>(reader: &mut R) -> std::io::Result<Self> {
523                let mut inner = [0u64; #num_limbs];
524                for limb in inner.iter_mut() {
525                    let mut buf = [0; 8];
526                    reader.read_exact(&mut buf)?;
527                    *limb = u64::from_le_bytes(buf);
528                }
529                let elt = Self(inner);
530                Self::is_less_than_modulus(&elt.0)
531                    .then(|| elt)
532                    .ok_or_else(|| {
533                        std::io::Error::new(
534                            std::io::ErrorKind::InvalidData,
535                            "input number is not less than field modulus",
536                        )
537                    })
538            }
539            fn write_raw<W: std::io::Write>(&self, writer: &mut W) -> std::io::Result<()> {
540                for limb in self.0.iter() {
541                    writer.write_all(&limb.to_le_bytes())?;
542                }
543                Ok(())
544            }
545        }
546    };
547
548    #[cfg(feature = "asm")]
549    let impl_arith = {
550        if num_limbs == 4 && num_bits < 256 {
551            println!("implementing asm, {}", identifier);
552            asm::limb4::impl_arith(&field, inv64)
553        } else {
554            arith::impl_arith(&field, num_limbs, inv64)
555        }
556    };
557    #[cfg(not(feature = "asm"))]
558    let impl_arith = arith::impl_arith(&field, num_limbs, inv64);
559
560    let impl_arith_always_const = arith::impl_arith_always_const(&field, num_limbs, inv64);
561
562    let impl_from_uniform_bytes = from_uniform
563        .iter()
564        .map(|input_size| {
565            assert!(*input_size >= size);
566            assert!(*input_size <= size*2);
567            quote! {
568                impl ff::FromUniformBytes<#input_size> for #field {
569                    fn from_uniform_bytes(bytes: &[u8; #input_size]) -> Self {
570                        let mut wide = [0u8; Self::SIZE * 2];
571                        wide[..#input_size].copy_from_slice(bytes);
572                        let (a0, a1) = wide.split_at(Self::SIZE);
573
574                        let a0: [u64; Self::NUM_LIMBS] = (0..Self::NUM_LIMBS)
575                            .map(|off| u64::from_le_bytes(a0[off * 8..(off + 1) * 8].try_into().unwrap()))
576                            .collect::<Vec<_>>()
577                            .try_into()
578                            .unwrap();
579                        let a0 = #field(a0);
580
581                        let a1: [u64; Self::NUM_LIMBS] = (0..Self::NUM_LIMBS)
582                            .map(|off| u64::from_le_bytes(a1[off * 8..(off + 1) * 8].try_into().unwrap()))
583                            .collect::<Vec<_>>()
584                            .try_into()
585                            .unwrap();
586                        let a1 = #field(a1);
587
588                        // enforce non assembly impl since asm is likely to be optimized for sparse fields
589                        a0.mul_const(&Self::R2) + a1.mul_const(&Self::R3)
590
591                    }
592                }
593            }
594        })
595        .collect::<proc_macro2::TokenStream>();
596
597    let impl_zeta = quote! {
598        impl ff::WithSmallOrderMulGroup<3> for #field {
599            const ZETA: Self = Self(#zeta);
600        }
601    };
602
603    let output = quote! {
604        #impl_arith
605        #impl_arith_always_const
606        #impl_field
607        #impl_prime_field
608        #impl_serde_object
609        #impl_from_uniform_bytes
610        #impl_zeta
611    };
612
613    output.into()
614}