openvm_circuit_derive/
lib.rs

1extern crate alloc;
2extern crate proc_macro;
3
4use itertools::{multiunzip, Itertools};
5use proc_macro::{Span, TokenStream};
6use quote::{quote, ToTokens};
7use syn::{
8    parse_quote, punctuated::Punctuated, spanned::Spanned, Data, DataStruct, Field, Fields,
9    GenericParam, Ident, Meta, Token,
10};
11
12mod common;
13#[cfg(not(feature = "tco"))]
14mod nontco;
15#[cfg(feature = "tco")]
16mod tco;
17
18#[proc_macro_derive(PreflightExecutor)]
19pub fn preflight_executor_derive(input: TokenStream) -> TokenStream {
20    let ast: syn::DeriveInput = syn::parse(input).unwrap();
21
22    let name = &ast.ident;
23    let generics = &ast.generics;
24    let (_, ty_generics, _) = generics.split_for_impl();
25
26    let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
27    let mut new_generics = generics.clone();
28    new_generics.params.push(syn::parse_quote! { RA });
29    let field_ty_generic = generics
30        .params
31        .first()
32        .and_then(|param| match param {
33            GenericParam::Type(type_param) => Some(&type_param.ident),
34            _ => None,
35        })
36        .unwrap_or_else(|| {
37            new_generics.params.push(syn::parse_quote! { F });
38            &default_ty_generic
39        });
40
41    match &ast.data {
42        Data::Struct(inner) => {
43            // Check if the struct has only one unnamed field
44            let inner_ty = match &inner.fields {
45                Fields::Unnamed(fields) => {
46                    if fields.unnamed.len() != 1 {
47                        panic!("Only one unnamed field is supported");
48                    }
49                    fields.unnamed.first().unwrap().ty.clone()
50                }
51                _ => panic!("Only unnamed fields are supported"),
52            };
53            // Use full path ::openvm_circuit... so it can be used either within or outside the vm
54            // crate. Assume F is already generic of the field.
55            let where_clause = new_generics.make_where_clause();
56            where_clause.predicates.push(
57                syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA> },
58            );
59            let (impl_generics, _, where_clause) = new_generics.split_for_impl();
60            quote! {
61                impl #impl_generics ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA> for #name #ty_generics #where_clause {
62                    fn execute(
63                        &self,
64                        state: ::openvm_circuit::arch::VmStateMut<#field_ty_generic, ::openvm_circuit::system::memory::online::TracingMemory, RA>,
65                        instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<#field_ty_generic>,
66                    ) -> Result<(), ::openvm_circuit::arch::ExecutionError> {
67                        self.0.execute(state, instruction)
68                    }
69
70                    fn get_opcode_name(&self, opcode: usize) -> String {
71                        self.0.get_opcode_name(opcode)
72                    }
73                }
74            }
75            .into()
76        }
77        Data::Enum(e) => {
78            let variants = e
79                .variants
80                .iter()
81                .map(|variant| {
82                    let variant_name = &variant.ident;
83
84                    let mut fields = variant.fields.iter();
85                    let field = fields.next().unwrap();
86                    assert!(fields.next().is_none(), "Only one field is supported");
87                    (variant_name, field)
88                })
89                .collect::<Vec<_>>();
90            // Use full path ::openvm_circuit... so it can be used either within or outside the vm
91            // crate.
92            let (execute_arms, get_opcode_name_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>) =
93                multiunzip(variants.iter().map(|(variant_name, field)| {
94                    let field_ty = &field.ty;
95                    let execute_arm = quote! {
96                        #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA>>::execute(x, state, instruction)
97                    };
98                    let get_opcode_name_arm = quote! {
99                        #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA>>::get_opcode_name(x, opcode)
100                    };
101                    let where_predicate = syn::parse_quote! {
102                        #field_ty: ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA>
103                    };
104                    (execute_arm, get_opcode_name_arm, where_predicate)
105                }));
106            let where_clause = new_generics.make_where_clause();
107            for predicate in where_predicates {
108                where_clause.predicates.push(predicate);
109            }
110            // Don't use these ty_generics because it might have extra "F"
111            let (impl_generics, _, where_clause) = new_generics.split_for_impl();
112            quote! {
113                impl #impl_generics ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA> for #name #ty_generics #where_clause {
114                    fn execute(
115                        &self,
116                        state: ::openvm_circuit::arch::VmStateMut<#field_ty_generic, ::openvm_circuit::system::memory::online::TracingMemory, RA>,
117                        instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<#field_ty_generic>,
118                    ) -> Result<(), ::openvm_circuit::arch::ExecutionError> {
119                        match self {
120                            #(#execute_arms,)*
121                        }
122                    }
123
124                    fn get_opcode_name(&self, opcode: usize) -> String {
125                        match self {
126                            #(#get_opcode_name_arms,)*
127                        }
128                    }
129                }
130            }
131            .into()
132        }
133        Data::Union(_) => unimplemented!("Unions are not supported"),
134    }
135}
136
137#[proc_macro_derive(Executor)]
138pub fn executor_derive(input: TokenStream) -> TokenStream {
139    let ast: syn::DeriveInput = syn::parse(input).unwrap();
140
141    let name = &ast.ident;
142    let generics = &ast.generics;
143    let (impl_generics, ty_generics, _) = generics.split_for_impl();
144
145    match &ast.data {
146        Data::Struct(inner) => {
147            // Check if the struct has only one unnamed field
148            let inner_ty = match &inner.fields {
149                Fields::Unnamed(fields) => {
150                    if fields.unnamed.len() != 1 {
151                        panic!("Only one unnamed field is supported");
152                    }
153                    fields.unnamed.first().unwrap().ty.clone()
154                }
155                _ => panic!("Only unnamed fields are supported"),
156            };
157            // Use full path ::openvm_circuit... so it can be used either within or outside the vm
158            // crate. Assume F is already generic of the field.
159            let mut new_generics = generics.clone();
160            let where_clause = new_generics.make_where_clause();
161            where_clause.predicates.push(
162                syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::InterpreterExecutor<F> },
163            );
164
165            // We use the macro's feature to decide whether to generate the impl or not. This avoids
166            // the target crate needing the "tco" feature defined.
167            #[cfg(feature = "tco")]
168            let handler = quote! {
169                fn handler<Ctx>(
170                    &self,
171                    pc: u32,
172                    inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
173                    data: &mut [u8],
174                ) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
175                where
176                    Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
177                    self.0.handler(pc, inst, data)
178                }
179            };
180            #[cfg(not(feature = "tco"))]
181            let handler = quote! {};
182
183            quote! {
184                impl #impl_generics ::openvm_circuit::arch::InterpreterExecutor<F> for #name #ty_generics #where_clause {
185                    #[inline(always)]
186                    fn pre_compute_size(&self) -> usize {
187                        self.0.pre_compute_size()
188                    }
189                    #[cfg(not(feature = "tco"))]
190                    #[inline(always)]
191                    fn pre_compute<Ctx>(
192                        &self,
193                        pc: u32,
194                        inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
195                        data: &mut [u8],
196                    ) -> Result<::openvm_circuit::arch::ExecuteFunc<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
197                    where
198                        Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
199                        self.0.pre_compute(pc, inst, data)
200                    }
201
202                    #handler
203                }
204            }
205            .into()
206        }
207        Data::Enum(e) => {
208            let variants = e
209                .variants
210                .iter()
211                .map(|variant| {
212                    let variant_name = &variant.ident;
213
214                    let mut fields = variant.fields.iter();
215                    let field = fields.next().unwrap();
216                    assert!(fields.next().is_none(), "Only one field is supported");
217                    (variant_name, field)
218                })
219                .collect::<Vec<_>>();
220            let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
221            let mut new_generics = generics.clone();
222            let first_ty_generic = ast
223                .generics
224                .params
225                .first()
226                .and_then(|param| match param {
227                    GenericParam::Type(type_param) => Some(&type_param.ident),
228                    _ => None,
229                })
230                .unwrap_or_else(|| {
231                    new_generics.params.push(syn::parse_quote! { F });
232                    &default_ty_generic
233                });
234            // Use full path ::openvm_circuit... so it can be used either within or outside the vm
235            // crate. Assume F is already generic of the field.
236            let (pre_compute_size_arms, pre_compute_arms, _handler_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(|(variant_name, field)| {
237                let field_ty = &field.ty;
238                let pre_compute_size_arm = quote! {
239                    #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic>>::pre_compute_size(x)
240                };
241                let pre_compute_arm = quote! {
242                    #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic>>::pre_compute(x, pc, instruction, data)
243                };
244                let handler_arm = quote! {
245                    #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic>>::handler(x, pc, instruction, data)
246                };
247                let where_predicate = syn::parse_quote! {
248                    #field_ty: ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic>
249                };
250                (pre_compute_size_arm, pre_compute_arm, handler_arm, where_predicate)
251            }));
252            let where_clause = new_generics.make_where_clause();
253            for predicate in where_predicates {
254                where_clause.predicates.push(predicate);
255            }
256            // We use the macro's feature to decide whether to generate the impl or not. This avoids
257            // the target crate needing the "tco" feature defined.
258            #[cfg(feature = "tco")]
259            let handler = quote! {
260                fn handler<Ctx>(
261                    &self,
262                    pc: u32,
263                    instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
264                    data: &mut [u8],
265                ) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
266                where
267                    Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
268                    match self {
269                        #(#_handler_arms,)*
270                    }
271                }
272            };
273            #[cfg(not(feature = "tco"))]
274            let handler = quote! {};
275
276            // Don't use these ty_generics because it might have extra "F"
277            let (impl_generics, _, where_clause) = new_generics.split_for_impl();
278
279            quote! {
280                impl #impl_generics ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic> for #name #ty_generics #where_clause {
281                    #[inline(always)]
282                    fn pre_compute_size(&self) -> usize {
283                        match self {
284                            #(#pre_compute_size_arms,)*
285                        }
286                    }
287
288                    #[cfg(not(feature = "tco"))]
289                    #[inline(always)]
290                    fn pre_compute<Ctx>(
291                        &self,
292                        pc: u32,
293                        instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
294                        data: &mut [u8],
295                    ) -> Result<::openvm_circuit::arch::ExecuteFunc<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
296                    where
297                        Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
298                        match self {
299                            #(#pre_compute_arms,)*
300                        }
301                    }
302                    #handler
303                }
304            }
305            .into()
306        }
307        Data::Union(_) => unimplemented!("Unions are not supported"),
308    }
309}
310
311#[proc_macro_derive(AotExecutor)]
312pub fn aot_executor_derive(input: TokenStream) -> TokenStream {
313    let ast: syn::DeriveInput = syn::parse(input).unwrap();
314
315    let name = &ast.ident;
316    let generics = &ast.generics;
317    let (_, ty_generics, _) = generics.split_for_impl();
318
319    match &ast.data {
320        Data::Struct(inner) => {
321            let inner_ty = match &inner.fields {
322                Fields::Unnamed(fields) => {
323                    if fields.unnamed.len() != 1 {
324                        panic!("Only one unnamed field is supported");
325                    }
326                    fields.unnamed.first().unwrap().ty.clone()
327                }
328                _ => panic!("Only unnamed fields are supported"),
329            };
330            let mut new_generics = generics.clone();
331            let where_clause = new_generics.make_where_clause();
332            where_clause
333                .predicates
334                .push(syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::AotExecutor<F> });
335            let (impl_generics, _, where_clause) = new_generics.split_for_impl();
336
337            quote! {
338                #[cfg(feature = "aot")]
339                impl #impl_generics ::openvm_circuit::arch::AotExecutor<F> for #name #ty_generics #where_clause {
340                    #[inline(always)]
341                    fn is_aot_supported(&self, inst: &::openvm_instructions::instruction::Instruction<F>) -> bool {
342                        self.0.is_aot_supported(inst)
343                    }
344
345                    fn generate_x86_asm(
346                        &self,
347                        inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
348                        pc: u32,
349                    ) -> ::std::result::Result<
350                        ::std::string::String,
351                        ::openvm_circuit::arch::AotError,
352                    > {
353                        self.0.generate_x86_asm(inst, pc)
354                    }
355                }
356            }
357            .into()
358        }
359        Data::Enum(e) => {
360            let variants = e
361                .variants
362                .iter()
363                .map(|variant| {
364                    let variant_name = &variant.ident;
365                    let mut fields = variant.fields.iter();
366                    let field = fields.next().unwrap();
367                    assert!(fields.next().is_none(), "Only one field is supported");
368                    (variant_name, field)
369                })
370                .collect::<Vec<_>>();
371            let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
372            let mut new_generics = generics.clone();
373            let first_ty_generic = ast
374                .generics
375                .params
376                .first()
377                .and_then(|param| match param {
378                    GenericParam::Type(type_param) => Some(&type_param.ident),
379                    _ => None,
380                })
381                .unwrap_or_else(|| {
382                    new_generics.params.push(syn::parse_quote! { F });
383                    &default_ty_generic
384                });
385            let (
386                is_aot_supported_arms,
387                generate_x86_asm_arms,
388                where_predicates,
389            ): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(
390                |(variant_name, field)| {
391                    let field_ty = &field.ty;
392                    let is_aot_supported_arm = quote! {
393                        #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AotExecutor<#first_ty_generic>>::is_aot_supported(x, inst)
394                    };
395                    let generate_x86_asm_arm = quote! {
396                        #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AotExecutor<#first_ty_generic>>::generate_x86_asm(
397                            x,
398                            inst,
399                            pc,
400                        )
401                    };
402                    let where_predicate =
403                        syn::parse_quote! { #field_ty: ::openvm_circuit::arch::AotExecutor<#first_ty_generic> };
404                    (
405                        is_aot_supported_arm,
406                        generate_x86_asm_arm,
407                        where_predicate,
408                    )
409                },
410            ));
411            let where_clause = new_generics.make_where_clause();
412            for predicate in where_predicates {
413                where_clause.predicates.push(predicate);
414            }
415            let (impl_generics, _, where_clause) = new_generics.split_for_impl();
416
417            quote! {
418                #[cfg(feature = "aot")]
419                impl #impl_generics ::openvm_circuit::arch::AotExecutor<#first_ty_generic> for #name #ty_generics #where_clause {
420                    #[inline(always)]
421                    fn is_aot_supported(&self, inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>) -> bool {
422                        match self {
423                            #(#is_aot_supported_arms,)*
424                        }
425                    }
426
427                    fn generate_x86_asm(
428                        &self,
429                        inst: &::openvm_circuit::arch::instructions::instruction::Instruction<#first_ty_generic>,
430                        pc: u32,
431                    ) -> ::std::result::Result<
432                        ::std::string::String,
433                        ::openvm_circuit::arch::AotError,
434                    > {
435                        match self {
436                            #(#generate_x86_asm_arms,)*
437                        }
438                    }
439                }
440            }
441            .into()
442        }
443        Data::Union(_) => unimplemented!("Unions are not supported"),
444    }
445}
446
447#[proc_macro_derive(MeteredExecutor)]
448pub fn metered_executor_derive(input: TokenStream) -> TokenStream {
449    let ast: syn::DeriveInput = syn::parse(input).unwrap();
450
451    let name = &ast.ident;
452    let generics = &ast.generics;
453    let (impl_generics, ty_generics, _) = generics.split_for_impl();
454
455    match &ast.data {
456        Data::Struct(inner) => {
457            // Check if the struct has only one unnamed field
458            let inner_ty = match &inner.fields {
459                Fields::Unnamed(fields) => {
460                    if fields.unnamed.len() != 1 {
461                        panic!("Only one unnamed field is supported");
462                    }
463                    fields.unnamed.first().unwrap().ty.clone()
464                }
465                _ => panic!("Only unnamed fields are supported"),
466            };
467            // Use full path ::openvm_circuit... so it can be used either within or outside the vm
468            // crate.
469            let mut new_generics = generics.clone();
470            let where_clause = new_generics.make_where_clause();
471            where_clause
472                .predicates
473                .push(syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::InterpreterMeteredExecutor<F> });
474
475            // We use the macro's feature to decide whether to generate the impl or not. This avoids
476            // the target crate needing the "tco" feature defined.
477            #[cfg(feature = "tco")]
478            let metered_handler = quote! {
479                fn metered_handler<Ctx>(
480                    &self,
481                    chip_idx: usize,
482                    pc: u32,
483                    inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
484                    data: &mut [u8],
485                ) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
486                where
487                    Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait, {
488                    self.0.metered_handler(chip_idx, pc, inst, data)
489                }
490            };
491            #[cfg(not(feature = "tco"))]
492            let metered_handler = quote! {};
493
494            quote! {
495                impl #impl_generics ::openvm_circuit::arch::InterpreterMeteredExecutor<F> for #name #ty_generics #where_clause {
496                    #[inline(always)]
497                    fn metered_pre_compute_size(&self) -> usize {
498                        self.0.metered_pre_compute_size()
499                    }
500                    #[cfg(not(feature = "tco"))]
501                    #[inline(always)]
502                    fn metered_pre_compute<Ctx>(
503                        &self,
504                        chip_idx: usize,
505                        pc: u32,
506                        inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
507                        data: &mut [u8],
508                    ) -> Result<::openvm_circuit::arch::ExecuteFunc<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
509                    where
510                        Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait, {
511                        self.0.metered_pre_compute(chip_idx, pc, inst, data)
512                    }
513                    #metered_handler
514                }
515            }
516                .into()
517        }
518        Data::Enum(e) => {
519            let variants = e
520                .variants
521                .iter()
522                .map(|variant| {
523                    let variant_name = &variant.ident;
524
525                    let mut fields = variant.fields.iter();
526                    let field = fields.next().unwrap();
527                    assert!(fields.next().is_none(), "Only one field is supported");
528                    (variant_name, field)
529                })
530                .collect::<Vec<_>>();
531            let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
532            let mut new_generics = generics.clone();
533            let first_ty_generic = ast
534                .generics
535                .params
536                .first()
537                .and_then(|param| match param {
538                    GenericParam::Type(type_param) => Some(&type_param.ident),
539                    _ => None,
540                })
541                .unwrap_or_else(|| {
542                    new_generics.params.push(syn::parse_quote! { F });
543                    &default_ty_generic
544                });
545            // Use full path ::openvm_circuit... so it can be used either within or outside the vm
546            // crate. Assume F is already generic of the field.
547            let (pre_compute_size_arms, metered_pre_compute_arms, _metered_handler_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(|(variant_name, field)| {
548                let field_ty = &field.ty;
549                let pre_compute_size_arm = quote! {
550                    #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic>>::metered_pre_compute_size(x)
551                };
552                let metered_pre_compute_arm = quote! {
553                    #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic>>::metered_pre_compute(x, chip_idx, pc, instruction, data)
554                };
555                let metered_handler_arm = quote! {
556                    #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic>>::metered_handler(x, chip_idx, pc, instruction, data)
557                };
558                let where_predicate = syn::parse_quote! {
559                    #field_ty: ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic>
560                };
561                (pre_compute_size_arm, metered_pre_compute_arm, metered_handler_arm, where_predicate)
562            }));
563            let where_clause = new_generics.make_where_clause();
564            for predicate in where_predicates {
565                where_clause.predicates.push(predicate);
566            }
567            // Don't use these ty_generics because it might have extra "F"
568            let (impl_generics, _, where_clause) = new_generics.split_for_impl();
569
570            // We use the macro's feature to decide whether to generate the impl or not. This avoids
571            // the target crate needing the "tco" feature defined.
572            #[cfg(feature = "tco")]
573            let metered_handler = quote! {
574                fn metered_handler<Ctx>(
575                    &self,
576                    chip_idx: usize,
577                    pc: u32,
578                    instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
579                    data: &mut [u8],
580                ) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
581                where
582                    Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait,
583                {
584                    match self {
585                        #(#_metered_handler_arms,)*
586                    }
587                }
588            };
589            #[cfg(not(feature = "tco"))]
590            let metered_handler = quote! {};
591
592            quote! {
593                impl #impl_generics ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic> for #name #ty_generics #where_clause {
594                    #[inline(always)]
595                    fn metered_pre_compute_size(&self) -> usize {
596                        match self {
597                            #(#pre_compute_size_arms,)*
598                        }
599                    }
600
601                    #[cfg(not(feature = "tco"))]
602                    #[inline(always)]
603                    fn metered_pre_compute<Ctx>(
604                        &self,
605                        chip_idx: usize,
606                        pc: u32,
607                        instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
608                        data: &mut [u8],
609                    ) -> Result<::openvm_circuit::arch::ExecuteFunc<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
610                    where
611                        Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait, {
612                        match self {
613                            #(#metered_pre_compute_arms,)*
614                        }
615                    }
616
617                    #metered_handler
618                }
619            }
620                .into()
621        }
622        Data::Union(_) => unimplemented!("Unions are not supported"),
623    }
624}
625
626#[proc_macro_derive(AotMeteredExecutor)]
627pub fn aot_metered_executor_derive(input: TokenStream) -> TokenStream {
628    let ast: syn::DeriveInput = syn::parse(input).unwrap();
629
630    let name = &ast.ident;
631    let generics = &ast.generics;
632    let (_, ty_generics, _) = generics.split_for_impl();
633
634    match &ast.data {
635        Data::Struct(inner) => {
636            let inner_ty = match &inner.fields {
637                Fields::Unnamed(fields) => {
638                    if fields.unnamed.len() != 1 {
639                        panic!("Only one unnamed field is supported");
640                    }
641                    fields.unnamed.first().unwrap().ty.clone()
642                }
643                _ => panic!("Only unnamed fields are supported"),
644            };
645            let mut new_generics = generics.clone();
646            let where_clause = new_generics.make_where_clause();
647            where_clause.predicates.push(
648                syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::AotMeteredExecutor<F> },
649            );
650            let (impl_generics, _, where_clause) = new_generics.split_for_impl();
651
652            quote! {
653                #[cfg(feature = "aot")]
654                impl #impl_generics ::openvm_circuit::arch::AotMeteredExecutor<F> for #name #ty_generics #where_clause {
655                    #[inline(always)]
656                    fn is_aot_metered_supported(&self, inst: &::openvm_instructions::instruction::Instruction<F>) -> bool {
657                        self.0.is_aot_metered_supported(inst)
658                    }
659
660                    fn generate_x86_metered_asm(
661                        &self,
662                        inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
663                        pc: u32,
664                        chip_idx: usize,
665                        config: &::openvm_circuit::arch::SystemConfig,
666                    ) -> ::std::result::Result<
667                        ::std::string::String,
668                        ::openvm_circuit::arch::AotError,
669                    > {
670                        self.0.generate_x86_metered_asm(inst, pc, chip_idx, config)
671                    }
672                }
673            }
674            .into()
675        }
676        Data::Enum(e) => {
677            let variants = e
678                .variants
679                .iter()
680                .map(|variant| {
681                    let variant_name = &variant.ident;
682                    let mut fields = variant.fields.iter();
683                    let field = fields.next().unwrap();
684                    assert!(fields.next().is_none(), "Only one field is supported");
685                    (variant_name, field)
686                })
687                .collect::<Vec<_>>();
688            let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
689            let mut new_generics = generics.clone();
690            let first_ty_generic = ast
691                .generics
692                .params
693                .first()
694                .and_then(|param| match param {
695                    GenericParam::Type(type_param) => Some(&type_param.ident),
696                    _ => None,
697                })
698                .unwrap_or_else(|| {
699                    new_generics.params.push(syn::parse_quote! { F });
700                    &default_ty_generic
701                });
702            let (
703                is_aot_metered_supported_arms,
704                generate_x86_metered_asm_arms,
705                where_predicates,
706            ): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(
707                |(variant_name, field)| {
708                    let field_ty = &field.ty;
709                    let is_aot_metered_supported_arm = quote! {
710                        #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AotMeteredExecutor<#first_ty_generic>>::is_aot_metered_supported(x, inst)
711                    };
712                    let generate_x86_metered_asm_arm = quote! {
713                        #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AotMeteredExecutor<#first_ty_generic>>::generate_x86_metered_asm(
714                            x,
715                            inst,
716                            pc,
717                            chip_idx,
718                            config,
719                        )
720                    };
721                    let where_predicate =
722                        syn::parse_quote! { #field_ty: ::openvm_circuit::arch::AotMeteredExecutor<#first_ty_generic> };
723                    (
724                        is_aot_metered_supported_arm,
725                        generate_x86_metered_asm_arm,
726                        where_predicate,
727                    )
728                },
729            ));
730            let where_clause = new_generics.make_where_clause();
731            for predicate in where_predicates {
732                where_clause.predicates.push(predicate);
733            }
734            let (impl_generics, _, where_clause) = new_generics.split_for_impl();
735
736            quote! {
737                #[cfg(feature = "aot")]
738                impl #impl_generics ::openvm_circuit::arch::AotMeteredExecutor<#first_ty_generic> for #name #ty_generics #where_clause {
739                    #[inline(always)]
740                    fn is_aot_metered_supported(&self, inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>) -> bool {
741                        match self {
742                            #(#is_aot_metered_supported_arms,)*
743                        }
744                    }
745
746                    fn generate_x86_metered_asm(
747                        &self,
748                        inst: &::openvm_circuit::arch::instructions::instruction::Instruction<#first_ty_generic>,
749                        pc: u32,
750                        chip_idx: usize,
751                        config: &::openvm_circuit::arch::SystemConfig,
752                    ) -> ::std::result::Result<
753                        ::std::string::String,
754                        ::openvm_circuit::arch::AotError,
755                    > {
756                        match self {
757                            #(#generate_x86_metered_asm_arms,)*
758                        }
759                    }
760                }
761            }
762            .into()
763        }
764        Data::Union(_) => unimplemented!("Unions are not supported"),
765    }
766}
767
768/// Derives `AnyEnum` trait on an enum type.
769/// By default an enum arm will just return `self` as `&dyn Any`.
770///
771/// Use the `#[any_enum]` field attribute to specify that the
772/// arm itself implements `AnyEnum` and should call the inner `as_any_kind` method.
773#[proc_macro_derive(AnyEnum, attributes(any_enum))]
774pub fn any_enum_derive(input: TokenStream) -> TokenStream {
775    let ast: syn::DeriveInput = syn::parse(input).unwrap();
776
777    let name = &ast.ident;
778    let generics = &ast.generics;
779    let (impl_generics, ty_generics, _) = generics.split_for_impl();
780
781    match &ast.data {
782        Data::Enum(e) => {
783            let variants = e
784                .variants
785                .iter()
786                .map(|variant| {
787                    let variant_name = &variant.ident;
788
789                    // Check if the variant has #[any_enum] attribute
790                    let is_enum = variant
791                        .attrs
792                        .iter()
793                        .any(|attr| attr.path().is_ident("any_enum"));
794                    let mut fields = variant.fields.iter();
795                    let field = fields.next().unwrap();
796                    assert!(fields.next().is_none(), "Only one field is supported");
797                    (variant_name, field, is_enum)
798                })
799                .collect::<Vec<_>>();
800            let (arms, arms_mut): (Vec<_>, Vec<_>) =
801                variants.iter().map(|(variant_name, field, is_enum)| {
802                    let field_ty = &field.ty;
803
804                    if *is_enum {
805                        // Call the inner trait impl
806                        (quote! {
807                            #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AnyEnum>::as_any_kind(x)
808                        },
809                        quote! {
810                            #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AnyEnum>::as_any_kind_mut(x)
811                        })
812                    } else {
813                        (quote! {
814                            #name::#variant_name(x) => x
815                        },
816                        quote! {
817                            #name::#variant_name(x) => x
818                        })
819                    }
820                }).unzip();
821            quote! {
822                impl #impl_generics ::openvm_circuit::arch::AnyEnum for #name #ty_generics {
823                    fn as_any_kind(&self) -> &dyn std::any::Any {
824                        match self {
825                            #(#arms,)*
826                        }
827                    }
828
829                    fn as_any_kind_mut(&mut self) -> &mut dyn std::any::Any {
830                        match self {
831                            #(#arms_mut,)*
832                        }
833                    }
834                }
835            }
836            .into()
837        }
838        _ => syn::Error::new(name.span(), "Only enums are supported")
839            .to_compile_error()
840            .into(),
841    }
842}
843
844#[proc_macro_derive(VmConfig, attributes(config, extension))]
845pub fn vm_generic_config_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
846    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
847    let name = &ast.ident;
848
849    match &ast.data {
850        syn::Data::Struct(inner) => match generate_config_traits_impl(name, inner) {
851            Ok(tokens) => tokens,
852            Err(err) => err.to_compile_error().into(),
853        },
854        _ => syn::Error::new(name.span(), "Only structs are supported")
855            .to_compile_error()
856            .into(),
857    }
858}
859
860fn generate_config_traits_impl(name: &Ident, inner: &DataStruct) -> syn::Result<TokenStream> {
861    let gen_name_with_uppercase_idents = |ident: &Ident| {
862        let mut name = ident.to_string().chars().collect::<Vec<_>>();
863        assert!(name[0].is_lowercase(), "Field name must not be capitalized");
864        let res_lower = Ident::new(&name.iter().collect::<String>(), Span::call_site().into());
865        name[0] = name[0].to_ascii_uppercase();
866        let res_upper = Ident::new(&name.iter().collect::<String>(), Span::call_site().into());
867        (res_lower, res_upper)
868    };
869
870    let fields = match &inner.fields {
871        Fields::Named(named) => named.named.iter().collect(),
872        Fields::Unnamed(_) => {
873            return Err(syn::Error::new(
874                name.span(),
875                "Only named fields are supported",
876            ))
877        }
878        Fields::Unit => vec![],
879    };
880
881    let source_field = fields
882        .iter()
883        .filter(|f| f.attrs.iter().any(|attr| attr.path().is_ident("config")))
884        .exactly_one()
885        .map_err(|_| {
886            syn::Error::new(
887                name.span(),
888                "Exactly one field must have the #[config] attribute",
889            )
890        })?;
891    let (source_name, source_name_upper) =
892        gen_name_with_uppercase_idents(source_field.ident.as_ref().unwrap());
893
894    let extensions = fields
895        .iter()
896        .filter(|f| f.attrs.iter().any(|attr| attr.path().is_ident("extension")))
897        .cloned()
898        .collect::<Vec<_>>();
899
900    let mut executor_enum_fields = Vec::new();
901    let mut create_executors = Vec::new();
902    let mut create_airs = Vec::new();
903    let mut execution_where_predicates: Vec<syn::WherePredicate> = Vec::new();
904    let mut circuit_where_predicates: Vec<syn::WherePredicate> = Vec::new();
905    execution_where_predicates.push(parse_quote! { F: ::openvm_circuit::arch::VmField });
906
907    let source_field_ty = source_field.ty.clone();
908
909    for e in extensions.iter() {
910        let (ext_field_name, ext_name_upper) =
911            gen_name_with_uppercase_idents(e.ident.as_ref().expect("field must be named"));
912        let executor_type = parse_executor_type(e, false)?;
913        executor_enum_fields.push(quote! {
914            #[any_enum]
915            #ext_name_upper(#executor_type),
916        });
917        create_executors.push(quote! {
918            let inventory: ::openvm_circuit::arch::ExecutorInventory<Self::Executor> = inventory.extend::<F, _, _>(&self.#ext_field_name)?;
919        });
920        let extension_ty = e.ty.clone();
921        execution_where_predicates.push(parse_quote! {
922            #extension_ty: ::openvm_circuit::arch::VmExecutionExtension<F, Executor = #executor_type>
923        });
924        create_airs.push(quote! {
925            inventory.start_new_extension();
926            ::openvm_circuit::arch::VmCircuitExtension::extend_circuit(&self.#ext_field_name, &mut inventory)?;
927        });
928        circuit_where_predicates.push(parse_quote! {
929            #extension_ty: ::openvm_circuit::arch::VmCircuitExtension<SC>
930        });
931    }
932
933    // The config type always needs <F> due to SystemExecutor
934    let source_executor_type = parse_executor_type(source_field, true)?;
935    execution_where_predicates.push(parse_quote! {
936        #source_field_ty: ::openvm_circuit::arch::VmExecutionConfig<F, Executor = #source_executor_type>
937    });
938    circuit_where_predicates.push(parse_quote! {
939        #source_field_ty: ::openvm_circuit::arch::VmCircuitConfig<SC>
940    });
941    let execution_where_clause = quote! { where #(#execution_where_predicates),* };
942    let circuit_where_clause = quote! { where #(#circuit_where_predicates),* };
943
944    let executor_type = Ident::new(&format!("{name}Executor"), name.span());
945
946    let token_stream = TokenStream::from(quote! {
947        #[derive(
948            Clone,
949            ::derive_more::derive::From,
950            ::openvm_circuit::derive::AnyEnum,
951            ::openvm_circuit::derive::Executor,
952            ::openvm_circuit::derive::MeteredExecutor,
953            ::openvm_circuit::derive::PreflightExecutor,
954        )]
955        #[cfg_attr(feature = "aot", derive(::openvm_circuit::derive::AotExecutor, ::openvm_circuit::derive::AotMeteredExecutor))]
956        pub enum #executor_type<F: ::openvm_circuit::arch::VmField> #execution_where_clause {
957            #[any_enum]
958            #source_name_upper(#source_executor_type),
959            #(#executor_enum_fields)*
960        }
961
962        impl<F: ::openvm_circuit::arch::VmField> ::openvm_circuit::arch::VmExecutionConfig<F> for #name #execution_where_clause {
963            type Executor = #executor_type<F>;
964
965            fn create_executors(
966                &self,
967            ) -> Result<::openvm_circuit::arch::ExecutorInventory<Self::Executor>, ::openvm_circuit::arch::ExecutorInventoryError> {
968                let inventory = self.#source_name.create_executors()?.transmute::<Self::Executor>();
969                #(#create_executors)*
970                Ok(inventory)
971            }
972        }
973
974        impl<SC: openvm_stark_backend::config::StarkGenericConfig> ::openvm_circuit::arch::VmCircuitConfig<SC> for #name #circuit_where_clause {
975            fn create_airs(
976                &self,
977            ) -> Result<::openvm_circuit::arch::AirInventory<SC>, ::openvm_circuit::arch::AirInventoryError> {
978                let mut inventory = self.#source_name.create_airs()?;
979                #(#create_airs)*
980                Ok(inventory)
981            }
982        }
983
984        impl AsRef<SystemConfig> for #name {
985            fn as_ref(&self) -> &SystemConfig {
986                self.#source_name.as_ref()
987            }
988        }
989
990        impl AsMut<SystemConfig> for #name {
991            fn as_mut(&mut self) -> &mut SystemConfig {
992                self.#source_name.as_mut()
993            }
994        }
995    });
996    Ok(token_stream)
997}
998
999// Parse the executor name as either
1000// `{type_name}Executor` or whatever the attribute `executor = ` specifies
1001// Also determines whether the executor type needs generic parameters
1002fn parse_executor_type(
1003    f: &Field,
1004    default_needs_generics: bool,
1005) -> syn::Result<proc_macro2::TokenStream> {
1006    // TRACKING ISSUE:
1007    // We cannot just use <e.ty.to_token_stream() as VmExecutionExtension<F>>::Executor because of this: <https://github.com/rust-lang/rust/issues/85576>
1008    let mut executor_type = None;
1009    // Do not unwrap the Result until needed
1010    let executor_name = syn::parse_str::<Ident>(&format!("{}Executor", f.ty.to_token_stream()));
1011
1012    if let Some(attr) = f
1013        .attrs
1014        .iter()
1015        .find(|attr| attr.path().is_ident("extension") || attr.path().is_ident("config"))
1016    {
1017        match attr.meta {
1018            Meta::Path(_) => {}
1019            Meta::NameValue(_) => {
1020                return Err(syn::Error::new(
1021                    f.ty.span(),
1022                    "Only `#[config]`, `#[extension]`, `#[config(...)]` or `#[extension(...)]` formats are supported",
1023                ))
1024            }
1025            _ => {
1026                let nested = attr
1027                    .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
1028                for meta in nested {
1029                    match meta {
1030                        Meta::NameValue(nv) => {
1031                            if nv.path.is_ident("executor") {
1032                                executor_type = match nv.value {
1033                                    syn::Expr::Lit(syn::ExprLit {
1034                                        lit: syn::Lit::Str(lit_str), ..
1035                                    }) => {
1036                                        let executor_type: syn::Type = syn::parse_str(&lit_str.value())?;
1037                                        Some(quote! { #executor_type })
1038                                    },
1039                                    syn::Expr::Path(path) => {
1040                                        // Handle identifier paths like `executor = MyExecutor`
1041                                        Some(path.to_token_stream())
1042                                    },
1043                                    _ => {
1044                                        return Err(syn::Error::new(
1045                                            nv.value.span(),
1046                                            "executor value must be a string literal or identifier"
1047                                        ));
1048                                    }
1049                                };
1050                            } else if nv.path.is_ident("generics") {
1051                                // Parse boolean value for generics
1052                                let value_str = nv.value.to_token_stream().to_string();
1053                                let needs_generics = match value_str.as_str() {
1054                                    "true" => true,
1055                                    "false" => false,
1056                                    _ => return Err(syn::Error::new(
1057                                        nv.value.span(),
1058                                        "generics attribute must be either true or false"
1059                                    ))
1060                                };
1061                                let executor_name = executor_name.clone()?;
1062                                executor_type = Some(if needs_generics {
1063                                    quote! { #executor_name<F> }
1064                                } else {
1065                                    quote! { #executor_name }
1066                                });
1067                            } else {
1068                                return Err(syn::Error::new(nv.span(), "only executor and generics keys are supported"));
1069                            }
1070                        }
1071                        _ => {
1072                            return Err(syn::Error::new(meta.span(), "only name = value format is supported"));
1073                        }
1074                    }
1075                }
1076            }
1077        }
1078    }
1079    if let Some(executor_type) = executor_type {
1080        Ok(executor_type)
1081    } else {
1082        let executor_name = executor_name?;
1083        Ok(if default_needs_generics {
1084            quote! { #executor_name<F> }
1085        } else {
1086            quote! { #executor_name }
1087        })
1088    }
1089}
1090
1091/// An attribute procedural macro for creating TCO (Tail Call Optimization) handlers.
1092///
1093/// This macro generates a handler function that wraps an execute implementation
1094/// with tail call optimization using the `become` keyword. It extracts the generics
1095/// and where clauses from the original function.
1096///
1097/// # Usage
1098///
1099/// Place this attribute above a function definition:
1100/// ```
1101/// #[create_tco_handler]
1102/// unsafe fn execute_e1_impl<F: PrimeField32, CTX, const B_IS_IMM: bool>(
1103///     pre_compute: *const u8,
1104///     state: &mut VmExecState<F, GuestMemory, CTX>,
1105/// ) where
1106///     CTX: ExecutionCtxTrait,
1107/// {
1108///     // function body
1109/// }
1110/// ```
1111///
1112/// This will generate a TCO handler function with the same generics and where clauses.
1113///
1114/// # Safety
1115///
1116/// Do not use this macro if your function wants to terminate execution without error with a
1117/// specific error code. The handler generated by this macro assumes that execution should continue
1118/// unless the execute_impl returns an error. This is done for performance to skip an exit code
1119/// check.
1120#[proc_macro_attribute]
1121pub fn create_handler(_attr: TokenStream, item: TokenStream) -> TokenStream {
1122    #[cfg(feature = "tco")]
1123    {
1124        tco::tco_impl(item)
1125    }
1126    #[cfg(not(feature = "tco"))]
1127    {
1128        nontco::nontco_impl(item)
1129    }
1130}