openvm_algebra_moduli_macros/
lib.rs

1extern crate proc_macro;
2
3use std::sync::atomic::AtomicUsize;
4
5use openvm_macros_common::{string_to_bytes, MacroArgs};
6use proc_macro::TokenStream;
7use quote::format_ident;
8use syn::{
9    parse::{Parse, ParseStream},
10    parse_macro_input, LitStr, Token,
11};
12
/// Running counter of items declared through `moduli_declare!`. Each declared item takes the
/// next value (via `fetch_add`) and uses it to name its private trait-impl module
/// `algebra_impl_<idx>`, so that the modules generated for different items do not collide.
static MOD_IDX: AtomicUsize = AtomicUsize::new(0);
14
15/// This macro generates the code to setup the modulus for a given prime. Also it places the moduli
16/// into a special static variable to be later extracted from the ELF and used by the VM. Usage:
17/// ```
18/// moduli_declare! {
19///     Bls12381 { modulus = "0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab" },
20///     Bn254 { modulus = "21888242871839275222246405745257275088696311157297823662689037894645226208583" },
21/// }
22/// ```
23/// This creates two structs, `Bls12381` and `Bn254`, each representing the modular arithmetic class
24/// (implementing `Add`, `Sub` and so on).
25#[proc_macro]
26pub fn moduli_declare(input: TokenStream) -> TokenStream {
27    let MacroArgs { items } = parse_macro_input!(input as MacroArgs);
28
29    let mut output = Vec::new();
30
31    let span = proc_macro::Span::call_site();
32
33    for item in items {
34        let struct_name = item.name.to_string();
35        let struct_name = syn::Ident::new(&struct_name, span.into());
36        let mut modulus: Option<String> = None;
37        for param in item.params {
38            match param.name.to_string().as_str() {
39                "modulus" => {
40                    if let syn::Expr::Lit(syn::ExprLit {
41                        lit: syn::Lit::Str(value),
42                        ..
43                    }) = param.value
44                    {
45                        modulus = Some(value.value());
46                    } else {
47                        return syn::Error::new_spanned(param.value, "Expected a string literal")
48                            .to_compile_error()
49                            .into();
50                    }
51                }
52                _ => {
53                    panic!("Unknown parameter {}", param.name);
54                }
55            }
56        }
57
58        // Parsing the parameters is over at this point
59
60        let mod_idx = MOD_IDX.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
61
62        let modulus = modulus.expect("modulus parameter is required");
63        let modulus_bytes = string_to_bytes(&modulus);
64        let mut limbs = modulus_bytes.len();
65        let mut block_size = 32;
66
67        if limbs <= 32 {
68            limbs = 32;
69        } else if limbs <= 48 {
70            limbs = 48;
71            block_size = 16;
72        } else {
73            panic!("limbs must be at most 48");
74        }
75
76        let modulus_bytes = modulus_bytes
77            .into_iter()
78            .chain(vec![0u8; limbs])
79            .take(limbs)
80            .collect::<Vec<_>>();
81
82        let modulus_hex = modulus_bytes
83            .iter()
84            .rev()
85            .map(|x| format!("{:02x}", x))
86            .collect::<Vec<_>>()
87            .join("");
88        macro_rules! create_extern_func {
89            ($name:ident) => {
90                let $name = syn::Ident::new(
91                    &format!("{}_{}", stringify!($name), modulus_hex),
92                    span.into(),
93                );
94            };
95        }
96        create_extern_func!(add_extern_func);
97        create_extern_func!(sub_extern_func);
98        create_extern_func!(mul_extern_func);
99        create_extern_func!(div_extern_func);
100        create_extern_func!(is_eq_extern_func);
101
102        let block_size = proc_macro::Literal::usize_unsuffixed(block_size);
103        let block_size = syn::Lit::new(block_size.to_string().parse::<_>().unwrap());
104
105        let module_name = format_ident!("algebra_impl_{}", mod_idx);
106
107        let result = TokenStream::from(quote::quote_spanned! { span.into() =>
108            /// An element of the ring of integers modulo a positive integer.
109            /// The element is internally represented as a fixed size array of bytes.
110            ///
111            /// ## Caution
112            /// It is not guaranteed that the integer representation is less than the modulus.
113            /// After any arithmetic operation, the honest host should normalize the result
114            /// to its canonical representation less than the modulus, but guest execution does not
115            /// require it.
116            ///
117            /// See [`assert_reduced`](openvm_algebra_guest::IntMod::assert_reduced) and
118            /// [`is_reduced`](openvm_algebra_guest::IntMod::is_reduced).
119            #[derive(Clone, Eq, serde::Serialize, serde::Deserialize)]
120            #[repr(C, align(#block_size))]
121            pub struct #struct_name(#[serde(with = "openvm_algebra_guest::BigArray")] [u8; #limbs]);
122
123            extern "C" {
124                fn #add_extern_func(rd: usize, rs1: usize, rs2: usize);
125                fn #sub_extern_func(rd: usize, rs1: usize, rs2: usize);
126                fn #mul_extern_func(rd: usize, rs1: usize, rs2: usize);
127                fn #div_extern_func(rd: usize, rs1: usize, rs2: usize);
128                fn #is_eq_extern_func(rs1: usize, rs2: usize) -> bool;
129            }
130
131            impl #struct_name {
132                #[inline(always)]
133                const fn from_const_u8(val: u8) -> Self {
134                    let mut bytes = [0; #limbs];
135                    bytes[0] = val;
136                    Self(bytes)
137                }
138
139                /// Constructor from little-endian bytes. Does not enforce the integer value of `bytes`
140                /// must be less than the modulus.
141                pub const fn from_const_bytes(bytes: [u8; #limbs]) -> Self {
142                    Self(bytes)
143                }
144
145                #[inline(always)]
146                fn add_assign_impl(&mut self, other: &Self) {
147                    #[cfg(not(target_os = "zkvm"))]
148                    {
149                        *self = Self::from_biguint(
150                            (self.as_biguint() + other.as_biguint()) % Self::modulus_biguint(),
151                        );
152                    }
153                    #[cfg(target_os = "zkvm")]
154                    {
155                        unsafe {
156                            #add_extern_func(
157                                self as *mut Self as usize,
158                                self as *const Self as usize,
159                                other as *const Self as usize,
160                            );
161                        }
162                    }
163                }
164
165                #[inline(always)]
166                fn sub_assign_impl(&mut self, other: &Self) {
167                    #[cfg(not(target_os = "zkvm"))]
168                    {
169                        let modulus = Self::modulus_biguint();
170                        *self = Self::from_biguint(
171                            (self.as_biguint() + modulus.clone() - other.as_biguint()) % modulus,
172                        );
173                    }
174                    #[cfg(target_os = "zkvm")]
175                    {
176                        unsafe {
177                            #sub_extern_func(
178                                self as *mut Self as usize,
179                                self as *const Self as usize,
180                                other as *const Self as usize,
181                            );
182                        }
183                    }
184                }
185
186                #[inline(always)]
187                fn mul_assign_impl(&mut self, other: &Self) {
188                    #[cfg(not(target_os = "zkvm"))]
189                    {
190                        *self = Self::from_biguint(
191                            (self.as_biguint() * other.as_biguint()) % Self::modulus_biguint(),
192                        );
193                    }
194                    #[cfg(target_os = "zkvm")]
195                    {
196                        unsafe {
197                            #mul_extern_func(
198                                self as *mut Self as usize,
199                                self as *const Self as usize,
200                                other as *const Self as usize,
201                            );
202                        }
203                    }
204                }
205
206                #[inline(always)]
207                fn div_assign_unsafe_impl(&mut self, other: &Self) {
208                    #[cfg(not(target_os = "zkvm"))]
209                    {
210                        let modulus = Self::modulus_biguint();
211                        let inv = other.as_biguint().modinv(&modulus).unwrap();
212                        *self = Self::from_biguint((self.as_biguint() * inv) % modulus);
213                    }
214                    #[cfg(target_os = "zkvm")]
215                    {
216                        unsafe {
217                            #div_extern_func(
218                                self as *mut Self as usize,
219                                self as *const Self as usize,
220                                other as *const Self as usize,
221                            );
222                        }
223                    }
224                }
225
226                /// SAFETY: `dst_ptr` must be a raw pointer to `&mut Self`.
227                /// It will be written to only at the very end .
228                #[inline(always)]
229                unsafe fn add_refs_impl(&self, other: &Self, dst_ptr: *mut Self) {
230                    #[cfg(not(target_os = "zkvm"))]
231                    {
232                        let mut res = self.clone();
233                        res += other;
234                        // BEWARE order of operations: when dst_ptr = other as pointers
235                        let dst = unsafe { &mut *dst_ptr };
236                        *dst = res;
237                    }
238                    #[cfg(target_os = "zkvm")]
239                    {
240                        unsafe {
241                            #add_extern_func(
242                                dst_ptr as usize,
243                                self as *const #struct_name as usize,
244                                other as *const #struct_name as usize,
245                            );
246                        }
247                    }
248                }
249
250                /// SAFETY: `dst_ptr` must be a raw pointer to `&mut Self`.
251                /// It will be written to only at the very end .
252                #[inline(always)]
253                unsafe fn sub_refs_impl(&self, other: &Self, dst_ptr: *mut Self) {
254                    #[cfg(not(target_os = "zkvm"))]
255                    {
256                        let mut res = self.clone();
257                        res -= other;
258                        // BEWARE order of operations: when dst_ptr = other as pointers
259                        let dst = unsafe { &mut *dst_ptr };
260                        *dst = res;
261                    }
262                    #[cfg(target_os = "zkvm")]
263                    {
264                        unsafe {
265                            #sub_extern_func(
266                                dst_ptr as usize,
267                                self as *const #struct_name as usize,
268                                other as *const #struct_name as usize,
269                            );
270                        }
271                    }
272                }
273
274                /// SAFETY: `dst_ptr` must be a raw pointer to `&mut Self`.
275                /// It will be written to only at the very end .
276                #[inline(always)]
277                unsafe fn mul_refs_impl(&self, other: &Self, dst_ptr: *mut Self) {
278                    #[cfg(not(target_os = "zkvm"))]
279                    {
280                        let mut res = self.clone();
281                        res *= other;
282                        // BEWARE order of operations: when dst_ptr = other as pointers
283                        let dst = unsafe { &mut *dst_ptr };
284                        *dst = res;
285                    }
286                    #[cfg(target_os = "zkvm")]
287                    {
288                        unsafe {
289                            #mul_extern_func(
290                                dst_ptr as usize,
291                                self as *const #struct_name as usize,
292                                other as *const #struct_name as usize,
293                            );
294                        }
295                    }
296                }
297
298                #[inline(always)]
299                fn div_unsafe_refs_impl(&self, other: &Self) -> Self {
300                    #[cfg(not(target_os = "zkvm"))]
301                    {
302                        let modulus = Self::modulus_biguint();
303                        let inv = other.as_biguint().modinv(&modulus).unwrap();
304                        Self::from_biguint((self.as_biguint() * inv) % modulus)
305                    }
306                    #[cfg(target_os = "zkvm")]
307                    {
308                        let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
309                        unsafe {
310                            #div_extern_func(
311                                uninit.as_mut_ptr() as usize,
312                                self as *const #struct_name as usize,
313                                other as *const #struct_name as usize,
314                            );
315                        }
316                        unsafe { uninit.assume_init() }
317                    }
318                }
319
320                #[inline(always)]
321                fn eq_impl(&self, other: &Self) -> bool {
322                    #[cfg(not(target_os = "zkvm"))]
323                    {
324                        self.as_le_bytes() == other.as_le_bytes()
325                    }
326                    #[cfg(target_os = "zkvm")]
327                    {
328                        unsafe {
329                            #is_eq_extern_func(self as *const #struct_name as usize, other as *const #struct_name as usize)
330                        }
331                    }
332                }
333            }
334
335            // Put trait implementations in a private module to avoid conflicts
336            mod #module_name {
337                use openvm_algebra_guest::IntMod;
338
339                use super::#struct_name;
340
341                impl IntMod for #struct_name {
342                    type Repr = [u8; #limbs];
343                    type SelfRef<'a> = &'a Self;
344
345                    const MODULUS: Self::Repr = [#(#modulus_bytes),*];
346
347                    const ZERO: Self = Self([0; #limbs]);
348
349                    const NUM_LIMBS: usize = #limbs;
350
351                    const ONE: Self = Self::from_const_u8(1);
352
353                    fn from_repr(repr: Self::Repr) -> Self {
354                        Self(repr)
355                    }
356
357                    fn from_le_bytes(bytes: &[u8]) -> Self {
358                        let mut arr = [0u8; #limbs];
359                        arr.copy_from_slice(bytes);
360                        Self(arr)
361                    }
362
363                    fn from_be_bytes(bytes: &[u8]) -> Self {
364                        let mut arr = [0u8; #limbs];
365                        for (a, b) in arr.iter_mut().zip(bytes.iter().rev()) {
366                            *a = *b;
367                        }
368                        Self(arr)
369                    }
370
371                    fn from_u8(val: u8) -> Self {
372                        Self::from_const_u8(val)
373                    }
374
375                    fn from_u32(val: u32) -> Self {
376                        let mut bytes = [0; #limbs];
377                        bytes[..4].copy_from_slice(&val.to_le_bytes());
378                        Self(bytes)
379                    }
380
381                    fn from_u64(val: u64) -> Self {
382                        let mut bytes = [0; #limbs];
383                        bytes[..8].copy_from_slice(&val.to_le_bytes());
384                        Self(bytes)
385                    }
386
387                    fn as_le_bytes(&self) -> &[u8] {
388                        &(self.0)
389                    }
390
391                    fn to_be_bytes(&self) -> [u8; #limbs] {
392                        core::array::from_fn(|i| self.0[#limbs - 1 - i])
393                    }
394
395                    #[cfg(not(target_os = "zkvm"))]
396                    fn modulus_biguint() -> num_bigint::BigUint {
397                        num_bigint::BigUint::from_bytes_le(&Self::MODULUS)
398                    }
399
400                    #[cfg(not(target_os = "zkvm"))]
401                    fn from_biguint(biguint: num_bigint::BigUint) -> Self {
402                        Self(openvm::utils::biguint_to_limbs(&biguint))
403                    }
404
405                    #[cfg(not(target_os = "zkvm"))]
406                    fn as_biguint(&self) -> num_bigint::BigUint {
407                        num_bigint::BigUint::from_bytes_le(self.as_le_bytes())
408                    }
409
410                    fn neg_assign(&mut self) {
411                        unsafe {
412                            // SAFETY: we borrow self as &Self and as *mut Self but
413                            // the latter will only be written to at the very end.
414                            (#struct_name::ZERO).sub_refs_impl(self, self as *const Self as *mut Self);
415                        }
416                    }
417
418                    fn double_assign(&mut self) {
419                        unsafe {
420                            // SAFETY: we borrow self as &Self and as *mut Self but
421                            // the latter will only be written to at the very end.
422                            self.add_refs_impl(self, self as *const Self as *mut Self);
423                        }
424                    }
425
426                    fn square_assign(&mut self) {
427                        unsafe {
428                            // SAFETY: we borrow self as &Self and as *mut Self but
429                            // the latter will only be written to at the very end.
430                            self.mul_refs_impl(self, self as *const Self as *mut Self);
431                        }
432                    }
433
434                    fn double(&self) -> Self {
435                        self + self
436                    }
437
438                    fn square(&self) -> Self {
439                        self * self
440                    }
441
442                    fn cube(&self) -> Self {
443                        &self.square() * self
444                    }
445
446                    /// If `self` is not in its canonical form, the proof will fail to verify.
447                    /// This means guest execution will never terminate (either successfully or
448                    /// unsuccessfully) if `self` is not in its canonical form.
449                    // is_eq_mod enforces `self` is less than `modulus`
450                    fn assert_reduced(&self) {
451                        // This must not be optimized out
452                        let _ = core::hint::black_box(PartialEq::eq(self, self));
453                    }
454
455                    fn is_reduced(&self) -> bool {
456                        // limbs are little endian
457                        for (x_limb, p_limb) in self.0.iter().rev().zip(Self::MODULUS.iter().rev()) {
458                            if x_limb < p_limb {
459                                return true;
460                            } else if x_limb > p_limb {
461                                return false;
462                            }
463                        }
464                        // At this point, all limbs are equal
465                        false
466                    }
467                }
468
469                impl<'a> core::ops::AddAssign<&'a #struct_name> for #struct_name {
470                    #[inline(always)]
471                    fn add_assign(&mut self, other: &'a #struct_name) {
472                        self.add_assign_impl(other);
473                    }
474                }
475
476                impl core::ops::AddAssign for #struct_name {
477                    #[inline(always)]
478                    fn add_assign(&mut self, other: Self) {
479                        self.add_assign_impl(&other);
480                    }
481                }
482
483                impl core::ops::Add for #struct_name {
484                    type Output = Self;
485                    #[inline(always)]
486                    fn add(mut self, other: Self) -> Self::Output {
487                        self += other;
488                        self
489                    }
490                }
491
492                impl<'a> core::ops::Add<&'a #struct_name> for #struct_name {
493                    type Output = Self;
494                    #[inline(always)]
495                    fn add(mut self, other: &'a #struct_name) -> Self::Output {
496                        self += other;
497                        self
498                    }
499                }
500
501                impl<'a> core::ops::Add<&'a #struct_name> for &#struct_name {
502                    type Output = #struct_name;
503                    #[inline(always)]
504                    fn add(self, other: &'a #struct_name) -> Self::Output {
505                        let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
506                        unsafe {
507                            self.add_refs_impl(other, uninit.as_mut_ptr());
508                            uninit.assume_init()
509                        }
510                    }
511                }
512
513                impl<'a> core::ops::SubAssign<&'a #struct_name> for #struct_name {
514                    #[inline(always)]
515                    fn sub_assign(&mut self, other: &'a #struct_name) {
516                        self.sub_assign_impl(other);
517                    }
518                }
519
520                impl core::ops::SubAssign for #struct_name {
521                    #[inline(always)]
522                    fn sub_assign(&mut self, other: Self) {
523                        self.sub_assign_impl(&other);
524                    }
525                }
526
527                impl core::ops::Sub for #struct_name {
528                    type Output = Self;
529                    #[inline(always)]
530                    fn sub(mut self, other: Self) -> Self::Output {
531                        self -= other;
532                        self
533                    }
534                }
535
536                impl<'a> core::ops::Sub<&'a #struct_name> for #struct_name {
537                    type Output = Self;
538                    #[inline(always)]
539                    fn sub(mut self, other: &'a #struct_name) -> Self::Output {
540                        self -= other;
541                        self
542                    }
543                }
544
545                impl<'a> core::ops::Sub<&'a #struct_name> for &'a #struct_name {
546                    type Output = #struct_name;
547                    #[inline(always)]
548                    fn sub(self, other: &'a #struct_name) -> Self::Output {
549                        let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
550                        unsafe {
551                            self.sub_refs_impl(other, uninit.as_mut_ptr());
552                            uninit.assume_init()
553                        }
554                    }
555                }
556
557                impl<'a> core::ops::MulAssign<&'a #struct_name> for #struct_name {
558                    #[inline(always)]
559                    fn mul_assign(&mut self, other: &'a #struct_name) {
560                        self.mul_assign_impl(other);
561                    }
562                }
563
564                impl core::ops::MulAssign for #struct_name {
565                    #[inline(always)]
566                    fn mul_assign(&mut self, other: Self) {
567                        self.mul_assign_impl(&other);
568                    }
569                }
570
571                impl core::ops::Mul for #struct_name {
572                    type Output = Self;
573                    #[inline(always)]
574                    fn mul(mut self, other: Self) -> Self::Output {
575                        self *= other;
576                        self
577                    }
578                }
579
580                impl<'a> core::ops::Mul<&'a #struct_name> for #struct_name {
581                    type Output = Self;
582                    #[inline(always)]
583                    fn mul(mut self, other: &'a #struct_name) -> Self::Output {
584                        self *= other;
585                        self
586                    }
587                }
588
589                impl<'a> core::ops::Mul<&'a #struct_name> for &#struct_name {
590                    type Output = #struct_name;
591                    #[inline(always)]
592                    fn mul(self, other: &'a #struct_name) -> Self::Output {
593                        let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
594                        unsafe {
595                            self.mul_refs_impl(other, uninit.as_mut_ptr());
596                            uninit.assume_init()
597                        }
598                    }
599                }
600
601                impl<'a> openvm_algebra_guest::DivAssignUnsafe<&'a #struct_name> for #struct_name {
602                    /// Undefined behaviour when denominator is not coprime to N
603                    #[inline(always)]
604                    fn div_assign_unsafe(&mut self, other: &'a #struct_name) {
605                        self.div_assign_unsafe_impl(other);
606                    }
607                }
608
609                impl openvm_algebra_guest::DivAssignUnsafe for #struct_name {
610                    /// Undefined behaviour when denominator is not coprime to N
611                    #[inline(always)]
612                    fn div_assign_unsafe(&mut self, other: Self) {
613                        self.div_assign_unsafe_impl(&other);
614                    }
615                }
616
617                impl openvm_algebra_guest::DivUnsafe for #struct_name {
618                    type Output = Self;
619                    /// Undefined behaviour when denominator is not coprime to N
620                    #[inline(always)]
621                    fn div_unsafe(mut self, other: Self) -> Self::Output {
622                        self.div_assign_unsafe_impl(&other);
623                        self
624                    }
625                }
626
627                impl<'a> openvm_algebra_guest::DivUnsafe<&'a #struct_name> for #struct_name {
628                    type Output = Self;
629                    /// Undefined behaviour when denominator is not coprime to N
630                    #[inline(always)]
631                    fn div_unsafe(mut self, other: &'a #struct_name) -> Self::Output {
632                        self.div_assign_unsafe_impl(other);
633                        self
634                    }
635                }
636
637                impl<'a> openvm_algebra_guest::DivUnsafe<&'a #struct_name> for &#struct_name {
638                    type Output = #struct_name;
639                    /// Undefined behaviour when denominator is not coprime to N
640                    #[inline(always)]
641                    fn div_unsafe(self, other: &'a #struct_name) -> Self::Output {
642                        self.div_unsafe_refs_impl(other)
643                    }
644                }
645
646                impl PartialEq for #struct_name {
647                    #[inline(always)]
648                    fn eq(&self, other: &Self) -> bool {
649                        self.eq_impl(other)
650                    }
651                }
652
653                impl<'a> core::iter::Sum<&'a #struct_name> for #struct_name {
654                    fn sum<I: Iterator<Item = &'a #struct_name>>(iter: I) -> Self {
655                        iter.fold(Self::ZERO, |acc, x| &acc + x)
656                    }
657                }
658
659                impl core::iter::Sum for #struct_name {
660                    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
661                        iter.fold(Self::ZERO, |acc, x| &acc + &x)
662                    }
663                }
664
665                impl<'a> core::iter::Product<&'a #struct_name> for #struct_name {
666                    fn product<I: Iterator<Item = &'a #struct_name>>(iter: I) -> Self {
667                        iter.fold(Self::ONE, |acc, x| &acc * x)
668                    }
669                }
670
671                impl core::iter::Product for #struct_name {
672                    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
673                        iter.fold(Self::ONE, |acc, x| &acc * &x)
674                    }
675                }
676
677                impl core::ops::Neg for #struct_name {
678                    type Output = #struct_name;
679                    fn neg(self) -> Self::Output {
680                        #struct_name::ZERO - &self
681                    }
682                }
683
684                impl<'a> core::ops::Neg for &'a #struct_name {
685                    type Output = #struct_name;
686                    fn neg(self) -> Self::Output {
687                        #struct_name::ZERO - self
688                    }
689                }
690
691                impl core::fmt::Debug for #struct_name {
692                    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
693                        write!(f, "{:?}", self.as_le_bytes())
694                    }
695                }
696            }
697
698            impl openvm_algebra_guest::Reduce for #struct_name {
699                fn reduce_le_bytes(bytes: &[u8]) -> Self {
700                    let mut res = <Self as openvm_algebra_guest::IntMod>::ZERO;
701                    // base should be 2 ^ #limbs which exceeds what Self can represent
702                    let mut base = Self::from_le_bytes(&[255u8; #limbs]);
703                    base += <Self as openvm_algebra_guest::IntMod>::ONE;
704                    for chunk in bytes.chunks(#limbs).rev() {
705                        res = res * &base + Self::from_le_bytes(chunk);
706                    }
707                    res
708                }
709            }
710        });
711
712        output.push(result);
713    }
714
715    TokenStream::from_iter(output)
716}
717
718struct ModuliDefine {
719    items: Vec<LitStr>,
720}
721
722impl Parse for ModuliDefine {
723    fn parse(input: ParseStream) -> syn::Result<Self> {
724        let items = input.parse_terminated(<LitStr as Parse>::parse, Token![,])?;
725        Ok(Self {
726            items: items.into_iter().collect(),
727        })
728    }
729}
730
731#[proc_macro]
732pub fn moduli_init(input: TokenStream) -> TokenStream {
733    let ModuliDefine { items } = parse_macro_input!(input as ModuliDefine);
734
735    let mut externs = Vec::new();
736    let mut setups = Vec::new();
737    let mut openvm_section = Vec::new();
738    let mut setup_all_moduli = Vec::new();
739
740    // List of all modular limbs in one (that is, with a compile-time known size) array.
741    let mut two_modular_limbs_flattened_list = Vec::<u8>::new();
742    // List of "bars" between adjacent modular limbs sublists.
743    let mut limb_list_borders = vec![0usize];
744
745    let span = proc_macro::Span::call_site();
746
747    for (mod_idx, item) in items.into_iter().enumerate() {
748        let modulus = item.value();
749        println!("[init] modulus #{} = {}", mod_idx, modulus);
750
751        let modulus_bytes = string_to_bytes(&modulus);
752        let mut limbs = modulus_bytes.len();
753        let mut block_size = 32;
754
755        if limbs <= 32 {
756            limbs = 32;
757        } else if limbs <= 48 {
758            limbs = 48;
759            block_size = 16;
760        } else {
761            panic!("limbs must be at most 48");
762        }
763
764        let block_size = proc_macro::Literal::usize_unsuffixed(block_size);
765        let block_size = syn::Lit::new(block_size.to_string().parse::<_>().unwrap());
766
767        let modulus_bytes = modulus_bytes
768            .into_iter()
769            .chain(vec![0u8; limbs])
770            .take(limbs)
771            .collect::<Vec<_>>();
772
773        // We need two copies of modular limbs for Fp2 setup.
774        let doubled_modulus = [modulus_bytes.clone(), modulus_bytes.clone()].concat();
775        two_modular_limbs_flattened_list.extend(doubled_modulus);
776        limb_list_borders.push(two_modular_limbs_flattened_list.len());
777
778        let modulus_hex = modulus_bytes
779            .iter()
780            .rev()
781            .map(|x| format!("{:02x}", x))
782            .collect::<Vec<_>>()
783            .join("");
784
785        let serialized_modulus =
786            core::iter::once(1) // 1 for "modulus"
787                .chain(core::iter::once(mod_idx as u8)) // mod_idx is u8 for now (can make it u32), because we don't know the order of
788                // variables in the elf
789                .chain((modulus_bytes.len() as u32).to_le_bytes().iter().copied())
790                .chain(modulus_bytes.iter().copied())
791                .collect::<Vec<_>>();
792        let serialized_name = syn::Ident::new(
793            &format!("OPENVM_SERIALIZED_MODULUS_{}", mod_idx),
794            span.into(),
795        );
796        let serialized_len = serialized_modulus.len();
797        let setup_function = syn::Ident::new(&format!("setup_{}", mod_idx), span.into());
798
799        openvm_section.push(quote::quote_spanned! { span.into() =>
800            #[cfg(target_os = "zkvm")]
801            #[link_section = ".openvm"]
802            #[no_mangle]
803            #[used]
804            static #serialized_name: [u8; #serialized_len] = [#(#serialized_modulus),*];
805        });
806
807        for op_type in ["add", "sub", "mul", "div"] {
808            let func_name = syn::Ident::new(
809                &format!("{}_extern_func_{}", op_type, modulus_hex),
810                span.into(),
811            );
812            let mut chars = op_type.chars().collect::<Vec<_>>();
813            chars[0] = chars[0].to_ascii_uppercase();
814            let local_opcode = syn::Ident::new(
815                &format!("{}Mod", chars.iter().collect::<String>()),
816                span.into(),
817            );
818            externs.push(quote::quote_spanned! { span.into() =>
819                #[no_mangle]
820                extern "C" fn #func_name(rd: usize, rs1: usize, rs2: usize) {
821                    openvm::platform::custom_insn_r!(
822                        opcode = ::openvm_algebra_guest::OPCODE,
823                        funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
824                        funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::#local_opcode as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
825                        rd = In rd,
826                        rs1 = In rs1,
827                        rs2 = In rs2
828                    )
829                }
830            });
831        }
832
833        let is_eq_extern_func =
834            syn::Ident::new(&format!("is_eq_extern_func_{}", modulus_hex), span.into());
835        externs.push(quote::quote_spanned! { span.into() =>
836            #[no_mangle]
837            extern "C" fn #is_eq_extern_func(rs1: usize, rs2: usize) -> bool {
838                let mut x: u32;
839                openvm::platform::custom_insn_r!(
840                    opcode = ::openvm_algebra_guest::OPCODE,
841                    funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
842                    funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::IsEqMod as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
843                    rd = Out x,
844                    rs1 = In rs1,
845                    rs2 = In rs2
846                );
847                x != 0
848            }
849        });
850
851        setup_all_moduli.push(quote::quote_spanned! { span.into() =>
852            #setup_function();
853        });
854
855        setups.push(quote::quote_spanned! { span.into() =>
856            #[allow(non_snake_case)]
857            pub fn #setup_function() {
858                #[cfg(target_os = "zkvm")]
859                {
860                    let mut ptr = 0;
861                    assert_eq!(#serialized_name[ptr], 1);
862                    ptr += 1;
863                    assert_eq!(#serialized_name[ptr], #mod_idx as u8);
864                    ptr += 1;
865                    assert_eq!(#serialized_name[ptr..ptr+4].iter().rev().fold(0, |acc, &x| acc * 256 + x as usize), #limbs);
866                    ptr += 4;
867                    let remaining = &#serialized_name[ptr..];
868
869                    // To avoid importing #struct_name, we create a placeholder struct with the same size and alignment.
870                    #[repr(C, align(#block_size))]
871                    struct AlignedPlaceholder([u8; #limbs]);
872
873                    // We are going to use the numeric representation of the `rs2` register to distinguish the chip to setup.
874                    // The transpiler will transform this instruction, based on whether `rs2` is `x0`, `x1` or `x2`, into a `SETUP_ADDSUB`, `SETUP_MULDIV` or `SETUP_ISEQ` instruction.
875                    let mut uninit: core::mem::MaybeUninit<AlignedPlaceholder> = core::mem::MaybeUninit::uninit();
876                    openvm::platform::custom_insn_r!(
877                        opcode = ::openvm_algebra_guest::OPCODE,
878                        funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3,
879                        funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::SetupMod as usize
880                            + #mod_idx
881                                * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
882                        rd = In uninit.as_mut_ptr(),
883                        rs1 = In remaining.as_ptr(),
884                        rs2 = Const "x0" // will be parsed as 0 and therefore transpiled to SETUP_ADDMOD
885                    );
886                    openvm::platform::custom_insn_r!(
887                        opcode = ::openvm_algebra_guest::OPCODE,
888                        funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3,
889                        funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::SetupMod as usize
890                            + #mod_idx
891                                * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
892                        rd = In uninit.as_mut_ptr(),
893                        rs1 = In remaining.as_ptr(),
894                        rs2 = Const "x1" // will be parsed as 1 and therefore transpiled to SETUP_MULDIV
895                    );
896                    unsafe {
897                        // This should not be x0:
898                        let mut tmp = uninit.as_mut_ptr() as usize;
899                        openvm::platform::custom_insn_r!(
900                            opcode = ::openvm_algebra_guest::OPCODE,
901                            funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
902                            funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::SetupMod as usize
903                                + #mod_idx
904                                    * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
905                            rd = InOut tmp,
906                            rs1 = In remaining.as_ptr(),
907                            rs2 = Const "x2" // will be parsed as 2 and therefore transpiled to SETUP_ISEQ
908                        );
909                        // rd = inout(reg) is necessary because this instruction will write to `rd` register
910                    }
911                }
912            }
913        });
914    }
915
916    let total_limbs_cnt = two_modular_limbs_flattened_list.len();
917    let cnt_limbs_list_len = limb_list_borders.len();
918    TokenStream::from(quote::quote_spanned! { span.into() =>
919        #(#openvm_section)*
920        #[cfg(target_os = "zkvm")]
921        mod openvm_intrinsics_ffi {
922            #(#externs)*
923        }
924        #[allow(non_snake_case, non_upper_case_globals)]
925        pub mod openvm_intrinsics_meta_do_not_type_this_by_yourself {
926            pub const two_modular_limbs_list: [u8; #total_limbs_cnt] = [#(#two_modular_limbs_flattened_list),*];
927            pub const limb_list_borders: [usize; #cnt_limbs_list_len] = [#(#limb_list_borders),*];
928        }
929        #(#setups)*
930        pub fn setup_all_moduli() {
931            #(#setup_all_moduli)*
932        }
933    })
934}