1extern crate alloc;
2extern crate proc_macro;
3
4use std::sync::atomic::AtomicUsize;
5
6use num_bigint::BigUint;
7use num_prime::nt_funcs::is_prime;
8use openvm_macros_common::{string_to_bytes, MacroArgs};
9use proc_macro::TokenStream;
10use quote::format_ident;
11use syn::{
12 parse::{Parse, ParseStream},
13 parse_macro_input, LitStr, Token,
14};
15
/// Counter shared by all `moduli_declare!` expansions in this proc-macro process.
///
/// Each declared type takes the next value (via `fetch_add`) and uses it to name its
/// private impl module `algebra_impl_{idx}`, so that the generated modules get distinct
/// names.
///
/// NOTE(review): this relies on proc-macro static state persisting across macro
/// invocations within a single compilation — confirm this holds for the build setups
/// (and IDE tooling) in use.
static MOD_IDX: AtomicUsize = AtomicUsize::new(0);
17
18#[proc_macro]
29pub fn moduli_declare(input: TokenStream) -> TokenStream {
30 let MacroArgs { items } = parse_macro_input!(input as MacroArgs);
31
32 let mut output = Vec::new();
33
34 let span = proc_macro::Span::call_site();
35
36 for item in items {
37 let struct_name = item.name.to_string();
38 let struct_name = syn::Ident::new(&struct_name, span.into());
39 let mut modulus: Option<String> = None;
40 for param in item.params {
41 match param.name.to_string().as_str() {
42 "modulus" => {
43 if let syn::Expr::Lit(syn::ExprLit {
44 lit: syn::Lit::Str(value),
45 ..
46 }) = param.value
47 {
48 modulus = Some(value.value());
49 } else {
50 return syn::Error::new_spanned(
51 param.value,
52 "Expected a string literal for macro argument `modulus`",
53 )
54 .to_compile_error()
55 .into();
56 }
57 }
58 _ => {
59 panic!("Unknown parameter {}", param.name);
60 }
61 }
62 }
63
64 let mod_idx = MOD_IDX.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
67
68 let modulus = modulus.expect("modulus parameter is required");
69 let modulus_bytes = string_to_bytes(&modulus);
70 let mut limbs = modulus_bytes.len();
71 let mut block_size = 32;
72
73 if limbs <= 32 {
74 limbs = 32;
75 } else if limbs <= 48 {
76 limbs = 48;
77 block_size = 16;
78 } else {
79 panic!("limbs must be at most 48");
80 }
81
82 let modulus_bytes = modulus_bytes
83 .into_iter()
84 .chain(vec![0u8; limbs])
85 .take(limbs)
86 .collect::<Vec<_>>();
87
88 let modulus_hex = modulus_bytes
89 .iter()
90 .rev()
91 .map(|x| format!("{x:02x}"))
92 .collect::<Vec<_>>()
93 .join("");
94 macro_rules! create_extern_func {
95 ($name:ident) => {
96 let $name = syn::Ident::new(
97 &format!("{}_{}", stringify!($name), modulus_hex),
98 span.into(),
99 );
100 };
101 }
102 create_extern_func!(add_extern_func);
103 create_extern_func!(sub_extern_func);
104 create_extern_func!(mul_extern_func);
105 create_extern_func!(div_extern_func);
106 create_extern_func!(is_eq_extern_func);
107 create_extern_func!(hint_sqrt_extern_func);
108 create_extern_func!(hint_non_qr_extern_func);
109 create_extern_func!(moduli_setup_extern_func);
110
111 let block_size = proc_macro::Literal::usize_unsuffixed(block_size);
112 let block_size = syn::Lit::new(block_size.to_string().parse::<_>().unwrap());
113
114 let module_name = format_ident!("algebra_impl_{}", mod_idx);
115
116 let result = TokenStream::from(quote::quote_spanned! { span.into() =>
117 #[derive(Clone, Eq, serde::Serialize, serde::Deserialize)]
129 #[repr(C, align(#block_size))]
130 pub struct #struct_name(#[serde(with = "openvm_algebra_guest::BigArray")] [u8; #limbs]);
131
132 extern "C" {
133 fn #add_extern_func(rd: usize, rs1: usize, rs2: usize);
134 fn #sub_extern_func(rd: usize, rs1: usize, rs2: usize);
135 fn #mul_extern_func(rd: usize, rs1: usize, rs2: usize);
136 fn #div_extern_func(rd: usize, rs1: usize, rs2: usize);
137 fn #is_eq_extern_func(rs1: usize, rs2: usize) -> bool;
138 fn #hint_sqrt_extern_func(rs1: usize);
139 fn #hint_non_qr_extern_func();
140 fn #moduli_setup_extern_func();
141 }
142
143 impl #struct_name {
144 #[inline(always)]
145 const fn from_const_u8(val: u8) -> Self {
146 let mut bytes = [0; #limbs];
147 bytes[0] = val;
148 Self(bytes)
149 }
150
151 pub const fn from_const_bytes(bytes: [u8; #limbs]) -> Self {
154 Self(bytes)
155 }
156
157 #[inline(always)]
158 fn add_assign_impl(&mut self, other: &Self) {
159 #[cfg(not(target_os = "zkvm"))]
160 {
161 *self = Self::from_biguint(
162 (self.as_biguint() + other.as_biguint()) % Self::modulus_biguint(),
163 );
164 }
165 #[cfg(target_os = "zkvm")]
166 {
167 Self::set_up_once();
168 unsafe {
169 #add_extern_func(
170 self as *mut Self as usize,
171 self as *const Self as usize,
172 other as *const Self as usize,
173 );
174 }
175 }
176 }
177
178 #[inline(always)]
179 fn sub_assign_impl(&mut self, other: &Self) {
180 #[cfg(not(target_os = "zkvm"))]
181 {
182 let modulus = Self::modulus_biguint();
183 *self = Self::from_biguint(
184 (self.as_biguint() + modulus.clone() - other.as_biguint()) % modulus,
185 );
186 }
187 #[cfg(target_os = "zkvm")]
188 {
189 Self::set_up_once();
190 unsafe {
191 #sub_extern_func(
192 self as *mut Self as usize,
193 self as *const Self as usize,
194 other as *const Self as usize,
195 );
196 }
197 }
198 }
199
200 #[inline(always)]
201 fn mul_assign_impl(&mut self, other: &Self) {
202 #[cfg(not(target_os = "zkvm"))]
203 {
204 *self = Self::from_biguint(
205 (self.as_biguint() * other.as_biguint()) % Self::modulus_biguint(),
206 );
207 }
208 #[cfg(target_os = "zkvm")]
209 {
210 Self::set_up_once();
211 unsafe {
212 #mul_extern_func(
213 self as *mut Self as usize,
214 self as *const Self as usize,
215 other as *const Self as usize,
216 );
217 }
218 }
219 }
220
221 #[inline(always)]
222 fn div_assign_unsafe_impl(&mut self, other: &Self) {
223 #[cfg(not(target_os = "zkvm"))]
224 {
225 let modulus = Self::modulus_biguint();
226 let inv = other.as_biguint().modinv(&modulus).unwrap();
227 *self = Self::from_biguint((self.as_biguint() * inv) % modulus);
228 }
229 #[cfg(target_os = "zkvm")]
230 {
231 Self::set_up_once();
232 unsafe {
233 #div_extern_func(
234 self as *mut Self as usize,
235 self as *const Self as usize,
236 other as *const Self as usize,
237 );
238 }
239 }
240 }
241
242 #[inline(always)]
245 unsafe fn add_refs_impl<const CHECK_SETUP: bool>(&self, other: &Self, dst_ptr: *mut Self) {
246 #[cfg(not(target_os = "zkvm"))]
247 {
248 let mut res = self.clone();
249 res += other;
250 let dst = unsafe { &mut *dst_ptr };
252 *dst = res;
253 }
254 #[cfg(target_os = "zkvm")]
255 {
256 if CHECK_SETUP {
257 Self::set_up_once();
258 }
259 #add_extern_func(
260 dst_ptr as usize,
261 self as *const #struct_name as usize,
262 other as *const #struct_name as usize,
263 );
264 }
265 }
266
267 #[inline(always)]
270 unsafe fn sub_refs_impl(&self, other: &Self, dst_ptr: *mut Self) {
271 #[cfg(not(target_os = "zkvm"))]
272 {
273 let mut res = self.clone();
274 res -= other;
275 let dst = unsafe { &mut *dst_ptr };
277 *dst = res;
278 }
279 #[cfg(target_os = "zkvm")]
280 {
281 Self::set_up_once();
282 unsafe {
283 #sub_extern_func(
284 dst_ptr as usize,
285 self as *const #struct_name as usize,
286 other as *const #struct_name as usize,
287 );
288 }
289 }
290 }
291
292 #[inline(always)]
295 unsafe fn mul_refs_impl(&self, other: &Self, dst_ptr: *mut Self) {
296 #[cfg(not(target_os = "zkvm"))]
297 {
298 let mut res = self.clone();
299 res *= other;
300 let dst = unsafe { &mut *dst_ptr };
302 *dst = res;
303 }
304 #[cfg(target_os = "zkvm")]
305 {
306 Self::set_up_once();
307 unsafe {
308 #mul_extern_func(
309 dst_ptr as usize,
310 self as *const #struct_name as usize,
311 other as *const #struct_name as usize,
312 );
313 }
314 }
315 }
316
317 #[inline(always)]
318 fn div_unsafe_refs_impl(&self, other: &Self) -> Self {
319 #[cfg(not(target_os = "zkvm"))]
320 {
321 let modulus = Self::modulus_biguint();
322 let inv = other.as_biguint().modinv(&modulus).unwrap();
323 Self::from_biguint((self.as_biguint() * inv) % modulus)
324 }
325 #[cfg(target_os = "zkvm")]
326 {
327 Self::set_up_once();
328 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
329 unsafe {
330 #div_extern_func(
331 uninit.as_mut_ptr() as usize,
332 self as *const #struct_name as usize,
333 other as *const #struct_name as usize,
334 );
335 }
336 unsafe { uninit.assume_init() }
337 }
338 }
339
340 #[inline(always)]
341 unsafe fn eq_impl<const CHECK_SETUP: bool>(&self, other: &Self) -> bool {
342 #[cfg(not(target_os = "zkvm"))]
343 {
344 self.as_le_bytes() == other.as_le_bytes()
345 }
346 #[cfg(target_os = "zkvm")]
347 {
348 if CHECK_SETUP {
349 Self::set_up_once();
350 }
351 #is_eq_extern_func(self as *const #struct_name as usize, other as *const #struct_name as usize)
352 }
353 }
354
355 #[inline(always)]
357 #[cfg(target_os = "zkvm")]
358 fn set_up_once() {
359 static is_setup: ::openvm_algebra_guest::once_cell::race::OnceBool = ::openvm_algebra_guest::once_cell::race::OnceBool::new();
360 is_setup.get_or_init(|| {
361 unsafe { #moduli_setup_extern_func(); }
362 true
363 });
364 }
365 #[inline(always)]
366 #[cfg(not(target_os = "zkvm"))]
367 fn set_up_once() {
368 }
370 }
371
372 mod #module_name {
374 use openvm_algebra_guest::IntMod;
375
376 use super::#struct_name;
377
378 impl IntMod for #struct_name {
379 type Repr = [u8; #limbs];
380 type SelfRef<'a> = &'a Self;
381
382 const MODULUS: Self::Repr = [#(#modulus_bytes),*];
383
384 const ZERO: Self = Self([0; #limbs]);
385
386 const NUM_LIMBS: usize = #limbs;
387
388 const ONE: Self = Self::from_const_u8(1);
389
390 fn from_repr(repr: Self::Repr) -> Self {
391 Self(repr)
392 }
393
394 fn from_le_bytes(bytes: &[u8]) -> Option<Self> {
395 let elt = Self::from_le_bytes_unchecked(bytes);
396 if elt.is_reduced() {
397 Some(elt)
398 } else {
399 None
400 }
401 }
402
403 fn from_be_bytes(bytes: &[u8]) -> Option<Self> {
404 let elt = Self::from_be_bytes_unchecked(bytes);
405 if elt.is_reduced() {
406 Some(elt)
407 } else {
408 None
409 }
410 }
411
412 fn from_le_bytes_unchecked(bytes: &[u8]) -> Self {
413 let mut arr = [0u8; #limbs];
414 arr.copy_from_slice(bytes);
415 Self(arr)
416 }
417
418 fn from_be_bytes_unchecked(bytes: &[u8]) -> Self {
419 let mut arr = [0u8; #limbs];
420 for (a, b) in arr.iter_mut().zip(bytes.iter().rev()) {
421 *a = *b;
422 }
423 Self(arr)
424 }
425
426 fn from_u8(val: u8) -> Self {
427 Self::from_const_u8(val)
428 }
429
430 fn from_u32(val: u32) -> Self {
431 let mut bytes = [0; #limbs];
432 bytes[..4].copy_from_slice(&val.to_le_bytes());
433 Self(bytes)
434 }
435
436 fn from_u64(val: u64) -> Self {
437 let mut bytes = [0; #limbs];
438 bytes[..8].copy_from_slice(&val.to_le_bytes());
439 Self(bytes)
440 }
441
442 #[inline(always)]
443 fn as_le_bytes(&self) -> &[u8] {
444 &(self.0)
445 }
446
447 #[inline(always)]
448 fn to_be_bytes(&self) -> [u8; #limbs] {
449 core::array::from_fn(|i| self.0[#limbs - 1 - i])
450 }
451
452 #[cfg(not(target_os = "zkvm"))]
453 fn modulus_biguint() -> num_bigint::BigUint {
454 num_bigint::BigUint::from_bytes_le(&Self::MODULUS)
455 }
456
457 #[cfg(not(target_os = "zkvm"))]
458 fn from_biguint(biguint: num_bigint::BigUint) -> Self {
459 Self(openvm::utils::biguint_to_limbs(&biguint))
460 }
461
462 #[cfg(not(target_os = "zkvm"))]
463 fn as_biguint(&self) -> num_bigint::BigUint {
464 num_bigint::BigUint::from_bytes_le(self.as_le_bytes())
465 }
466
467 #[inline(always)]
468 fn neg_assign(&mut self) {
469 unsafe {
470 (#struct_name::ZERO).sub_refs_impl(self, self as *const Self as *mut Self);
473 }
474 }
475
476 #[inline(always)]
477 fn double_assign(&mut self) {
478 unsafe {
479 self.add_refs_impl::<true>(self, self as *const Self as *mut Self);
482 }
483 }
484
485 #[inline(always)]
486 fn square_assign(&mut self) {
487 unsafe {
488 self.mul_refs_impl(self, self as *const Self as *mut Self);
491 }
492 }
493
494 #[inline(always)]
495 fn double(&self) -> Self {
496 self + self
497 }
498
499 #[inline(always)]
500 fn square(&self) -> Self {
501 self * self
502 }
503
504 #[inline(always)]
505 fn cube(&self) -> Self {
506 &self.square() * self
507 }
508
509 fn assert_reduced(&self) {
514 let _ = core::hint::black_box(PartialEq::eq(self, self));
516 }
517
518 fn is_reduced(&self) -> bool {
519 for (x_limb, p_limb) in self.0.iter().rev().zip(Self::MODULUS.iter().rev()) {
521 if x_limb < p_limb {
522 return true;
523 } else if x_limb > p_limb {
524 return false;
525 }
526 }
527 false
529 }
530
531 #[inline(always)]
532 fn set_up_once() {
533 Self::set_up_once();
534 }
535
536 #[inline(always)]
537 unsafe fn eq_impl<const CHECK_SETUP: bool>(&self, other: &Self) -> bool {
538 Self::eq_impl::<CHECK_SETUP>(self, other)
539 }
540
541 #[inline(always)]
542 unsafe fn add_ref<const CHECK_SETUP: bool>(&self, other: &Self) -> Self {
543 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
544 self.add_refs_impl::<CHECK_SETUP>(other, uninit.as_mut_ptr());
545 uninit.assume_init()
546 }
547 }
548
549 impl<'a> core::ops::AddAssign<&'a #struct_name> for #struct_name {
550 #[inline(always)]
551 fn add_assign(&mut self, other: &'a #struct_name) {
552 self.add_assign_impl(other);
553 }
554 }
555
556 impl core::ops::AddAssign for #struct_name {
557 #[inline(always)]
558 fn add_assign(&mut self, other: Self) {
559 self.add_assign_impl(&other);
560 }
561 }
562
563 impl core::ops::Add for #struct_name {
564 type Output = Self;
565 #[inline(always)]
566 fn add(mut self, other: Self) -> Self::Output {
567 self += other;
568 self
569 }
570 }
571
572 impl<'a> core::ops::Add<&'a #struct_name> for #struct_name {
573 type Output = Self;
574 #[inline(always)]
575 fn add(mut self, other: &'a #struct_name) -> Self::Output {
576 self += other;
577 self
578 }
579 }
580
581 impl<'a> core::ops::Add<&'a #struct_name> for &#struct_name {
582 type Output = #struct_name;
583 #[inline(always)]
584 fn add(self, other: &'a #struct_name) -> Self::Output {
585 unsafe { self.add_ref::<true>(other) }
587 }
588 }
589
590 impl<'a> core::ops::SubAssign<&'a #struct_name> for #struct_name {
591 #[inline(always)]
592 fn sub_assign(&mut self, other: &'a #struct_name) {
593 self.sub_assign_impl(other);
594 }
595 }
596
597 impl core::ops::SubAssign for #struct_name {
598 #[inline(always)]
599 fn sub_assign(&mut self, other: Self) {
600 self.sub_assign_impl(&other);
601 }
602 }
603
604 impl core::ops::Sub for #struct_name {
605 type Output = Self;
606 #[inline(always)]
607 fn sub(mut self, other: Self) -> Self::Output {
608 self -= other;
609 self
610 }
611 }
612
613 impl<'a> core::ops::Sub<&'a #struct_name> for #struct_name {
614 type Output = Self;
615 #[inline(always)]
616 fn sub(mut self, other: &'a #struct_name) -> Self::Output {
617 self -= other;
618 self
619 }
620 }
621
622 impl<'a> core::ops::Sub<&'a #struct_name> for &'a #struct_name {
623 type Output = #struct_name;
624 #[inline(always)]
625 fn sub(self, other: &'a #struct_name) -> Self::Output {
626 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
627 unsafe {
628 self.sub_refs_impl(other, uninit.as_mut_ptr());
629 uninit.assume_init()
630 }
631 }
632 }
633
634 impl<'a> core::ops::MulAssign<&'a #struct_name> for #struct_name {
635 #[inline(always)]
636 fn mul_assign(&mut self, other: &'a #struct_name) {
637 self.mul_assign_impl(other);
638 }
639 }
640
641 impl core::ops::MulAssign for #struct_name {
642 #[inline(always)]
643 fn mul_assign(&mut self, other: Self) {
644 self.mul_assign_impl(&other);
645 }
646 }
647
648 impl core::ops::Mul for #struct_name {
649 type Output = Self;
650 #[inline(always)]
651 fn mul(mut self, other: Self) -> Self::Output {
652 self *= other;
653 self
654 }
655 }
656
657 impl<'a> core::ops::Mul<&'a #struct_name> for #struct_name {
658 type Output = Self;
659 #[inline(always)]
660 fn mul(mut self, other: &'a #struct_name) -> Self::Output {
661 self *= other;
662 self
663 }
664 }
665
666 impl<'a> core::ops::Mul<&'a #struct_name> for &#struct_name {
667 type Output = #struct_name;
668 #[inline(always)]
669 fn mul(self, other: &'a #struct_name) -> Self::Output {
670 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
671 unsafe {
672 self.mul_refs_impl(other, uninit.as_mut_ptr());
673 uninit.assume_init()
674 }
675 }
676 }
677
678 impl<'a> openvm_algebra_guest::DivAssignUnsafe<&'a #struct_name> for #struct_name {
679 #[inline(always)]
681 fn div_assign_unsafe(&mut self, other: &'a #struct_name) {
682 self.div_assign_unsafe_impl(other);
683 }
684 }
685
686 impl openvm_algebra_guest::DivAssignUnsafe for #struct_name {
687 #[inline(always)]
689 fn div_assign_unsafe(&mut self, other: Self) {
690 self.div_assign_unsafe_impl(&other);
691 }
692 }
693
694 impl openvm_algebra_guest::DivUnsafe for #struct_name {
695 type Output = Self;
696 #[inline(always)]
698 fn div_unsafe(mut self, other: Self) -> Self::Output {
699 self.div_assign_unsafe_impl(&other);
700 self
701 }
702 }
703
704 impl<'a> openvm_algebra_guest::DivUnsafe<&'a #struct_name> for #struct_name {
705 type Output = Self;
706 #[inline(always)]
708 fn div_unsafe(mut self, other: &'a #struct_name) -> Self::Output {
709 self.div_assign_unsafe_impl(other);
710 self
711 }
712 }
713
714 impl<'a> openvm_algebra_guest::DivUnsafe<&'a #struct_name> for &#struct_name {
715 type Output = #struct_name;
716 #[inline(always)]
718 fn div_unsafe(self, other: &'a #struct_name) -> Self::Output {
719 self.div_unsafe_refs_impl(other)
720 }
721 }
722
723 impl PartialEq for #struct_name {
724 #[inline(always)]
725 fn eq(&self, other: &Self) -> bool {
726 unsafe { self.eq_impl::<true>(other) }
728 }
729 }
730
731 impl<'a> core::iter::Sum<&'a #struct_name> for #struct_name {
732 fn sum<I: Iterator<Item = &'a #struct_name>>(iter: I) -> Self {
733 iter.fold(Self::ZERO, |acc, x| &acc + x)
734 }
735 }
736
737 impl core::iter::Sum for #struct_name {
738 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
739 iter.fold(Self::ZERO, |acc, x| &acc + &x)
740 }
741 }
742
743 impl<'a> core::iter::Product<&'a #struct_name> for #struct_name {
744 fn product<I: Iterator<Item = &'a #struct_name>>(iter: I) -> Self {
745 iter.fold(Self::ONE, |acc, x| &acc * x)
746 }
747 }
748
749 impl core::iter::Product for #struct_name {
750 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
751 iter.fold(Self::ONE, |acc, x| &acc * &x)
752 }
753 }
754
755 impl core::ops::Neg for #struct_name {
756 type Output = #struct_name;
757 #[inline(always)]
758 fn neg(self) -> Self::Output {
759 #struct_name::ZERO - &self
760 }
761 }
762
763 impl<'a> core::ops::Neg for &'a #struct_name {
764 type Output = #struct_name;
765 #[inline(always)]
766 fn neg(self) -> Self::Output {
767 #struct_name::ZERO - self
768 }
769 }
770
771 impl core::fmt::Debug for #struct_name {
772 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
773 write!(f, "{:?}", self.as_le_bytes())
774 }
775 }
776 }
777
778 impl openvm_algebra_guest::Reduce for #struct_name {
779 fn reduce_le_bytes(bytes: &[u8]) -> Self {
780 debug_assert!(
781 bytes.len() % #limbs == 0,
782 "reduce_le_bytes: input length {} is not a multiple of modulus byte size {}",
783 bytes.len(),
784 #limbs,
785 );
786 let mut res = <Self as openvm_algebra_guest::IntMod>::ZERO;
787 let mut base = <Self as openvm_algebra_guest::IntMod>::from_le_bytes_unchecked(&[255u8; #limbs]);
789 base += <Self as openvm_algebra_guest::IntMod>::ONE;
790 for chunk in bytes.chunks(#limbs).rev() {
791 res = res * &base + <Self as openvm_algebra_guest::IntMod>::from_le_bytes_unchecked(chunk);
792 }
793 openvm_algebra_guest::IntMod::assert_reduced(&res);
794 res
795 }
796 }
797 });
798
799 output.push(result);
800
801 let modulus_biguint = BigUint::from_bytes_le(&modulus_bytes);
802 let modulus_is_prime = is_prime(&modulus_biguint, None);
803
804 if modulus_is_prime.probably() {
805 let field_and_sqrt_impl = TokenStream::from(quote::quote_spanned! { span.into() =>
807 impl ::openvm_algebra_guest::Field for #struct_name {
808 const ZERO: Self = <Self as ::openvm_algebra_guest::IntMod>::ZERO;
809 const ONE: Self = <Self as ::openvm_algebra_guest::IntMod>::ONE;
810
811 type SelfRef<'a> = &'a Self;
812
813 fn double_assign(&mut self) {
814 ::openvm_algebra_guest::IntMod::double_assign(self);
815 }
816
817 fn square_assign(&mut self) {
818 ::openvm_algebra_guest::IntMod::square_assign(self);
819 }
820
821 }
822
823 impl openvm_algebra_guest::Sqrt for #struct_name {
824 fn sqrt(&self) -> Option<Self> {
829 match self.honest_host_sqrt() {
830 Some(Some(sqrt)) => Some(sqrt),
832 Some(None) => None,
834 None => {
836 loop {
838 openvm::io::println("ERROR: Square root hint is invalid. Entering infinite loop.");
839 }
840 }
841 }
842 }
843 }
844
845 impl #struct_name {
846 fn honest_host_sqrt(&self) -> Option<Option<Self>> {
850 if self == &<Self as ::openvm_algebra_guest::IntMod>::ZERO {
854 return Some(Some(<Self as ::openvm_algebra_guest::IntMod>::ZERO));
855 }
856
857 let (is_square, sqrt) = self.hint_sqrt_impl()?;
858
859 if is_square {
860 <Self as ::openvm_algebra_guest::IntMod>::assert_reduced(&sqrt);
862
863 if &(&sqrt * &sqrt) == self {
864 Some(Some(sqrt))
865 } else {
866 None
867 }
868 } else {
869 <Self as ::openvm_algebra_guest::IntMod>::assert_reduced(&sqrt);
871
872 if &sqrt * &sqrt == self * Self::get_non_qr() {
873 Some(None)
874 } else {
875 None
876 }
877 }
878 }
879
880
881 fn hint_sqrt_impl(&self) -> Option<(bool, Self)> {
885 #[cfg(not(target_os = "zkvm"))]
886 {
887 unimplemented!();
888 }
889 #[cfg(target_os = "zkvm")]
890 {
891 use ::openvm_algebra_guest::{openvm_custom_insn, openvm_rv32im_guest}; let is_square = core::mem::MaybeUninit::<u32>::uninit();
894 let sqrt = core::mem::MaybeUninit::<#struct_name>::uninit();
895 unsafe {
896 #hint_sqrt_extern_func(self as *const #struct_name as usize);
897 let is_square_ptr = is_square.as_ptr() as *const u32;
898 openvm_rv32im_guest::hint_store_u32!(is_square_ptr);
899 openvm_rv32im_guest::hint_buffer_u32!(sqrt.as_ptr() as *const u8, <#struct_name as ::openvm_algebra_guest::IntMod>::NUM_LIMBS / 4);
900 let is_square = is_square.assume_init();
901 if is_square == 0 || is_square == 1 {
902 Some((is_square == 1, sqrt.assume_init()))
903 } else {
904 None
905 }
906 }
907 }
908 }
909
910 fn init_non_qr() -> alloc::boxed::Box<#struct_name> {
912 #[cfg(not(target_os = "zkvm"))]
913 {
914 unimplemented!();
915 }
916 #[cfg(target_os = "zkvm")]
917 {
918 use ::openvm_algebra_guest::{openvm_custom_insn, openvm_rv32im_guest}; let mut non_qr_uninit = core::mem::MaybeUninit::<Self>::uninit();
921 let mut non_qr;
922 unsafe {
923 #hint_non_qr_extern_func();
924 let ptr = non_qr_uninit.as_ptr() as *const u8;
925 openvm_rv32im_guest::hint_buffer_u32!(ptr, <Self as ::openvm_algebra_guest::IntMod>::NUM_LIMBS / 4);
926 non_qr = non_qr_uninit.assume_init();
927 }
928 <Self as ::openvm_algebra_guest::IntMod>::assert_reduced(&non_qr);
930
931 use ::openvm_algebra_guest::{DivUnsafe, ExpBytes};
932 let exp = -<Self as ::openvm_algebra_guest::IntMod>::ONE.div_unsafe(Self::from_const_u8(2));
934 <Self as ::openvm_algebra_guest::IntMod>::assert_reduced(&exp);
935
936 if non_qr.exp_bytes(true, &<Self as ::openvm_algebra_guest::IntMod>::to_be_bytes(&exp)) != -<Self as ::openvm_algebra_guest::IntMod>::ONE
937 {
938 loop {
940 openvm::io::println("ERROR: Non quadratic residue hint is invalid. Entering infinite loop.");
941 }
942 }
943
944 alloc::boxed::Box::new(non_qr)
945 }
946 }
947
948 pub fn get_non_qr() -> &'static #struct_name {
950 static non_qr: ::openvm_algebra_guest::once_cell::race::OnceBox<#struct_name> = ::openvm_algebra_guest::once_cell::race::OnceBox::new();
951 &non_qr.get_or_init(Self::init_non_qr)
952 }
953 }
954 });
955
956 output.push(field_and_sqrt_impl);
957 }
958 }
959
960 TokenStream::from_iter(output)
961}
962
963struct ModuliDefine {
964 items: Vec<LitStr>,
965}
966
967impl Parse for ModuliDefine {
968 fn parse(input: ParseStream) -> syn::Result<Self> {
969 let items = input.parse_terminated(<LitStr as Parse>::parse, Token![,])?;
970 Ok(Self {
971 items: items.into_iter().collect(),
972 })
973 }
974}
975
976#[proc_macro]
977pub fn moduli_init(input: TokenStream) -> TokenStream {
978 let ModuliDefine { items } = parse_macro_input!(input as ModuliDefine);
979
980 let mut externs = Vec::new();
981
982 let mut two_modular_limbs_flattened_list = Vec::<u8>::new();
984 let mut limb_list_borders = vec![0usize];
986
987 let span = proc_macro::Span::call_site();
988
989 let mut max_block_size = 4;
990
991 for (mod_idx, item) in items.into_iter().enumerate() {
992 let modulus = item.value();
993 let modulus_bytes = string_to_bytes(&modulus);
994 let mut limbs = modulus_bytes.len();
995 let mut block_size = 32;
996
997 if limbs <= 32 {
998 limbs = 32;
999 } else if limbs <= 48 {
1000 limbs = 48;
1001 block_size = 16;
1002 } else {
1003 panic!("limbs must be at most 48");
1004 }
1005
1006 max_block_size = max_block_size.max(block_size);
1007
1008 let block_size = proc_macro::Literal::usize_unsuffixed(block_size);
1009 let block_size = syn::Lit::new(block_size.to_string().parse::<_>().unwrap());
1010
1011 let modulus_bytes = modulus_bytes
1012 .into_iter()
1013 .chain(vec![0u8; limbs])
1014 .take(limbs)
1015 .collect::<Vec<_>>();
1016
1017 let doubled_modulus = [modulus_bytes.clone(), modulus_bytes.clone()].concat();
1019 two_modular_limbs_flattened_list.extend(doubled_modulus);
1020 limb_list_borders.push(two_modular_limbs_flattened_list.len());
1021
1022 let modulus_hex = modulus_bytes
1023 .iter()
1024 .rev()
1025 .map(|x| format!("{x:02x}"))
1026 .collect::<Vec<_>>()
1027 .join("");
1028
1029 let setup_extern_func = syn::Ident::new(
1030 &format!("moduli_setup_extern_func_{modulus_hex}"),
1031 span.into(),
1032 );
1033
1034 for op_type in ["add", "sub", "mul", "div"] {
1035 let func_name =
1036 syn::Ident::new(&format!("{op_type}_extern_func_{modulus_hex}"), span.into());
1037 let mut chars = op_type.chars().collect::<Vec<_>>();
1038 chars[0] = chars[0].to_ascii_uppercase();
1039 let local_opcode = syn::Ident::new(
1040 &format!("{}Mod", chars.iter().collect::<String>()),
1041 span.into(),
1042 );
1043 externs.push(quote::quote_spanned! { span.into() =>
1044 #[no_mangle]
1045 extern "C" fn #func_name(rd: usize, rs1: usize, rs2: usize) {
1046 openvm::platform::custom_insn_r!(
1047 opcode = ::openvm_algebra_guest::OPCODE,
1048 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
1049 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::#local_opcode as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1050 rd = In rd,
1051 rs1 = In rs1,
1052 rs2 = In rs2
1053 )
1054 }
1055 });
1056 }
1057
1058 let is_eq_extern_func =
1059 syn::Ident::new(&format!("is_eq_extern_func_{modulus_hex}"), span.into());
1060 externs.push(quote::quote_spanned! { span.into() =>
1061 #[no_mangle]
1062 extern "C" fn #is_eq_extern_func(rs1: usize, rs2: usize) -> bool {
1063 let mut x: u32;
1064 openvm::platform::custom_insn_r!(
1065 opcode = ::openvm_algebra_guest::OPCODE,
1066 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
1067 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::IsEqMod as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1068 rd = Out x,
1069 rs1 = In rs1,
1070 rs2 = In rs2
1071 );
1072 x != 0
1073 }
1074 });
1075
1076 let hint_non_qr_extern_func = syn::Ident::new(
1077 &format!("hint_non_qr_extern_func_{modulus_hex}"),
1078 span.into(),
1079 );
1080 externs.push(quote::quote_spanned! { span.into() =>
1081 #[no_mangle]
1082 extern "C" fn #hint_non_qr_extern_func() {
1083 openvm::platform::custom_insn_r!(
1084 opcode = ::openvm_algebra_guest::OPCODE,
1085 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
1086 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::HintNonQr as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1087 rd = Const "x0",
1088 rs1 = Const "x0",
1089 rs2 = Const "x0"
1090 );
1091 }
1092
1093
1094 });
1095
1096 let hint_sqrt_extern_func =
1099 syn::Ident::new(&format!("hint_sqrt_extern_func_{modulus_hex}"), span.into());
1100 externs.push(quote::quote_spanned! { span.into() =>
1101 #[no_mangle]
1102 extern "C" fn #hint_sqrt_extern_func(rs1: usize) {
1103 openvm::platform::custom_insn_r!(
1104 opcode = ::openvm_algebra_guest::OPCODE,
1105 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
1106 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::HintSqrt as usize + #mod_idx * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1107 rd = Const "x0",
1108 rs1 = In rs1,
1109 rs2 = Const "x0"
1110 );
1111 }
1112 });
1113
1114 externs.push(quote::quote_spanned! { span.into() =>
1115 #[no_mangle]
1116 extern "C" fn #setup_extern_func() {
1117 #[cfg(target_os = "zkvm")]
1118 {
1119 #[repr(C, align(#block_size))]
1121 struct AlignedPlaceholder([u8; #limbs]);
1122
1123 const MODULUS_BYTES: AlignedPlaceholder = AlignedPlaceholder([#(#modulus_bytes),*]);
1124
1125 let mut uninit: core::mem::MaybeUninit<AlignedPlaceholder> = core::mem::MaybeUninit::uninit();
1128 openvm::platform::custom_insn_r!(
1129 opcode = ::openvm_algebra_guest::OPCODE,
1130 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3,
1131 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::SetupMod as usize
1132 + #mod_idx
1133 * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1134 rd = In uninit.as_mut_ptr(),
1135 rs1 = In MODULUS_BYTES.0.as_ptr(),
1136 rs2 = Const "x0" );
1138 openvm::platform::custom_insn_r!(
1139 opcode = ::openvm_algebra_guest::OPCODE,
1140 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3,
1141 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::SetupMod as usize
1142 + #mod_idx
1143 * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1144 rd = In uninit.as_mut_ptr(),
1145 rs1 = In MODULUS_BYTES.0.as_ptr(),
1146 rs2 = Const "x1" );
1148 unsafe {
1149 let mut tmp = uninit.as_mut_ptr() as usize;
1151 openvm::platform::custom_insn_r!(
1152 opcode = ::openvm_algebra_guest::OPCODE,
1153 funct3 = ::openvm_algebra_guest::MODULAR_ARITHMETIC_FUNCT3 as usize,
1154 funct7 = ::openvm_algebra_guest::ModArithBaseFunct7::SetupMod as usize
1155 + #mod_idx
1156 * (::openvm_algebra_guest::ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize),
1157 rd = InOut tmp,
1158 rs1 = In MODULUS_BYTES.0.as_ptr(),
1159 rs2 = Const "x2" );
1161 }
1163 }
1164 }
1165 });
1166 }
1167
1168 let max_block_size = proc_macro::Literal::usize_unsuffixed(max_block_size);
1169 let max_block_size = syn::Lit::new(max_block_size.to_string().parse::<_>().unwrap());
1170
1171 let total_limbs_cnt = two_modular_limbs_flattened_list.len();
1172 let cnt_limbs_list_len = limb_list_borders.len();
1173 TokenStream::from(quote::quote_spanned! { span.into() =>
1174 #[allow(non_snake_case)]
1175 #[cfg(target_os = "zkvm")]
1176 mod openvm_intrinsics_ffi {
1177 #(#externs)*
1178 }
1179 #[allow(non_snake_case, non_upper_case_globals)]
1180 pub mod openvm_intrinsics_meta_do_not_type_this_by_yourself {
1181 #[repr(C, align(#max_block_size))]
1182 pub struct Aligned<T>(pub T);
1183
1184 pub const two_modular_limbs_list: Aligned<[u8; #total_limbs_cnt]> = Aligned([#(#two_modular_limbs_flattened_list),*]);
1185 pub const limb_list_borders: [usize; #cnt_limbs_list_len] = [#(#limb_list_borders),*];
1186 }
1187 })
1188}