// openvm_circuit_derive/lib.rs

1extern crate alloc;
2extern crate proc_macro;
3
4use itertools::{multiunzip, Itertools};
5use proc_macro::{Span, TokenStream};
6use quote::{quote, ToTokens};
7use syn::{punctuated::Punctuated, Data, Fields, GenericParam, Ident, Meta, Token};
8
9#[proc_macro_derive(InstructionExecutor)]
10pub fn instruction_executor_derive(input: TokenStream) -> TokenStream {
11    let ast: syn::DeriveInput = syn::parse(input).unwrap();
12
13    let name = &ast.ident;
14    let generics = &ast.generics;
15    let (impl_generics, ty_generics, _) = generics.split_for_impl();
16
17    match &ast.data {
18        Data::Struct(inner) => {
19            // Check if the struct has only one unnamed field
20            let inner_ty = match &inner.fields {
21                Fields::Unnamed(fields) => {
22                    if fields.unnamed.len() != 1 {
23                        panic!("Only one unnamed field is supported");
24                    }
25                    fields.unnamed.first().unwrap().ty.clone()
26                }
27                _ => panic!("Only unnamed fields are supported"),
28            };
29            // Use full path ::openvm_circuit... so it can be used either within or outside the vm crate.
30            // Assume F is already generic of the field.
31            let mut new_generics = generics.clone();
32            let where_clause = new_generics.make_where_clause();
33            where_clause.predicates.push(
34                syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::InstructionExecutor<F> },
35            );
36            quote! {
37                impl #impl_generics ::openvm_circuit::arch::InstructionExecutor<F> for #name #ty_generics #where_clause {
38                    fn execute(
39                        &mut self,
40                        memory: &mut ::openvm_circuit::system::memory::MemoryController<F>,
41                        instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
42                        from_state: ::openvm_circuit::arch::ExecutionState<u32>,
43                    ) -> ::openvm_circuit::arch::Result<::openvm_circuit::arch::ExecutionState<u32>> {
44                        self.0.execute(memory, instruction, from_state)
45                    }
46
47                    fn get_opcode_name(&self, opcode: usize) -> String {
48                        self.0.get_opcode_name(opcode)
49                    }
50                }
51            }
52            .into()
53        }
54        Data::Enum(e) => {
55            let variants = e
56                .variants
57                .iter()
58                .map(|variant| {
59                    let variant_name = &variant.ident;
60
61                    let mut fields = variant.fields.iter();
62                    let field = fields.next().unwrap();
63                    assert!(fields.next().is_none(), "Only one field is supported");
64                    (variant_name, field)
65                })
66                .collect::<Vec<_>>();
67            let first_ty_generic = ast
68                .generics
69                .params
70                .first()
71                .and_then(|param| match param {
72                    GenericParam::Type(type_param) => Some(&type_param.ident),
73                    _ => None,
74                })
75                .expect("First generic must be type for Field");
76            // Use full path ::openvm_circuit... so it can be used either within or outside the vm crate.
77            // Assume F is already generic of the field.
78            let (execute_arms, get_opcode_name_arms): (Vec<_>, Vec<_>) =
79                multiunzip(variants.iter().map(|(variant_name, field)| {
80                    let field_ty = &field.ty;
81                    let execute_arm = quote! {
82                        #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InstructionExecutor<#first_ty_generic>>::execute(x, memory, instruction, from_state)
83                    };
84                    let get_opcode_name_arm = quote! {
85                        #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InstructionExecutor<#first_ty_generic>>::get_opcode_name(x, opcode)
86                    };
87
88                    (execute_arm, get_opcode_name_arm)
89                }));
90            quote! {
91                impl #impl_generics ::openvm_circuit::arch::InstructionExecutor<#first_ty_generic> for #name #ty_generics {
92                    fn execute(
93                        &mut self,
94                        memory: &mut ::openvm_circuit::system::memory::MemoryController<#first_ty_generic>,
95                        instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<#first_ty_generic>,
96                        from_state: ::openvm_circuit::arch::ExecutionState<u32>,
97                    ) -> ::openvm_circuit::arch::Result<::openvm_circuit::arch::ExecutionState<u32>> {
98                        match self {
99                            #(#execute_arms,)*
100                        }
101                    }
102
103                    fn get_opcode_name(&self, opcode: usize) -> String {
104                        match self {
105                            #(#get_opcode_name_arms,)*
106                        }
107                    }
108                }
109            }
110            .into()
111        }
112        Data::Union(_) => unimplemented!("Unions are not supported"),
113    }
114}
115
116/// Derives `AnyEnum` trait on an enum type.
117/// By default an enum arm will just return `self` as `&dyn Any`.
118///
119/// Use the `#[any_enum]` field attribute to specify that the
120/// arm itself implements `AnyEnum` and should call the inner `as_any_kind` method.
121#[proc_macro_derive(AnyEnum, attributes(any_enum))]
122pub fn any_enum_derive(input: TokenStream) -> TokenStream {
123    let ast: syn::DeriveInput = syn::parse(input).unwrap();
124
125    let name = &ast.ident;
126    let generics = &ast.generics;
127    let (impl_generics, ty_generics, _) = generics.split_for_impl();
128
129    match &ast.data {
130        Data::Enum(e) => {
131            let variants = e
132                .variants
133                .iter()
134                .map(|variant| {
135                    let variant_name = &variant.ident;
136
137                    // Check if the variant has #[any_enum] attribute
138                    let is_enum = variant
139                        .attrs
140                        .iter()
141                        .any(|attr| attr.path().is_ident("any_enum"));
142                    let mut fields = variant.fields.iter();
143                    let field = fields.next().unwrap();
144                    assert!(fields.next().is_none(), "Only one field is supported");
145                    (variant_name, field, is_enum)
146                })
147                .collect::<Vec<_>>();
148            let (arms, arms_mut): (Vec<_>, Vec<_>) =
149                variants.iter().map(|(variant_name, field, is_enum)| {
150                    let field_ty = &field.ty;
151
152                    if *is_enum {
153                        // Call the inner trait impl
154                        (quote! {
155                            #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AnyEnum>::as_any_kind(x)
156                        },
157                        quote! {
158                            #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AnyEnum>::as_any_kind_mut(x)
159                        })
160                    } else {
161                        (quote! {
162                            #name::#variant_name(x) => x
163                        },
164                        quote! {
165                            #name::#variant_name(x) => x
166                        })
167                    }
168                }).unzip();
169            quote! {
170                impl #impl_generics ::openvm_circuit::arch::AnyEnum for #name #ty_generics {
171                    fn as_any_kind(&self) -> &dyn std::any::Any {
172                        match self {
173                            #(#arms,)*
174                        }
175                    }
176
177                    fn as_any_kind_mut(&mut self) -> &mut dyn std::any::Any {
178                        match self {
179                            #(#arms_mut,)*
180                        }
181                    }
182                }
183            }
184            .into()
185        }
186        _ => syn::Error::new(name.span(), "Only enums are supported")
187            .to_compile_error()
188            .into(),
189    }
190}
191
192// VmConfig derive macro
193#[derive(Debug)]
194enum Source {
195    System(Ident),
196    Config(Ident),
197}
198
199#[proc_macro_derive(VmConfig, attributes(system, config, extension))]
200pub fn vm_generic_config_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
201    let ast = syn::parse_macro_input!(input as syn::DeriveInput);
202    let name = &ast.ident;
203
204    let gen_name_with_uppercase_idents = |ident: &Ident| {
205        let mut name = ident.to_string().chars().collect::<Vec<_>>();
206        assert!(name[0].is_lowercase(), "Field name must not be capitalized");
207        let res_lower = Ident::new(&name.iter().collect::<String>(), Span::call_site().into());
208        name[0] = name[0].to_ascii_uppercase();
209        let res_upper = Ident::new(&name.iter().collect::<String>(), Span::call_site().into());
210        (res_lower, res_upper)
211    };
212
213    match &ast.data {
214        syn::Data::Struct(inner) => {
215            let fields = match &inner.fields {
216                Fields::Named(named) => named.named.iter().collect(),
217                Fields::Unnamed(_) => {
218                    return syn::Error::new(name.span(), "Only named fields are supported")
219                        .to_compile_error()
220                        .into();
221                }
222                Fields::Unit => vec![],
223            };
224
225            let source = fields
226                .iter()
227                .filter_map(|f| {
228                    if f.attrs.iter().any(|attr| attr.path().is_ident("system")) {
229                        Some(Source::System(f.ident.clone().unwrap()))
230                    } else if f.attrs.iter().any(|attr| attr.path().is_ident("config")) {
231                        Some(Source::Config(f.ident.clone().unwrap()))
232                    } else {
233                        None
234                    }
235                })
236                .exactly_one()
237                .expect("Exactly one field must have #[system] or #[config] attribute");
238            let (source_name, source_name_upper) = match &source {
239                Source::System(ident) | Source::Config(ident) => {
240                    gen_name_with_uppercase_idents(ident)
241                }
242            };
243
244            let extensions = fields
245                .iter()
246                .filter(|f| f.attrs.iter().any(|attr| attr.path().is_ident("extension")))
247                .cloned()
248                .collect::<Vec<_>>();
249
250            let mut executor_enum_fields = Vec::new();
251            let mut periphery_enum_fields = Vec::new();
252            let mut create_chip_complex = Vec::new();
253            for &e in extensions.iter() {
254                let (field_name, field_name_upper) =
255                    gen_name_with_uppercase_idents(&e.ident.clone().unwrap());
256                // TRACKING ISSUE:
257                // We cannot just use <e.ty.to_token_stream() as VmExtension<F>>::Executor because of this: <https://github.com/rust-lang/rust/issues/85576>
258                let mut executor_name = Ident::new(
259                    &format!("{}Executor", e.ty.to_token_stream()),
260                    Span::call_site().into(),
261                );
262                let mut periphery_name = Ident::new(
263                    &format!("{}Periphery", e.ty.to_token_stream()),
264                    Span::call_site().into(),
265                );
266                if let Some(attr) = e
267                    .attrs
268                    .iter()
269                    .find(|attr| attr.path().is_ident("extension"))
270                {
271                    match attr.meta {
272                        Meta::Path(_) => {}
273                        Meta::NameValue(_) => {
274                            return syn::Error::new(
275                                name.span(),
276                                "Only `#[extension]` or `#[extension(...)] formats are supported",
277                            )
278                            .to_compile_error()
279                            .into()
280                        }
281                        _ => {
282                            let nested = attr
283                                .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)
284                                .unwrap();
285                            for meta in nested {
286                                match meta {
287                                    Meta::NameValue(nv) => {
288                                        if nv.path.is_ident("executor") {
289                                            executor_name = Ident::new(
290                                                &nv.value.to_token_stream().to_string(),
291                                                Span::call_site().into(),
292                                            );
293                                            Ok(())
294                                        } else if nv.path.is_ident("periphery") {
295                                            periphery_name = Ident::new(
296                                                &nv.value.to_token_stream().to_string(),
297                                                Span::call_site().into(),
298                                            );
299                                            Ok(())
300                                        } else {
301                                            Err("only executor and periphery keys are supported")
302                                        }
303                                    }
304                                    _ => Err("only name = value format is supported"),
305                                }
306                                .expect("wrong attributes format");
307                            }
308                        }
309                    }
310                };
311                executor_enum_fields.push(quote! {
312                    #[any_enum]
313                    #field_name_upper(#executor_name<F>),
314                });
315                periphery_enum_fields.push(quote! {
316                    #[any_enum]
317                    #field_name_upper(#periphery_name<F>),
318                });
319                create_chip_complex.push(quote! {
320                    let complex: VmChipComplex<F, Self::Executor, Self::Periphery> = complex.extend(&self.#field_name)?;
321                });
322            }
323
324            let (source_executor_type, source_periphery_type) = match &source {
325                Source::System(_) => (quote! { SystemExecutor }, quote! { SystemPeriphery }),
326                Source::Config(field_ident) => {
327                    let field_type = fields
328                        .iter()
329                        .find(|f| f.ident.as_ref() == Some(field_ident))
330                        .map(|f| &f.ty)
331                        .expect("Field not found");
332
333                    let executor_type = format!("{}Executor", quote!(#field_type));
334                    let periphery_type = format!("{}Periphery", quote!(#field_type));
335
336                    let executor_ident = Ident::new(&executor_type, field_ident.span());
337                    let periphery_ident = Ident::new(&periphery_type, field_ident.span());
338
339                    (quote! { #executor_ident }, quote! { #periphery_ident })
340                }
341            };
342
343            let executor_type = Ident::new(&format!("{}Executor", name), name.span());
344            let periphery_type = Ident::new(&format!("{}Periphery", name), name.span());
345
346            TokenStream::from(quote! {
347                #[derive(ChipUsageGetter, Chip, InstructionExecutor, From, AnyEnum)]
348                pub enum #executor_type<F: PrimeField32> {
349                    #[any_enum]
350                    #source_name_upper(#source_executor_type<F>),
351                    #(#executor_enum_fields)*
352                }
353
354                #[derive(ChipUsageGetter, Chip, From, AnyEnum)]
355                pub enum #periphery_type<F: PrimeField32> {
356                    #[any_enum]
357                    #source_name_upper(#source_periphery_type<F>),
358                    #(#periphery_enum_fields)*
359                }
360
361                impl<F: PrimeField32> VmConfig<F> for #name {
362                    type Executor = #executor_type<F>;
363                    type Periphery = #periphery_type<F>;
364
365                    fn system(&self) -> &SystemConfig {
366                        VmConfig::<F>::system(&self.#source_name)
367                    }
368                    fn system_mut(&mut self) -> &mut SystemConfig {
369                        VmConfig::<F>::system_mut(&mut self.#source_name)
370                    }
371
372                    fn create_chip_complex(
373                        &self,
374                    ) -> Result<VmChipComplex<F, Self::Executor, Self::Periphery>, VmInventoryError> {
375                        let complex = self.#source_name.create_chip_complex()?;
376                        #(#create_chip_complex)*
377                        Ok(complex)
378                    }
379                }
380            })
381        }
382        _ => syn::Error::new(name.span(), "Only structs are supported")
383            .to_compile_error()
384            .into(),
385    }
386}