1extern crate alloc;
2extern crate proc_macro;
3
4use std::sync::atomic::AtomicUsize;
5
6use num_bigint::BigUint;
7use num_prime::nt_funcs::is_prime;
8use openvm_macros_common::{string_to_bytes, MacroArgs};
9use proc_macro::TokenStream;
10use quote::format_ident;
11use syn::{
12 parse::{Parse, ParseStream},
13 parse_macro_input, LitStr, Token,
14};
15
/// Process-wide counter, bumped once per struct generated by `moduli_declare!`, so every
/// generated `algebra_impl_<n>` module gets a distinct name.
static MOD_IDX: AtomicUsize = AtomicUsize::new(0_usize);
17
18#[proc_macro]
29pub fn moduli_declare(input: TokenStream) -> TokenStream {
30 let MacroArgs { items } = parse_macro_input!(input as MacroArgs);
31
32 let mut output = Vec::new();
33
34 let span = proc_macro::Span::call_site();
35
36 for item in items {
37 let struct_name = item.name.to_string();
38 let struct_name = syn::Ident::new(&struct_name, span.into());
39 let mut modulus: Option<String> = None;
40 for param in item.params {
41 match param.name.to_string().as_str() {
42 "modulus" => {
43 if let syn::Expr::Lit(syn::ExprLit {
44 lit: syn::Lit::Str(value),
45 ..
46 }) = param.value
47 {
48 modulus = Some(value.value());
49 } else {
50 return syn::Error::new_spanned(
51 param.value,
52 "Expected a string literal for macro argument `modulus`",
53 )
54 .to_compile_error()
55 .into();
56 }
57 }
58 _ => {
59 panic!("Unknown parameter {}", param.name);
60 }
61 }
62 }
63
64 let mod_idx = MOD_IDX.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
67
68 let modulus = modulus.expect("modulus parameter is required");
69 let modulus_bytes = string_to_bytes(&modulus);
70 let mut limbs = modulus_bytes.len();
71 let mut block_size = 32;
72
73 if limbs <= 32 {
74 limbs = 32;
75 } else if limbs <= 48 {
76 limbs = 48;
77 block_size = 16;
78 } else {
79 panic!("limbs must be at most 48");
80 }
81
82 let modulus_bytes = modulus_bytes
83 .into_iter()
84 .chain(vec![0u8; limbs])
85 .take(limbs)
86 .collect::<Vec<_>>();
87
88 let modulus_hex = modulus_bytes
89 .iter()
90 .rev()
91 .map(|x| format!("{:02x}", x))
92 .collect::<Vec<_>>()
93 .join("");
94 macro_rules! create_extern_func {
95 ($name:ident) => {
96 let $name = syn::Ident::new(
97 &format!("{}_{}", stringify!($name), modulus_hex),
98 span.into(),
99 );
100 };
101 }
102 create_extern_func!(add_extern_func);
103 create_extern_func!(sub_extern_func);
104 create_extern_func!(mul_extern_func);
105 create_extern_func!(div_extern_func);
106 create_extern_func!(is_eq_extern_func);
107 create_extern_func!(hint_sqrt_extern_func);
108 create_extern_func!(hint_non_qr_extern_func);
109 create_extern_func!(moduli_setup_extern_func);
110
111 let block_size = proc_macro::Literal::usize_unsuffixed(block_size);
112 let block_size = syn::Lit::new(block_size.to_string().parse::<_>().unwrap());
113
114 let module_name = format_ident!("algebra_impl_{}", mod_idx);
115
116 let result = TokenStream::from(quote::quote_spanned! { span.into() =>
117 #[derive(Clone, Eq, serde::Serialize, serde::Deserialize)]
129 #[repr(C, align(#block_size))]
130 pub struct #struct_name(#[serde(with = "openvm_algebra_guest::BigArray")] [u8; #limbs]);
131
132 extern "C" {
133 fn #add_extern_func(rd: usize, rs1: usize, rs2: usize);
134 fn #sub_extern_func(rd: usize, rs1: usize, rs2: usize);
135 fn #mul_extern_func(rd: usize, rs1: usize, rs2: usize);
136 fn #div_extern_func(rd: usize, rs1: usize, rs2: usize);
137 fn #is_eq_extern_func(rs1: usize, rs2: usize) -> bool;
138 fn #hint_sqrt_extern_func(rs1: usize);
139 fn #hint_non_qr_extern_func();
140 fn #moduli_setup_extern_func();
141 }
142
143 impl #struct_name {
144 #[inline(always)]
145 const fn from_const_u8(val: u8) -> Self {
146 let mut bytes = [0; #limbs];
147 bytes[0] = val;
148 Self(bytes)
149 }
150
151 pub const fn from_const_bytes(bytes: [u8; #limbs]) -> Self {
154 Self(bytes)
155 }
156
157 #[inline(always)]
158 fn add_assign_impl(&mut self, other: &Self) {
159 #[cfg(not(target_os = "zkvm"))]
160 {
161 *self = Self::from_biguint(
162 (self.as_biguint() + other.as_biguint()) % Self::modulus_biguint(),
163 );
164 }
165 #[cfg(target_os = "zkvm")]
166 {
167 Self::set_up_once();
168 unsafe {
169 #add_extern_func(
170 self as *mut Self as usize,
171 self as *const Self as usize,
172 other as *const Self as usize,
173 );
174 }
175 }
176 }
177
178 #[inline(always)]
179 fn sub_assign_impl(&mut self, other: &Self) {
180 #[cfg(not(target_os = "zkvm"))]
181 {
182 let modulus = Self::modulus_biguint();
183 *self = Self::from_biguint(
184 (self.as_biguint() + modulus.clone() - other.as_biguint()) % modulus,
185 );
186 }
187 #[cfg(target_os = "zkvm")]
188 {
189 Self::set_up_once();
190 unsafe {
191 #sub_extern_func(
192 self as *mut Self as usize,
193 self as *const Self as usize,
194 other as *const Self as usize,
195 );
196 }
197 }
198 }
199
200 #[inline(always)]
201 fn mul_assign_impl(&mut self, other: &Self) {
202 #[cfg(not(target_os = "zkvm"))]
203 {
204 *self = Self::from_biguint(
205 (self.as_biguint() * other.as_biguint()) % Self::modulus_biguint(),
206 );
207 }
208 #[cfg(target_os = "zkvm")]
209 {
210 Self::set_up_once();
211 unsafe {
212 #mul_extern_func(
213 self as *mut Self as usize,
214 self as *const Self as usize,
215 other as *const Self as usize,
216 );
217 }
218 }
219 }
220
221 #[inline(always)]
222 fn div_assign_unsafe_impl(&mut self, other: &Self) {
223 #[cfg(not(target_os = "zkvm"))]
224 {
225 let modulus = Self::modulus_biguint();
226 let inv = other.as_biguint().modinv(&modulus).unwrap();
227 *self = Self::from_biguint((self.as_biguint() * inv) % modulus);
228 }
229 #[cfg(target_os = "zkvm")]
230 {
231 Self::set_up_once();
232 unsafe {
233 #div_extern_func(
234 self as *mut Self as usize,
235 self as *const Self as usize,
236 other as *const Self as usize,
237 );
238 }
239 }
240 }
241
242 #[inline(always)]
245 unsafe fn add_refs_impl<const CHECK_SETUP: bool>(&self, other: &Self, dst_ptr: *mut Self) {
246 #[cfg(not(target_os = "zkvm"))]
247 {
248 let mut res = self.clone();
249 res += other;
250 let dst = unsafe { &mut *dst_ptr };
252 *dst = res;
253 }
254 #[cfg(target_os = "zkvm")]
255 {
256 if CHECK_SETUP {
257 Self::set_up_once();
258 }
259 #add_extern_func(
260 dst_ptr as usize,
261 self as *const #struct_name as usize,
262 other as *const #struct_name as usize,
263 );
264 }
265 }
266
267 #[inline(always)]
270 unsafe fn sub_refs_impl(&self, other: &Self, dst_ptr: *mut Self) {
271 #[cfg(not(target_os = "zkvm"))]
272 {
273 let mut res = self.clone();
274 res -= other;
275 let dst = unsafe { &mut *dst_ptr };
277 *dst = res;
278 }
279 #[cfg(target_os = "zkvm")]
280 {
281 Self::set_up_once();
282 unsafe {
283 #sub_extern_func(
284 dst_ptr as usize,
285 self as *const #struct_name as usize,
286 other as *const #struct_name as usize,
287 );
288 }
289 }
290 }
291
292 #[inline(always)]
295 unsafe fn mul_refs_impl(&self, other: &Self, dst_ptr: *mut Self) {
296 #[cfg(not(target_os = "zkvm"))]
297 {
298 let mut res = self.clone();
299 res *= other;
300 let dst = unsafe { &mut *dst_ptr };
302 *dst = res;
303 }
304 #[cfg(target_os = "zkvm")]
305 {
306 Self::set_up_once();
307 unsafe {
308 #mul_extern_func(
309 dst_ptr as usize,
310 self as *const #struct_name as usize,
311 other as *const #struct_name as usize,
312 );
313 }
314 }
315 }
316
317 #[inline(always)]
318 fn div_unsafe_refs_impl(&self, other: &Self) -> Self {
319 #[cfg(not(target_os = "zkvm"))]
320 {
321 let modulus = Self::modulus_biguint();
322 let inv = other.as_biguint().modinv(&modulus).unwrap();
323 Self::from_biguint((self.as_biguint() * inv) % modulus)
324 }
325 #[cfg(target_os = "zkvm")]
326 {
327 Self::set_up_once();
328 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
329 unsafe {
330 #div_extern_func(
331 uninit.as_mut_ptr() as usize,
332 self as *const #struct_name as usize,
333 other as *const #struct_name as usize,
334 );
335 }
336 unsafe { uninit.assume_init() }
337 }
338 }
339
340 #[inline(always)]
341 unsafe fn eq_impl<const CHECK_SETUP: bool>(&self, other: &Self) -> bool {
342 #[cfg(not(target_os = "zkvm"))]
343 {
344 self.as_le_bytes() == other.as_le_bytes()
345 }
346 #[cfg(target_os = "zkvm")]
347 {
348 if CHECK_SETUP {
349 Self::set_up_once();
350 }
351 #is_eq_extern_func(self as *const #struct_name as usize, other as *const #struct_name as usize)
352 }
353 }
354
355 #[inline(always)]
357 #[cfg(target_os = "zkvm")]
358 fn set_up_once() {
359 static is_setup: ::openvm_algebra_guest::once_cell::race::OnceBool = ::openvm_algebra_guest::once_cell::race::OnceBool::new();
360 is_setup.get_or_init(|| {
361 unsafe { #moduli_setup_extern_func(); }
362 true
363 });
364 }
365 #[inline(always)]
366 #[cfg(not(target_os = "zkvm"))]
367 fn set_up_once() {
368 }
370 }
371
372 mod #module_name {
374 use openvm_algebra_guest::IntMod;
375
376 use super::#struct_name;
377
378 impl IntMod for #struct_name {
379 type Repr = [u8; #limbs];
380 type SelfRef<'a> = &'a Self;
381
382 const MODULUS: Self::Repr = [#(#modulus_bytes),*];
383
384 const ZERO: Self = Self([0; #limbs]);
385
386 const NUM_LIMBS: usize = #limbs;
387
388 const ONE: Self = Self::from_const_u8(1);
389
390 fn from_repr(repr: Self::Repr) -> Self {
391 Self(repr)
392 }
393
394 fn from_le_bytes(bytes: &[u8]) -> Option<Self> {
395 let elt = Self::from_le_bytes_unchecked(bytes);
396 if elt.is_reduced() {
397 Some(elt)
398 } else {
399 None
400 }
401 }
402
403 fn from_be_bytes(bytes: &[u8]) -> Option<Self> {
404 let elt = Self::from_be_bytes_unchecked(bytes);
405 if elt.is_reduced() {
406 Some(elt)
407 } else {
408 None
409 }
410 }
411
412 fn from_le_bytes_unchecked(bytes: &[u8]) -> Self {
413 let mut arr = [0u8; #limbs];
414 arr.copy_from_slice(bytes);
415 Self(arr)
416 }
417
418 fn from_be_bytes_unchecked(bytes: &[u8]) -> Self {
419 let mut arr = [0u8; #limbs];
420 for (a, b) in arr.iter_mut().zip(bytes.iter().rev()) {
421 *a = *b;
422 }
423 Self(arr)
424 }
425
426 fn from_u8(val: u8) -> Self {
427 Self::from_const_u8(val)
428 }
429
430 fn from_u32(val: u32) -> Self {
431 let mut bytes = [0; #limbs];
432 bytes[..4].copy_from_slice(&val.to_le_bytes());
433 Self(bytes)
434 }
435
436 fn from_u64(val: u64) -> Self {
437 let mut bytes = [0; #limbs];
438 bytes[..8].copy_from_slice(&val.to_le_bytes());
439 Self(bytes)
440 }
441
442 #[inline(always)]
443 fn as_le_bytes(&self) -> &[u8] {
444 &(self.0)
445 }
446
447 #[inline(always)]
448 fn to_be_bytes(&self) -> [u8; #limbs] {
449 core::array::from_fn(|i| self.0[#limbs - 1 - i])
450 }
451
452 #[cfg(not(target_os = "zkvm"))]
453 fn modulus_biguint() -> num_bigint::BigUint {
454 num_bigint::BigUint::from_bytes_le(&Self::MODULUS)
455 }
456
457 #[cfg(not(target_os = "zkvm"))]
458 fn from_biguint(biguint: num_bigint::BigUint) -> Self {
459 Self(openvm::utils::biguint_to_limbs(&biguint))
460 }
461
462 #[cfg(not(target_os = "zkvm"))]
463 fn as_biguint(&self) -> num_bigint::BigUint {
464 num_bigint::BigUint::from_bytes_le(self.as_le_bytes())
465 }
466
467 #[inline(always)]
468 fn neg_assign(&mut self) {
469 unsafe {
470 (#struct_name::ZERO).sub_refs_impl(self, self as *const Self as *mut Self);
473 }
474 }
475
476 #[inline(always)]
477 fn double_assign(&mut self) {
478 unsafe {
479 self.add_refs_impl::<true>(self, self as *const Self as *mut Self);
482 }
483 }
484
485 #[inline(always)]
486 fn square_assign(&mut self) {
487 unsafe {
488 self.mul_refs_impl(self, self as *const Self as *mut Self);
491 }
492 }
493
494 #[inline(always)]
495 fn double(&self) -> Self {
496 self + self
497 }
498
499 #[inline(always)]
500 fn square(&self) -> Self {
501 self * self
502 }
503
504 #[inline(always)]
505 fn cube(&self) -> Self {
506 &self.square() * self
507 }
508
509 fn assert_reduced(&self) {
514 let _ = core::hint::black_box(PartialEq::eq(self, self));
516 }
517
518 fn is_reduced(&self) -> bool {
519 for (x_limb, p_limb) in self.0.iter().rev().zip(Self::MODULUS.iter().rev()) {
521 if x_limb < p_limb {
522 return true;
523 } else if x_limb > p_limb {
524 return false;
525 }
526 }
527 false
529 }
530
531 #[inline(always)]
532 fn set_up_once() {
533 Self::set_up_once();
534 }
535
536 #[inline(always)]
537 unsafe fn eq_impl<const CHECK_SETUP: bool>(&self, other: &Self) -> bool {
538 Self::eq_impl::<CHECK_SETUP>(self, other)
539 }
540
541 #[inline(always)]
542 unsafe fn add_ref<const CHECK_SETUP: bool>(&self, other: &Self) -> Self {
543 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
544 self.add_refs_impl::<CHECK_SETUP>(other, uninit.as_mut_ptr());
545 uninit.assume_init()
546 }
547 }
548
549 impl<'a> core::ops::AddAssign<&'a #struct_name> for #struct_name {
550 #[inline(always)]
551 fn add_assign(&mut self, other: &'a #struct_name) {
552 self.add_assign_impl(other);
553 }
554 }
555
556 impl core::ops::AddAssign for #struct_name {
557 #[inline(always)]
558 fn add_assign(&mut self, other: Self) {
559 self.add_assign_impl(&other);
560 }
561 }
562
563 impl core::ops::Add for #struct_name {
564 type Output = Self;
565 #[inline(always)]
566 fn add(mut self, other: Self) -> Self::Output {
567 self += other;
568 self
569 }
570 }
571
572 impl<'a> core::ops::Add<&'a #struct_name> for #struct_name {
573 type Output = Self;
574 #[inline(always)]
575 fn add(mut self, other: &'a #struct_name) -> Self::Output {
576 self += other;
577 self
578 }
579 }
580
581 impl<'a> core::ops::Add<&'a #struct_name> for &#struct_name {
582 type Output = #struct_name;
583 #[inline(always)]
584 fn add(self, other: &'a #struct_name) -> Self::Output {
585 unsafe { self.add_ref::<true>(other) }
587 }
588 }
589
590 impl<'a> core::ops::SubAssign<&'a #struct_name> for #struct_name {
591 #[inline(always)]
592 fn sub_assign(&mut self, other: &'a #struct_name) {
593 self.sub_assign_impl(other);
594 }
595 }
596
597 impl core::ops::SubAssign for #struct_name {
598 #[inline(always)]
599 fn sub_assign(&mut self, other: Self) {
600 self.sub_assign_impl(&other);
601 }
602 }
603
604 impl core::ops::Sub for #struct_name {
605 type Output = Self;
606 #[inline(always)]
607 fn sub(mut self, other: Self) -> Self::Output {
608 self -= other;
609 self
610 }
611 }
612
613 impl<'a> core::ops::Sub<&'a #struct_name> for #struct_name {
614 type Output = Self;
615 #[inline(always)]
616 fn sub(mut self, other: &'a #struct_name) -> Self::Output {
617 self -= other;
618 self
619 }
620 }
621
622 impl<'a> core::ops::Sub<&'a #struct_name> for &'a #struct_name {
623 type Output = #struct_name;
624 #[inline(always)]
625 fn sub(self, other: &'a #struct_name) -> Self::Output {
626 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
627 unsafe {
628 self.sub_refs_impl(other, uninit.as_mut_ptr());
629 uninit.assume_init()
630 }
631 }
632 }
633
634 impl<'a> core::ops::MulAssign<&'a #struct_name> for #struct_name {
635 #[inline(always)]
636 fn mul_assign(&mut self, other: &'a #struct_name) {
637 self.mul_assign_impl(other);
638 }
639 }
640
641 impl core::ops::MulAssign for #struct_name {
642 #[inline(always)]
643 fn mul_assign(&mut self, other: Self) {
644 self.mul_assign_impl(&other);
645 }
646 }
647
648 impl core::ops::Mul for #struct_name {
649 type Output = Self;
650 #[inline(always)]
651 fn mul(mut self, other: Self) -> Self::Output {
652 self *= other;
653 self
654 }
655 }
656
657 impl<'a> core::ops::Mul<&'a #struct_name> for #struct_name {
658 type Output = Self;
659 #[inline(always)]
660 fn mul(mut self, other: &'a #struct_name) -> Self::Output {
661 self *= other;
662 self
663 }
664 }
665
666 impl<'a> core::ops::Mul<&'a #struct_name> for &#struct_name {
667 type Output = #struct_name;
668 #[inline(always)]
669 fn mul(self, other: &'a #struct_name) -> Self::Output {
670 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
671 unsafe {
672 self.mul_refs_impl(other, uninit.as_mut_ptr());
673 uninit.assume_init()
674 }
675 }
676 }
677
678 impl<'a> openvm_algebra_guest::DivAssignUnsafe<&'a #struct_name> for #struct_name {
679 #[inline(always)]
681 fn div_assign_unsafe(&mut self, other: &'a #struct_name) {
682 self.div_assign_unsafe_impl(other);
683 }
684 }
685
686 impl openvm_algebra_guest::DivAssignUnsafe for #struct_name {
687 #[inline(always)]
689 fn div_assign_unsafe(&mut self, other: Self) {
690 self.div_assign_unsafe_impl(&other);
691 }
692 }
693
694 impl openvm_algebra_guest::DivUnsafe for #struct_name {
695 type Output = Self;
696 #[inline(always)]
698 fn div_unsafe(mut self, other: Self) -> Self::Output {
699 self.div_assign_unsafe_impl(&other);
700 self
701 }
702 }
703
704 impl<'a> openvm_algebra_guest::DivUnsafe<&'a #struct_name> for #struct_name {
705 type Output = Self;
706 #[inline(always)]
708 fn div_unsafe(mut self, other: &'a #struct_name) -> Self::Output {
709 self.div_assign_unsafe_impl(other);
710 self
711 }
712 }
713
714 impl<'a> openvm_algebra_guest::DivUnsafe<&'a #struct_name> for &#struct_name {
715 type Output = #struct_name;
716 #[inline(always)]
718 fn div_unsafe(self, other: &'a #struct_name) -> Self::Output {
719 self.div_unsafe_refs_impl(other)
720 }
721 }
722
723 impl PartialEq for #struct_name {
724 #[inline(always)]
725 fn eq(&self, other: &Self) -> bool {
726 unsafe { self.eq_impl::<true>(other) }
728 }
729 }
730
731 impl<'a> core::iter::Sum<&'a #struct_name> for #struct_name {
732 fn sum<I: Iterator<Item = &'a #struct_name>>(iter: I) -> Self {
733 iter.fold(Self::ZERO, |acc, x| &acc + x)
734 }
735 }
736
737 impl core::iter::Sum for #struct_name {
738 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
739 iter.fold(Self::ZERO, |acc, x| &acc + &x)
740 }
741 }
742
743 impl<'a> core::iter::Product<&'a #struct_name> for #struct_name {
744 fn product<I: Iterator<Item = &'a #struct_name>>(iter: I) -> Self {
745 iter.fold(Self::ONE, |acc, x| &acc * x)
746 }
747 }
748
749 impl core::iter::Product for #struct_name {
750 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
751 iter.fold(Self::ONE, |acc, x| &acc * &x)
752 }
753 }
754
755 impl core::ops::Neg for #struct_name {
756 type Output = #struct_name;
757 #[inline(always)]
758 fn neg(self) -> Self::Output {
759 #struct_name::ZERO - &self
760 }
761 }
762
763 impl<'a> core::ops::Neg for &'a #struct_name {
764 type Output = #struct_name;
765 #[inline(always)]
766 fn neg(self) -> Self::Output {
767 #struct_name::ZERO - self
768 }
769 }
770
771 impl core::fmt::Debug for #struct_name {
772 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
773 write!(f, "{:?}", self.as_le_bytes())
774 }
775 }
776 }
777
778 impl openvm_algebra_guest::Reduce for #struct_name {
779 fn reduce_le_bytes(bytes: &[u8]) -> Self {
780 let mut res = <Self as openvm_algebra_guest::IntMod>::ZERO;
781 let mut base = <Self as openvm_algebra_guest::IntMod>::from_le_bytes_unchecked(&[255u8; #limbs]);
783 base += <Self as openvm_algebra_guest::IntMod>::ONE;
784 for chunk in bytes.chunks(#limbs).rev() {
785 res = res * &base + <Self as openvm_algebra_guest::IntMod>::from_le_bytes_unchecked(chunk);
786 }
787 openvm_algebra_guest::IntMod::assert_reduced(&res);
788 res
789 }
790 }
791 });
792
793 output.push(result);
794
795 let modulus_biguint = BigUint::from_bytes_le(&modulus_bytes);
796 let modulus_is_prime = is_prime(&modulus_biguint, None);
797
798 if modulus_is_prime.probably() {
799 let field_and_sqrt_impl = TokenStream::from(quote::quote_spanned! { span.into() =>
801 impl ::openvm_algebra_guest::Field for #struct_name {
802 const ZERO: Self = <Self as ::openvm_algebra_guest::IntMod>::ZERO;
803 const ONE: Self = <Self as ::openvm_algebra_guest::IntMod>::ONE;
804
805 type SelfRef<'a> = &'a Self;
806
807 fn double_assign(&mut self) {
808 ::openvm_algebra_guest::IntMod::double_assign(self);
809 }
810
811 fn square_assign(&mut self) {
812 ::openvm_algebra_guest::IntMod::square_assign(self);
813 }
814
815 }
816
817 impl openvm_algebra_guest::Sqrt for #struct_name {
818 fn sqrt(&self) -> Option<Self> {
823 match self.honest_host_sqrt() {
824 Some(Some(sqrt)) => Some(sqrt),
826 Some(None) => None,
828 None => {
830 loop {
832 openvm::io::println("ERROR: Square root hint is invalid. Entering infinite loop.");
833 }
834 }
835 }
836 }
837 }
838
839 impl #struct_name {
840 fn honest_host_sqrt(&self) -> Option<Option<Self>> {
844 let (is_square, sqrt) = self.hint_sqrt_impl()?;
845
846 if is_square {
847 <Self as ::openvm_algebra_guest::IntMod>::assert_reduced(&sqrt);
849
850 if &(&sqrt * &sqrt) == self {
851 Some(Some(sqrt))
852 } else {
853 None
854 }
855 } else {
856 <Self as ::openvm_algebra_guest::IntMod>::assert_reduced(&sqrt);
858
859 if &sqrt * &sqrt == self * Self::get_non_qr() {
860 Some(None)
861 } else {
862 None
863 }
864 }
865 }
866
867
868 fn hint_sqrt_impl(&self) -> Option<(bool, Self)> {
872 #[cfg(not(target_os = "zkvm"))]
873 {
874 unimplemented!();
875 }
876 #[cfg(target_os = "zkvm")]
877 {
878 use ::openvm_algebra_guest::{openvm_custom_insn, openvm_rv32im_guest}; let is_square = core::mem::MaybeUninit::<u32>::uninit();
881 let sqrt = core::mem::MaybeUninit::<#struct_name>::uninit();
882 unsafe {
883 #hint_sqrt_extern_func(self as *const #struct_name as usize);
884 let is_square_ptr = is_square.as_ptr() as *const u32;
885 openvm_rv32im_guest::hint_store_u32!(is_square_ptr);
886 openvm_rv32im_guest::hint_buffer_u32!(sqrt.as_ptr() as *const u8, <#struct_name as ::openvm_algebra_guest::IntMod>::NUM_LIMBS / 4);
887 let is_square = is_square.assume_init();
888 if is_square == 0 || is_square == 1 {
889 Some((is_square == 1, sqrt.assume_init()))
890 } else {
891 None
892 }
893 }
894 }
895 }
896
897 fn init_non_qr() -> alloc::boxed::Box<#struct_name> {
899 #[cfg(not(target_os = "zkvm"))]
900 {
901 unimplemented!();
902 }
903 #[cfg(target_os = "zkvm")]
904 {
905 use ::openvm_algebra_guest::{openvm_custom_insn, openvm_rv32im_guest}; let mut non_qr_uninit = core::mem::MaybeUninit::<Self>::uninit();
908 let mut non_qr;
909 unsafe {
910 #hint_non_qr_extern_func();
911 let ptr = non_qr_uninit.as_ptr() as *const u8;
912 openvm_rv32im_guest::hint_buffer_u32!(ptr, <Self as ::openvm_algebra_guest::IntMod>::NUM_LIMBS / 4);
913 non_qr = non_qr_uninit.assume_init();
914 }
915 <Self as ::openvm_algebra_guest::IntMod>::assert_reduced(&non_qr);
917
918 use ::openvm_algebra_guest::{DivUnsafe, ExpBytes};
919 let exp = -<Self as ::openvm_algebra_guest::IntMod>::ONE.div_unsafe(Self::from_const_u8(2));
921 <Self as ::openvm_algebra_guest::IntMod>::assert_reduced(&exp);
922
923 if non_qr.exp_bytes(true, &<Self as ::openvm_algebra_guest::IntMod>::to_be_bytes(&exp)) != -<Self as ::openvm_algebra_guest::IntMod>::ONE
924 {
925 loop {
927 openvm::io::println("ERROR: Non quadratic residue hint is invalid. Entering infinite loop.");
928 }
929 }
930
931 alloc::boxed::Box::new(non_qr)
932 }
933 }
934
935 pub fn get_non_qr() -> &'static #struct_name {
937 static non_qr: ::openvm_algebra_guest::once_cell::race::OnceBox<#struct_name> = ::openvm_algebra_guest::once_cell::race::OnceBox::new();
938 &non_qr.get_or_init(Self::init_non_qr)
939 }
940 }
941 });
942
943 output.push(field_and_sqrt_impl);
944 }
945 }
946
947 TokenStream::from_iter(output)
948}
949
950struct ModuliDefine {
951 items: Vec<LitStr>,
952}
953
954impl Parse for ModuliDefine {
955 fn parse(input: ParseStream) -> syn::Result<Self> {
956 let items = input.parse_terminated(<LitStr as Parse>::parse, Token![,])?;
957 Ok(Self {
958 items: items.into_iter().collect(),
959 })
960 }
961}
962
963#[proc_macro]
964pub fn moduli_init(input: TokenStream) -> TokenStream {
965 let ModuliDefine { items } = parse_macro_input!(input as ModuliDefine);
966
967 let mut externs = Vec::new();
968
969 let mut two_modular_limbs_flattened_list = Vec::<u8>::new();
971 let mut limb_list_borders = vec![0usize];
973
974 let span = proc_macro::Span::call_site();
975
976 for (mod_idx, item) in items.into_iter().enumerate() {
977 let modulus = item.value();
978 let modulus_bytes = string_to_bytes(&modulus);
979 let mut limbs = modulus_bytes.len();
980 let mut block_size = 32;
981
982 if limbs <= 32 {
983 limbs = 32;
984 } else if limbs <= 48 {
985 limbs = 48;
986 block_size = 16;
987 } else {
988 panic!("limbs must be at most 48");
989 }
990
991 let block_size = proc_macro::Literal::usize_unsuffixed(block_size);
992 let block_size = syn::Lit::new(block_size.to_string().parse::<_>().unwrap());
993
994 let modulus_bytes = modulus_bytes
995 .into_iter()
996 .chain(vec![0u8; limbs])
997 .take(limbs)
998 .collect::<Vec<_>>();
999
1000 let doubled_modulus = [modulus_bytes.clone(), modulus_bytes.clone()].concat();
1002 two_modular_limbs_flattened_list.extend(doubled_modulus);
1003 limb_list_borders.push(two_modular_limbs_flattened_list.len());
1004
1005 let modulus_hex = modulus_bytes
1006 .iter()
1007 .rev()
1008 .map(|x| format!("{:02x}", x))
1009 .collect::<Vec<_>>()
1010 .join("");
1011
1012 let setup_extern_func = syn::Ident::new(
1013 &format!("moduli_setup_extern_func_{}", modulus_hex),
1014 span.into(),
1015 );
1016
1017 for op_type in ["add", "sub", "mul", "div"] {
1018 let func_name = syn::Ident::new(
1019 &format!("{}_extern_func_{}", op_type, modulus_hex),
1020 span.into(),
1021 );
1022 let mut chars = op_type.chars().collect::<Vec<_>>();
1023 chars[0] = chars[0].to_ascii_uppercase();
1024 let local_opcode = syn::Ident::new(
1025 &format!("{}Mod", chars.iter().collect::<String>()),
1026 span.into(),
1027 );
1028 externs.push(quote::quote_spanned! { span.into() =>
1029 #[no_mangle]
1030 extern "C" fn #func_name(rd: usize, rs1: usize, rs2: usize) {
1031 openvm::platform::custom_insn_r!(
1032 opcode = ::openvm_algebra_guest::OPCODE,
1033 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
1034 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::#local_opcode as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1035 rd = In rd,
1036 rs1 = In rs1,
1037 rs2 = In rs2
1038 )
1039 }
1040 });
1041 }
1042
1043 let is_eq_extern_func =
1044 syn::Ident::new(&format!("is_eq_extern_func_{}", modulus_hex), span.into());
1045 externs.push(quote::quote_spanned! { span.into() =>
1046 #[no_mangle]
1047 extern "C" fn #is_eq_extern_func(rs1: usize, rs2: usize) -> bool {
1048 let mut x: u32;
1049 openvm::platform::custom_insn_r!(
1050 opcode = ::openvm_algebra_guest::OPCODE,
1051 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
1052 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::IsEqMod as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1053 rd = Out x,
1054 rs1 = In rs1,
1055 rs2 = In rs2
1056 );
1057 x != 0
1058 }
1059 });
1060
1061 let hint_non_qr_extern_func = syn::Ident::new(
1062 &format!("hint_non_qr_extern_func_{}", modulus_hex),
1063 span.into(),
1064 );
1065 externs.push(quote::quote_spanned! { span.into() =>
1066 #[no_mangle]
1067 extern "C" fn #hint_non_qr_extern_func() {
1068 openvm::platform::custom_insn_r!(
1069 opcode = ::openvm_algebra_guest::OPCODE,
1070 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
1071 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::HintNonQr as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1072 rd = Const "x0",
1073 rs1 = Const "x0",
1074 rs2 = Const "x0"
1075 );
1076 }
1077
1078
1079 });
1080
1081 let hint_sqrt_extern_func = syn::Ident::new(
1084 &format!("hint_sqrt_extern_func_{}", modulus_hex),
1085 span.into(),
1086 );
1087 externs.push(quote::quote_spanned! { span.into() =>
1088 #[no_mangle]
1089 extern "C" fn #hint_sqrt_extern_func(rs1: usize) {
1090 openvm::platform::custom_insn_r!(
1091 opcode = ::openvm_algebra_guest::OPCODE,
1092 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
1093 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::HintSqrt as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1094 rd = Const "x0",
1095 rs1 = In rs1,
1096 rs2 = Const "x0"
1097 );
1098 }
1099 });
1100
1101 externs.push(quote::quote_spanned! { span.into() =>
1102 #[no_mangle]
1103 extern "C" fn #setup_extern_func() {
1104 #[cfg(target_os = "zkvm")]
1105 {
1106 #[repr(C, align(#block_size))]
1108 struct AlignedPlaceholder([u8; #limbs]);
1109
1110 const MODULUS_BYTES: AlignedPlaceholder = AlignedPlaceholder([#(#modulus_bytes),*]);
1111
1112 let mut uninit: core::mem::MaybeUninit<AlignedPlaceholder> = core::mem::MaybeUninit::uninit();
1115 openvm::platform::custom_insn_r!(
1116 opcode = ::openvm_algebra_guest::OPCODE,
1117 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3,
1118 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::SetupMod as usize
1119 + #mod_idx
1120 * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1121 rd = In uninit.as_mut_ptr(),
1122 rs1 = In MODULUS_BYTES.0.as_ptr(),
1123 rs2 = Const "x0" );
1125 openvm::platform::custom_insn_r!(
1126 opcode = ::openvm_algebra_guest::OPCODE,
1127 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3,
1128 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::SetupMod as usize
1129 + #mod_idx
1130 * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1131 rd = In uninit.as_mut_ptr(),
1132 rs1 = In MODULUS_BYTES.0.as_ptr(),
1133 rs2 = Const "x1" );
1135 unsafe {
1136 let mut tmp = uninit.as_mut_ptr() as usize;
1138 openvm::platform::custom_insn_r!(
1139 opcode = ::openvm_algebra_guest::OPCODE,
1140 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
1141 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::SetupMod as usize
1142 + #mod_idx
1143 * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1144 rd = InOut tmp,
1145 rs1 = In MODULUS_BYTES.0.as_ptr(),
1146 rs2 = Const "x2" );
1148 }
1150 }
1151 }
1152 });
1153 }
1154
1155 let total_limbs_cnt = two_modular_limbs_flattened_list.len();
1156 let cnt_limbs_list_len = limb_list_borders.len();
1157 TokenStream::from(quote::quote_spanned! { span.into() =>
1158 #[allow(non_snake_case)]
1159 #[cfg(target_os = "zkvm")]
1160 mod openvm_intrinsics_ffi {
1161 #(#externs)*
1162 }
1163 #[allow(non_snake_case, non_upper_case_globals)]
1164 pub mod openvm_intrinsics_meta_do_not_type_this_by_yourself {
1165 pub const two_modular_limbs_list: [u8; #total_limbs_cnt] = [#(#two_modular_limbs_flattened_list),*];
1166 pub const limb_list_borders: [usize; #cnt_limbs_list_len] = [#(#limb_list_borders),*];
1167 }
1168 })
1169}