openvm_algebra_complex_macros/
lib.rs

1extern crate proc_macro;
2
3use openvm_macros_common::{MacroArgs, Param};
4use proc_macro::TokenStream;
5use syn::{
6    parse::{Parse, ParseStream},
7    parse_macro_input,
8    punctuated::Punctuated,
9    Expr, ExprPath, LitStr, Path, Token,
10};
11
12/// This macro is used to declare the complex extension fields.
13/// Usage:
14/// ```rust
15/// complex_declare! {
16///     Complex1 { mod_type = Mod1 },
17///     Complex2 { mod_type = Mod2 },
18/// }
19/// ```
20#[proc_macro]
21pub fn complex_declare(input: TokenStream) -> TokenStream {
22    let MacroArgs { items } = parse_macro_input!(input as MacroArgs);
23
24    let mut output = Vec::new();
25
26    let span = proc_macro::Span::call_site();
27
28    for item in items.into_iter() {
29        let struct_name = item.name.to_string();
30        let struct_name = syn::Ident::new(&struct_name, span.into());
31        let mut intmod_type: Option<syn::Path> = None;
32        for param in item.params {
33            match param.name.to_string().as_str() {
34                "mod_type" => {
35                    if let syn::Expr::Path(ExprPath { path, .. }) = param.value {
36                        intmod_type = Some(path)
37                    } else {
38                        return syn::Error::new_spanned(param.value, "Expected a type")
39                            .to_compile_error()
40                            .into();
41                    }
42                }
43                _ => {
44                    panic!("Unknown parameter {}", param.name);
45                }
46            }
47        }
48
49        let intmod_type = intmod_type.expect("mod_type parameter is required");
50
51        macro_rules! create_extern_func {
52            ($name:ident) => {
53                let $name = syn::Ident::new(
54                    &format!("{}_{}", stringify!($name), struct_name),
55                    span.into(),
56                );
57            };
58        }
59        create_extern_func!(complex_add_extern_func);
60        create_extern_func!(complex_sub_extern_func);
61        create_extern_func!(complex_mul_extern_func);
62        create_extern_func!(complex_div_extern_func);
63        create_extern_func!(complex_setup_extern_func);
64
65        let result = TokenStream::from(quote::quote_spanned! { span.into() =>
66            extern "C" {
67                fn #complex_add_extern_func(rd: usize, rs1: usize, rs2: usize);
68                fn #complex_sub_extern_func(rd: usize, rs1: usize, rs2: usize);
69                fn #complex_mul_extern_func(rd: usize, rs1: usize, rs2: usize);
70                fn #complex_div_extern_func(rd: usize, rs1: usize, rs2: usize);
71                fn #complex_setup_extern_func();
72            }
73
74
75            /// Quadratic extension field of `#intmod_type` with irreducible polynomial `X^2 + 1`.
76            /// Elements are represented as `c0 + c1 * u` where `u^2 = -1`.
77            ///
78            /// Memory alignment follows alignment of `#intmod_type`.
79            /// Memory layout is concatenation of `c0` and `c1`.
80            #[derive(Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
81            #[repr(C)]
82            pub struct #struct_name {
83                /// Real coordinate
84                pub c0: #intmod_type,
85                /// Imaginary coordinate
86                pub c1: #intmod_type,
87            }
88
89            impl #struct_name {
90                pub const fn new(c0: #intmod_type, c1: #intmod_type) -> Self {
91                    Self { c0, c1 }
92                }
93            }
94
95            impl #struct_name {
96                // Zero element (i.e. additive identity)
97                pub const ZERO: Self = Self::new(<#intmod_type as openvm_algebra_guest::IntMod>::ZERO, <#intmod_type as openvm_algebra_guest::IntMod>::ZERO);
98
99                // One element (i.e. multiplicative identity)
100                pub const ONE: Self = Self::new(<#intmod_type as openvm_algebra_guest::IntMod>::ONE, <#intmod_type as openvm_algebra_guest::IntMod>::ZERO);
101
102                pub fn neg_assign(&mut self) {
103                    self.c0.neg_assign();
104                    self.c1.neg_assign();
105                }
106
107                /// Implementation of AddAssign.
108                #[inline(always)]
109                fn add_assign_impl(&mut self, other: &Self) {
110                    #[cfg(not(target_os = "zkvm"))]
111                    {
112                        self.c0 += &other.c0;
113                        self.c1 += &other.c1;
114                    }
115                    #[cfg(target_os = "zkvm")]
116                    {
117                        Self::set_up_once();
118                        unsafe {
119                            #complex_add_extern_func(
120                                self as *mut Self as usize,
121                                self as *const Self as usize,
122                                other as *const Self as usize
123                            );
124                        }
125                    }
126                }
127
128                /// Implementation of SubAssign.
129                #[inline(always)]
130                fn sub_assign_impl(&mut self, other: &Self) {
131                    #[cfg(not(target_os = "zkvm"))]
132                    {
133                        self.c0 -= &other.c0;
134                        self.c1 -= &other.c1;
135                    }
136                    #[cfg(target_os = "zkvm")]
137                    {
138                        Self::set_up_once();
139                        unsafe {
140                            #complex_sub_extern_func(
141                                self as *mut Self as usize,
142                                self as *const Self as usize,
143                                other as *const Self as usize
144                            );
145                        }
146                    }
147                }
148
149                /// Implementation of MulAssign.
150                #[inline(always)]
151                fn mul_assign_impl(&mut self, other: &Self) {
152                    #[cfg(not(target_os = "zkvm"))]
153                    {
154                        let (c0, c1) = (&self.c0, &self.c1);
155                        let (d0, d1) = (&other.c0, &other.c1);
156                        *self = Self::new(
157                            c0.clone() * d0 - c1.clone() * d1,
158                            c0.clone() * d1 + c1.clone() * d0,
159                        );
160                    }
161                    #[cfg(target_os = "zkvm")]
162                    {
163                        Self::set_up_once();
164                        unsafe {
165                            #complex_mul_extern_func(
166                                self as *mut Self as usize,
167                                self as *const Self as usize,
168                                other as *const Self as usize
169                            );
170                        }
171                    }
172                }
173
174                /// Implementation of DivAssignUnsafe.
175                #[inline(always)]
176                fn div_assign_unsafe_impl(&mut self, other: &Self) {
177                    #[cfg(not(target_os = "zkvm"))]
178                    {
179                        let (c0, c1) = (&self.c0, &self.c1);
180                        let (d0, d1) = (&other.c0, &other.c1);
181                        let denom = openvm_algebra_guest::DivUnsafe::div_unsafe(<#intmod_type as openvm_algebra_guest::IntMod>::ONE, d0.square() + d1.square());
182                        *self = Self::new(
183                            denom.clone() * (c0.clone() * d0 + c1.clone() * d1),
184                            denom * &(c1.clone() * d0 - c0.clone() * d1),
185                        );
186                    }
187                    #[cfg(target_os = "zkvm")]
188                    {
189                        Self::set_up_once();
190                        unsafe {
191                            #complex_div_extern_func(
192                                self as *mut Self as usize,
193                                self as *const Self as usize,
194                                other as *const Self as usize
195                            );
196                        }
197                    }
198                }
199
200                /// Implementation of Add that doesn't cause zkvm to use an additional store.
201                fn add_refs_impl(&self, other: &Self) -> Self {
202                    #[cfg(not(target_os = "zkvm"))]
203                    {
204                        let mut res = self.clone();
205                        res.add_assign_impl(other);
206                        res
207                    }
208                    #[cfg(target_os = "zkvm")]
209                    {
210                        Self::set_up_once();
211                        let mut uninit: core::mem::MaybeUninit<Self> = core::mem::MaybeUninit::uninit();
212                        unsafe {
213                            #complex_add_extern_func(
214                                uninit.as_mut_ptr() as usize,
215                                self as *const Self as usize,
216                                other as *const Self as usize
217                            );
218                        }
219                        unsafe { uninit.assume_init() }
220                    }
221                }
222
223                /// Implementation of Sub that doesn't cause zkvm to use an additional store.
224                #[inline(always)]
225                fn sub_refs_impl(&self, other: &Self) -> Self {
226                    #[cfg(not(target_os = "zkvm"))]
227                    {
228                        let mut res = self.clone();
229                        res.sub_assign_impl(other);
230                        res
231                    }
232                    #[cfg(target_os = "zkvm")]
233                    {
234                        Self::set_up_once();
235                        let mut uninit: core::mem::MaybeUninit<Self> = core::mem::MaybeUninit::uninit();
236                        unsafe {
237                            #complex_sub_extern_func(
238                                uninit.as_mut_ptr() as usize,
239                                self as *const Self as usize,
240                                other as *const Self as usize
241                            );
242                        }
243                        unsafe { uninit.assume_init() }
244                    }
245                }
246
247                /// Implementation of Mul that doesn't cause zkvm to use an additional store.
248                ///
249                /// SAFETY: dst_ptr must be pointer for `&mut Self`.
250                /// It will only be written to at the end of the function.
251                #[inline(always)]
252                unsafe fn mul_refs_impl(&self, other: &Self, dst_ptr: *mut Self) {
253                    #[cfg(not(target_os = "zkvm"))]
254                    {
255                        let mut res = self.clone();
256                        res.mul_assign_impl(other);
257                        let dst = unsafe { &mut *dst_ptr };
258                        *dst = res;
259                    }
260                    #[cfg(target_os = "zkvm")]
261                    {
262                        Self::set_up_once();
263                        unsafe {
264                            #complex_mul_extern_func(
265                                dst_ptr as usize,
266                                self as *const Self as usize,
267                                other as *const Self as usize
268                            );
269                        }
270                    }
271                }
272
273                /// Implementation of DivUnsafe that doesn't cause zkvm to use an additional store.
274                #[inline(always)]
275                fn div_unsafe_refs_impl(&self, other: &Self) -> Self {
276                    #[cfg(not(target_os = "zkvm"))]
277                    {
278                        let mut res = self.clone();
279                        res.div_assign_unsafe_impl(other);
280                        res
281                    }
282                    #[cfg(target_os = "zkvm")]
283                    {
284                        Self::set_up_once();
285                        let mut uninit: core::mem::MaybeUninit<Self> = core::mem::MaybeUninit::uninit();
286                        unsafe {
287                            #complex_div_extern_func(
288                                uninit.as_mut_ptr() as usize,
289                                self as *const Self as usize,
290                                other as *const Self as usize
291                            );
292                        }
293                        unsafe { uninit.assume_init() }
294                    }
295                }
296
297                // Helper function to call the setup instruction on first use
298                fn set_up_once() {
299                    static is_setup: ::openvm_algebra_guest::once_cell::race::OnceBool = ::openvm_algebra_guest::once_cell::race::OnceBool::new();
300                    is_setup.get_or_init(|| {
301                        unsafe { #complex_setup_extern_func(); }
302                        true
303                    });
304                }
305            }
306
307            impl openvm_algebra_guest::field::ComplexConjugate for #struct_name {
308                fn conjugate(self) -> Self {
309                    Self {
310                        c0: self.c0,
311                        c1: -self.c1,
312                    }
313                }
314
315                fn conjugate_assign(&mut self) {
316                    self.c1.neg_assign();
317                }
318            }
319
320            impl<'a> core::ops::AddAssign<&'a #struct_name> for #struct_name {
321                #[inline(always)]
322                fn add_assign(&mut self, other: &'a #struct_name) {
323                    self.add_assign_impl(other);
324                }
325            }
326
327            impl core::ops::AddAssign for #struct_name {
328                #[inline(always)]
329                fn add_assign(&mut self, other: Self) {
330                    self.add_assign_impl(&other);
331                }
332            }
333
334            impl core::ops::Add for #struct_name {
335                type Output = Self;
336                #[inline(always)]
337                fn add(mut self, other: Self) -> Self::Output {
338                    self += other;
339                    self
340                }
341            }
342
343            impl<'a> core::ops::Add<&'a #struct_name> for #struct_name {
344                type Output = Self;
345                #[inline(always)]
346                fn add(mut self, other: &'a #struct_name) -> Self::Output {
347                    self += other;
348                    self
349                }
350            }
351
352            impl<'a> core::ops::Add<&'a #struct_name> for &#struct_name {
353                type Output = #struct_name;
354                #[inline(always)]
355                fn add(self, other: &'a #struct_name) -> Self::Output {
356                    self.add_refs_impl(other)
357                }
358            }
359
360            impl<'a> core::ops::SubAssign<&'a #struct_name> for #struct_name {
361                #[inline(always)]
362                fn sub_assign(&mut self, other: &'a #struct_name) {
363                    self.sub_assign_impl(other);
364                }
365            }
366
367            impl core::ops::SubAssign for #struct_name {
368                #[inline(always)]
369                fn sub_assign(&mut self, other: Self) {
370                    self.sub_assign_impl(&other);
371                }
372            }
373
374            impl core::ops::Sub for #struct_name {
375                type Output = Self;
376                #[inline(always)]
377                fn sub(mut self, other: Self) -> Self::Output {
378                    self -= other;
379                    self
380                }
381            }
382
383            impl<'a> core::ops::Sub<&'a #struct_name> for #struct_name {
384                type Output = Self;
385                #[inline(always)]
386                fn sub(mut self, other: &'a #struct_name) -> Self::Output {
387                    self -= other;
388                    self
389                }
390            }
391
392            impl<'a> core::ops::Sub<&'a #struct_name> for &#struct_name {
393                type Output = #struct_name;
394                #[inline(always)]
395                fn sub(self, other: &'a #struct_name) -> Self::Output {
396                    self.sub_refs_impl(other)
397                }
398            }
399
400            impl<'a> core::ops::MulAssign<&'a #struct_name> for #struct_name {
401                #[inline(always)]
402                fn mul_assign(&mut self, other: &'a #struct_name) {
403                    self.mul_assign_impl(other);
404                }
405            }
406
407            impl core::ops::MulAssign for #struct_name {
408                #[inline(always)]
409                fn mul_assign(&mut self, other: Self) {
410                    self.mul_assign_impl(&other);
411                }
412            }
413
414            impl core::ops::Mul for #struct_name {
415                type Output = Self;
416                #[inline(always)]
417                fn mul(mut self, other: Self) -> Self::Output {
418                    self *= other;
419                    self
420                }
421            }
422
423            impl<'a> core::ops::Mul<&'a #struct_name> for #struct_name {
424                type Output = Self;
425                #[inline(always)]
426                fn mul(mut self, other: &'a #struct_name) -> Self::Output {
427                    self *= other;
428                    self
429                }
430            }
431
432            impl<'a> core::ops::Mul<&'a #struct_name> for &'a #struct_name {
433                type Output = #struct_name;
434                #[inline(always)]
435                fn mul(self, other: &'a #struct_name) -> Self::Output {
436                    let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
437                    unsafe {
438                        self.mul_refs_impl(other, uninit.as_mut_ptr());
439                        uninit.assume_init()
440                    }
441                }
442            }
443
444            impl<'a> openvm_algebra_guest::DivAssignUnsafe<&'a #struct_name> for #struct_name {
445                #[inline(always)]
446                fn div_assign_unsafe(&mut self, other: &'a #struct_name) {
447                    self.div_assign_unsafe_impl(other);
448                }
449            }
450
451            impl openvm_algebra_guest::DivAssignUnsafe for #struct_name {
452                #[inline(always)]
453                fn div_assign_unsafe(&mut self, other: Self) {
454                    self.div_assign_unsafe_impl(&other);
455                }
456            }
457
458            impl openvm_algebra_guest::DivUnsafe for #struct_name {
459                type Output = Self;
460                #[inline(always)]
461                fn div_unsafe(mut self, other: Self) -> Self::Output {
462                    self = self.div_unsafe_refs_impl(&other);
463                    self
464                }
465            }
466
467            impl<'a> openvm_algebra_guest::DivUnsafe<&'a #struct_name> for #struct_name {
468                type Output = Self;
469                #[inline(always)]
470                fn div_unsafe(mut self, other: &'a #struct_name) -> Self::Output {
471                    self = self.div_unsafe_refs_impl(other);
472                    self
473                }
474            }
475
476            impl<'a> openvm_algebra_guest::DivUnsafe<&'a #struct_name> for &#struct_name {
477                type Output = #struct_name;
478                #[inline(always)]
479                fn div_unsafe(self, other: &'a #struct_name) -> Self::Output {
480                    self.div_unsafe_refs_impl(other)
481                }
482            }
483
484            impl<'a> core::iter::Sum<&'a #struct_name> for #struct_name {
485                fn sum<I: core::iter::Iterator<Item = &'a #struct_name>>(iter: I) -> Self {
486                    iter.fold(Self::ZERO, |acc, x| &acc + x)
487                }
488            }
489
490            impl core::iter::Sum for #struct_name {
491                fn sum<I: core::iter::Iterator<Item = Self>>(iter: I) -> Self {
492                    iter.fold(Self::ZERO, |acc, x| &acc + &x)
493                }
494            }
495
496            impl<'a> core::iter::Product<&'a #struct_name> for #struct_name {
497                fn product<I: core::iter::Iterator<Item = &'a #struct_name>>(iter: I) -> Self {
498                    iter.fold(Self::ONE, |acc, x| &acc * x)
499                }
500            }
501
502            impl core::iter::Product for #struct_name {
503                fn product<I: core::iter::Iterator<Item = Self>>(iter: I) -> Self {
504                    iter.fold(Self::ONE, |acc, x| &acc * &x)
505                }
506            }
507
508            impl core::ops::Neg for #struct_name {
509                type Output = #struct_name;
510                fn neg(self) -> Self::Output {
511                    Self::ZERO - &self
512                }
513            }
514
515            impl core::ops::Neg for &#struct_name {
516                type Output = #struct_name;
517                fn neg(self) -> Self::Output {
518                    #struct_name::ZERO - self
519                }
520            }
521
522            impl core::fmt::Debug for #struct_name {
523                fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
524                    write!(f, "{:?} + {:?} * u", self.c0, self.c1)
525                }
526            }
527        });
528        output.push(result);
529    }
530
531    TokenStream::from_iter(output)
532}
533
534// Override the MacroArgs struct to use LitStr for item names instead of Ident.
535// This removes the need to import the complex struct when using the complex_init macro.
536struct ComplexInitArgs {
537    pub items: Vec<ComplexInitItem>,
538}
539
540struct ComplexInitItem {
541    pub name: LitStr,
542    pub params: Punctuated<Param, Token![,]>,
543}
544
545impl Parse for ComplexInitArgs {
546    fn parse(input: ParseStream) -> syn::Result<Self> {
547        Ok(ComplexInitArgs {
548            items: input
549                .parse_terminated(ComplexInitItem::parse, Token![,])?
550                .into_iter()
551                .collect(),
552        })
553    }
554}
555
556impl Parse for ComplexInitItem {
557    fn parse(input: ParseStream) -> syn::Result<Self> {
558        let name = input.parse()?;
559        let content;
560        syn::braced!(content in input);
561        let params = content.parse_terminated(Param::parse, Token![,])?;
562        Ok(ComplexInitItem { name, params })
563    }
564}
565
566/// This macro is used to initialize the complex extension fields.
567/// It must be called after `moduli_init!` is called.
568///
569/// Usage:
570/// ```rust
571/// moduli_init!("998244353", "1000000007");
572///
573/// complex_init!(Complex2 { mod_idx = 1 }, Complex1 { mod_idx = 0 });
574/// ```
575/// In particular, the order of complex types in the macro doesn't have to match the order of moduli
576/// in `moduli_init!`, but they should be accompanied by the `mod_idx` corresponding to the order in
577/// the `moduli_init!` macro (not `moduli_declare!`).
578#[proc_macro]
579pub fn complex_init(input: TokenStream) -> TokenStream {
580    let ComplexInitArgs { items } = parse_macro_input!(input as ComplexInitArgs);
581
582    let mut externs = Vec::new();
583
584    let span = proc_macro::Span::call_site();
585
586    for (complex_idx, item) in items.into_iter().enumerate() {
587        let struct_name = item.name.value();
588        let struct_name = syn::Ident::new(&struct_name, span.into());
589        let mut intmod_idx: Option<usize> = None;
590        for param in item.params {
591            match param.name.to_string().as_str() {
592                "mod_idx" => {
593                    if let syn::Expr::Lit(syn::ExprLit {
594                        lit: syn::Lit::Int(int),
595                        ..
596                    }) = param.value
597                    {
598                        intmod_idx = Some(int.base10_parse::<usize>().unwrap());
599                    } else {
600                        return syn::Error::new_spanned(param.value, "Expected usize")
601                            .to_compile_error()
602                            .into();
603                    }
604                }
605                _ => {
606                    panic!("Unknown parameter {}", param.name);
607                }
608            }
609        }
610        let mod_idx = intmod_idx.expect("mod_idx is required");
611
612        println!(
613            "[init] complex #{} = {} (mod_idx = {})",
614            complex_idx, struct_name, mod_idx
615        );
616
617        for op_type in ["add", "sub", "mul", "div"] {
618            let func_name = syn::Ident::new(
619                &format!("complex_{}_extern_func_{}", op_type, struct_name),
620                span.into(),
621            );
622            let mut chars = op_type.chars().collect::<Vec<_>>();
623            chars[0] = chars[0].to_ascii_uppercase();
624            let local_opcode = syn::Ident::new(&chars.iter().collect::<String>(), span.into());
625            externs.push(quote::quote_spanned! { span.into() =>
626                #[no_mangle]
627                extern "C" fn #func_name(rd: usize, rs1: usize, rs2: usize) {
628                    openvm::platform::custom_insn_r!(
629                        opcode = openvm_algebra_guest::OPCODE,
630                        funct3 = openvm_algebra_guest::COMPLEX_EXT_FIELD_FUNCT3,
631                        funct7 = openvm_algebra_guest::ComplexExtFieldBaseFunct7::#local_opcode as usize
632                            + #complex_idx * (openvm_algebra_guest::ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS as usize),
633                        rd = In rd,
634                        rs1 = In rs1,
635                        rs2 = In rs2
636                    )
637                }
638            });
639        }
640
641        let setup_extern_func = syn::Ident::new(
642            &format!("complex_setup_extern_func_{}", struct_name),
643            span.into(),
644        );
645
646        externs.push(quote::quote_spanned! { span.into() =>
647            #[no_mangle]
648            extern "C" fn #setup_extern_func() {
649                #[cfg(target_os = "zkvm")]
650                {
651                    use super::openvm_intrinsics_meta_do_not_type_this_by_yourself::{two_modular_limbs_list, limb_list_borders};
652                    let two_modulus_bytes = &two_modular_limbs_list[limb_list_borders[#mod_idx]..limb_list_borders[#mod_idx + 1]];
653
654                    // We are going to use the numeric representation of the `rs2` register to distinguish the chip to setup.
655                    // The transpiler will transform this instruction, based on whether `rs2` is `x0` or `x1`, into a `SETUP_ADDSUB` or `SETUP_MULDIV` instruction.
656                    let mut uninit: core::mem::MaybeUninit<[u8; limb_list_borders[#mod_idx + 1] - limb_list_borders[#mod_idx]]> = core::mem::MaybeUninit::uninit();
657                    openvm::platform::custom_insn_r!(
658                        opcode = ::openvm_algebra_guest::OPCODE,
659                        funct3 = ::openvm_algebra_guest::COMPLEX_EXT_FIELD_FUNCT3,
660                        funct7 = ::openvm_algebra_guest::ComplexExtFieldBaseFunct7::Setup as usize
661                            + #complex_idx
662                                * (::openvm_algebra_guest::ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS as usize),
663                        rd = In uninit.as_mut_ptr(),
664                        rs1 = In two_modulus_bytes.as_ptr(),
665                        rs2 = Const "x0" // will be parsed as 0 and therefore transpiled to SETUP_ADDMOD
666                    );
667                    openvm::platform::custom_insn_r!(
668                        opcode = ::openvm_algebra_guest::OPCODE,
669                        funct3 = ::openvm_algebra_guest::COMPLEX_EXT_FIELD_FUNCT3,
670                        funct7 = ::openvm_algebra_guest::ComplexExtFieldBaseFunct7::Setup as usize
671                            + #complex_idx
672                                * (::openvm_algebra_guest::ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS as usize),
673                        rd = In uninit.as_mut_ptr(),
674                        rs1 = In two_modulus_bytes.as_ptr(),
675                        rs2 = Const "x1" // will be parsed as 1 and therefore transpiled to SETUP_MULDIV
676                    );
677                }
678            }
679        });
680    }
681
682    TokenStream::from(quote::quote_spanned! { span.into() =>
683        #[allow(non_snake_case)]
684        #[cfg(target_os = "zkvm")]
685        mod openvm_intrinsics_ffi_complex {
686            #(#externs)*
687        }
688    })
689}
690
691struct ComplexSimpleItem {
692    items: Vec<Path>,
693}
694
695impl Parse for ComplexSimpleItem {
696    fn parse(input: ParseStream) -> syn::Result<Self> {
697        let items = input.parse_terminated(<Expr as Parse>::parse, Token![,])?;
698        Ok(Self {
699            items: items
700                .into_iter()
701                .map(|e| {
702                    if let Expr::Path(p) = e {
703                        p.path
704                    } else {
705                        panic!("expected path");
706                    }
707                })
708                .collect(),
709        })
710    }
711}
712
713#[proc_macro]
714pub fn complex_impl_field(input: TokenStream) -> TokenStream {
715    let ComplexSimpleItem { items } = parse_macro_input!(input as ComplexSimpleItem);
716
717    let mut output = Vec::new();
718
719    let span = proc_macro::Span::call_site();
720
721    for item in items.into_iter() {
722        let str_path = item
723            .segments
724            .iter()
725            .map(|x| x.ident.to_string())
726            .collect::<Vec<_>>()
727            .join("_");
728        let struct_name = syn::Ident::new(&str_path, span.into());
729
730        output.push(quote::quote_spanned! { span.into() =>
731            impl openvm_algebra_guest::field::Field for #struct_name {
732                type SelfRef<'a>
733                    = &'a Self
734                where
735                    Self: 'a;
736
737                const ZERO: Self = Self::ZERO;
738                const ONE: Self = Self::ONE;
739
740                fn double_assign(&mut self) {
741                    openvm_algebra_guest::field::Field::double_assign(&mut self.c0);
742                    openvm_algebra_guest::field::Field::double_assign(&mut self.c1);
743                }
744
745                fn square_assign(&mut self) {
746                    unsafe {
747                        self.mul_refs_impl(self, self as *const Self as *mut Self);
748                    }
749                }
750            }
751        });
752    }
753
754    TokenStream::from(quote::quote_spanned! { span.into() =>
755        #(#output)*
756    })
757}