openvm_algebra_complex_macros/
lib.rs
1extern crate proc_macro;
2
3use openvm_macros_common::MacroArgs;
4use proc_macro::TokenStream;
5use syn::{
6 parse::{Parse, ParseStream},
7 parse_macro_input, Expr, ExprPath, Path, Token,
8};
9
10#[proc_macro]
19pub fn complex_declare(input: TokenStream) -> TokenStream {
20 let MacroArgs { items } = parse_macro_input!(input as MacroArgs);
21
22 let mut output = Vec::new();
23
24 let span = proc_macro::Span::call_site();
25
26 for item in items.into_iter() {
27 let struct_name = item.name.to_string();
28 let struct_name = syn::Ident::new(&struct_name, span.into());
29 let mut intmod_type: Option<syn::Path> = None;
30 for param in item.params {
31 match param.name.to_string().as_str() {
32 "mod_type" => {
33 if let syn::Expr::Path(ExprPath { path, .. }) = param.value {
34 intmod_type = Some(path)
35 } else {
36 return syn::Error::new_spanned(param.value, "Expected a type")
37 .to_compile_error()
38 .into();
39 }
40 }
41 _ => {
42 panic!("Unknown parameter {}", param.name);
43 }
44 }
45 }
46
47 let intmod_type = intmod_type.expect("mod_type parameter is required");
48
49 macro_rules! create_extern_func {
50 ($name:ident) => {
51 let $name = syn::Ident::new(
52 &format!("{}_{}", stringify!($name), struct_name),
53 span.into(),
54 );
55 };
56 }
57 create_extern_func!(complex_add_extern_func);
58 create_extern_func!(complex_sub_extern_func);
59 create_extern_func!(complex_mul_extern_func);
60 create_extern_func!(complex_div_extern_func);
61
62 let result = TokenStream::from(quote::quote_spanned! { span.into() =>
63 extern "C" {
64 fn #complex_add_extern_func(rd: usize, rs1: usize, rs2: usize);
65 fn #complex_sub_extern_func(rd: usize, rs1: usize, rs2: usize);
66 fn #complex_mul_extern_func(rd: usize, rs1: usize, rs2: usize);
67 fn #complex_div_extern_func(rd: usize, rs1: usize, rs2: usize);
68 }
69
70
71 #[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
77 #[repr(C)]
78 pub struct #struct_name {
79 pub c0: #intmod_type,
81 pub c1: #intmod_type,
83 }
84
85 impl #struct_name {
86 pub const fn new(c0: #intmod_type, c1: #intmod_type) -> Self {
87 Self { c0, c1 }
88 }
89 }
90
91 impl #struct_name {
92 pub const ZERO: Self = Self::new(<#intmod_type as openvm_algebra_guest::IntMod>::ZERO, <#intmod_type as openvm_algebra_guest::IntMod>::ZERO);
94
95 pub const ONE: Self = Self::new(<#intmod_type as openvm_algebra_guest::IntMod>::ONE, <#intmod_type as openvm_algebra_guest::IntMod>::ZERO);
97
98 pub fn neg_assign(&mut self) {
99 self.c0.neg_assign();
100 self.c1.neg_assign();
101 }
102
103 #[inline(always)]
105 fn add_assign_impl(&mut self, other: &Self) {
106 #[cfg(not(target_os = "zkvm"))]
107 {
108 self.c0 += &other.c0;
109 self.c1 += &other.c1;
110 }
111 #[cfg(target_os = "zkvm")]
112 {
113 unsafe {
114 #complex_add_extern_func(
115 self as *mut Self as usize,
116 self as *const Self as usize,
117 other as *const Self as usize
118 );
119 }
120 }
121 }
122
123 #[inline(always)]
125 fn sub_assign_impl(&mut self, other: &Self) {
126 #[cfg(not(target_os = "zkvm"))]
127 {
128 self.c0 -= &other.c0;
129 self.c1 -= &other.c1;
130 }
131 #[cfg(target_os = "zkvm")]
132 {
133 unsafe {
134 #complex_sub_extern_func(
135 self as *mut Self as usize,
136 self as *const Self as usize,
137 other as *const Self as usize
138 );
139 }
140 }
141 }
142
143 #[inline(always)]
145 fn mul_assign_impl(&mut self, other: &Self) {
146 #[cfg(not(target_os = "zkvm"))]
147 {
148 let (c0, c1) = (&self.c0, &self.c1);
149 let (d0, d1) = (&other.c0, &other.c1);
150 *self = Self::new(
151 c0.clone() * d0 - c1.clone() * d1,
152 c0.clone() * d1 + c1.clone() * d0,
153 );
154 }
155 #[cfg(target_os = "zkvm")]
156 {
157 unsafe {
158 #complex_mul_extern_func(
159 self as *mut Self as usize,
160 self as *const Self as usize,
161 other as *const Self as usize
162 );
163 }
164 }
165 }
166
167 #[inline(always)]
169 fn div_assign_unsafe_impl(&mut self, other: &Self) {
170 #[cfg(not(target_os = "zkvm"))]
171 {
172 let (c0, c1) = (&self.c0, &self.c1);
173 let (d0, d1) = (&other.c0, &other.c1);
174 let denom = openvm_algebra_guest::DivUnsafe::div_unsafe(<#intmod_type as openvm_algebra_guest::IntMod>::ONE, d0.square() + d1.square());
175 *self = Self::new(
176 denom.clone() * (c0.clone() * d0 + c1.clone() * d1),
177 denom * &(c1.clone() * d0 - c0.clone() * d1),
178 );
179 }
180 #[cfg(target_os = "zkvm")]
181 {
182 unsafe {
183 #complex_div_extern_func(
184 self as *mut Self as usize,
185 self as *const Self as usize,
186 other as *const Self as usize
187 );
188 }
189 }
190 }
191
192 fn add_refs_impl(&self, other: &Self) -> Self {
194 #[cfg(not(target_os = "zkvm"))]
195 {
196 let mut res = self.clone();
197 res.add_assign_impl(other);
198 res
199 }
200 #[cfg(target_os = "zkvm")]
201 {
202 let mut uninit: core::mem::MaybeUninit<Self> = core::mem::MaybeUninit::uninit();
203 unsafe {
204 #complex_add_extern_func(
205 uninit.as_mut_ptr() as usize,
206 self as *const Self as usize,
207 other as *const Self as usize
208 );
209 }
210 unsafe { uninit.assume_init() }
211 }
212 }
213
214 #[inline(always)]
216 fn sub_refs_impl(&self, other: &Self) -> Self {
217 #[cfg(not(target_os = "zkvm"))]
218 {
219 let mut res = self.clone();
220 res.sub_assign_impl(other);
221 res
222 }
223 #[cfg(target_os = "zkvm")]
224 {
225 let mut uninit: core::mem::MaybeUninit<Self> = core::mem::MaybeUninit::uninit();
226 unsafe {
227 #complex_sub_extern_func(
228 uninit.as_mut_ptr() as usize,
229 self as *const Self as usize,
230 other as *const Self as usize
231 );
232 }
233 unsafe { uninit.assume_init() }
234 }
235 }
236
237 #[inline(always)]
242 unsafe fn mul_refs_impl(&self, other: &Self, dst_ptr: *mut Self) {
243 #[cfg(not(target_os = "zkvm"))]
244 {
245 let mut res = self.clone();
246 res.mul_assign_impl(other);
247 let dst = unsafe { &mut *dst_ptr };
248 *dst = res;
249 }
250 #[cfg(target_os = "zkvm")]
251 {
252 unsafe {
253 #complex_mul_extern_func(
254 dst_ptr as usize,
255 self as *const Self as usize,
256 other as *const Self as usize
257 );
258 }
259 }
260 }
261
262 #[inline(always)]
264 fn div_unsafe_refs_impl(&self, other: &Self) -> Self {
265 #[cfg(not(target_os = "zkvm"))]
266 {
267 let mut res = self.clone();
268 res.div_assign_unsafe_impl(other);
269 res
270 }
271 #[cfg(target_os = "zkvm")]
272 {
273 let mut uninit: core::mem::MaybeUninit<Self> = core::mem::MaybeUninit::uninit();
274 unsafe {
275 #complex_div_extern_func(
276 uninit.as_mut_ptr() as usize,
277 self as *const Self as usize,
278 other as *const Self as usize
279 );
280 }
281 unsafe { uninit.assume_init() }
282 }
283 }
284 }
285
286 impl openvm_algebra_guest::field::ComplexConjugate for #struct_name {
287 fn conjugate(self) -> Self {
288 Self {
289 c0: self.c0,
290 c1: -self.c1,
291 }
292 }
293
294 fn conjugate_assign(&mut self) {
295 self.c1.neg_assign();
296 }
297 }
298
299 impl<'a> core::ops::AddAssign<&'a #struct_name> for #struct_name {
300 #[inline(always)]
301 fn add_assign(&mut self, other: &'a #struct_name) {
302 self.add_assign_impl(other);
303 }
304 }
305
306 impl core::ops::AddAssign for #struct_name {
307 #[inline(always)]
308 fn add_assign(&mut self, other: Self) {
309 self.add_assign_impl(&other);
310 }
311 }
312
313 impl core::ops::Add for #struct_name {
314 type Output = Self;
315 #[inline(always)]
316 fn add(mut self, other: Self) -> Self::Output {
317 self += other;
318 self
319 }
320 }
321
322 impl<'a> core::ops::Add<&'a #struct_name> for #struct_name {
323 type Output = Self;
324 #[inline(always)]
325 fn add(mut self, other: &'a #struct_name) -> Self::Output {
326 self += other;
327 self
328 }
329 }
330
331 impl<'a> core::ops::Add<&'a #struct_name> for &#struct_name {
332 type Output = #struct_name;
333 #[inline(always)]
334 fn add(self, other: &'a #struct_name) -> Self::Output {
335 self.add_refs_impl(other)
336 }
337 }
338
339 impl<'a> core::ops::SubAssign<&'a #struct_name> for #struct_name {
340 #[inline(always)]
341 fn sub_assign(&mut self, other: &'a #struct_name) {
342 self.sub_assign_impl(other);
343 }
344 }
345
346 impl core::ops::SubAssign for #struct_name {
347 #[inline(always)]
348 fn sub_assign(&mut self, other: Self) {
349 self.sub_assign_impl(&other);
350 }
351 }
352
353 impl core::ops::Sub for #struct_name {
354 type Output = Self;
355 #[inline(always)]
356 fn sub(mut self, other: Self) -> Self::Output {
357 self -= other;
358 self
359 }
360 }
361
362 impl<'a> core::ops::Sub<&'a #struct_name> for #struct_name {
363 type Output = Self;
364 #[inline(always)]
365 fn sub(mut self, other: &'a #struct_name) -> Self::Output {
366 self -= other;
367 self
368 }
369 }
370
371 impl<'a> core::ops::Sub<&'a #struct_name> for &#struct_name {
372 type Output = #struct_name;
373 #[inline(always)]
374 fn sub(self, other: &'a #struct_name) -> Self::Output {
375 self.sub_refs_impl(other)
376 }
377 }
378
379 impl<'a> core::ops::MulAssign<&'a #struct_name> for #struct_name {
380 #[inline(always)]
381 fn mul_assign(&mut self, other: &'a #struct_name) {
382 self.mul_assign_impl(other);
383 }
384 }
385
386 impl core::ops::MulAssign for #struct_name {
387 #[inline(always)]
388 fn mul_assign(&mut self, other: Self) {
389 self.mul_assign_impl(&other);
390 }
391 }
392
393 impl core::ops::Mul for #struct_name {
394 type Output = Self;
395 #[inline(always)]
396 fn mul(mut self, other: Self) -> Self::Output {
397 self *= other;
398 self
399 }
400 }
401
402 impl<'a> core::ops::Mul<&'a #struct_name> for #struct_name {
403 type Output = Self;
404 #[inline(always)]
405 fn mul(mut self, other: &'a #struct_name) -> Self::Output {
406 self *= other;
407 self
408 }
409 }
410
411 impl<'a> core::ops::Mul<&'a #struct_name> for &'a #struct_name {
412 type Output = #struct_name;
413 #[inline(always)]
414 fn mul(self, other: &'a #struct_name) -> Self::Output {
415 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
416 unsafe {
417 self.mul_refs_impl(other, uninit.as_mut_ptr());
418 uninit.assume_init()
419 }
420 }
421 }
422
423 impl<'a> openvm_algebra_guest::DivAssignUnsafe<&'a #struct_name> for #struct_name {
424 #[inline(always)]
425 fn div_assign_unsafe(&mut self, other: &'a #struct_name) {
426 self.div_assign_unsafe_impl(other);
427 }
428 }
429
430 impl openvm_algebra_guest::DivAssignUnsafe for #struct_name {
431 #[inline(always)]
432 fn div_assign_unsafe(&mut self, other: Self) {
433 self.div_assign_unsafe_impl(&other);
434 }
435 }
436
437 impl openvm_algebra_guest::DivUnsafe for #struct_name {
438 type Output = Self;
439 #[inline(always)]
440 fn div_unsafe(mut self, other: Self) -> Self::Output {
441 self = self.div_unsafe_refs_impl(&other);
442 self
443 }
444 }
445
446 impl<'a> openvm_algebra_guest::DivUnsafe<&'a #struct_name> for #struct_name {
447 type Output = Self;
448 #[inline(always)]
449 fn div_unsafe(mut self, other: &'a #struct_name) -> Self::Output {
450 self = self.div_unsafe_refs_impl(other);
451 self
452 }
453 }
454
455 impl<'a> openvm_algebra_guest::DivUnsafe<&'a #struct_name> for &#struct_name {
456 type Output = #struct_name;
457 #[inline(always)]
458 fn div_unsafe(self, other: &'a #struct_name) -> Self::Output {
459 self.div_unsafe_refs_impl(other)
460 }
461 }
462
463 impl<'a> core::iter::Sum<&'a #struct_name> for #struct_name {
464 fn sum<I: core::iter::Iterator<Item = &'a #struct_name>>(iter: I) -> Self {
465 iter.fold(Self::ZERO, |acc, x| &acc + x)
466 }
467 }
468
469 impl core::iter::Sum for #struct_name {
470 fn sum<I: core::iter::Iterator<Item = Self>>(iter: I) -> Self {
471 iter.fold(Self::ZERO, |acc, x| &acc + &x)
472 }
473 }
474
475 impl<'a> core::iter::Product<&'a #struct_name> for #struct_name {
476 fn product<I: core::iter::Iterator<Item = &'a #struct_name>>(iter: I) -> Self {
477 iter.fold(Self::ONE, |acc, x| &acc * x)
478 }
479 }
480
481 impl core::iter::Product for #struct_name {
482 fn product<I: core::iter::Iterator<Item = Self>>(iter: I) -> Self {
483 iter.fold(Self::ONE, |acc, x| &acc * &x)
484 }
485 }
486
487 impl core::ops::Neg for #struct_name {
488 type Output = #struct_name;
489 fn neg(self) -> Self::Output {
490 Self::ZERO - &self
491 }
492 }
493
494 impl core::ops::Neg for &#struct_name {
495 type Output = #struct_name;
496 fn neg(self) -> Self::Output {
497 #struct_name::ZERO - self
498 }
499 }
500
501 impl core::fmt::Debug for #struct_name {
502 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
503 write!(f, "{:?} + {:?} * u", self.c0, self.c1)
504 }
505 }
506 });
507 output.push(result);
508 }
509
510 TokenStream::from_iter(output)
511}
512
513#[proc_macro]
525pub fn complex_init(input: TokenStream) -> TokenStream {
526 let MacroArgs { items } = parse_macro_input!(input as MacroArgs);
527
528 let mut externs = Vec::new();
529 let mut setups = Vec::new();
530 let mut setup_all_complex_extensions = Vec::new();
531
532 let span = proc_macro::Span::call_site();
533
534 for (complex_idx, item) in items.into_iter().enumerate() {
535 let struct_name = item.name.to_string();
536 let struct_name = syn::Ident::new(&struct_name, span.into());
537 let mut intmod_idx: Option<usize> = None;
538 for param in item.params {
539 match param.name.to_string().as_str() {
540 "mod_idx" => {
541 if let syn::Expr::Lit(syn::ExprLit {
542 lit: syn::Lit::Int(int),
543 ..
544 }) = param.value
545 {
546 intmod_idx = Some(int.base10_parse::<usize>().unwrap());
547 } else {
548 return syn::Error::new_spanned(param.value, "Expected usize")
549 .to_compile_error()
550 .into();
551 }
552 }
553 _ => {
554 panic!("Unknown parameter {}", param.name);
555 }
556 }
557 }
558 let mod_idx = intmod_idx.expect("mod_idx is required");
559
560 println!(
561 "[init] complex #{} = {} (mod_idx = {})",
562 complex_idx, struct_name, mod_idx
563 );
564
565 for op_type in ["add", "sub", "mul", "div"] {
566 let func_name = syn::Ident::new(
567 &format!("complex_{}_extern_func_{}", op_type, struct_name),
568 span.into(),
569 );
570 let mut chars = op_type.chars().collect::<Vec<_>>();
571 chars[0] = chars[0].to_ascii_uppercase();
572 let local_opcode = syn::Ident::new(&chars.iter().collect::<String>(), span.into());
573 externs.push(quote::quote_spanned! { span.into() =>
574 #[no_mangle]
575 extern "C" fn #func_name(rd: usize, rs1: usize, rs2: usize) {
576 openvm::platform::custom_insn_r!(
577 opcode = openvm_algebra_guest::OPCODE,
578 funct3 = openvm_algebra_guest::COMPLEX_EXT_FIELD_FUNCT3,
579 funct7 = openvm_algebra_guest::ComplexExtFieldBaseFunct7::#local_opcode as usize
580 + #complex_idx * (openvm_algebra_guest::ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS as usize),
581 rd = In rd,
582 rs1 = In rs1,
583 rs2 = In rs2
584 )
585 }
586 });
587 }
588
589 let setup_function =
590 syn::Ident::new(&format!("setup_complex_{}", complex_idx), span.into());
591
592 setup_all_complex_extensions.push(quote::quote_spanned! { span.into() =>
593 #setup_function();
594 });
595 setups.push(quote::quote_spanned! { span.into() =>
596 #[allow(non_snake_case)]
597 pub fn #setup_function() {
598 #[cfg(target_os = "zkvm")]
599 {
600 let two_modulus_bytes = &openvm_intrinsics_meta_do_not_type_this_by_yourself::two_modular_limbs_list[openvm_intrinsics_meta_do_not_type_this_by_yourself::limb_list_borders[#mod_idx]..openvm_intrinsics_meta_do_not_type_this_by_yourself::limb_list_borders[#mod_idx + 1]];
601
602 let mut uninit: core::mem::MaybeUninit<[u8; openvm_intrinsics_meta_do_not_type_this_by_yourself::limb_list_borders[#mod_idx + 1] - openvm_intrinsics_meta_do_not_type_this_by_yourself::limb_list_borders[#mod_idx]]> = core::mem::MaybeUninit::uninit();
605 openvm::platform::custom_insn_r!(
606 opcode = ::openvm_algebra_guest::OPCODE,
607 funct3 = ::openvm_algebra_guest::COMPLEX_EXT_FIELD_FUNCT3,
608 funct7 = ::openvm_algebra_guest::ComplexExtFieldBaseFunct7::Setup as usize
609 + #complex_idx
610 * (::openvm_algebra_guest::ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS as usize),
611 rd = In uninit.as_mut_ptr(),
612 rs1 = In two_modulus_bytes.as_ptr(),
613 rs2 = Const "x0" );
615 openvm::platform::custom_insn_r!(
616 opcode = ::openvm_algebra_guest::OPCODE,
617 funct3 = ::openvm_algebra_guest::COMPLEX_EXT_FIELD_FUNCT3,
618 funct7 = ::openvm_algebra_guest::ComplexExtFieldBaseFunct7::Setup as usize
619 + #complex_idx
620 * (::openvm_algebra_guest::ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS as usize),
621 rd = In uninit.as_mut_ptr(),
622 rs1 = In two_modulus_bytes.as_ptr(),
623 rs2 = Const "x1" );
625 }
626 }
627 });
628 }
629
630 TokenStream::from(quote::quote_spanned! { span.into() =>
631 #[cfg(target_os = "zkvm")]
632 mod openvm_intrinsics_ffi_complex {
633 #(#externs)*
634 }
635 #(#setups)*
636 pub fn setup_all_complex_extensions() {
637 #(#setup_all_complex_extensions)*
638 }
639 })
640}
641
642struct ComplexSimpleItem {
643 items: Vec<Path>,
644}
645
646impl Parse for ComplexSimpleItem {
647 fn parse(input: ParseStream) -> syn::Result<Self> {
648 let items = input.parse_terminated(<Expr as Parse>::parse, Token![,])?;
649 Ok(Self {
650 items: items
651 .into_iter()
652 .map(|e| {
653 if let Expr::Path(p) = e {
654 p.path
655 } else {
656 panic!("expected path");
657 }
658 })
659 .collect(),
660 })
661 }
662}
663
664#[proc_macro]
665pub fn complex_impl_field(input: TokenStream) -> TokenStream {
666 let ComplexSimpleItem { items } = parse_macro_input!(input as ComplexSimpleItem);
667
668 let mut output = Vec::new();
669
670 let span = proc_macro::Span::call_site();
671
672 for item in items.into_iter() {
673 let str_path = item
674 .segments
675 .iter()
676 .map(|x| x.ident.to_string())
677 .collect::<Vec<_>>()
678 .join("_");
679 let struct_name = syn::Ident::new(&str_path, span.into());
680
681 output.push(quote::quote_spanned! { span.into() =>
682 impl openvm_algebra_guest::field::Field for #struct_name {
683 type SelfRef<'a>
684 = &'a Self
685 where
686 Self: 'a;
687
688 const ZERO: Self = Self::ZERO;
689 const ONE: Self = Self::ONE;
690
691 fn double_assign(&mut self) {
692 openvm_algebra_guest::field::Field::double_assign(&mut self.c0);
693 openvm_algebra_guest::field::Field::double_assign(&mut self.c1);
694 }
695
696 fn square_assign(&mut self) {
697 unsafe {
698 self.mul_refs_impl(self, self as *const Self as *mut Self);
699 }
700 }
701 }
702 });
703 }
704
705 TokenStream::from(quote::quote_spanned! { span.into() =>
706 #(#output)*
707 })
708}