openvm_circuit_derive/
lib.rs

1extern crate alloc;
2extern crate proc_macro;
3
4use itertools::{multiunzip, Itertools};
5use proc_macro::{Span, TokenStream};
6use quote::{quote, ToTokens};
7use syn::{
8    parse_quote, punctuated::Punctuated, spanned::Spanned, Data, DataStruct, Field, Fields,
9    GenericParam, Ident, Meta, Token,
10};
11
12mod common;
13#[cfg(not(feature = "tco"))]
14mod nontco;
15#[cfg(feature = "tco")]
16mod tco;
17
18#[proc_macro_derive(PreflightExecutor)]
19pub fn preflight_executor_derive(input: TokenStream) -> TokenStream {
20    let ast: syn::DeriveInput = syn::parse(input).unwrap();
21
22    let name = &ast.ident;
23    let generics = &ast.generics;
24    let (_, ty_generics, _) = generics.split_for_impl();
25
26    let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
27    let mut new_generics = generics.clone();
28    new_generics.params.push(syn::parse_quote! { RA });
29    let field_ty_generic = generics
30        .params
31        .first()
32        .and_then(|param| match param {
33            GenericParam::Type(type_param) => Some(&type_param.ident),
34            _ => None,
35        })
36        .unwrap_or_else(|| {
37            new_generics.params.push(syn::parse_quote! { F });
38            &default_ty_generic
39        });
40
41    match &ast.data {
42        Data::Struct(inner) => {
43            // Check if the struct has only one unnamed field
44            let inner_ty = match &inner.fields {
45                Fields::Unnamed(fields) => {
46                    if fields.unnamed.len() != 1 {
47                        panic!("Only one unnamed field is supported");
48                    }
49                    fields.unnamed.first().unwrap().ty.clone()
50                }
51                _ => panic!("Only unnamed fields are supported"),
52            };
53            // Use full path ::openvm_circuit... so it can be used either within or outside the vm
54            // crate. Assume F is already generic of the field.
55            let where_clause = new_generics.make_where_clause();
56            where_clause.predicates.push(
57                syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA> },
58            );
59            let (impl_generics, _, where_clause) = new_generics.split_for_impl();
60            quote! {
61                impl #impl_generics ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA> for #name #ty_generics #where_clause {
62                    fn execute(
63                        &self,
64                        state: ::openvm_circuit::arch::VmStateMut<#field_ty_generic, ::openvm_circuit::system::memory::online::TracingMemory, RA>,
65                        instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<#field_ty_generic>,
66                    ) -> Result<(), ::openvm_circuit::arch::ExecutionError> {
67                        self.0.execute(state, instruction)
68                    }
69
70                    fn get_opcode_name(&self, opcode: usize) -> String {
71                        self.0.get_opcode_name(opcode)
72                    }
73                }
74            }
75            .into()
76        }
77        Data::Enum(e) => {
78            let variants = e
79                .variants
80                .iter()
81                .map(|variant| {
82                    let variant_name = &variant.ident;
83
84                    let mut fields = variant.fields.iter();
85                    let field = fields.next().unwrap();
86                    assert!(fields.next().is_none(), "Only one field is supported");
87                    (variant_name, field)
88                })
89                .collect::<Vec<_>>();
90            // Use full path ::openvm_circuit... so it can be used either within or outside the vm
91            // crate.
92            let (execute_arms, get_opcode_name_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>) =
93                multiunzip(variants.iter().map(|(variant_name, field)| {
94                    let field_ty = &field.ty;
95                    let execute_arm = quote! {
96                        #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA>>::execute(x, state, instruction)
97                    };
98                    let get_opcode_name_arm = quote! {
99                        #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA>>::get_opcode_name(x, opcode)
100                    };
101                    let where_predicate = syn::parse_quote! {
102                        #field_ty: ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA>
103                    };
104                    (execute_arm, get_opcode_name_arm, where_predicate)
105                }));
106            let where_clause = new_generics.make_where_clause();
107            for predicate in where_predicates {
108                where_clause.predicates.push(predicate);
109            }
110            // Don't use these ty_generics because it might have extra "F"
111            let (impl_generics, _, where_clause) = new_generics.split_for_impl();
112            quote! {
113                impl #impl_generics ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA> for #name #ty_generics #where_clause {
114                    fn execute(
115                        &self,
116                        state: ::openvm_circuit::arch::VmStateMut<#field_ty_generic, ::openvm_circuit::system::memory::online::TracingMemory, RA>,
117                        instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<#field_ty_generic>,
118                    ) -> Result<(), ::openvm_circuit::arch::ExecutionError> {
119                        match self {
120                            #(#execute_arms,)*
121                        }
122                    }
123
124                    fn get_opcode_name(&self, opcode: usize) -> String {
125                        match self {
126                            #(#get_opcode_name_arms,)*
127                        }
128                    }
129                }
130            }
131            .into()
132        }
133        Data::Union(_) => unimplemented!("Unions are not supported"),
134    }
135}
136
137#[proc_macro_derive(Executor)]
138pub fn executor_derive(input: TokenStream) -> TokenStream {
139    let ast: syn::DeriveInput = syn::parse(input).unwrap();
140
141    let name = &ast.ident;
142    let generics = &ast.generics;
143    let (impl_generics, ty_generics, _) = generics.split_for_impl();
144
145    match &ast.data {
146        Data::Struct(inner) => {
147            // Check if the struct has only one unnamed field
148            let inner_ty = match &inner.fields {
149                Fields::Unnamed(fields) => {
150                    if fields.unnamed.len() != 1 {
151                        panic!("Only one unnamed field is supported");
152                    }
153                    fields.unnamed.first().unwrap().ty.clone()
154                }
155                _ => panic!("Only unnamed fields are supported"),
156            };
157            // Use full path ::openvm_circuit... so it can be used either within or outside the vm
158            // crate. Assume F is already generic of the field.
159            let mut new_generics = generics.clone();
160            let where_clause = new_generics.make_where_clause();
161            where_clause.predicates.push(
162                syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::InterpreterExecutor<F> },
163            );
164
165            // We use the macro's feature to decide whether to generate the impl or not. This avoids
166            // the target crate needing the "tco" feature defined.
167            #[cfg(feature = "tco")]
168            let handler = quote! {
169                fn handler<Ctx>(
170                    &self,
171                    pc: u32,
172                    inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
173                    data: &mut [u8],
174                ) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
175                where
176                    Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
177                    self.0.handler(pc, inst, data)
178                }
179            };
180            #[cfg(not(feature = "tco"))]
181            let handler = quote! {};
182
183            quote! {
184                impl #impl_generics ::openvm_circuit::arch::InterpreterExecutor<F> for #name #ty_generics #where_clause {
185                    #[inline(always)]
186                    fn pre_compute_size(&self) -> usize {
187                        self.0.pre_compute_size()
188                    }
189                    #[cfg(not(feature = "tco"))]
190                    #[inline(always)]
191                    fn pre_compute<Ctx>(
192                        &self,
193                        pc: u32,
194                        inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
195                        data: &mut [u8],
196                    ) -> Result<::openvm_circuit::arch::ExecuteFunc<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
197                    where
198                        Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
199                        self.0.pre_compute(pc, inst, data)
200                    }
201
202                    #handler
203                }
204            }
205            .into()
206        }
207        Data::Enum(e) => {
208            let variants = e
209                .variants
210                .iter()
211                .map(|variant| {
212                    let variant_name = &variant.ident;
213
214                    let mut fields = variant.fields.iter();
215                    let field = fields.next().unwrap();
216                    assert!(fields.next().is_none(), "Only one field is supported");
217                    (variant_name, field)
218                })
219                .collect::<Vec<_>>();
220            let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
221            let mut new_generics = generics.clone();
222            let first_ty_generic = ast
223                .generics
224                .params
225                .first()
226                .and_then(|param| match param {
227                    GenericParam::Type(type_param) => Some(&type_param.ident),
228                    _ => None,
229                })
230                .unwrap_or_else(|| {
231                    new_generics.params.push(syn::parse_quote! { F });
232                    &default_ty_generic
233                });
234            // Use full path ::openvm_circuit... so it can be used either within or outside the vm
235            // crate. Assume F is already generic of the field.
236            let (pre_compute_size_arms, pre_compute_arms, _handler_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(|(variant_name, field)| {
237                let field_ty = &field.ty;
238                let pre_compute_size_arm = quote! {
239                    #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic>>::pre_compute_size(x)
240                };
241                let pre_compute_arm = quote! {
242                    #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic>>::pre_compute(x, pc, instruction, data)
243                };
244                let handler_arm = quote! {
245                    #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic>>::handler(x, pc, instruction, data)
246                };
247                let where_predicate = syn::parse_quote! {
248                    #field_ty: ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic>
249                };
250                (pre_compute_size_arm, pre_compute_arm, handler_arm, where_predicate)
251            }));
252            let where_clause = new_generics.make_where_clause();
253            for predicate in where_predicates {
254                where_clause.predicates.push(predicate);
255            }
256            // We use the macro's feature to decide whether to generate the impl or not. This avoids
257            // the target crate needing the "tco" feature defined.
258            #[cfg(feature = "tco")]
259            let handler = quote! {
260                fn handler<Ctx>(
261                    &self,
262                    pc: u32,
263                    instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
264                    data: &mut [u8],
265                ) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
266                where
267                    Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
268                    match self {
269                        #(#_handler_arms,)*
270                    }
271                }
272            };
273            #[cfg(not(feature = "tco"))]
274            let handler = quote! {};
275
276            // Don't use these ty_generics because it might have extra "F"
277            let (impl_generics, _, where_clause) = new_generics.split_for_impl();
278
279            quote! {
280                impl #impl_generics ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic> for #name #ty_generics #where_clause {
281                    #[inline(always)]
282                    fn pre_compute_size(&self) -> usize {
283                        match self {
284                            #(#pre_compute_size_arms,)*
285                        }
286                    }
287
288                    #[cfg(not(feature = "tco"))]
289                    #[inline(always)]
290                    fn pre_compute<Ctx>(
291                        &self,
292                        pc: u32,
293                        instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
294                        data: &mut [u8],
295                    ) -> Result<::openvm_circuit::arch::ExecuteFunc<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
296                    where
297                        Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
298                        match self {
299                            #(#pre_compute_arms,)*
300                        }
301                    }
302                    #handler
303                }
304            }
305            .into()
306        }
307        Data::Union(_) => unimplemented!("Unions are not supported"),
308    }
309}
310
311#[proc_macro_derive(AotExecutor)]
312pub fn aot_executor_derive(input: TokenStream) -> TokenStream {
313    let ast: syn::DeriveInput = syn::parse(input).unwrap();
314
315    let name = &ast.ident;
316    let generics = &ast.generics;
317    let (_, ty_generics, _) = generics.split_for_impl();
318
319    match &ast.data {
320        Data::Struct(inner) => {
321            let inner_ty = match &inner.fields {
322                Fields::Unnamed(fields) => {
323                    if fields.unnamed.len() != 1 {
324                        panic!("Only one unnamed field is supported");
325                    }
326                    fields.unnamed.first().unwrap().ty.clone()
327                }
328                _ => panic!("Only unnamed fields are supported"),
329            };
330            let mut new_generics = generics.clone();
331            let where_clause = new_generics.make_where_clause();
332            where_clause
333                .predicates
334                .push(syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::AotExecutor<F> });
335            let (impl_generics, _, where_clause) = new_generics.split_for_impl();
336
337            quote! {
338                #[cfg(feature = "aot")]
339                impl #impl_generics ::openvm_circuit::arch::AotExecutor<F> for #name #ty_generics #where_clause {
340                    #[inline(always)]
341                    fn is_aot_supported(&self, inst: &::openvm_instructions::instruction::Instruction<F>) -> bool {
342                        self.0.is_aot_supported(inst)
343                    }
344
345                    fn generate_x86_asm(
346                        &self,
347                        inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
348                        pc: u32,
349                    ) -> ::std::result::Result<
350                        ::std::string::String,
351                        ::openvm_circuit::arch::AotError,
352                    > {
353                        self.0.generate_x86_asm(inst, pc)
354                    }
355                }
356            }
357            .into()
358        }
359        Data::Enum(e) => {
360            let variants = e
361                .variants
362                .iter()
363                .map(|variant| {
364                    let variant_name = &variant.ident;
365                    let mut fields = variant.fields.iter();
366                    let field = fields.next().unwrap();
367                    assert!(fields.next().is_none(), "Only one field is supported");
368                    (variant_name, field)
369                })
370                .collect::<Vec<_>>();
371            let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
372            let mut new_generics = generics.clone();
373            let first_ty_generic = ast
374                .generics
375                .params
376                .first()
377                .and_then(|param| match param {
378                    GenericParam::Type(type_param) => Some(&type_param.ident),
379                    _ => None,
380                })
381                .unwrap_or_else(|| {
382                    new_generics.params.push(syn::parse_quote! { F });
383                    &default_ty_generic
384                });
385            let (
386                is_aot_supported_arms,
387                generate_x86_asm_arms,
388                where_predicates,
389            ): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(
390                |(variant_name, field)| {
391                    let field_ty = &field.ty;
392                    let is_aot_supported_arm = quote! {
393                        #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AotExecutor<#first_ty_generic>>::is_aot_supported(x, inst)
394                    };
395                    let generate_x86_asm_arm = quote! {
396                        #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AotExecutor<#first_ty_generic>>::generate_x86_asm(
397                            x,
398                            inst,
399                            pc,
400                        )
401                    };
402                    let where_predicate =
403                        syn::parse_quote! { #field_ty: ::openvm_circuit::arch::AotExecutor<#first_ty_generic> };
404                    (
405                        is_aot_supported_arm,
406                        generate_x86_asm_arm,
407                        where_predicate,
408                    )
409                },
410            ));
411            let where_clause = new_generics.make_where_clause();
412            for predicate in where_predicates {
413                where_clause.predicates.push(predicate);
414            }
415            let (impl_generics, _, where_clause) = new_generics.split_for_impl();
416
417            quote! {
418                #[cfg(feature = "aot")]
419                impl #impl_generics ::openvm_circuit::arch::AotExecutor<#first_ty_generic> for #name #ty_generics #where_clause {
420                    #[inline(always)]
421                    fn is_aot_supported(&self, inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>) -> bool {
422                        match self {
423                            #(#is_aot_supported_arms,)*
424                        }
425                    }
426
427                    fn generate_x86_asm(
428                        &self,
429                        inst: &::openvm_circuit::arch::instructions::instruction::Instruction<#first_ty_generic>,
430                        pc: u32,
431                    ) -> ::std::result::Result<
432                        ::std::string::String,
433                        ::openvm_circuit::arch::AotError,
434                    > {
435                        match self {
436                            #(#generate_x86_asm_arms,)*
437                        }
438                    }
439                }
440            }
441            .into()
442        }
443        Data::Union(_) => unimplemented!("Unions are not supported"),
444    }
445}
446
447#[proc_macro_derive(MeteredExecutor)]
448pub fn metered_executor_derive(input: TokenStream) -> TokenStream {
449    let ast: syn::DeriveInput = syn::parse(input).unwrap();
450
451    let name = &ast.ident;
452    let generics = &ast.generics;
453    let (impl_generics, ty_generics, _) = generics.split_for_impl();
454
455    match &ast.data {
456        Data::Struct(inner) => {
457            // Check if the struct has only one unnamed field
458            let inner_ty = match &inner.fields {
459                Fields::Unnamed(fields) => {
460                    if fields.unnamed.len() != 1 {
461                        panic!("Only one unnamed field is supported");
462                    }
463                    fields.unnamed.first().unwrap().ty.clone()
464                }
465                _ => panic!("Only unnamed fields are supported"),
466            };
467            // Use full path ::openvm_circuit... so it can be used either within or outside the vm
468            // crate.
469            let mut new_generics = generics.clone();
470            let where_clause = new_generics.make_where_clause();
471            where_clause
472                .predicates
473                .push(syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::InterpreterMeteredExecutor<F> });
474
475            // We use the macro's feature to decide whether to generate the impl or not. This avoids
476            // the target crate needing the "tco" feature defined.
477            #[cfg(feature = "tco")]
478            let metered_handler = quote! {
479                fn metered_handler<Ctx>(
480                    &self,
481                    chip_idx: usize,
482                    pc: u32,
483                    inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
484                    data: &mut [u8],
485                ) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
486                where
487                    Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait, {
488                    self.0.metered_handler(chip_idx, pc, inst, data)
489                }
490            };
491            #[cfg(not(feature = "tco"))]
492            let metered_handler = quote! {};
493
494            quote! {
495                impl #impl_generics ::openvm_circuit::arch::InterpreterMeteredExecutor<F> for #name #ty_generics #where_clause {
496                    #[inline(always)]
497                    fn metered_pre_compute_size(&self) -> usize {
498                        self.0.metered_pre_compute_size()
499                    }
500                    #[cfg(not(feature = "tco"))]
501                    #[inline(always)]
502                    fn metered_pre_compute<Ctx>(
503                        &self,
504                        chip_idx: usize,
505                        pc: u32,
506                        inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
507                        data: &mut [u8],
508                    ) -> Result<::openvm_circuit::arch::ExecuteFunc<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
509                    where
510                        Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait, {
511                        self.0.metered_pre_compute(chip_idx, pc, inst, data)
512                    }
513                    #metered_handler
514                }
515            }
516                .into()
517        }
518        Data::Enum(e) => {
519            let variants = e
520                .variants
521                .iter()
522                .map(|variant| {
523                    let variant_name = &variant.ident;
524
525                    let mut fields = variant.fields.iter();
526                    let field = fields.next().unwrap();
527                    assert!(fields.next().is_none(), "Only one field is supported");
528                    (variant_name, field)
529                })
530                .collect::<Vec<_>>();
531            let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
532            let mut new_generics = generics.clone();
533            let first_ty_generic = ast
534                .generics
535                .params
536                .first()
537                .and_then(|param| match param {
538                    GenericParam::Type(type_param) => Some(&type_param.ident),
539                    _ => None,
540                })
541                .unwrap_or_else(|| {
542                    new_generics.params.push(syn::parse_quote! { F });
543                    &default_ty_generic
544                });
545            // Use full path ::openvm_circuit... so it can be used either within or outside the vm
546            // crate. Assume F is already generic of the field.
547            let (pre_compute_size_arms, metered_pre_compute_arms, _metered_handler_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(|(variant_name, field)| {
548                let field_ty = &field.ty;
549                let pre_compute_size_arm = quote! {
550                    #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic>>::metered_pre_compute_size(x)
551                };
552                let metered_pre_compute_arm = quote! {
553                    #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic>>::metered_pre_compute(x, chip_idx, pc, instruction, data)
554                };
555                let metered_handler_arm = quote! {
556                    #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic>>::metered_handler(x, chip_idx, pc, instruction, data)
557                };
558                let where_predicate = syn::parse_quote! {
559                    #field_ty: ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic>
560                };
561                (pre_compute_size_arm, metered_pre_compute_arm, metered_handler_arm, where_predicate)
562            }));
563            let where_clause = new_generics.make_where_clause();
564            for predicate in where_predicates {
565                where_clause.predicates.push(predicate);
566            }
567            // Don't use these ty_generics because it might have extra "F"
568            let (impl_generics, _, where_clause) = new_generics.split_for_impl();
569
570            // We use the macro's feature to decide whether to generate the impl or not. This avoids
571            // the target crate needing the "tco" feature defined.
572            #[cfg(feature = "tco")]
573            let metered_handler = quote! {
574                fn metered_handler<Ctx>(
575                    &self,
576                    chip_idx: usize,
577                    pc: u32,
578                    instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
579                    data: &mut [u8],
580                ) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
581                where
582                    Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait,
583                {
584                    match self {
585                        #(#_metered_handler_arms,)*
586                    }
587                }
588            };
589            #[cfg(not(feature = "tco"))]
590            let metered_handler = quote! {};
591
592            quote! {
593                impl #impl_generics ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic> for #name #ty_generics #where_clause {
594                    #[inline(always)]
595                    fn metered_pre_compute_size(&self) -> usize {
596                        match self {
597                            #(#pre_compute_size_arms,)*
598                        }
599                    }
600
601                    #[cfg(not(feature = "tco"))]
602                    #[inline(always)]
603                    fn metered_pre_compute<Ctx>(
604                        &self,
605                        chip_idx: usize,
606                        pc: u32,
607                        instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
608                        data: &mut [u8],
609                    ) -> Result<::openvm_circuit::arch::ExecuteFunc<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
610                    where
611                        Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait, {
612                        match self {
613                            #(#metered_pre_compute_arms,)*
614                        }
615                    }
616
617                    #metered_handler
618                }
619            }
620                .into()
621        }
622        Data::Union(_) => unimplemented!("Unions are not supported"),
623    }
624}
625
626#[proc_macro_derive(AotMeteredExecutor)]
627pub fn aot_metered_executor_derive(input: TokenStream) -> TokenStream {
628    let ast: syn::DeriveInput = syn::parse(input).unwrap();
629
630    let name = &ast.ident;
631    let generics = &ast.generics;
632    let (_, ty_generics, _) = generics.split_for_impl();
633
634    match &ast.data {
635        Data::Struct(inner) => {
636            let inner_ty = match &inner.fields {
637                Fields::Unnamed(fields) => {
638                    if fields.unnamed.len() != 1 {
639                        panic!("Only one unnamed field is supported");
640                    }
641                    fields.unnamed.first().unwrap().ty.clone()
642                }
643                _ => panic!("Only unnamed fields are supported"),
644            };
645            let mut new_generics = generics.clone();
646            let where_clause = new_generics.make_where_clause();
647            where_clause.predicates.push(
648                syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::AotMeteredExecutor<F> },
649            );
650            let (impl_generics, _, where_clause) = new_generics.split_for_impl();
651
652            quote! {
653                #[cfg(feature = "aot")]
654                impl #impl_generics ::openvm_circuit::arch::AotMeteredExecutor<F> for #name #ty_generics #where_clause {
655                    #[inline(always)]
656                    fn is_aot_metered_supported(&self, inst: &::openvm_instructions::instruction::Instruction<F>) -> bool {
657                        self.0.is_aot_metered_supported(inst)
658                    }
659
660                    fn generate_x86_metered_asm(
661                        &self,
662                        inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
663                        pc: u32,
664                        chip_idx: usize,
665                        config: &::openvm_circuit::arch::SystemConfig,
666                    ) -> ::std::result::Result<
667                        ::std::string::String,
668                        ::openvm_circuit::arch::AotError,
669                    > {
670                        self.0.generate_x86_metered_asm(inst, pc, chip_idx, config)
671                    }
672                }
673            }
674            .into()
675        }
676        Data::Enum(e) => {
677            let variants = e
678                .variants
679                .iter()
680                .map(|variant| {
681                    let variant_name = &variant.ident;
682                    let mut fields = variant.fields.iter();
683                    let field = fields.next().unwrap();
684                    assert!(fields.next().is_none(), "Only one field is supported");
685                    (variant_name, field)
686                })
687                .collect::<Vec<_>>();
688            let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
689            let mut new_generics = generics.clone();
690            let first_ty_generic = ast
691                .generics
692                .params
693                .first()
694                .and_then(|param| match param {
695                    GenericParam::Type(type_param) => Some(&type_param.ident),
696                    _ => None,
697                })
698                .unwrap_or_else(|| {
699                    new_generics.params.push(syn::parse_quote! { F });
700                    &default_ty_generic
701                });
702            let (
703                is_aot_metered_supported_arms,
704                generate_x86_metered_asm_arms,
705                where_predicates,
706            ): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(
707                |(variant_name, field)| {
708                    let field_ty = &field.ty;
709                    let is_aot_metered_supported_arm = quote! {
710                        #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AotMeteredExecutor<#first_ty_generic>>::is_aot_metered_supported(x, inst)
711                    };
712                    let generate_x86_metered_asm_arm = quote! {
713                        #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AotMeteredExecutor<#first_ty_generic>>::generate_x86_metered_asm(
714                            x,
715                            inst,
716                            pc,
717                            chip_idx,
718                            config,
719                        )
720                    };
721                    let where_predicate =
722                        syn::parse_quote! { #field_ty: ::openvm_circuit::arch::AotMeteredExecutor<#first_ty_generic> };
723                    (
724                        is_aot_metered_supported_arm,
725                        generate_x86_metered_asm_arm,
726                        where_predicate,
727                    )
728                },
729            ));
730            let where_clause = new_generics.make_where_clause();
731            for predicate in where_predicates {
732                where_clause.predicates.push(predicate);
733            }
734            let (impl_generics, _, where_clause) = new_generics.split_for_impl();
735
736            quote! {
737                #[cfg(feature = "aot")]
738                impl #impl_generics ::openvm_circuit::arch::AotMeteredExecutor<#first_ty_generic> for #name #ty_generics #where_clause {
739                    #[inline(always)]
740                    fn is_aot_metered_supported(&self, inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>) -> bool {
741                        match self {
742                            #(#is_aot_metered_supported_arms,)*
743                        }
744                    }
745
746                    fn generate_x86_metered_asm(
747                        &self,
748                        inst: &::openvm_circuit::arch::instructions::instruction::Instruction<#first_ty_generic>,
749                        pc: u32,
750                        chip_idx: usize,
751                        config: &::openvm_circuit::arch::SystemConfig,
752                    ) -> ::std::result::Result<
753                        ::std::string::String,
754                        ::openvm_circuit::arch::AotError,
755                    > {
756                        match self {
757                            #(#generate_x86_metered_asm_arms,)*
758                        }
759                    }
760                }
761            }
762            .into()
763        }
764        Data::Union(_) => unimplemented!("Unions are not supported"),
765    }
766}
767
768/// Derives `AnyEnum` trait on an enum type.
769/// By default an enum arm will just return `self` as `&dyn Any`.
770///
771/// Use the `#[any_enum]` field attribute to specify that the
772/// arm itself implements `AnyEnum` and should call the inner `as_any_kind` method.
773#[proc_macro_derive(AnyEnum, attributes(any_enum))]
774pub fn any_enum_derive(input: TokenStream) -> TokenStream {
775    let ast: syn::DeriveInput = syn::parse(input).unwrap();
776
777    let name = &ast.ident;
778    let generics = &ast.generics;
779    let (impl_generics, ty_generics, _) = generics.split_for_impl();
780
781    match &ast.data {
782        Data::Enum(e) => {
783            let variants = e
784                .variants
785                .iter()
786                .map(|variant| {
787                    let variant_name = &variant.ident;
788
789                    // Check if the variant has #[any_enum] attribute
790                    let is_enum = variant
791                        .attrs
792                        .iter()
793                        .any(|attr| attr.path().is_ident("any_enum"));
794                    let mut fields = variant.fields.iter();
795                    let field = fields.next().unwrap();
796                    assert!(fields.next().is_none(), "Only one field is supported");
797                    (variant_name, field, is_enum)
798                })
799                .collect::<Vec<_>>();
800            let (arms, arms_mut): (Vec<_>, Vec<_>) =
801                variants.iter().map(|(variant_name, field, is_enum)| {
802                    let field_ty = &field.ty;
803
804                    if *is_enum {
805                        // Call the inner trait impl
806                        (quote! {
807                            #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AnyEnum>::as_any_kind(x)
808                        },
809                        quote! {
810                            #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AnyEnum>::as_any_kind_mut(x)
811                        })
812                    } else {
813                        (quote! {
814                            #name::#variant_name(x) => x
815                        },
816                        quote! {
817                            #name::#variant_name(x) => x
818                        })
819                    }
820                }).unzip();
821            quote! {
822                impl #impl_generics ::openvm_circuit::arch::AnyEnum for #name #ty_generics {
823                    fn as_any_kind(&self) -> &dyn std::any::Any {
824                        match self {
825                            #(#arms,)*
826                        }
827                    }
828
829                    fn as_any_kind_mut(&mut self) -> &mut dyn std::any::Any {
830                        match self {
831                            #(#arms_mut,)*
832                        }
833                    }
834                }
835            }
836            .into()
837        }
838        _ => syn::Error::new(name.span(), "Only enums are supported")
839            .to_compile_error()
840            .into(),
841    }
842}
843
844#[proc_macro_derive(VmConfig, attributes(config, extension))]
845pub fn vm_generic_config_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
846    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
847    let name = &ast.ident;
848
849    match &ast.data {
850        syn::Data::Struct(inner) => match generate_config_traits_impl(name, inner) {
851            Ok(tokens) => tokens,
852            Err(err) => err.to_compile_error().into(),
853        },
854        _ => syn::Error::new(name.span(), "Only structs are supported")
855            .to_compile_error()
856            .into(),
857    }
858}
859
860fn generate_config_traits_impl(name: &Ident, inner: &DataStruct) -> syn::Result<TokenStream> {
861    let gen_name_with_uppercase_idents = |ident: &Ident| {
862        let mut name = ident.to_string().chars().collect::<Vec<_>>();
863        assert!(name[0].is_lowercase(), "Field name must not be capitalized");
864        let res_lower = Ident::new(&name.iter().collect::<String>(), Span::call_site().into());
865        name[0] = name[0].to_ascii_uppercase();
866        let res_upper = Ident::new(&name.iter().collect::<String>(), Span::call_site().into());
867        (res_lower, res_upper)
868    };
869
870    let fields = match &inner.fields {
871        Fields::Named(named) => named.named.iter().collect(),
872        Fields::Unnamed(_) => {
873            return Err(syn::Error::new(
874                name.span(),
875                "Only named fields are supported",
876            ))
877        }
878        Fields::Unit => vec![],
879    };
880
881    let source_field = fields
882        .iter()
883        .filter(|f| f.attrs.iter().any(|attr| attr.path().is_ident("config")))
884        .exactly_one()
885        .map_err(|_| {
886            syn::Error::new(
887                name.span(),
888                "Exactly one field must have the #[config] attribute",
889            )
890        })?;
891    let (source_name, source_name_upper) =
892        gen_name_with_uppercase_idents(source_field.ident.as_ref().unwrap());
893
894    let extensions = fields
895        .iter()
896        .filter(|f| f.attrs.iter().any(|attr| attr.path().is_ident("extension")))
897        .cloned()
898        .collect::<Vec<_>>();
899
900    let mut executor_enum_fields = Vec::new();
901    let mut create_executors = Vec::new();
902    let mut create_airs = Vec::new();
903    let mut execution_where_predicates: Vec<syn::WherePredicate> = Vec::new();
904    let mut circuit_where_predicates: Vec<syn::WherePredicate> = Vec::new();
905
906    let source_field_ty = source_field.ty.clone();
907
908    for e in extensions.iter() {
909        let (ext_field_name, ext_name_upper) =
910            gen_name_with_uppercase_idents(e.ident.as_ref().expect("field must be named"));
911        let executor_type = parse_executor_type(e, false)?;
912        executor_enum_fields.push(quote! {
913            #[any_enum]
914            #ext_name_upper(#executor_type),
915        });
916        create_executors.push(quote! {
917            let inventory: ::openvm_circuit::arch::ExecutorInventory<Self::Executor> = inventory.extend::<F, _, _>(&self.#ext_field_name)?;
918        });
919        let extension_ty = e.ty.clone();
920        execution_where_predicates.push(parse_quote! {
921            #extension_ty: ::openvm_circuit::arch::VmExecutionExtension<F, Executor = #executor_type>
922        });
923        create_airs.push(quote! {
924            inventory.start_new_extension();
925            ::openvm_circuit::arch::VmCircuitExtension::extend_circuit(&self.#ext_field_name, &mut inventory)?;
926        });
927        circuit_where_predicates.push(parse_quote! {
928            #extension_ty: ::openvm_circuit::arch::VmCircuitExtension<SC>
929        });
930    }
931
932    // The config type always needs <F> due to SystemExecutor
933    let source_executor_type = parse_executor_type(source_field, true)?;
934    execution_where_predicates.push(parse_quote! {
935        #source_field_ty: ::openvm_circuit::arch::VmExecutionConfig<F, Executor = #source_executor_type>
936    });
937    circuit_where_predicates.push(parse_quote! {
938        #source_field_ty: ::openvm_circuit::arch::VmCircuitConfig<SC>
939    });
940    let execution_where_clause = quote! { where #(#execution_where_predicates),* };
941    let circuit_where_clause = quote! { where #(#circuit_where_predicates),* };
942
943    let executor_type = Ident::new(&format!("{name}Executor"), name.span());
944
945    let token_stream = TokenStream::from(quote! {
946        #[derive(
947            Clone,
948            ::derive_more::derive::From,
949            ::openvm_circuit::derive::AnyEnum,
950            ::openvm_circuit::derive::Executor,
951            ::openvm_circuit::derive::MeteredExecutor,
952            ::openvm_circuit::derive::PreflightExecutor,
953        )]
954        #[cfg_attr(feature = "aot", derive(::openvm_circuit::derive::AotExecutor, ::openvm_circuit::derive::AotMeteredExecutor))]
955        pub enum #executor_type<F: openvm_stark_backend::p3_field::Field> {
956            #[any_enum]
957            #source_name_upper(#source_executor_type),
958            #(#executor_enum_fields)*
959        }
960
961        impl<F: openvm_stark_backend::p3_field::Field> ::openvm_circuit::arch::VmExecutionConfig<F> for #name #execution_where_clause {
962            type Executor = #executor_type<F>;
963
964            fn create_executors(
965                &self,
966            ) -> Result<::openvm_circuit::arch::ExecutorInventory<Self::Executor>, ::openvm_circuit::arch::ExecutorInventoryError> {
967                let inventory = self.#source_name.create_executors()?.transmute::<Self::Executor>();
968                #(#create_executors)*
969                Ok(inventory)
970            }
971        }
972
973        impl<SC: openvm_stark_backend::config::StarkGenericConfig> ::openvm_circuit::arch::VmCircuitConfig<SC> for #name #circuit_where_clause {
974            fn create_airs(
975                &self,
976            ) -> Result<::openvm_circuit::arch::AirInventory<SC>, ::openvm_circuit::arch::AirInventoryError> {
977                let mut inventory = self.#source_name.create_airs()?;
978                #(#create_airs)*
979                Ok(inventory)
980            }
981        }
982
983        impl AsRef<SystemConfig> for #name {
984            fn as_ref(&self) -> &SystemConfig {
985                self.#source_name.as_ref()
986            }
987        }
988
989        impl AsMut<SystemConfig> for #name {
990            fn as_mut(&mut self) -> &mut SystemConfig {
991                self.#source_name.as_mut()
992            }
993        }
994    });
995    Ok(token_stream)
996}
997
998// Parse the executor name as either
999// `{type_name}Executor` or whatever the attribute `executor = ` specifies
1000// Also determines whether the executor type needs generic parameters
1001fn parse_executor_type(
1002    f: &Field,
1003    default_needs_generics: bool,
1004) -> syn::Result<proc_macro2::TokenStream> {
1005    // TRACKING ISSUE:
1006    // We cannot just use <e.ty.to_token_stream() as VmExecutionExtension<F>>::Executor because of this: <https://github.com/rust-lang/rust/issues/85576>
1007    let mut executor_type = None;
1008    // Do not unwrap the Result until needed
1009    let executor_name = syn::parse_str::<Ident>(&format!("{}Executor", f.ty.to_token_stream()));
1010
1011    if let Some(attr) = f
1012        .attrs
1013        .iter()
1014        .find(|attr| attr.path().is_ident("extension") || attr.path().is_ident("config"))
1015    {
1016        match attr.meta {
1017            Meta::Path(_) => {}
1018            Meta::NameValue(_) => {
1019                return Err(syn::Error::new(
1020                    f.ty.span(),
1021                    "Only `#[config]`, `#[extension]`, `#[config(...)]` or `#[extension(...)]` formats are supported",
1022                ))
1023            }
1024            _ => {
1025                let nested = attr
1026                    .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
1027                for meta in nested {
1028                    match meta {
1029                        Meta::NameValue(nv) => {
1030                            if nv.path.is_ident("executor") {
1031                                executor_type = match nv.value {
1032                                    syn::Expr::Lit(syn::ExprLit {
1033                                        lit: syn::Lit::Str(lit_str), ..
1034                                    }) => {
1035                                        let executor_type: syn::Type = syn::parse_str(&lit_str.value())?;
1036                                        Some(quote! { #executor_type })
1037                                    },
1038                                    syn::Expr::Path(path) => {
1039                                        // Handle identifier paths like `executor = MyExecutor`
1040                                        Some(path.to_token_stream())
1041                                    },
1042                                    _ => {
1043                                        return Err(syn::Error::new(
1044                                            nv.value.span(),
1045                                            "executor value must be a string literal or identifier"
1046                                        ));
1047                                    }
1048                                };
1049                            } else if nv.path.is_ident("generics") {
1050                                // Parse boolean value for generics
1051                                let value_str = nv.value.to_token_stream().to_string();
1052                                let needs_generics = match value_str.as_str() {
1053                                    "true" => true,
1054                                    "false" => false,
1055                                    _ => return Err(syn::Error::new(
1056                                        nv.value.span(),
1057                                        "generics attribute must be either true or false"
1058                                    ))
1059                                };
1060                                let executor_name = executor_name.clone()?;
1061                                executor_type = Some(if needs_generics {
1062                                    quote! { #executor_name<F> }
1063                                } else {
1064                                    quote! { #executor_name }
1065                                });
1066                            } else {
1067                                return Err(syn::Error::new(nv.span(), "only executor and generics keys are supported"));
1068                            }
1069                        }
1070                        _ => {
1071                            return Err(syn::Error::new(meta.span(), "only name = value format is supported"));
1072                        }
1073                    }
1074                }
1075            }
1076        }
1077    }
1078    if let Some(executor_type) = executor_type {
1079        Ok(executor_type)
1080    } else {
1081        let executor_name = executor_name?;
1082        Ok(if default_needs_generics {
1083            quote! { #executor_name<F> }
1084        } else {
1085            quote! { #executor_name }
1086        })
1087    }
1088}
1089
1090/// An attribute procedural macro for creating TCO (Tail Call Optimization) handlers.
1091///
1092/// This macro generates a handler function that wraps an execute implementation
1093/// with tail call optimization using the `become` keyword. It extracts the generics
1094/// and where clauses from the original function.
1095///
1096/// # Usage
1097///
1098/// Place this attribute above a function definition:
1099/// ```
1100/// #[create_tco_handler]
1101/// unsafe fn execute_e1_impl<F: PrimeField32, CTX, const B_IS_IMM: bool>(
1102///     pre_compute: *const u8,
1103///     state: &mut VmExecState<F, GuestMemory, CTX>,
1104/// ) where
1105///     CTX: ExecutionCtxTrait,
1106/// {
1107///     // function body
1108/// }
1109/// ```
1110///
1111/// This will generate a TCO handler function with the same generics and where clauses.
1112///
1113/// # Safety
1114///
1115/// Do not use this macro if your function wants to terminate execution without error with a
1116/// specific error code. The handler generated by this macro assumes that execution should continue
1117/// unless the execute_impl returns an error. This is done for performance to skip an exit code
1118/// check.
1119#[proc_macro_attribute]
1120pub fn create_handler(_attr: TokenStream, item: TokenStream) -> TokenStream {
1121    #[cfg(feature = "tco")]
1122    {
1123        tco::tco_impl(item)
1124    }
1125    #[cfg(not(feature = "tco"))]
1126    {
1127        nontco::nontco_impl(item)
1128    }
1129}