1extern crate proc_macro;
2
3use std::sync::atomic::AtomicUsize;
4
5use openvm_macros_common::{string_to_bytes, MacroArgs};
6use proc_macro::TokenStream;
7use quote::format_ident;
8use syn::{
9 parse::{Parse, ParseStream},
10 parse_macro_input, LitStr, Token,
11};
12
/// Running counter read by `moduli_declare!`: it is bumped once per declared modulus and the
/// value is used to name the generated implementation module (`algebra_impl_<n>`).
static MOD_IDX: AtomicUsize = AtomicUsize::new(0_usize);
14
15#[proc_macro]
26pub fn moduli_declare(input: TokenStream) -> TokenStream {
27 let MacroArgs { items } = parse_macro_input!(input as MacroArgs);
28
29 let mut output = Vec::new();
30
31 let span = proc_macro::Span::call_site();
32
33 for item in items {
34 let struct_name = item.name.to_string();
35 let struct_name = syn::Ident::new(&struct_name, span.into());
36 let mut modulus: Option<String> = None;
37 for param in item.params {
38 match param.name.to_string().as_str() {
39 "modulus" => {
40 if let syn::Expr::Lit(syn::ExprLit {
41 lit: syn::Lit::Str(value),
42 ..
43 }) = param.value
44 {
45 modulus = Some(value.value());
46 } else {
47 return syn::Error::new_spanned(param.value, "Expected a string literal")
48 .to_compile_error()
49 .into();
50 }
51 }
52 _ => {
53 panic!("Unknown parameter {}", param.name);
54 }
55 }
56 }
57
58 let mod_idx = MOD_IDX.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
61
62 let modulus = modulus.expect("modulus parameter is required");
63 let modulus_bytes = string_to_bytes(&modulus);
64 let mut limbs = modulus_bytes.len();
65 let mut block_size = 32;
66
67 if limbs <= 32 {
68 limbs = 32;
69 } else if limbs <= 48 {
70 limbs = 48;
71 block_size = 16;
72 } else {
73 panic!("limbs must be at most 48");
74 }
75
76 let modulus_bytes = modulus_bytes
77 .into_iter()
78 .chain(vec![0u8; limbs])
79 .take(limbs)
80 .collect::<Vec<_>>();
81
82 let modulus_hex = modulus_bytes
83 .iter()
84 .rev()
85 .map(|x| format!("{:02x}", x))
86 .collect::<Vec<_>>()
87 .join("");
88 macro_rules! create_extern_func {
89 ($name:ident) => {
90 let $name = syn::Ident::new(
91 &format!("{}_{}", stringify!($name), modulus_hex),
92 span.into(),
93 );
94 };
95 }
96 create_extern_func!(add_extern_func);
97 create_extern_func!(sub_extern_func);
98 create_extern_func!(mul_extern_func);
99 create_extern_func!(div_extern_func);
100 create_extern_func!(is_eq_extern_func);
101
102 let block_size = proc_macro::Literal::usize_unsuffixed(block_size);
103 let block_size = syn::Lit::new(block_size.to_string().parse::<_>().unwrap());
104
105 let module_name = format_ident!("algebra_impl_{}", mod_idx);
106
107 let result = TokenStream::from(quote::quote_spanned! { span.into() =>
108 #[derive(Clone, Eq, serde::Serialize, serde::Deserialize)]
120 #[repr(C, align(#block_size))]
121 pub struct #struct_name(#[serde(with = "openvm_algebra_guest::BigArray")] [u8; #limbs]);
122
123 extern "C" {
124 fn #add_extern_func(rd: usize, rs1: usize, rs2: usize);
125 fn #sub_extern_func(rd: usize, rs1: usize, rs2: usize);
126 fn #mul_extern_func(rd: usize, rs1: usize, rs2: usize);
127 fn #div_extern_func(rd: usize, rs1: usize, rs2: usize);
128 fn #is_eq_extern_func(rs1: usize, rs2: usize) -> bool;
129 }
130
131 impl #struct_name {
132 #[inline(always)]
133 const fn from_const_u8(val: u8) -> Self {
134 let mut bytes = [0; #limbs];
135 bytes[0] = val;
136 Self(bytes)
137 }
138
139 pub const fn from_const_bytes(bytes: [u8; #limbs]) -> Self {
142 Self(bytes)
143 }
144
145 #[inline(always)]
146 fn add_assign_impl(&mut self, other: &Self) {
147 #[cfg(not(target_os = "zkvm"))]
148 {
149 *self = Self::from_biguint(
150 (self.as_biguint() + other.as_biguint()) % Self::modulus_biguint(),
151 );
152 }
153 #[cfg(target_os = "zkvm")]
154 {
155 unsafe {
156 #add_extern_func(
157 self as *mut Self as usize,
158 self as *const Self as usize,
159 other as *const Self as usize,
160 );
161 }
162 }
163 }
164
165 #[inline(always)]
166 fn sub_assign_impl(&mut self, other: &Self) {
167 #[cfg(not(target_os = "zkvm"))]
168 {
169 let modulus = Self::modulus_biguint();
170 *self = Self::from_biguint(
171 (self.as_biguint() + modulus.clone() - other.as_biguint()) % modulus,
172 );
173 }
174 #[cfg(target_os = "zkvm")]
175 {
176 unsafe {
177 #sub_extern_func(
178 self as *mut Self as usize,
179 self as *const Self as usize,
180 other as *const Self as usize,
181 );
182 }
183 }
184 }
185
186 #[inline(always)]
187 fn mul_assign_impl(&mut self, other: &Self) {
188 #[cfg(not(target_os = "zkvm"))]
189 {
190 *self = Self::from_biguint(
191 (self.as_biguint() * other.as_biguint()) % Self::modulus_biguint(),
192 );
193 }
194 #[cfg(target_os = "zkvm")]
195 {
196 unsafe {
197 #mul_extern_func(
198 self as *mut Self as usize,
199 self as *const Self as usize,
200 other as *const Self as usize,
201 );
202 }
203 }
204 }
205
206 #[inline(always)]
207 fn div_assign_unsafe_impl(&mut self, other: &Self) {
208 #[cfg(not(target_os = "zkvm"))]
209 {
210 let modulus = Self::modulus_biguint();
211 let inv = other.as_biguint().modinv(&modulus).unwrap();
212 *self = Self::from_biguint((self.as_biguint() * inv) % modulus);
213 }
214 #[cfg(target_os = "zkvm")]
215 {
216 unsafe {
217 #div_extern_func(
218 self as *mut Self as usize,
219 self as *const Self as usize,
220 other as *const Self as usize,
221 );
222 }
223 }
224 }
225
226 #[inline(always)]
229 unsafe fn add_refs_impl(&self, other: &Self, dst_ptr: *mut Self) {
230 #[cfg(not(target_os = "zkvm"))]
231 {
232 let mut res = self.clone();
233 res += other;
234 let dst = unsafe { &mut *dst_ptr };
236 *dst = res;
237 }
238 #[cfg(target_os = "zkvm")]
239 {
240 unsafe {
241 #add_extern_func(
242 dst_ptr as usize,
243 self as *const #struct_name as usize,
244 other as *const #struct_name as usize,
245 );
246 }
247 }
248 }
249
250 #[inline(always)]
253 unsafe fn sub_refs_impl(&self, other: &Self, dst_ptr: *mut Self) {
254 #[cfg(not(target_os = "zkvm"))]
255 {
256 let mut res = self.clone();
257 res -= other;
258 let dst = unsafe { &mut *dst_ptr };
260 *dst = res;
261 }
262 #[cfg(target_os = "zkvm")]
263 {
264 unsafe {
265 #sub_extern_func(
266 dst_ptr as usize,
267 self as *const #struct_name as usize,
268 other as *const #struct_name as usize,
269 );
270 }
271 }
272 }
273
274 #[inline(always)]
277 unsafe fn mul_refs_impl(&self, other: &Self, dst_ptr: *mut Self) {
278 #[cfg(not(target_os = "zkvm"))]
279 {
280 let mut res = self.clone();
281 res *= other;
282 let dst = unsafe { &mut *dst_ptr };
284 *dst = res;
285 }
286 #[cfg(target_os = "zkvm")]
287 {
288 unsafe {
289 #mul_extern_func(
290 dst_ptr as usize,
291 self as *const #struct_name as usize,
292 other as *const #struct_name as usize,
293 );
294 }
295 }
296 }
297
298 #[inline(always)]
299 fn div_unsafe_refs_impl(&self, other: &Self) -> Self {
300 #[cfg(not(target_os = "zkvm"))]
301 {
302 let modulus = Self::modulus_biguint();
303 let inv = other.as_biguint().modinv(&modulus).unwrap();
304 Self::from_biguint((self.as_biguint() * inv) % modulus)
305 }
306 #[cfg(target_os = "zkvm")]
307 {
308 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
309 unsafe {
310 #div_extern_func(
311 uninit.as_mut_ptr() as usize,
312 self as *const #struct_name as usize,
313 other as *const #struct_name as usize,
314 );
315 }
316 unsafe { uninit.assume_init() }
317 }
318 }
319
320 #[inline(always)]
321 fn eq_impl(&self, other: &Self) -> bool {
322 #[cfg(not(target_os = "zkvm"))]
323 {
324 self.as_le_bytes() == other.as_le_bytes()
325 }
326 #[cfg(target_os = "zkvm")]
327 {
328 unsafe {
329 #is_eq_extern_func(self as *const #struct_name as usize, other as *const #struct_name as usize)
330 }
331 }
332 }
333 }
334
335 mod #module_name {
337 use openvm_algebra_guest::IntMod;
338
339 use super::#struct_name;
340
341 impl IntMod for #struct_name {
342 type Repr = [u8; #limbs];
343 type SelfRef<'a> = &'a Self;
344
345 const MODULUS: Self::Repr = [#(#modulus_bytes),*];
346
347 const ZERO: Self = Self([0; #limbs]);
348
349 const NUM_LIMBS: usize = #limbs;
350
351 const ONE: Self = Self::from_const_u8(1);
352
353 fn from_repr(repr: Self::Repr) -> Self {
354 Self(repr)
355 }
356
357 fn from_le_bytes(bytes: &[u8]) -> Self {
358 let mut arr = [0u8; #limbs];
359 arr.copy_from_slice(bytes);
360 Self(arr)
361 }
362
363 fn from_be_bytes(bytes: &[u8]) -> Self {
364 let mut arr = [0u8; #limbs];
365 for (a, b) in arr.iter_mut().zip(bytes.iter().rev()) {
366 *a = *b;
367 }
368 Self(arr)
369 }
370
371 fn from_u8(val: u8) -> Self {
372 Self::from_const_u8(val)
373 }
374
375 fn from_u32(val: u32) -> Self {
376 let mut bytes = [0; #limbs];
377 bytes[..4].copy_from_slice(&val.to_le_bytes());
378 Self(bytes)
379 }
380
381 fn from_u64(val: u64) -> Self {
382 let mut bytes = [0; #limbs];
383 bytes[..8].copy_from_slice(&val.to_le_bytes());
384 Self(bytes)
385 }
386
387 fn as_le_bytes(&self) -> &[u8] {
388 &(self.0)
389 }
390
391 fn to_be_bytes(&self) -> [u8; #limbs] {
392 core::array::from_fn(|i| self.0[#limbs - 1 - i])
393 }
394
395 #[cfg(not(target_os = "zkvm"))]
396 fn modulus_biguint() -> num_bigint::BigUint {
397 num_bigint::BigUint::from_bytes_le(&Self::MODULUS)
398 }
399
400 #[cfg(not(target_os = "zkvm"))]
401 fn from_biguint(biguint: num_bigint::BigUint) -> Self {
402 Self(openvm::utils::biguint_to_limbs(&biguint))
403 }
404
405 #[cfg(not(target_os = "zkvm"))]
406 fn as_biguint(&self) -> num_bigint::BigUint {
407 num_bigint::BigUint::from_bytes_le(self.as_le_bytes())
408 }
409
410 fn neg_assign(&mut self) {
411 unsafe {
412 (#struct_name::ZERO).sub_refs_impl(self, self as *const Self as *mut Self);
415 }
416 }
417
418 fn double_assign(&mut self) {
419 unsafe {
420 self.add_refs_impl(self, self as *const Self as *mut Self);
423 }
424 }
425
426 fn square_assign(&mut self) {
427 unsafe {
428 self.mul_refs_impl(self, self as *const Self as *mut Self);
431 }
432 }
433
434 fn double(&self) -> Self {
435 self + self
436 }
437
438 fn square(&self) -> Self {
439 self * self
440 }
441
442 fn cube(&self) -> Self {
443 &self.square() * self
444 }
445
446 fn assert_reduced(&self) {
451 let _ = core::hint::black_box(PartialEq::eq(self, self));
453 }
454
455 fn is_reduced(&self) -> bool {
456 for (x_limb, p_limb) in self.0.iter().rev().zip(Self::MODULUS.iter().rev()) {
458 if x_limb < p_limb {
459 return true;
460 } else if x_limb > p_limb {
461 return false;
462 }
463 }
464 false
466 }
467 }
468
469 impl<'a> core::ops::AddAssign<&'a #struct_name> for #struct_name {
470 #[inline(always)]
471 fn add_assign(&mut self, other: &'a #struct_name) {
472 self.add_assign_impl(other);
473 }
474 }
475
476 impl core::ops::AddAssign for #struct_name {
477 #[inline(always)]
478 fn add_assign(&mut self, other: Self) {
479 self.add_assign_impl(&other);
480 }
481 }
482
483 impl core::ops::Add for #struct_name {
484 type Output = Self;
485 #[inline(always)]
486 fn add(mut self, other: Self) -> Self::Output {
487 self += other;
488 self
489 }
490 }
491
492 impl<'a> core::ops::Add<&'a #struct_name> for #struct_name {
493 type Output = Self;
494 #[inline(always)]
495 fn add(mut self, other: &'a #struct_name) -> Self::Output {
496 self += other;
497 self
498 }
499 }
500
501 impl<'a> core::ops::Add<&'a #struct_name> for &#struct_name {
502 type Output = #struct_name;
503 #[inline(always)]
504 fn add(self, other: &'a #struct_name) -> Self::Output {
505 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
506 unsafe {
507 self.add_refs_impl(other, uninit.as_mut_ptr());
508 uninit.assume_init()
509 }
510 }
511 }
512
513 impl<'a> core::ops::SubAssign<&'a #struct_name> for #struct_name {
514 #[inline(always)]
515 fn sub_assign(&mut self, other: &'a #struct_name) {
516 self.sub_assign_impl(other);
517 }
518 }
519
520 impl core::ops::SubAssign for #struct_name {
521 #[inline(always)]
522 fn sub_assign(&mut self, other: Self) {
523 self.sub_assign_impl(&other);
524 }
525 }
526
527 impl core::ops::Sub for #struct_name {
528 type Output = Self;
529 #[inline(always)]
530 fn sub(mut self, other: Self) -> Self::Output {
531 self -= other;
532 self
533 }
534 }
535
536 impl<'a> core::ops::Sub<&'a #struct_name> for #struct_name {
537 type Output = Self;
538 #[inline(always)]
539 fn sub(mut self, other: &'a #struct_name) -> Self::Output {
540 self -= other;
541 self
542 }
543 }
544
545 impl<'a> core::ops::Sub<&'a #struct_name> for &'a #struct_name {
546 type Output = #struct_name;
547 #[inline(always)]
548 fn sub(self, other: &'a #struct_name) -> Self::Output {
549 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
550 unsafe {
551 self.sub_refs_impl(other, uninit.as_mut_ptr());
552 uninit.assume_init()
553 }
554 }
555 }
556
557 impl<'a> core::ops::MulAssign<&'a #struct_name> for #struct_name {
558 #[inline(always)]
559 fn mul_assign(&mut self, other: &'a #struct_name) {
560 self.mul_assign_impl(other);
561 }
562 }
563
564 impl core::ops::MulAssign for #struct_name {
565 #[inline(always)]
566 fn mul_assign(&mut self, other: Self) {
567 self.mul_assign_impl(&other);
568 }
569 }
570
571 impl core::ops::Mul for #struct_name {
572 type Output = Self;
573 #[inline(always)]
574 fn mul(mut self, other: Self) -> Self::Output {
575 self *= other;
576 self
577 }
578 }
579
580 impl<'a> core::ops::Mul<&'a #struct_name> for #struct_name {
581 type Output = Self;
582 #[inline(always)]
583 fn mul(mut self, other: &'a #struct_name) -> Self::Output {
584 self *= other;
585 self
586 }
587 }
588
589 impl<'a> core::ops::Mul<&'a #struct_name> for &#struct_name {
590 type Output = #struct_name;
591 #[inline(always)]
592 fn mul(self, other: &'a #struct_name) -> Self::Output {
593 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
594 unsafe {
595 self.mul_refs_impl(other, uninit.as_mut_ptr());
596 uninit.assume_init()
597 }
598 }
599 }
600
601 impl<'a> openvm_algebra_guest::DivAssignUnsafe<&'a #struct_name> for #struct_name {
602 #[inline(always)]
604 fn div_assign_unsafe(&mut self, other: &'a #struct_name) {
605 self.div_assign_unsafe_impl(other);
606 }
607 }
608
609 impl openvm_algebra_guest::DivAssignUnsafe for #struct_name {
610 #[inline(always)]
612 fn div_assign_unsafe(&mut self, other: Self) {
613 self.div_assign_unsafe_impl(&other);
614 }
615 }
616
617 impl openvm_algebra_guest::DivUnsafe for #struct_name {
618 type Output = Self;
619 #[inline(always)]
621 fn div_unsafe(mut self, other: Self) -> Self::Output {
622 self.div_assign_unsafe_impl(&other);
623 self
624 }
625 }
626
627 impl<'a> openvm_algebra_guest::DivUnsafe<&'a #struct_name> for #struct_name {
628 type Output = Self;
629 #[inline(always)]
631 fn div_unsafe(mut self, other: &'a #struct_name) -> Self::Output {
632 self.div_assign_unsafe_impl(other);
633 self
634 }
635 }
636
637 impl<'a> openvm_algebra_guest::DivUnsafe<&'a #struct_name> for &#struct_name {
638 type Output = #struct_name;
639 #[inline(always)]
641 fn div_unsafe(self, other: &'a #struct_name) -> Self::Output {
642 self.div_unsafe_refs_impl(other)
643 }
644 }
645
646 impl PartialEq for #struct_name {
647 #[inline(always)]
648 fn eq(&self, other: &Self) -> bool {
649 self.eq_impl(other)
650 }
651 }
652
653 impl<'a> core::iter::Sum<&'a #struct_name> for #struct_name {
654 fn sum<I: Iterator<Item = &'a #struct_name>>(iter: I) -> Self {
655 iter.fold(Self::ZERO, |acc, x| &acc + x)
656 }
657 }
658
659 impl core::iter::Sum for #struct_name {
660 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
661 iter.fold(Self::ZERO, |acc, x| &acc + &x)
662 }
663 }
664
665 impl<'a> core::iter::Product<&'a #struct_name> for #struct_name {
666 fn product<I: Iterator<Item = &'a #struct_name>>(iter: I) -> Self {
667 iter.fold(Self::ONE, |acc, x| &acc * x)
668 }
669 }
670
671 impl core::iter::Product for #struct_name {
672 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
673 iter.fold(Self::ONE, |acc, x| &acc * &x)
674 }
675 }
676
677 impl core::ops::Neg for #struct_name {
678 type Output = #struct_name;
679 fn neg(self) -> Self::Output {
680 #struct_name::ZERO - &self
681 }
682 }
683
684 impl<'a> core::ops::Neg for &'a #struct_name {
685 type Output = #struct_name;
686 fn neg(self) -> Self::Output {
687 #struct_name::ZERO - self
688 }
689 }
690
691 impl core::fmt::Debug for #struct_name {
692 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
693 write!(f, "{:?}", self.as_le_bytes())
694 }
695 }
696 }
697
698 impl openvm_algebra_guest::Reduce for #struct_name {
699 fn reduce_le_bytes(bytes: &[u8]) -> Self {
700 let mut res = <Self as openvm_algebra_guest::IntMod>::ZERO;
701 let mut base = Self::from_le_bytes(&[255u8; #limbs]);
703 base += <Self as openvm_algebra_guest::IntMod>::ONE;
704 for chunk in bytes.chunks(#limbs).rev() {
705 res = res * &base + Self::from_le_bytes(chunk);
706 }
707 res
708 }
709 }
710 });
711
712 output.push(result);
713 }
714
715 TokenStream::from_iter(output)
716}
717
718struct ModuliDefine {
719 items: Vec<LitStr>,
720}
721
722impl Parse for ModuliDefine {
723 fn parse(input: ParseStream) -> syn::Result<Self> {
724 let items = input.parse_terminated(<LitStr as Parse>::parse, Token![,])?;
725 Ok(Self {
726 items: items.into_iter().collect(),
727 })
728 }
729}
730
731#[proc_macro]
732pub fn moduli_init(input: TokenStream) -> TokenStream {
733 let ModuliDefine { items } = parse_macro_input!(input as ModuliDefine);
734
735 let mut externs = Vec::new();
736 let mut setups = Vec::new();
737 let mut openvm_section = Vec::new();
738 let mut setup_all_moduli = Vec::new();
739
740 let mut two_modular_limbs_flattened_list = Vec::<u8>::new();
742 let mut limb_list_borders = vec![0usize];
744
745 let span = proc_macro::Span::call_site();
746
747 for (mod_idx, item) in items.into_iter().enumerate() {
748 let modulus = item.value();
749 println!("[init] modulus #{} = {}", mod_idx, modulus);
750
751 let modulus_bytes = string_to_bytes(&modulus);
752 let mut limbs = modulus_bytes.len();
753 let mut block_size = 32;
754
755 if limbs <= 32 {
756 limbs = 32;
757 } else if limbs <= 48 {
758 limbs = 48;
759 block_size = 16;
760 } else {
761 panic!("limbs must be at most 48");
762 }
763
764 let block_size = proc_macro::Literal::usize_unsuffixed(block_size);
765 let block_size = syn::Lit::new(block_size.to_string().parse::<_>().unwrap());
766
767 let modulus_bytes = modulus_bytes
768 .into_iter()
769 .chain(vec![0u8; limbs])
770 .take(limbs)
771 .collect::<Vec<_>>();
772
773 let doubled_modulus = [modulus_bytes.clone(), modulus_bytes.clone()].concat();
775 two_modular_limbs_flattened_list.extend(doubled_modulus);
776 limb_list_borders.push(two_modular_limbs_flattened_list.len());
777
778 let modulus_hex = modulus_bytes
779 .iter()
780 .rev()
781 .map(|x| format!("{:02x}", x))
782 .collect::<Vec<_>>()
783 .join("");
784
785 let serialized_modulus =
786 core::iter::once(1) .chain(core::iter::once(mod_idx as u8)) .chain((modulus_bytes.len() as u32).to_le_bytes().iter().copied())
790 .chain(modulus_bytes.iter().copied())
791 .collect::<Vec<_>>();
792 let serialized_name = syn::Ident::new(
793 &format!("OPENVM_SERIALIZED_MODULUS_{}", mod_idx),
794 span.into(),
795 );
796 let serialized_len = serialized_modulus.len();
797 let setup_function = syn::Ident::new(&format!("setup_{}", mod_idx), span.into());
798
799 openvm_section.push(quote::quote_spanned! { span.into() =>
800 #[cfg(target_os = "zkvm")]
801 #[link_section = ".openvm"]
802 #[no_mangle]
803 #[used]
804 static #serialized_name: [u8; #serialized_len] = [#(#serialized_modulus),*];
805 });
806
807 for op_type in ["add", "sub", "mul", "div"] {
808 let func_name = syn::Ident::new(
809 &format!("{}_extern_func_{}", op_type, modulus_hex),
810 span.into(),
811 );
812 let mut chars = op_type.chars().collect::<Vec<_>>();
813 chars[0] = chars[0].to_ascii_uppercase();
814 let local_opcode = syn::Ident::new(
815 &format!("{}Mod", chars.iter().collect::<String>()),
816 span.into(),
817 );
818 externs.push(quote::quote_spanned! { span.into() =>
819 #[no_mangle]
820 extern "C" fn #func_name(rd: usize, rs1: usize, rs2: usize) {
821 openvm::platform::custom_insn_r!(
822 opcode = ::openvm_algebra_guest::OPCODE,
823 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
824 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::#local_opcode as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
825 rd = In rd,
826 rs1 = In rs1,
827 rs2 = In rs2
828 )
829 }
830 });
831 }
832
833 let is_eq_extern_func =
834 syn::Ident::new(&format!("is_eq_extern_func_{}", modulus_hex), span.into());
835 externs.push(quote::quote_spanned! { span.into() =>
836 #[no_mangle]
837 extern "C" fn #is_eq_extern_func(rs1: usize, rs2: usize) -> bool {
838 let mut x: u32;
839 openvm::platform::custom_insn_r!(
840 opcode = ::openvm_algebra_guest::OPCODE,
841 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
842 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::IsEqMod as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
843 rd = Out x,
844 rs1 = In rs1,
845 rs2 = In rs2
846 );
847 x != 0
848 }
849 });
850
851 setup_all_moduli.push(quote::quote_spanned! { span.into() =>
852 #setup_function();
853 });
854
855 setups.push(quote::quote_spanned! { span.into() =>
856 #[allow(non_snake_case)]
857 pub fn #setup_function() {
858 #[cfg(target_os = "zkvm")]
859 {
860 let mut ptr = 0;
861 assert_eq!(#serialized_name[ptr], 1);
862 ptr += 1;
863 assert_eq!(#serialized_name[ptr], #mod_idx as u8);
864 ptr += 1;
865 assert_eq!(#serialized_name[ptr..ptr+4].iter().rev().fold(0, |acc, &x| acc * 256 + x as usize), #limbs);
866 ptr += 4;
867 let remaining = &#serialized_name[ptr..];
868
869 #[repr(C, align(#block_size))]
871 struct AlignedPlaceholder([u8; #limbs]);
872
873 let mut uninit: core::mem::MaybeUninit<AlignedPlaceholder> = core::mem::MaybeUninit::uninit();
876 openvm::platform::custom_insn_r!(
877 opcode = ::openvm_algebra_guest::OPCODE,
878 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3,
879 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::SetupMod as usize
880 + #mod_idx
881 * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
882 rd = In uninit.as_mut_ptr(),
883 rs1 = In remaining.as_ptr(),
884 rs2 = Const "x0" );
886 openvm::platform::custom_insn_r!(
887 opcode = ::openvm_algebra_guest::OPCODE,
888 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3,
889 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::SetupMod as usize
890 + #mod_idx
891 * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
892 rd = In uninit.as_mut_ptr(),
893 rs1 = In remaining.as_ptr(),
894 rs2 = Const "x1" );
896 unsafe {
897 let mut tmp = uninit.as_mut_ptr() as usize;
899 openvm::platform::custom_insn_r!(
900 opcode = ::openvm_algebra_guest::OPCODE,
901 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
902 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::SetupMod as usize
903 + #mod_idx
904 * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
905 rd = InOut tmp,
906 rs1 = In remaining.as_ptr(),
907 rs2 = Const "x2" );
909 }
911 }
912 }
913 });
914 }
915
916 let total_limbs_cnt = two_modular_limbs_flattened_list.len();
917 let cnt_limbs_list_len = limb_list_borders.len();
918 TokenStream::from(quote::quote_spanned! { span.into() =>
919 #(#openvm_section)*
920 #[cfg(target_os = "zkvm")]
921 mod openvm_intrinsics_ffi {
922 #(#externs)*
923 }
924 #[allow(non_snake_case, non_upper_case_globals)]
925 pub mod openvm_intrinsics_meta_do_not_type_this_by_yourself {
926 pub const two_modular_limbs_list: [u8; #total_limbs_cnt] = [#(#two_modular_limbs_flattened_list),*];
927 pub const limb_list_borders: [usize; #cnt_limbs_list_len] = [#(#limb_list_borders),*];
928 }
929 #(#setups)*
930 pub fn setup_all_moduli() {
931 #(#setup_all_moduli)*
932 }
933 })
934}