openvm_algebra_moduli_macros/
lib.rs

1extern crate alloc;
2extern crate proc_macro;
3
4use std::sync::atomic::AtomicUsize;
5
6use num_bigint::BigUint;
7use num_prime::nt_funcs::is_prime;
8use openvm_macros_common::{string_to_bytes, MacroArgs};
9use proc_macro::TokenStream;
10use quote::format_ident;
11use syn::{
12    parse::{Parse, ParseStream},
13    parse_macro_input, LitStr, Token,
14};
15
/// Running index handed out to every modulus declared through `moduli_declare!`.
///
/// Each declared modulus claims the next value via `fetch_add`, and that value is used to
/// name the private module that holds the modulus' trait implementations
/// (`algebra_impl_<idx>`), so those module names stay distinct from one another.
///
/// NOTE(review): this is process-local state of the proc-macro library; it is only shared
/// between expansions while the compiler keeps this macro library loaded — confirm that
/// uniqueness within a single crate compilation is all that is required.
static MOD_IDX: AtomicUsize = AtomicUsize::new(0_usize);
17
/// This macro generates the code to set up the modulus for a given prime. It also places the
/// moduli into a special static variable, to be extracted later from the ELF and used by the VM. Usage:
20/// ```
21/// moduli_declare! {
22///     Bls12381 { modulus = "0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab" },
23///     Bn254 { modulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583" },
24/// }
25/// ```
26/// This creates two structs, `Bls12381` and `Bn254`, each representing the modular arithmetic class
27/// (implementing `Add`, `Sub` and so on).
28#[proc_macro]
29pub fn moduli_declare(input: TokenStream) -> TokenStream {
30    let MacroArgs { items } = parse_macro_input!(input as MacroArgs);
31
32    let mut output = Vec::new();
33
34    let span = proc_macro::Span::call_site();
35
36    for item in items {
37        let struct_name = item.name.to_string();
38        let struct_name = syn::Ident::new(&struct_name, span.into());
39        let mut modulus: Option<String> = None;
40        for param in item.params {
41            match param.name.to_string().as_str() {
42                "modulus" => {
43                    if let syn::Expr::Lit(syn::ExprLit {
44                        lit: syn::Lit::Str(value),
45                        ..
46                    }) = param.value
47                    {
48                        modulus = Some(value.value());
49                    } else {
50                        return syn::Error::new_spanned(
51                            param.value,
52                            "Expected a string literal for macro argument `modulus`",
53                        )
54                        .to_compile_error()
55                        .into();
56                    }
57                }
58                _ => {
59                    panic!("Unknown parameter {}", param.name);
60                }
61            }
62        }
63
64        // Parsing the parameters is over at this point
65
66        let mod_idx = MOD_IDX.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
67
68        let modulus = modulus.expect("modulus parameter is required");
69        let modulus_bytes = string_to_bytes(&modulus);
70        let mut limbs = modulus_bytes.len();
71        let mut block_size = 32;
72
73        if limbs <= 32 {
74            limbs = 32;
75        } else if limbs <= 48 {
76            limbs = 48;
77            block_size = 16;
78        } else {
79            panic!("limbs must be at most 48");
80        }
81
82        let modulus_bytes = modulus_bytes
83            .into_iter()
84            .chain(vec![0u8; limbs])
85            .take(limbs)
86            .collect::<Vec<_>>();
87
88        let modulus_hex = modulus_bytes
89            .iter()
90            .rev()
91            .map(|x| format!("{:02x}", x))
92            .collect::<Vec<_>>()
93            .join("");
94        macro_rules! create_extern_func {
95            ($name:ident) => {
96                let $name = syn::Ident::new(
97                    &format!("{}_{}", stringify!($name), modulus_hex),
98                    span.into(),
99                );
100            };
101        }
102        create_extern_func!(add_extern_func);
103        create_extern_func!(sub_extern_func);
104        create_extern_func!(mul_extern_func);
105        create_extern_func!(div_extern_func);
106        create_extern_func!(is_eq_extern_func);
107        create_extern_func!(hint_sqrt_extern_func);
108        create_extern_func!(hint_non_qr_extern_func);
109        create_extern_func!(moduli_setup_extern_func);
110
111        let block_size = proc_macro::Literal::usize_unsuffixed(block_size);
112        let block_size = syn::Lit::new(block_size.to_string().parse::<_>().unwrap());
113
114        let module_name = format_ident!("algebra_impl_{}", mod_idx);
115
116        let result = TokenStream::from(quote::quote_spanned! { span.into() =>
117            /// An element of the ring of integers modulo a positive integer.
118            /// The element is internally represented as a fixed size array of bytes.
119            ///
120            /// ## Caution
121            /// It is not guaranteed that the integer representation is less than the modulus.
122            /// After any arithmetic operation, the honest host should normalize the result
123            /// to its canonical representation less than the modulus, but guest execution does not
124            /// require it.
125            ///
126            /// See [`assert_reduced`](openvm_algebra_guest::IntMod::assert_reduced) and
127            /// [`is_reduced`](openvm_algebra_guest::IntMod::is_reduced).
128            #[derive(Clone, Eq, serde::Serialize, serde::Deserialize)]
129            #[repr(C, align(#block_size))]
130            pub struct #struct_name(#[serde(with = "openvm_algebra_guest::BigArray")] [u8; #limbs]);
131
132            extern "C" {
133                fn #add_extern_func(rd: usize, rs1: usize, rs2: usize);
134                fn #sub_extern_func(rd: usize, rs1: usize, rs2: usize);
135                fn #mul_extern_func(rd: usize, rs1: usize, rs2: usize);
136                fn #div_extern_func(rd: usize, rs1: usize, rs2: usize);
137                fn #is_eq_extern_func(rs1: usize, rs2: usize) -> bool;
138                fn #hint_sqrt_extern_func(rs1: usize);
139                fn #hint_non_qr_extern_func();
140                fn #moduli_setup_extern_func();
141            }
142
143            impl #struct_name {
144                #[inline(always)]
145                const fn from_const_u8(val: u8) -> Self {
146                    let mut bytes = [0; #limbs];
147                    bytes[0] = val;
148                    Self(bytes)
149                }
150
151                /// Constructor from little-endian bytes. Does not enforce the integer value of `bytes`
152                /// must be less than the modulus.
153                pub const fn from_const_bytes(bytes: [u8; #limbs]) -> Self {
154                    Self(bytes)
155                }
156
157                #[inline(always)]
158                fn add_assign_impl(&mut self, other: &Self) {
159                    #[cfg(not(target_os = "zkvm"))]
160                    {
161                        *self = Self::from_biguint(
162                            (self.as_biguint() + other.as_biguint()) % Self::modulus_biguint(),
163                        );
164                    }
165                    #[cfg(target_os = "zkvm")]
166                    {
167                        Self::set_up_once();
168                        unsafe {
169                            #add_extern_func(
170                                self as *mut Self as usize,
171                                self as *const Self as usize,
172                                other as *const Self as usize,
173                            );
174                        }
175                    }
176                }
177
178                #[inline(always)]
179                fn sub_assign_impl(&mut self, other: &Self) {
180                    #[cfg(not(target_os = "zkvm"))]
181                    {
182                        let modulus = Self::modulus_biguint();
183                        *self = Self::from_biguint(
184                            (self.as_biguint() + modulus.clone() - other.as_biguint()) % modulus,
185                        );
186                    }
187                    #[cfg(target_os = "zkvm")]
188                    {
189                        Self::set_up_once();
190                        unsafe {
191                            #sub_extern_func(
192                                self as *mut Self as usize,
193                                self as *const Self as usize,
194                                other as *const Self as usize,
195                            );
196                        }
197                    }
198                }
199
200                #[inline(always)]
201                fn mul_assign_impl(&mut self, other: &Self) {
202                    #[cfg(not(target_os = "zkvm"))]
203                    {
204                        *self = Self::from_biguint(
205                            (self.as_biguint() * other.as_biguint()) % Self::modulus_biguint(),
206                        );
207                    }
208                    #[cfg(target_os = "zkvm")]
209                    {
210                        Self::set_up_once();
211                        unsafe {
212                            #mul_extern_func(
213                                self as *mut Self as usize,
214                                self as *const Self as usize,
215                                other as *const Self as usize,
216                            );
217                        }
218                    }
219                }
220
221                #[inline(always)]
222                fn div_assign_unsafe_impl(&mut self, other: &Self) {
223                    #[cfg(not(target_os = "zkvm"))]
224                    {
225                        let modulus = Self::modulus_biguint();
226                        let inv = other.as_biguint().modinv(&modulus).unwrap();
227                        *self = Self::from_biguint((self.as_biguint() * inv) % modulus);
228                    }
229                    #[cfg(target_os = "zkvm")]
230                    {
231                        Self::set_up_once();
232                        unsafe {
233                            #div_extern_func(
234                                self as *mut Self as usize,
235                                self as *const Self as usize,
236                                other as *const Self as usize,
237                            );
238                        }
239                    }
240                }
241
242                /// # Safety
243                /// - `dst_ptr` must be a raw pointer to `&mut Self`. It will be written to only at the very end.
244                #[inline(always)]
245                unsafe fn add_refs_impl<const CHECK_SETUP: bool>(&self, other: &Self, dst_ptr: *mut Self) {
246                    #[cfg(not(target_os = "zkvm"))]
247                    {
248                        let mut res = self.clone();
249                        res += other;
                        // BEWARE order of operations: `dst_ptr` may alias `self` or `other`, so the
                        // result is fully computed in `res` before anything is written through `dst_ptr`.
251                        let dst = unsafe { &mut *dst_ptr };
252                        *dst = res;
253                    }
254                    #[cfg(target_os = "zkvm")]
255                    {
256                        if CHECK_SETUP {
257                            Self::set_up_once();
258                        }
259                        #add_extern_func(
260                            dst_ptr as usize,
261                            self as *const #struct_name as usize,
262                            other as *const #struct_name as usize,
263                        );
264                    }
265                }
266
                /// # Safety
                /// - `dst_ptr` must be a raw pointer to `&mut Self`. It will be written to only at the very end.
269                #[inline(always)]
270                unsafe fn sub_refs_impl(&self, other: &Self, dst_ptr: *mut Self) {
271                    #[cfg(not(target_os = "zkvm"))]
272                    {
273                        let mut res = self.clone();
274                        res -= other;
                        // BEWARE order of operations: `dst_ptr` may alias `self` or `other`, so the
                        // result is fully computed in `res` before anything is written through `dst_ptr`.
276                        let dst = unsafe { &mut *dst_ptr };
277                        *dst = res;
278                    }
279                    #[cfg(target_os = "zkvm")]
280                    {
281                        Self::set_up_once();
282                        unsafe {
283                            #sub_extern_func(
284                                dst_ptr as usize,
285                                self as *const #struct_name as usize,
286                                other as *const #struct_name as usize,
287                            );
288                        }
289                    }
290                }
291
                /// # Safety
                /// - `dst_ptr` must be a raw pointer to `&mut Self`. It will be written to only at the very end.
294                #[inline(always)]
295                unsafe fn mul_refs_impl(&self, other: &Self, dst_ptr: *mut Self) {
296                    #[cfg(not(target_os = "zkvm"))]
297                    {
298                        let mut res = self.clone();
299                        res *= other;
                        // BEWARE order of operations: `dst_ptr` may alias `self` or `other`, so the
                        // result is fully computed in `res` before anything is written through `dst_ptr`.
301                        let dst = unsafe { &mut *dst_ptr };
302                        *dst = res;
303                    }
304                    #[cfg(target_os = "zkvm")]
305                    {
306                        Self::set_up_once();
307                        unsafe {
308                            #mul_extern_func(
309                                dst_ptr as usize,
310                                self as *const #struct_name as usize,
311                                other as *const #struct_name as usize,
312                            );
313                        }
314                    }
315                }
316
317                #[inline(always)]
318                fn div_unsafe_refs_impl(&self, other: &Self) -> Self {
319                    #[cfg(not(target_os = "zkvm"))]
320                    {
321                        let modulus = Self::modulus_biguint();
322                        let inv = other.as_biguint().modinv(&modulus).unwrap();
323                        Self::from_biguint((self.as_biguint() * inv) % modulus)
324                    }
325                    #[cfg(target_os = "zkvm")]
326                    {
327                        Self::set_up_once();
328                        let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
329                        unsafe {
330                            #div_extern_func(
331                                uninit.as_mut_ptr() as usize,
332                                self as *const #struct_name as usize,
333                                other as *const #struct_name as usize,
334                            );
335                        }
336                        unsafe { uninit.assume_init() }
337                    }
338                }
339
340                #[inline(always)]
341                unsafe fn eq_impl<const CHECK_SETUP: bool>(&self, other: &Self) -> bool {
342                    #[cfg(not(target_os = "zkvm"))]
343                    {
344                        self.as_le_bytes() == other.as_le_bytes()
345                    }
346                    #[cfg(target_os = "zkvm")]
347                    {
348                        if CHECK_SETUP {
349                            Self::set_up_once();
350                        }
351                        #is_eq_extern_func(self as *const #struct_name as usize, other as *const #struct_name as usize)
352                    }
353                }
354
355                // Helper function to call the setup instruction on first use
356                #[inline(always)]
357                #[cfg(target_os = "zkvm")]
358                fn set_up_once() {
359                    static is_setup: ::openvm_algebra_guest::once_cell::race::OnceBool = ::openvm_algebra_guest::once_cell::race::OnceBool::new();
360                    is_setup.get_or_init(|| {
361                        unsafe { #moduli_setup_extern_func(); }
362                        true
363                    });
364                }
365                #[inline(always)]
366                #[cfg(not(target_os = "zkvm"))]
367                fn set_up_once() {
368                    // No-op for non-ZKVM targets
369                }
370            }
371
372            // Put trait implementations in a private module to avoid conflicts
373            mod #module_name {
374                use openvm_algebra_guest::IntMod;
375
376                use super::#struct_name;
377
378                impl IntMod for #struct_name {
379                    type Repr = [u8; #limbs];
380                    type SelfRef<'a> = &'a Self;
381
382                    const MODULUS: Self::Repr = [#(#modulus_bytes),*];
383
384                    const ZERO: Self = Self([0; #limbs]);
385
386                    const NUM_LIMBS: usize = #limbs;
387
388                    const ONE: Self = Self::from_const_u8(1);
389
390                    fn from_repr(repr: Self::Repr) -> Self {
391                        Self(repr)
392                    }
393
394                    fn from_le_bytes(bytes: &[u8]) -> Option<Self> {
395                        let elt = Self::from_le_bytes_unchecked(bytes);
396                        if elt.is_reduced() {
397                            Some(elt)
398                        } else {
399                            None
400                        }
401                    }
402
403                    fn from_be_bytes(bytes: &[u8]) -> Option<Self> {
404                        let elt = Self::from_be_bytes_unchecked(bytes);
405                        if elt.is_reduced() {
406                            Some(elt)
407                        } else {
408                            None
409                        }
410                    }
411
412                    fn from_le_bytes_unchecked(bytes: &[u8]) -> Self {
413                        let mut arr = [0u8; #limbs];
414                        arr.copy_from_slice(bytes);
415                        Self(arr)
416                    }
417
418                    fn from_be_bytes_unchecked(bytes: &[u8]) -> Self {
419                        let mut arr = [0u8; #limbs];
420                        for (a, b) in arr.iter_mut().zip(bytes.iter().rev()) {
421                            *a = *b;
422                        }
423                        Self(arr)
424                    }
425
426                    fn from_u8(val: u8) -> Self {
427                        Self::from_const_u8(val)
428                    }
429
430                    fn from_u32(val: u32) -> Self {
431                        let mut bytes = [0; #limbs];
432                        bytes[..4].copy_from_slice(&val.to_le_bytes());
433                        Self(bytes)
434                    }
435
436                    fn from_u64(val: u64) -> Self {
437                        let mut bytes = [0; #limbs];
438                        bytes[..8].copy_from_slice(&val.to_le_bytes());
439                        Self(bytes)
440                    }
441
442                    #[inline(always)]
443                    fn as_le_bytes(&self) -> &[u8] {
444                        &(self.0)
445                    }
446
447                    #[inline(always)]
448                    fn to_be_bytes(&self) -> [u8; #limbs] {
449                        core::array::from_fn(|i| self.0[#limbs - 1 - i])
450                    }
451
452                    #[cfg(not(target_os = "zkvm"))]
453                    fn modulus_biguint() -> num_bigint::BigUint {
454                        num_bigint::BigUint::from_bytes_le(&Self::MODULUS)
455                    }
456
457                    #[cfg(not(target_os = "zkvm"))]
458                    fn from_biguint(biguint: num_bigint::BigUint) -> Self {
459                        Self(openvm::utils::biguint_to_limbs(&biguint))
460                    }
461
462                    #[cfg(not(target_os = "zkvm"))]
463                    fn as_biguint(&self) -> num_bigint::BigUint {
464                        num_bigint::BigUint::from_bytes_le(self.as_le_bytes())
465                    }
466
467                    #[inline(always)]
468                    fn neg_assign(&mut self) {
469                        unsafe {
470                            // SAFETY: we borrow self as &Self and as *mut Self but
471                            // the latter will only be written to at the very end.
472                            (#struct_name::ZERO).sub_refs_impl(self, self as *const Self as *mut Self);
473                        }
474                    }
475
476                    #[inline(always)]
477                    fn double_assign(&mut self) {
478                        unsafe {
479                            // SAFETY: we borrow self as &Self and as *mut Self but
480                            // the latter will only be written to at the very end.
481                            self.add_refs_impl::<true>(self, self as *const Self as *mut Self);
482                        }
483                    }
484
485                    #[inline(always)]
486                    fn square_assign(&mut self) {
487                        unsafe {
488                            // SAFETY: we borrow self as &Self and as *mut Self but
489                            // the latter will only be written to at the very end.
490                            self.mul_refs_impl(self, self as *const Self as *mut Self);
491                        }
492                    }
493
494                    #[inline(always)]
495                    fn double(&self) -> Self {
496                        self + self
497                    }
498
499                    #[inline(always)]
500                    fn square(&self) -> Self {
501                        self * self
502                    }
503
504                    #[inline(always)]
505                    fn cube(&self) -> Self {
506                        &self.square() * self
507                    }
508
509                    /// If `self` is not in its canonical form, the proof will fail to verify.
510                    /// This means guest execution will never terminate (either successfully or
511                    /// unsuccessfully) if `self` is not in its canonical form.
512                    // is_eq_mod enforces `self` is less than `modulus`
513                    fn assert_reduced(&self) {
514                        // This must not be optimized out
515                        let _ = core::hint::black_box(PartialEq::eq(self, self));
516                    }
517
518                    fn is_reduced(&self) -> bool {
519                        // limbs are little endian
520                        for (x_limb, p_limb) in self.0.iter().rev().zip(Self::MODULUS.iter().rev()) {
521                            if x_limb < p_limb {
522                                return true;
523                            } else if x_limb > p_limb {
524                                return false;
525                            }
526                        }
527                        // At this point, all limbs are equal
528                        false
529                    }
530
531                    #[inline(always)]
532                    fn set_up_once() {
533                        Self::set_up_once();
534                    }
535
536                    #[inline(always)]
537                    unsafe fn eq_impl<const CHECK_SETUP: bool>(&self, other: &Self) -> bool {
538                        Self::eq_impl::<CHECK_SETUP>(self, other)
539                    }
540
541                    #[inline(always)]
542                    unsafe fn add_ref<const CHECK_SETUP: bool>(&self, other: &Self) -> Self {
543                        let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
544                        self.add_refs_impl::<CHECK_SETUP>(other, uninit.as_mut_ptr());
545                        uninit.assume_init()
546                    }
547                }
548
549                impl<'a> core::ops::AddAssign<&'a #struct_name> for #struct_name {
550                    #[inline(always)]
551                    fn add_assign(&mut self, other: &'a #struct_name) {
552                        self.add_assign_impl(other);
553                    }
554                }
555
556                impl core::ops::AddAssign for #struct_name {
557                    #[inline(always)]
558                    fn add_assign(&mut self, other: Self) {
559                        self.add_assign_impl(&other);
560                    }
561                }
562
563                impl core::ops::Add for #struct_name {
564                    type Output = Self;
565                    #[inline(always)]
566                    fn add(mut self, other: Self) -> Self::Output {
567                        self += other;
568                        self
569                    }
570                }
571
572                impl<'a> core::ops::Add<&'a #struct_name> for #struct_name {
573                    type Output = Self;
574                    #[inline(always)]
575                    fn add(mut self, other: &'a #struct_name) -> Self::Output {
576                        self += other;
577                        self
578                    }
579                }
580
581                impl<'a> core::ops::Add<&'a #struct_name> for &#struct_name {
582                    type Output = #struct_name;
583                    #[inline(always)]
584                    fn add(self, other: &'a #struct_name) -> Self::Output {
585                        // Safety: ensure setup
586                        unsafe { self.add_ref::<true>(other) }
587                    }
588                }
589
590                impl<'a> core::ops::SubAssign<&'a #struct_name> for #struct_name {
591                    #[inline(always)]
592                    fn sub_assign(&mut self, other: &'a #struct_name) {
593                        self.sub_assign_impl(other);
594                    }
595                }
596
597                impl core::ops::SubAssign for #struct_name {
598                    #[inline(always)]
599                    fn sub_assign(&mut self, other: Self) {
600                        self.sub_assign_impl(&other);
601                    }
602                }
603
604                impl core::ops::Sub for #struct_name {
605                    type Output = Self;
606                    #[inline(always)]
607                    fn sub(mut self, other: Self) -> Self::Output {
608                        self -= other;
609                        self
610                    }
611                }
612
613                impl<'a> core::ops::Sub<&'a #struct_name> for #struct_name {
614                    type Output = Self;
615                    #[inline(always)]
616                    fn sub(mut self, other: &'a #struct_name) -> Self::Output {
617                        self -= other;
618                        self
619                    }
620                }
621
622                impl<'a> core::ops::Sub<&'a #struct_name> for &'a #struct_name {
623                    type Output = #struct_name;
624                    #[inline(always)]
625                    fn sub(self, other: &'a #struct_name) -> Self::Output {
626                        let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
627                        unsafe {
628                            self.sub_refs_impl(other, uninit.as_mut_ptr());
629                            uninit.assume_init()
630                        }
631                    }
632                }
633
634                impl<'a> core::ops::MulAssign<&'a #struct_name> for #struct_name {
635                    #[inline(always)]
636                    fn mul_assign(&mut self, other: &'a #struct_name) {
637                        self.mul_assign_impl(other);
638                    }
639                }
640
641                impl core::ops::MulAssign for #struct_name {
642                    #[inline(always)]
643                    fn mul_assign(&mut self, other: Self) {
644                        self.mul_assign_impl(&other);
645                    }
646                }
647
648                impl core::ops::Mul for #struct_name {
649                    type Output = Self;
650                    #[inline(always)]
651                    fn mul(mut self, other: Self) -> Self::Output {
652                        self *= other;
653                        self
654                    }
655                }
656
657                impl<'a> core::ops::Mul<&'a #struct_name> for #struct_name {
658                    type Output = Self;
659                    #[inline(always)]
660                    fn mul(mut self, other: &'a #struct_name) -> Self::Output {
661                        self *= other;
662                        self
663                    }
664                }
665
666                impl<'a> core::ops::Mul<&'a #struct_name> for &#struct_name {
667                    type Output = #struct_name;
668                    #[inline(always)]
669                    fn mul(self, other: &'a #struct_name) -> Self::Output {
670                        let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
671                        unsafe {
672                            self.mul_refs_impl(other, uninit.as_mut_ptr());
673                            uninit.assume_init()
674                        }
675                    }
676                }
677
678                impl<'a> openvm_algebra_guest::DivAssignUnsafe<&'a #struct_name> for #struct_name {
679                    /// Undefined behaviour when denominator is not coprime to N
680                    #[inline(always)]
681                    fn div_assign_unsafe(&mut self, other: &'a #struct_name) {
682                        self.div_assign_unsafe_impl(other);
683                    }
684                }
685
686                impl openvm_algebra_guest::DivAssignUnsafe for #struct_name {
687                    /// Undefined behaviour when denominator is not coprime to N
688                    #[inline(always)]
689                    fn div_assign_unsafe(&mut self, other: Self) {
690                        self.div_assign_unsafe_impl(&other);
691                    }
692                }
693
694                impl openvm_algebra_guest::DivUnsafe for #struct_name {
695                    type Output = Self;
696                    /// Undefined behaviour when denominator is not coprime to N
697                    #[inline(always)]
698                    fn div_unsafe(mut self, other: Self) -> Self::Output {
699                        self.div_assign_unsafe_impl(&other);
700                        self
701                    }
702                }
703
704                impl<'a> openvm_algebra_guest::DivUnsafe<&'a #struct_name> for #struct_name {
705                    type Output = Self;
706                    /// Undefined behaviour when denominator is not coprime to N
707                    #[inline(always)]
708                    fn div_unsafe(mut self, other: &'a #struct_name) -> Self::Output {
709                        self.div_assign_unsafe_impl(other);
710                        self
711                    }
712                }
713
714                impl<'a> openvm_algebra_guest::DivUnsafe<&'a #struct_name> for &#struct_name {
715                    type Output = #struct_name;
716                    /// Undefined behaviour when denominator is not coprime to N
717                    #[inline(always)]
718                    fn div_unsafe(self, other: &'a #struct_name) -> Self::Output {
719                        self.div_unsafe_refs_impl(other)
720                    }
721                }
722
723                impl PartialEq for #struct_name {
724                    #[inline(always)]
725                    fn eq(&self, other: &Self) -> bool {
726                        // Safety: must check setup
727                        unsafe { self.eq_impl::<true>(other) }
728                    }
729                }
730
731                impl<'a> core::iter::Sum<&'a #struct_name> for #struct_name {
732                    fn sum<I: Iterator<Item = &'a #struct_name>>(iter: I) -> Self {
733                        iter.fold(Self::ZERO, |acc, x| &acc + x)
734                    }
735                }
736
737                impl core::iter::Sum for #struct_name {
738                    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
739                        iter.fold(Self::ZERO, |acc, x| &acc + &x)
740                    }
741                }
742
743                impl<'a> core::iter::Product<&'a #struct_name> for #struct_name {
744                    fn product<I: Iterator<Item = &'a #struct_name>>(iter: I) -> Self {
745                        iter.fold(Self::ONE, |acc, x| &acc * x)
746                    }
747                }
748
749                impl core::iter::Product for #struct_name {
750                    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
751                        iter.fold(Self::ONE, |acc, x| &acc * &x)
752                    }
753                }
754
755                impl core::ops::Neg for #struct_name {
756                    type Output = #struct_name;
757                    #[inline(always)]
758                    fn neg(self) -> Self::Output {
759                        #struct_name::ZERO - &self
760                    }
761                }
762
763                impl<'a> core::ops::Neg for &'a #struct_name {
764                    type Output = #struct_name;
765                    #[inline(always)]
766                    fn neg(self) -> Self::Output {
767                        #struct_name::ZERO - self
768                    }
769                }
770
771                impl core::fmt::Debug for #struct_name {
772                    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
773                        write!(f, "{:?}", self.as_le_bytes())
774                    }
775                }
776            }
777
778            impl openvm_algebra_guest::Reduce for #struct_name {
779                fn reduce_le_bytes(bytes: &[u8]) -> Self {
780                    let mut res = <Self as openvm_algebra_guest::IntMod>::ZERO;
781                    // base should be 2 ^ #limbs which exceeds what Self can represent
782                    let mut base = <Self as openvm_algebra_guest::IntMod>::from_le_bytes_unchecked(&[255u8; #limbs]);
783                    base += <Self as openvm_algebra_guest::IntMod>::ONE;
784                    for chunk in bytes.chunks(#limbs).rev() {
785                        res = res * &base + <Self as openvm_algebra_guest::IntMod>::from_le_bytes_unchecked(chunk);
786                    }
787                    openvm_algebra_guest::IntMod::assert_reduced(&res);
788                    res
789                }
790            }
791        });
792
793        output.push(result);
794
795        let modulus_biguint = BigUint::from_bytes_le(&modulus_bytes);
796        let modulus_is_prime = is_prime(&modulus_biguint, None);
797
798        if modulus_is_prime.probably() {
799            // implement Field and Sqrt traits for prime moduli
800            let field_and_sqrt_impl = TokenStream::from(quote::quote_spanned! { span.into() =>
801                impl ::openvm_algebra_guest::Field for #struct_name {
802                    const ZERO: Self = <Self as ::openvm_algebra_guest::IntMod>::ZERO;
803                    const ONE: Self = <Self as ::openvm_algebra_guest::IntMod>::ONE;
804
805                    type SelfRef<'a> = &'a Self;
806
807                    fn double_assign(&mut self) {
808                        ::openvm_algebra_guest::IntMod::double_assign(self);
809                    }
810
811                    fn square_assign(&mut self) {
812                        ::openvm_algebra_guest::IntMod::square_assign(self);
813                    }
814
815                }
816
817                impl openvm_algebra_guest::Sqrt for #struct_name {
818                    // Returns a sqrt of self if it exists, otherwise None.
819                    // Note that we use a hint-based approach to prove whether the square root exists.
820                    // This approach works for prime moduli, but not necessarily for composite moduli,
821                    // which is why we have the sqrt method in the Field trait, not the IntMod trait.
822                    fn sqrt(&self) -> Option<Self> {
823                        match self.honest_host_sqrt() {
824                            // self is a square
825                            Some(Some(sqrt)) => Some(sqrt),
826                            // self is not a square
827                            Some(None) => None,
828                            // host is dishonest
829                            None => {
830                                // host is dishonest, enter infinite loop
831                                loop {
832                                    openvm::io::println("ERROR: Square root hint is invalid. Entering infinite loop.");
833                                }
834                            }
835                        }
836                    }
837                }
838
839                impl #struct_name {
840                    // Returns None if the hint is incorrect (i.e. the host is dishonest)
841                    // Returns Some(None) if the hint proves that self is not a quadratic residue
842                    // Otherwise, returns Some(Some(sqrt)) where sqrt is a square root of self
843                    fn honest_host_sqrt(&self) -> Option<Option<Self>> {
844                        let (is_square, sqrt) = self.hint_sqrt_impl()?;
845
846                        if is_square {
847                            // ensure sqrt < modulus
848                            <Self as ::openvm_algebra_guest::IntMod>::assert_reduced(&sqrt);
849
850                            if &(&sqrt * &sqrt) == self {
851                                Some(Some(sqrt))
852                            } else {
853                                None
854                            }
855                        } else {
856                            // ensure sqrt < modulus
857                            <Self as ::openvm_algebra_guest::IntMod>::assert_reduced(&sqrt);
858
859                            if &sqrt * &sqrt == self * Self::get_non_qr() {
860                                Some(None)
861                            } else {
862                                None
863                            }
864                        }
865                    }
866
867
868                    // Returns None if the hint is malformed.
869                    // Otherwise, returns Some((is_square, sqrt)) where sqrt is a square root of self if is_square is true,
870                    // and a square root of self * non_qr if is_square is false.
871                    fn hint_sqrt_impl(&self) -> Option<(bool, Self)> {
872                        #[cfg(not(target_os = "zkvm"))]
873                        {
874                            unimplemented!();
875                        }
876                        #[cfg(target_os = "zkvm")]
877                        {
878                            use ::openvm_algebra_guest::{openvm_custom_insn, openvm_rv32im_guest}; // needed for hint_store_u32! and hint_buffer_u32!
879
880                            let is_square = core::mem::MaybeUninit::<u32>::uninit();
881                            let sqrt = core::mem::MaybeUninit::<#struct_name>::uninit();
882                            unsafe {
883                                #hint_sqrt_extern_func(self as *const #struct_name as usize);
884                                let is_square_ptr = is_square.as_ptr() as *const u32;
885                                openvm_rv32im_guest::hint_store_u32!(is_square_ptr);
886                                openvm_rv32im_guest::hint_buffer_u32!(sqrt.as_ptr() as *const u8, <#struct_name as ::openvm_algebra_guest::IntMod>::NUM_LIMBS / 4);
887                                let is_square = is_square.assume_init();
888                                if is_square == 0 || is_square == 1 {
889                                    Some((is_square == 1, sqrt.assume_init()))
890                                } else {
891                                    None
892                                }
893                            }
894                        }
895                    }
896
897                    // Generate a non quadratic residue by using a hint
898                    fn init_non_qr() -> alloc::boxed::Box<#struct_name> {
899                        #[cfg(not(target_os = "zkvm"))]
900                        {
901                            unimplemented!();
902                        }
903                        #[cfg(target_os = "zkvm")]
904                        {
905                            use ::openvm_algebra_guest::{openvm_custom_insn, openvm_rv32im_guest}; // needed for hint_buffer_u32!
906
907                            let mut non_qr_uninit = core::mem::MaybeUninit::<Self>::uninit();
908                            let mut non_qr;
909                            unsafe {
910                                #hint_non_qr_extern_func();
911                                let ptr = non_qr_uninit.as_ptr() as *const u8;
912                                openvm_rv32im_guest::hint_buffer_u32!(ptr, <Self as ::openvm_algebra_guest::IntMod>::NUM_LIMBS / 4);
913                                non_qr = non_qr_uninit.assume_init();
914                            }
915                            // ensure non_qr < modulus
916                            <Self as ::openvm_algebra_guest::IntMod>::assert_reduced(&non_qr);
917
918                            use ::openvm_algebra_guest::{DivUnsafe, ExpBytes};
919                            // construct exp = (p-1)/2 as an integer by first constraining exp = (p-1)/2 (mod p) and then exp < p
920                            let exp = -<Self as ::openvm_algebra_guest::IntMod>::ONE.div_unsafe(Self::from_const_u8(2));
921                            <Self as ::openvm_algebra_guest::IntMod>::assert_reduced(&exp);
922
923                            if non_qr.exp_bytes(true, &<Self as ::openvm_algebra_guest::IntMod>::to_be_bytes(&exp)) != -<Self as ::openvm_algebra_guest::IntMod>::ONE
924                            {
925                                // non_qr is not a non quadratic residue, so host is dishonest
926                                loop {
927                                    openvm::io::println("ERROR: Non quadratic residue hint is invalid. Entering infinite loop.");
928                                }
929                            }
930
931                            alloc::boxed::Box::new(non_qr)
932                        }
933                    }
934
935                    // This function is public for use in tests
936                    pub fn get_non_qr() -> &'static #struct_name {
937                        static non_qr: ::openvm_algebra_guest::once_cell::race::OnceBox<#struct_name> = ::openvm_algebra_guest::once_cell::race::OnceBox::new();
938                        &non_qr.get_or_init(Self::init_non_qr)
939                    }
940                }
941            });
942
943            output.push(field_and_sqrt_impl);
944        }
945    }
946
947    TokenStream::from_iter(output)
948}
949
950struct ModuliDefine {
951    items: Vec<LitStr>,
952}
953
954impl Parse for ModuliDefine {
955    fn parse(input: ParseStream) -> syn::Result<Self> {
956        let items = input.parse_terminated(<LitStr as Parse>::parse, Token![,])?;
957        Ok(Self {
958            items: items.into_iter().collect(),
959        })
960    }
961}
962
963#[proc_macro]
964pub fn moduli_init(input: TokenStream) -> TokenStream {
965    let ModuliDefine { items } = parse_macro_input!(input as ModuliDefine);
966
967    let mut externs = Vec::new();
968
969    // List of all modular limbs in one (that is, with a compile-time known size) array.
970    let mut two_modular_limbs_flattened_list = Vec::<u8>::new();
971    // List of "bars" between adjacent modular limbs sublists.
972    let mut limb_list_borders = vec![0usize];
973
974    let span = proc_macro::Span::call_site();
975
976    for (mod_idx, item) in items.into_iter().enumerate() {
977        let modulus = item.value();
978        let modulus_bytes = string_to_bytes(&modulus);
979        let mut limbs = modulus_bytes.len();
980        let mut block_size = 32;
981
982        if limbs <= 32 {
983            limbs = 32;
984        } else if limbs <= 48 {
985            limbs = 48;
986            block_size = 16;
987        } else {
988            panic!("limbs must be at most 48");
989        }
990
991        let block_size = proc_macro::Literal::usize_unsuffixed(block_size);
992        let block_size = syn::Lit::new(block_size.to_string().parse::<_>().unwrap());
993
994        let modulus_bytes = modulus_bytes
995            .into_iter()
996            .chain(vec![0u8; limbs])
997            .take(limbs)
998            .collect::<Vec<_>>();
999
1000        // We need two copies of modular limbs for Fp2 setup.
1001        let doubled_modulus = [modulus_bytes.clone(), modulus_bytes.clone()].concat();
1002        two_modular_limbs_flattened_list.extend(doubled_modulus);
1003        limb_list_borders.push(two_modular_limbs_flattened_list.len());
1004
1005        let modulus_hex = modulus_bytes
1006            .iter()
1007            .rev()
1008            .map(|x| format!("{:02x}", x))
1009            .collect::<Vec<_>>()
1010            .join("");
1011
1012        let setup_extern_func = syn::Ident::new(
1013            &format!("moduli_setup_extern_func_{}", modulus_hex),
1014            span.into(),
1015        );
1016
1017        for op_type in ["add", "sub", "mul", "div"] {
1018            let func_name = syn::Ident::new(
1019                &format!("{}_extern_func_{}", op_type, modulus_hex),
1020                span.into(),
1021            );
1022            let mut chars = op_type.chars().collect::<Vec<_>>();
1023            chars[0] = chars[0].to_ascii_uppercase();
1024            let local_opcode = syn::Ident::new(
1025                &format!("{}Mod", chars.iter().collect::<String>()),
1026                span.into(),
1027            );
1028            externs.push(quote::quote_spanned! { span.into() =>
1029                #[no_mangle]
1030                extern "C" fn #func_name(rd: usize, rs1: usize, rs2: usize) {
1031                    openvm::platform::custom_insn_r!(
1032                        opcode = ::openvm_algebra_guest::OPCODE,
1033                        funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
1034                        funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::#local_opcode as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1035                        rd = In rd,
1036                        rs1 = In rs1,
1037                        rs2 = In rs2
1038                    )
1039                }
1040            });
1041        }
1042
1043        let is_eq_extern_func =
1044            syn::Ident::new(&format!("is_eq_extern_func_{}", modulus_hex), span.into());
1045        externs.push(quote::quote_spanned! { span.into() =>
1046            #[no_mangle]
1047            extern "C" fn #is_eq_extern_func(rs1: usize, rs2: usize) -> bool {
1048                let mut x: u32;
1049                openvm::platform::custom_insn_r!(
1050                    opcode = ::openvm_algebra_guest::OPCODE,
1051                    funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
1052                    funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::IsEqMod as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1053                    rd = Out x,
1054                    rs1 = In rs1,
1055                    rs2 = In rs2
1056                );
1057                x != 0
1058            }
1059        });
1060
1061        let hint_non_qr_extern_func = syn::Ident::new(
1062            &format!("hint_non_qr_extern_func_{}", modulus_hex),
1063            span.into(),
1064        );
1065        externs.push(quote::quote_spanned! { span.into() =>
1066            #[no_mangle]
1067            extern "C" fn #hint_non_qr_extern_func() {
1068                openvm::platform::custom_insn_r!(
1069                    opcode = ::openvm_algebra_guest::OPCODE,
1070                    funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
1071                    funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::HintNonQr as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1072                    rd = Const "x0",
1073                    rs1 = Const "x0",
1074                    rs2 = Const "x0"
1075                );
1076            }
1077
1078
1079        });
1080
1081        // This function will be defined regardless of whether the modulus is prime or not,
1082        // but it will be called only if the modulus is prime.
1083        let hint_sqrt_extern_func = syn::Ident::new(
1084            &format!("hint_sqrt_extern_func_{}", modulus_hex),
1085            span.into(),
1086        );
1087        externs.push(quote::quote_spanned! { span.into() =>
1088            #[no_mangle]
1089            extern "C" fn #hint_sqrt_extern_func(rs1: usize) {
1090                openvm::platform::custom_insn_r!(
1091                    opcode = ::openvm_algebra_guest::OPCODE,
1092                    funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
1093                    funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::HintSqrt as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1094                    rd = Const "x0",
1095                    rs1 = In rs1,
1096                    rs2 = Const "x0"
1097                );
1098            }
1099        });
1100
1101        externs.push(quote::quote_spanned! { span.into() =>
1102            #[no_mangle]
1103            extern "C" fn #setup_extern_func() {
1104                #[cfg(target_os = "zkvm")]
1105                {
1106                    // To avoid importing #struct_name, we create a placeholder struct with the same size and alignment.
1107                    #[repr(C, align(#block_size))]
1108                    struct AlignedPlaceholder([u8; #limbs]);
1109
1110                    const MODULUS_BYTES: AlignedPlaceholder = AlignedPlaceholder([#(#modulus_bytes),*]);
1111
1112                    // We are going to use the numeric representation of the `rs2` register to distinguish the chip to setup.
1113                    // The transpiler will transform this instruction, based on whether `rs2` is `x0`, `x1` or `x2`, into a `SETUP_ADDSUB`, `SETUP_MULDIV` or `SETUP_ISEQ` instruction.
1114                    let mut uninit: core::mem::MaybeUninit<AlignedPlaceholder> = core::mem::MaybeUninit::uninit();
1115                    openvm::platform::custom_insn_r!(
1116                        opcode = ::openvm_algebra_guest::OPCODE,
1117                        funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3,
1118                        funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::SetupMod as usize
1119                            + #mod_idx
1120                                * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1121                        rd = In uninit.as_mut_ptr(),
1122                        rs1 = In MODULUS_BYTES.0.as_ptr(),
1123                        rs2 = Const "x0" // will be parsed as 0 and therefore transpiled to SETUP_ADDMOD
1124                    );
1125                    openvm::platform::custom_insn_r!(
1126                        opcode = ::openvm_algebra_guest::OPCODE,
1127                        funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3,
1128                        funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::SetupMod as usize
1129                            + #mod_idx
1130                                * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1131                        rd = In uninit.as_mut_ptr(),
1132                        rs1 = In MODULUS_BYTES.0.as_ptr(),
1133                        rs2 = Const "x1" // will be parsed as 1 and therefore transpiled to SETUP_MULDIV
1134                    );
1135                    unsafe {
1136                        // This should not be x0:
1137                        let mut tmp = uninit.as_mut_ptr() as usize;
1138                        openvm::platform::custom_insn_r!(
1139                            opcode = ::openvm_algebra_guest::OPCODE,
1140                            funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
1141                            funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::SetupMod as usize
1142                                + #mod_idx
1143                                    * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1144                            rd = InOut tmp,
1145                            rs1 = In MODULUS_BYTES.0.as_ptr(),
1146                            rs2 = Const "x2" // will be parsed as 2 and therefore transpiled to SETUP_ISEQ
1147                        );
1148                        // rd = inout(reg) is necessary because this instruction will write to `rd` register
1149                    }
1150                }
1151            }
1152        });
1153    }
1154
1155    let total_limbs_cnt = two_modular_limbs_flattened_list.len();
1156    let cnt_limbs_list_len = limb_list_borders.len();
1157    TokenStream::from(quote::quote_spanned! { span.into() =>
1158        #[allow(non_snake_case)]
1159        #[cfg(target_os = "zkvm")]
1160        mod openvm_intrinsics_ffi {
1161            #(#externs)*
1162        }
1163        #[allow(non_snake_case, non_upper_case_globals)]
1164        pub mod openvm_intrinsics_meta_do_not_type_this_by_yourself {
1165            pub const two_modular_limbs_list: [u8; #total_limbs_cnt] = [#(#two_modular_limbs_flattened_list),*];
1166            pub const limb_list_borders: [usize; #cnt_limbs_list_len] = [#(#limb_list_borders),*];
1167        }
1168    })
1169}