// openvm_algebra_complex_macros/lib.rs
extern crate proc_macro;
2
3use openvm_macros_common::{MacroArgs, Param};
4use proc_macro::TokenStream;
5use syn::{
6 parse::{Parse, ParseStream},
7 parse_macro_input,
8 punctuated::Punctuated,
9 Expr, ExprPath, LitStr, Path, Token,
10};
11
12#[proc_macro]
21pub fn complex_declare(input: TokenStream) -> TokenStream {
22 let MacroArgs { items } = parse_macro_input!(input as MacroArgs);
23
24 let mut output = Vec::new();
25
26 let span = proc_macro::Span::call_site();
27
28 for item in items.into_iter() {
29 let struct_name = item.name.to_string();
30 let struct_name = syn::Ident::new(&struct_name, span.into());
31 let mut intmod_type: Option<syn::Path> = None;
32 for param in item.params {
33 match param.name.to_string().as_str() {
34 "mod_type" => {
35 if let syn::Expr::Path(ExprPath { path, .. }) = param.value {
36 intmod_type = Some(path)
37 } else {
38 return syn::Error::new_spanned(param.value, "Expected a type")
39 .to_compile_error()
40 .into();
41 }
42 }
43 _ => {
44 panic!("Unknown parameter {}", param.name);
45 }
46 }
47 }
48
49 let intmod_type = intmod_type.expect("mod_type parameter is required");
50
51 macro_rules! create_extern_func {
52 ($name:ident) => {
53 let $name = syn::Ident::new(
54 &format!("{}_{}", stringify!($name), struct_name),
55 span.into(),
56 );
57 };
58 }
59 create_extern_func!(complex_add_extern_func);
60 create_extern_func!(complex_sub_extern_func);
61 create_extern_func!(complex_mul_extern_func);
62 create_extern_func!(complex_div_extern_func);
63 create_extern_func!(complex_setup_extern_func);
64
65 let result = TokenStream::from(quote::quote_spanned! { span.into() =>
66 extern "C" {
67 fn #complex_add_extern_func(rd: usize, rs1: usize, rs2: usize);
68 fn #complex_sub_extern_func(rd: usize, rs1: usize, rs2: usize);
69 fn #complex_mul_extern_func(rd: usize, rs1: usize, rs2: usize);
70 fn #complex_div_extern_func(rd: usize, rs1: usize, rs2: usize);
71 fn #complex_setup_extern_func();
72 }
73
74
75 #[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
81 #[repr(C)]
82 pub struct #struct_name {
83 pub c0: #intmod_type,
85 pub c1: #intmod_type,
87 }
88
89 impl #struct_name {
90 pub const fn new(c0: #intmod_type, c1: #intmod_type) -> Self {
91 Self { c0, c1 }
92 }
93 }
94
95 impl #struct_name {
96 pub const ZERO: Self = Self::new(<#intmod_type as openvm_algebra_guest::IntMod>::ZERO, <#intmod_type as openvm_algebra_guest::IntMod>::ZERO);
98
99 pub const ONE: Self = Self::new(<#intmod_type as openvm_algebra_guest::IntMod>::ONE, <#intmod_type as openvm_algebra_guest::IntMod>::ZERO);
101
102 pub fn neg_assign(&mut self) {
103 self.c0.neg_assign();
104 self.c1.neg_assign();
105 }
106
107 #[inline(always)]
109 fn add_assign_impl(&mut self, other: &Self) {
110 #[cfg(not(target_os = "zkvm"))]
111 {
112 self.c0 += &other.c0;
113 self.c1 += &other.c1;
114 }
115 #[cfg(target_os = "zkvm")]
116 {
117 Self::set_up_once();
118 unsafe {
119 #complex_add_extern_func(
120 self as *mut Self as usize,
121 self as *const Self as usize,
122 other as *const Self as usize
123 );
124 }
125 }
126 }
127
128 #[inline(always)]
130 fn sub_assign_impl(&mut self, other: &Self) {
131 #[cfg(not(target_os = "zkvm"))]
132 {
133 self.c0 -= &other.c0;
134 self.c1 -= &other.c1;
135 }
136 #[cfg(target_os = "zkvm")]
137 {
138 Self::set_up_once();
139 unsafe {
140 #complex_sub_extern_func(
141 self as *mut Self as usize,
142 self as *const Self as usize,
143 other as *const Self as usize
144 );
145 }
146 }
147 }
148
149 #[inline(always)]
151 fn mul_assign_impl(&mut self, other: &Self) {
152 #[cfg(not(target_os = "zkvm"))]
153 {
154 let (c0, c1) = (&self.c0, &self.c1);
155 let (d0, d1) = (&other.c0, &other.c1);
156 *self = Self::new(
157 c0.clone() * d0 - c1.clone() * d1,
158 c0.clone() * d1 + c1.clone() * d0,
159 );
160 }
161 #[cfg(target_os = "zkvm")]
162 {
163 Self::set_up_once();
164 unsafe {
165 #complex_mul_extern_func(
166 self as *mut Self as usize,
167 self as *const Self as usize,
168 other as *const Self as usize
169 );
170 }
171 }
172 }
173
174 #[inline(always)]
176 fn div_assign_unsafe_impl(&mut self, other: &Self) {
177 #[cfg(not(target_os = "zkvm"))]
178 {
179 let (c0, c1) = (&self.c0, &self.c1);
180 let (d0, d1) = (&other.c0, &other.c1);
181 let denom = openvm_algebra_guest::DivUnsafe::div_unsafe(<#intmod_type as openvm_algebra_guest::IntMod>::ONE, d0.square() + d1.square());
182 *self = Self::new(
183 denom.clone() * (c0.clone() * d0 + c1.clone() * d1),
184 denom * &(c1.clone() * d0 - c0.clone() * d1),
185 );
186 }
187 #[cfg(target_os = "zkvm")]
188 {
189 Self::set_up_once();
190 unsafe {
191 #complex_div_extern_func(
192 self as *mut Self as usize,
193 self as *const Self as usize,
194 other as *const Self as usize
195 );
196 }
197 }
198 }
199
200 fn add_refs_impl(&self, other: &Self) -> Self {
202 #[cfg(not(target_os = "zkvm"))]
203 {
204 let mut res = self.clone();
205 res.add_assign_impl(other);
206 res
207 }
208 #[cfg(target_os = "zkvm")]
209 {
210 Self::set_up_once();
211 let mut uninit: core::mem::MaybeUninit<Self> = core::mem::MaybeUninit::uninit();
212 unsafe {
213 #complex_add_extern_func(
214 uninit.as_mut_ptr() as usize,
215 self as *const Self as usize,
216 other as *const Self as usize
217 );
218 }
219 unsafe { uninit.assume_init() }
220 }
221 }
222
223 #[inline(always)]
225 fn sub_refs_impl(&self, other: &Self) -> Self {
226 #[cfg(not(target_os = "zkvm"))]
227 {
228 let mut res = self.clone();
229 res.sub_assign_impl(other);
230 res
231 }
232 #[cfg(target_os = "zkvm")]
233 {
234 Self::set_up_once();
235 let mut uninit: core::mem::MaybeUninit<Self> = core::mem::MaybeUninit::uninit();
236 unsafe {
237 #complex_sub_extern_func(
238 uninit.as_mut_ptr() as usize,
239 self as *const Self as usize,
240 other as *const Self as usize
241 );
242 }
243 unsafe { uninit.assume_init() }
244 }
245 }
246
247 #[inline(always)]
252 unsafe fn mul_refs_impl(&self, other: &Self, dst_ptr: *mut Self) {
253 #[cfg(not(target_os = "zkvm"))]
254 {
255 let mut res = self.clone();
256 res.mul_assign_impl(other);
257 let dst = unsafe { &mut *dst_ptr };
258 *dst = res;
259 }
260 #[cfg(target_os = "zkvm")]
261 {
262 Self::set_up_once();
263 unsafe {
264 #complex_mul_extern_func(
265 dst_ptr as usize,
266 self as *const Self as usize,
267 other as *const Self as usize
268 );
269 }
270 }
271 }
272
273 #[inline(always)]
275 fn div_unsafe_refs_impl(&self, other: &Self) -> Self {
276 #[cfg(not(target_os = "zkvm"))]
277 {
278 let mut res = self.clone();
279 res.div_assign_unsafe_impl(other);
280 res
281 }
282 #[cfg(target_os = "zkvm")]
283 {
284 Self::set_up_once();
285 let mut uninit: core::mem::MaybeUninit<Self> = core::mem::MaybeUninit::uninit();
286 unsafe {
287 #complex_div_extern_func(
288 uninit.as_mut_ptr() as usize,
289 self as *const Self as usize,
290 other as *const Self as usize
291 );
292 }
293 unsafe { uninit.assume_init() }
294 }
295 }
296
297 fn set_up_once() {
299 static is_setup: ::openvm_algebra_guest::once_cell::race::OnceBool = ::openvm_algebra_guest::once_cell::race::OnceBool::new();
300 is_setup.get_or_init(|| {
301 unsafe { #complex_setup_extern_func(); }
302 true
303 });
304 }
305 }
306
307 impl openvm_algebra_guest::field::ComplexConjugate for #struct_name {
308 fn conjugate(self) -> Self {
309 Self {
310 c0: self.c0,
311 c1: -self.c1,
312 }
313 }
314
315 fn conjugate_assign(&mut self) {
316 self.c1.neg_assign();
317 }
318 }
319
320 impl<'a> core::ops::AddAssign<&'a #struct_name> for #struct_name {
321 #[inline(always)]
322 fn add_assign(&mut self, other: &'a #struct_name) {
323 self.add_assign_impl(other);
324 }
325 }
326
327 impl core::ops::AddAssign for #struct_name {
328 #[inline(always)]
329 fn add_assign(&mut self, other: Self) {
330 self.add_assign_impl(&other);
331 }
332 }
333
334 impl core::ops::Add for #struct_name {
335 type Output = Self;
336 #[inline(always)]
337 fn add(mut self, other: Self) -> Self::Output {
338 self += other;
339 self
340 }
341 }
342
343 impl<'a> core::ops::Add<&'a #struct_name> for #struct_name {
344 type Output = Self;
345 #[inline(always)]
346 fn add(mut self, other: &'a #struct_name) -> Self::Output {
347 self += other;
348 self
349 }
350 }
351
352 impl<'a> core::ops::Add<&'a #struct_name> for &#struct_name {
353 type Output = #struct_name;
354 #[inline(always)]
355 fn add(self, other: &'a #struct_name) -> Self::Output {
356 self.add_refs_impl(other)
357 }
358 }
359
360 impl<'a> core::ops::SubAssign<&'a #struct_name> for #struct_name {
361 #[inline(always)]
362 fn sub_assign(&mut self, other: &'a #struct_name) {
363 self.sub_assign_impl(other);
364 }
365 }
366
367 impl core::ops::SubAssign for #struct_name {
368 #[inline(always)]
369 fn sub_assign(&mut self, other: Self) {
370 self.sub_assign_impl(&other);
371 }
372 }
373
374 impl core::ops::Sub for #struct_name {
375 type Output = Self;
376 #[inline(always)]
377 fn sub(mut self, other: Self) -> Self::Output {
378 self -= other;
379 self
380 }
381 }
382
383 impl<'a> core::ops::Sub<&'a #struct_name> for #struct_name {
384 type Output = Self;
385 #[inline(always)]
386 fn sub(mut self, other: &'a #struct_name) -> Self::Output {
387 self -= other;
388 self
389 }
390 }
391
392 impl<'a> core::ops::Sub<&'a #struct_name> for &#struct_name {
393 type Output = #struct_name;
394 #[inline(always)]
395 fn sub(self, other: &'a #struct_name) -> Self::Output {
396 self.sub_refs_impl(other)
397 }
398 }
399
400 impl<'a> core::ops::MulAssign<&'a #struct_name> for #struct_name {
401 #[inline(always)]
402 fn mul_assign(&mut self, other: &'a #struct_name) {
403 self.mul_assign_impl(other);
404 }
405 }
406
407 impl core::ops::MulAssign for #struct_name {
408 #[inline(always)]
409 fn mul_assign(&mut self, other: Self) {
410 self.mul_assign_impl(&other);
411 }
412 }
413
414 impl core::ops::Mul for #struct_name {
415 type Output = Self;
416 #[inline(always)]
417 fn mul(mut self, other: Self) -> Self::Output {
418 self *= other;
419 self
420 }
421 }
422
423 impl<'a> core::ops::Mul<&'a #struct_name> for #struct_name {
424 type Output = Self;
425 #[inline(always)]
426 fn mul(mut self, other: &'a #struct_name) -> Self::Output {
427 self *= other;
428 self
429 }
430 }
431
432 impl<'a> core::ops::Mul<&'a #struct_name> for &'a #struct_name {
433 type Output = #struct_name;
434 #[inline(always)]
435 fn mul(self, other: &'a #struct_name) -> Self::Output {
436 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
437 unsafe {
438 self.mul_refs_impl(other, uninit.as_mut_ptr());
439 uninit.assume_init()
440 }
441 }
442 }
443
444 impl<'a> openvm_algebra_guest::DivAssignUnsafe<&'a #struct_name> for #struct_name {
445 #[inline(always)]
446 fn div_assign_unsafe(&mut self, other: &'a #struct_name) {
447 self.div_assign_unsafe_impl(other);
448 }
449 }
450
451 impl openvm_algebra_guest::DivAssignUnsafe for #struct_name {
452 #[inline(always)]
453 fn div_assign_unsafe(&mut self, other: Self) {
454 self.div_assign_unsafe_impl(&other);
455 }
456 }
457
458 impl openvm_algebra_guest::DivUnsafe for #struct_name {
459 type Output = Self;
460 #[inline(always)]
461 fn div_unsafe(mut self, other: Self) -> Self::Output {
462 self = self.div_unsafe_refs_impl(&other);
463 self
464 }
465 }
466
467 impl<'a> openvm_algebra_guest::DivUnsafe<&'a #struct_name> for #struct_name {
468 type Output = Self;
469 #[inline(always)]
470 fn div_unsafe(mut self, other: &'a #struct_name) -> Self::Output {
471 self = self.div_unsafe_refs_impl(other);
472 self
473 }
474 }
475
476 impl<'a> openvm_algebra_guest::DivUnsafe<&'a #struct_name> for &#struct_name {
477 type Output = #struct_name;
478 #[inline(always)]
479 fn div_unsafe(self, other: &'a #struct_name) -> Self::Output {
480 self.div_unsafe_refs_impl(other)
481 }
482 }
483
484 impl<'a> core::iter::Sum<&'a #struct_name> for #struct_name {
485 fn sum<I: core::iter::Iterator<Item = &'a #struct_name>>(iter: I) -> Self {
486 iter.fold(Self::ZERO, |acc, x| &acc + x)
487 }
488 }
489
490 impl core::iter::Sum for #struct_name {
491 fn sum<I: core::iter::Iterator<Item = Self>>(iter: I) -> Self {
492 iter.fold(Self::ZERO, |acc, x| &acc + &x)
493 }
494 }
495
496 impl<'a> core::iter::Product<&'a #struct_name> for #struct_name {
497 fn product<I: core::iter::Iterator<Item = &'a #struct_name>>(iter: I) -> Self {
498 iter.fold(Self::ONE, |acc, x| &acc * x)
499 }
500 }
501
502 impl core::iter::Product for #struct_name {
503 fn product<I: core::iter::Iterator<Item = Self>>(iter: I) -> Self {
504 iter.fold(Self::ONE, |acc, x| &acc * &x)
505 }
506 }
507
508 impl core::ops::Neg for #struct_name {
509 type Output = #struct_name;
510 fn neg(self) -> Self::Output {
511 Self::ZERO - &self
512 }
513 }
514
515 impl core::ops::Neg for &#struct_name {
516 type Output = #struct_name;
517 fn neg(self) -> Self::Output {
518 #struct_name::ZERO - self
519 }
520 }
521
522 impl core::fmt::Debug for #struct_name {
523 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
524 write!(f, "{:?} + {:?} * u", self.c0, self.c1)
525 }
526 }
527 });
528 output.push(result);
529 }
530
531 TokenStream::from_iter(output)
532}
533
534struct ComplexInitArgs {
537 pub items: Vec<ComplexInitItem>,
538}
539
540struct ComplexInitItem {
541 pub name: LitStr,
542 pub params: Punctuated<Param, Token![,]>,
543}
544
545impl Parse for ComplexInitArgs {
546 fn parse(input: ParseStream) -> syn::Result<Self> {
547 Ok(ComplexInitArgs {
548 items: input
549 .parse_terminated(ComplexInitItem::parse, Token![,])?
550 .into_iter()
551 .collect(),
552 })
553 }
554}
555
556impl Parse for ComplexInitItem {
557 fn parse(input: ParseStream) -> syn::Result<Self> {
558 let name = input.parse()?;
559 let content;
560 syn::braced!(content in input);
561 let params = content.parse_terminated(Param::parse, Token![,])?;
562 Ok(ComplexInitItem { name, params })
563 }
564}
565
566#[proc_macro]
579pub fn complex_init(input: TokenStream) -> TokenStream {
580 let ComplexInitArgs { items } = parse_macro_input!(input as ComplexInitArgs);
581
582 let mut externs = Vec::new();
583
584 let span = proc_macro::Span::call_site();
585
586 for (complex_idx, item) in items.into_iter().enumerate() {
587 let struct_name = item.name.value();
588 let struct_name = syn::Ident::new(&struct_name, span.into());
589 let mut intmod_idx: Option<usize> = None;
590 for param in item.params {
591 match param.name.to_string().as_str() {
592 "mod_idx" => {
593 if let syn::Expr::Lit(syn::ExprLit {
594 lit: syn::Lit::Int(int),
595 ..
596 }) = param.value
597 {
598 intmod_idx = Some(int.base10_parse::<usize>().unwrap());
599 } else {
600 return syn::Error::new_spanned(param.value, "Expected usize")
601 .to_compile_error()
602 .into();
603 }
604 }
605 _ => {
606 panic!("Unknown parameter {}", param.name);
607 }
608 }
609 }
610 let mod_idx = intmod_idx.expect("mod_idx is required");
611
612 println!(
613 "[init] complex #{} = {} (mod_idx = {})",
614 complex_idx, struct_name, mod_idx
615 );
616
617 for op_type in ["add", "sub", "mul", "div"] {
618 let func_name = syn::Ident::new(
619 &format!("complex_{}_extern_func_{}", op_type, struct_name),
620 span.into(),
621 );
622 let mut chars = op_type.chars().collect::<Vec<_>>();
623 chars[0] = chars[0].to_ascii_uppercase();
624 let local_opcode = syn::Ident::new(&chars.iter().collect::<String>(), span.into());
625 externs.push(quote::quote_spanned! { span.into() =>
626 #[no_mangle]
627 extern "C" fn #func_name(rd: usize, rs1: usize, rs2: usize) {
628 openvm::platform::custom_insn_r!(
629 opcode = openvm_algebra_guest::OPCODE,
630 funct3 = openvm_algebra_guest::COMPLEX_EXT_FIELD_FUNCT3,
631 funct7 = openvm_algebra_guest::ComplexExtFieldBaseFunct7::#local_opcode as usize
632 + #complex_idx * (openvm_algebra_guest::ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS as usize),
633 rd = In rd,
634 rs1 = In rs1,
635 rs2 = In rs2
636 )
637 }
638 });
639 }
640
641 let setup_extern_func = syn::Ident::new(
642 &format!("complex_setup_extern_func_{}", struct_name),
643 span.into(),
644 );
645
646 externs.push(quote::quote_spanned! { span.into() =>
647 #[no_mangle]
648 extern "C" fn #setup_extern_func() {
649 #[cfg(target_os = "zkvm")]
650 {
651 use super::openvm_intrinsics_meta_do_not_type_this_by_yourself::{two_modular_limbs_list, limb_list_borders};
652 let two_modulus_bytes = &two_modular_limbs_list[limb_list_borders[#mod_idx]..limb_list_borders[#mod_idx + 1]];
653
654 let mut uninit: core::mem::MaybeUninit<[u8; limb_list_borders[#mod_idx + 1] - limb_list_borders[#mod_idx]]> = core::mem::MaybeUninit::uninit();
657 openvm::platform::custom_insn_r!(
658 opcode = ::openvm_algebra_guest::OPCODE,
659 funct3 = ::openvm_algebra_guest::COMPLEX_EXT_FIELD_FUNCT3,
660 funct7 = ::openvm_algebra_guest::ComplexExtFieldBaseFunct7::Setup as usize
661 + #complex_idx
662 * (::openvm_algebra_guest::ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS as usize),
663 rd = In uninit.as_mut_ptr(),
664 rs1 = In two_modulus_bytes.as_ptr(),
665 rs2 = Const "x0" );
667 openvm::platform::custom_insn_r!(
668 opcode = ::openvm_algebra_guest::OPCODE,
669 funct3 = ::openvm_algebra_guest::COMPLEX_EXT_FIELD_FUNCT3,
670 funct7 = ::openvm_algebra_guest::ComplexExtFieldBaseFunct7::Setup as usize
671 + #complex_idx
672 * (::openvm_algebra_guest::ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS as usize),
673 rd = In uninit.as_mut_ptr(),
674 rs1 = In two_modulus_bytes.as_ptr(),
675 rs2 = Const "x1" );
677 }
678 }
679 });
680 }
681
682 TokenStream::from(quote::quote_spanned! { span.into() =>
683 #[allow(non_snake_case)]
684 #[cfg(target_os = "zkvm")]
685 mod openvm_intrinsics_ffi_complex {
686 #(#externs)*
687 }
688 })
689}
690
691struct ComplexSimpleItem {
692 items: Vec<Path>,
693}
694
695impl Parse for ComplexSimpleItem {
696 fn parse(input: ParseStream) -> syn::Result<Self> {
697 let items = input.parse_terminated(<Expr as Parse>::parse, Token![,])?;
698 Ok(Self {
699 items: items
700 .into_iter()
701 .map(|e| {
702 if let Expr::Path(p) = e {
703 p.path
704 } else {
705 panic!("expected path");
706 }
707 })
708 .collect(),
709 })
710 }
711}
712
713#[proc_macro]
714pub fn complex_impl_field(input: TokenStream) -> TokenStream {
715 let ComplexSimpleItem { items } = parse_macro_input!(input as ComplexSimpleItem);
716
717 let mut output = Vec::new();
718
719 let span = proc_macro::Span::call_site();
720
721 for item in items.into_iter() {
722 let str_path = item
723 .segments
724 .iter()
725 .map(|x| x.ident.to_string())
726 .collect::<Vec<_>>()
727 .join("_");
728 let struct_name = syn::Ident::new(&str_path, span.into());
729
730 output.push(quote::quote_spanned! { span.into() =>
731 impl openvm_algebra_guest::field::Field for #struct_name {
732 type SelfRef<'a>
733 = &'a Self
734 where
735 Self: 'a;
736
737 const ZERO: Self = Self::ZERO;
738 const ONE: Self = Self::ONE;
739
740 fn double_assign(&mut self) {
741 openvm_algebra_guest::field::Field::double_assign(&mut self.c0);
742 openvm_algebra_guest::field::Field::double_assign(&mut self.c1);
743 }
744
745 fn square_assign(&mut self) {
746 unsafe {
747 self.mul_refs_impl(self, self as *const Self as *mut Self);
748 }
749 }
750 }
751 });
752 }
753
754 TokenStream::from(quote::quote_spanned! { span.into() =>
755 #(#output)*
756 })
757}