1extern crate alloc;
2extern crate proc_macro;
3
4use itertools::{multiunzip, Itertools};
5use proc_macro::{Span, TokenStream};
6use quote::{quote, ToTokens};
7use syn::{
8 parse_quote, punctuated::Punctuated, spanned::Spanned, Data, DataStruct, Field, Fields,
9 GenericParam, Ident, Meta, Token,
10};
11
12mod common;
13#[cfg(not(feature = "tco"))]
14mod nontco;
15#[cfg(feature = "tco")]
16mod tco;
17
18#[proc_macro_derive(PreflightExecutor)]
19pub fn preflight_executor_derive(input: TokenStream) -> TokenStream {
20 let ast: syn::DeriveInput = syn::parse(input).unwrap();
21
22 let name = &ast.ident;
23 let generics = &ast.generics;
24 let (_, ty_generics, _) = generics.split_for_impl();
25
26 let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
27 let mut new_generics = generics.clone();
28 new_generics.params.push(syn::parse_quote! { RA });
29 let field_ty_generic = generics
30 .params
31 .first()
32 .and_then(|param| match param {
33 GenericParam::Type(type_param) => Some(&type_param.ident),
34 _ => None,
35 })
36 .unwrap_or_else(|| {
37 new_generics.params.push(syn::parse_quote! { F });
38 &default_ty_generic
39 });
40
41 match &ast.data {
42 Data::Struct(inner) => {
43 let inner_ty = match &inner.fields {
45 Fields::Unnamed(fields) => {
46 if fields.unnamed.len() != 1 {
47 panic!("Only one unnamed field is supported");
48 }
49 fields.unnamed.first().unwrap().ty.clone()
50 }
51 _ => panic!("Only unnamed fields are supported"),
52 };
53 let where_clause = new_generics.make_where_clause();
56 where_clause.predicates.push(
57 syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA> },
58 );
59 let (impl_generics, _, where_clause) = new_generics.split_for_impl();
60 quote! {
61 impl #impl_generics ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA> for #name #ty_generics #where_clause {
62 fn execute(
63 &self,
64 state: ::openvm_circuit::arch::VmStateMut<#field_ty_generic, ::openvm_circuit::system::memory::online::TracingMemory, RA>,
65 instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<#field_ty_generic>,
66 ) -> Result<(), ::openvm_circuit::arch::ExecutionError> {
67 self.0.execute(state, instruction)
68 }
69
70 fn get_opcode_name(&self, opcode: usize) -> String {
71 self.0.get_opcode_name(opcode)
72 }
73 }
74 }
75 .into()
76 }
77 Data::Enum(e) => {
78 let variants = e
79 .variants
80 .iter()
81 .map(|variant| {
82 let variant_name = &variant.ident;
83
84 let mut fields = variant.fields.iter();
85 let field = fields.next().unwrap();
86 assert!(fields.next().is_none(), "Only one field is supported");
87 (variant_name, field)
88 })
89 .collect::<Vec<_>>();
90 let (execute_arms, get_opcode_name_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>) =
93 multiunzip(variants.iter().map(|(variant_name, field)| {
94 let field_ty = &field.ty;
95 let execute_arm = quote! {
96 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA>>::execute(x, state, instruction)
97 };
98 let get_opcode_name_arm = quote! {
99 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA>>::get_opcode_name(x, opcode)
100 };
101 let where_predicate = syn::parse_quote! {
102 #field_ty: ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA>
103 };
104 (execute_arm, get_opcode_name_arm, where_predicate)
105 }));
106 let where_clause = new_generics.make_where_clause();
107 for predicate in where_predicates {
108 where_clause.predicates.push(predicate);
109 }
110 let (impl_generics, _, where_clause) = new_generics.split_for_impl();
112 quote! {
113 impl #impl_generics ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA> for #name #ty_generics #where_clause {
114 fn execute(
115 &self,
116 state: ::openvm_circuit::arch::VmStateMut<#field_ty_generic, ::openvm_circuit::system::memory::online::TracingMemory, RA>,
117 instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<#field_ty_generic>,
118 ) -> Result<(), ::openvm_circuit::arch::ExecutionError> {
119 match self {
120 #(#execute_arms,)*
121 }
122 }
123
124 fn get_opcode_name(&self, opcode: usize) -> String {
125 match self {
126 #(#get_opcode_name_arms,)*
127 }
128 }
129 }
130 }
131 .into()
132 }
133 Data::Union(_) => unimplemented!("Unions are not supported"),
134 }
135}
136
137#[proc_macro_derive(Executor)]
138pub fn executor_derive(input: TokenStream) -> TokenStream {
139 let ast: syn::DeriveInput = syn::parse(input).unwrap();
140
141 let name = &ast.ident;
142 let generics = &ast.generics;
143 let (impl_generics, ty_generics, _) = generics.split_for_impl();
144
145 match &ast.data {
146 Data::Struct(inner) => {
147 let inner_ty = match &inner.fields {
149 Fields::Unnamed(fields) => {
150 if fields.unnamed.len() != 1 {
151 panic!("Only one unnamed field is supported");
152 }
153 fields.unnamed.first().unwrap().ty.clone()
154 }
155 _ => panic!("Only unnamed fields are supported"),
156 };
157 let mut new_generics = generics.clone();
160 let where_clause = new_generics.make_where_clause();
161 where_clause.predicates.push(
162 syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::InterpreterExecutor<F> },
163 );
164
165 #[cfg(feature = "tco")]
168 let handler = quote! {
169 fn handler<Ctx>(
170 &self,
171 pc: u32,
172 inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
173 data: &mut [u8],
174 ) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
175 where
176 Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
177 self.0.handler(pc, inst, data)
178 }
179 };
180 #[cfg(not(feature = "tco"))]
181 let handler = quote! {};
182
183 quote! {
184 impl #impl_generics ::openvm_circuit::arch::InterpreterExecutor<F> for #name #ty_generics #where_clause {
185 #[inline(always)]
186 fn pre_compute_size(&self) -> usize {
187 self.0.pre_compute_size()
188 }
189 #[cfg(not(feature = "tco"))]
190 #[inline(always)]
191 fn pre_compute<Ctx>(
192 &self,
193 pc: u32,
194 inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
195 data: &mut [u8],
196 ) -> Result<::openvm_circuit::arch::ExecuteFunc<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
197 where
198 Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
199 self.0.pre_compute(pc, inst, data)
200 }
201
202 #handler
203 }
204 }
205 .into()
206 }
207 Data::Enum(e) => {
208 let variants = e
209 .variants
210 .iter()
211 .map(|variant| {
212 let variant_name = &variant.ident;
213
214 let mut fields = variant.fields.iter();
215 let field = fields.next().unwrap();
216 assert!(fields.next().is_none(), "Only one field is supported");
217 (variant_name, field)
218 })
219 .collect::<Vec<_>>();
220 let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
221 let mut new_generics = generics.clone();
222 let first_ty_generic = ast
223 .generics
224 .params
225 .first()
226 .and_then(|param| match param {
227 GenericParam::Type(type_param) => Some(&type_param.ident),
228 _ => None,
229 })
230 .unwrap_or_else(|| {
231 new_generics.params.push(syn::parse_quote! { F });
232 &default_ty_generic
233 });
234 let (pre_compute_size_arms, pre_compute_arms, _handler_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(|(variant_name, field)| {
237 let field_ty = &field.ty;
238 let pre_compute_size_arm = quote! {
239 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic>>::pre_compute_size(x)
240 };
241 let pre_compute_arm = quote! {
242 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic>>::pre_compute(x, pc, instruction, data)
243 };
244 let handler_arm = quote! {
245 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic>>::handler(x, pc, instruction, data)
246 };
247 let where_predicate = syn::parse_quote! {
248 #field_ty: ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic>
249 };
250 (pre_compute_size_arm, pre_compute_arm, handler_arm, where_predicate)
251 }));
252 let where_clause = new_generics.make_where_clause();
253 for predicate in where_predicates {
254 where_clause.predicates.push(predicate);
255 }
256 #[cfg(feature = "tco")]
259 let handler = quote! {
260 fn handler<Ctx>(
261 &self,
262 pc: u32,
263 instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
264 data: &mut [u8],
265 ) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
266 where
267 Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
268 match self {
269 #(#_handler_arms,)*
270 }
271 }
272 };
273 #[cfg(not(feature = "tco"))]
274 let handler = quote! {};
275
276 let (impl_generics, _, where_clause) = new_generics.split_for_impl();
278
279 quote! {
280 impl #impl_generics ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic> for #name #ty_generics #where_clause {
281 #[inline(always)]
282 fn pre_compute_size(&self) -> usize {
283 match self {
284 #(#pre_compute_size_arms,)*
285 }
286 }
287
288 #[cfg(not(feature = "tco"))]
289 #[inline(always)]
290 fn pre_compute<Ctx>(
291 &self,
292 pc: u32,
293 instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
294 data: &mut [u8],
295 ) -> Result<::openvm_circuit::arch::ExecuteFunc<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
296 where
297 Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
298 match self {
299 #(#pre_compute_arms,)*
300 }
301 }
302 #handler
303 }
304 }
305 .into()
306 }
307 Data::Union(_) => unimplemented!("Unions are not supported"),
308 }
309}
310
311#[proc_macro_derive(AotExecutor)]
312pub fn aot_executor_derive(input: TokenStream) -> TokenStream {
313 let ast: syn::DeriveInput = syn::parse(input).unwrap();
314
315 let name = &ast.ident;
316 let generics = &ast.generics;
317 let (_, ty_generics, _) = generics.split_for_impl();
318
319 match &ast.data {
320 Data::Struct(inner) => {
321 let inner_ty = match &inner.fields {
322 Fields::Unnamed(fields) => {
323 if fields.unnamed.len() != 1 {
324 panic!("Only one unnamed field is supported");
325 }
326 fields.unnamed.first().unwrap().ty.clone()
327 }
328 _ => panic!("Only unnamed fields are supported"),
329 };
330 let mut new_generics = generics.clone();
331 let where_clause = new_generics.make_where_clause();
332 where_clause
333 .predicates
334 .push(syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::AotExecutor<F> });
335 let (impl_generics, _, where_clause) = new_generics.split_for_impl();
336
337 quote! {
338 #[cfg(feature = "aot")]
339 impl #impl_generics ::openvm_circuit::arch::AotExecutor<F> for #name #ty_generics #where_clause {
340 #[inline(always)]
341 fn is_aot_supported(&self, inst: &::openvm_instructions::instruction::Instruction<F>) -> bool {
342 self.0.is_aot_supported(inst)
343 }
344
345 fn generate_x86_asm(
346 &self,
347 inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
348 pc: u32,
349 ) -> ::std::result::Result<
350 ::std::string::String,
351 ::openvm_circuit::arch::AotError,
352 > {
353 self.0.generate_x86_asm(inst, pc)
354 }
355 }
356 }
357 .into()
358 }
359 Data::Enum(e) => {
360 let variants = e
361 .variants
362 .iter()
363 .map(|variant| {
364 let variant_name = &variant.ident;
365 let mut fields = variant.fields.iter();
366 let field = fields.next().unwrap();
367 assert!(fields.next().is_none(), "Only one field is supported");
368 (variant_name, field)
369 })
370 .collect::<Vec<_>>();
371 let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
372 let mut new_generics = generics.clone();
373 let first_ty_generic = ast
374 .generics
375 .params
376 .first()
377 .and_then(|param| match param {
378 GenericParam::Type(type_param) => Some(&type_param.ident),
379 _ => None,
380 })
381 .unwrap_or_else(|| {
382 new_generics.params.push(syn::parse_quote! { F });
383 &default_ty_generic
384 });
385 let (
386 is_aot_supported_arms,
387 generate_x86_asm_arms,
388 where_predicates,
389 ): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(
390 |(variant_name, field)| {
391 let field_ty = &field.ty;
392 let is_aot_supported_arm = quote! {
393 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AotExecutor<#first_ty_generic>>::is_aot_supported(x, inst)
394 };
395 let generate_x86_asm_arm = quote! {
396 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AotExecutor<#first_ty_generic>>::generate_x86_asm(
397 x,
398 inst,
399 pc,
400 )
401 };
402 let where_predicate =
403 syn::parse_quote! { #field_ty: ::openvm_circuit::arch::AotExecutor<#first_ty_generic> };
404 (
405 is_aot_supported_arm,
406 generate_x86_asm_arm,
407 where_predicate,
408 )
409 },
410 ));
411 let where_clause = new_generics.make_where_clause();
412 for predicate in where_predicates {
413 where_clause.predicates.push(predicate);
414 }
415 let (impl_generics, _, where_clause) = new_generics.split_for_impl();
416
417 quote! {
418 #[cfg(feature = "aot")]
419 impl #impl_generics ::openvm_circuit::arch::AotExecutor<#first_ty_generic> for #name #ty_generics #where_clause {
420 #[inline(always)]
421 fn is_aot_supported(&self, inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>) -> bool {
422 match self {
423 #(#is_aot_supported_arms,)*
424 }
425 }
426
427 fn generate_x86_asm(
428 &self,
429 inst: &::openvm_circuit::arch::instructions::instruction::Instruction<#first_ty_generic>,
430 pc: u32,
431 ) -> ::std::result::Result<
432 ::std::string::String,
433 ::openvm_circuit::arch::AotError,
434 > {
435 match self {
436 #(#generate_x86_asm_arms,)*
437 }
438 }
439 }
440 }
441 .into()
442 }
443 Data::Union(_) => unimplemented!("Unions are not supported"),
444 }
445}
446
447#[proc_macro_derive(MeteredExecutor)]
448pub fn metered_executor_derive(input: TokenStream) -> TokenStream {
449 let ast: syn::DeriveInput = syn::parse(input).unwrap();
450
451 let name = &ast.ident;
452 let generics = &ast.generics;
453 let (impl_generics, ty_generics, _) = generics.split_for_impl();
454
455 match &ast.data {
456 Data::Struct(inner) => {
457 let inner_ty = match &inner.fields {
459 Fields::Unnamed(fields) => {
460 if fields.unnamed.len() != 1 {
461 panic!("Only one unnamed field is supported");
462 }
463 fields.unnamed.first().unwrap().ty.clone()
464 }
465 _ => panic!("Only unnamed fields are supported"),
466 };
467 let mut new_generics = generics.clone();
470 let where_clause = new_generics.make_where_clause();
471 where_clause
472 .predicates
473 .push(syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::InterpreterMeteredExecutor<F> });
474
475 #[cfg(feature = "tco")]
478 let metered_handler = quote! {
479 fn metered_handler<Ctx>(
480 &self,
481 chip_idx: usize,
482 pc: u32,
483 inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
484 data: &mut [u8],
485 ) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
486 where
487 Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait, {
488 self.0.metered_handler(chip_idx, pc, inst, data)
489 }
490 };
491 #[cfg(not(feature = "tco"))]
492 let metered_handler = quote! {};
493
494 quote! {
495 impl #impl_generics ::openvm_circuit::arch::InterpreterMeteredExecutor<F> for #name #ty_generics #where_clause {
496 #[inline(always)]
497 fn metered_pre_compute_size(&self) -> usize {
498 self.0.metered_pre_compute_size()
499 }
500 #[cfg(not(feature = "tco"))]
501 #[inline(always)]
502 fn metered_pre_compute<Ctx>(
503 &self,
504 chip_idx: usize,
505 pc: u32,
506 inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
507 data: &mut [u8],
508 ) -> Result<::openvm_circuit::arch::ExecuteFunc<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
509 where
510 Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait, {
511 self.0.metered_pre_compute(chip_idx, pc, inst, data)
512 }
513 #metered_handler
514 }
515 }
516 .into()
517 }
518 Data::Enum(e) => {
519 let variants = e
520 .variants
521 .iter()
522 .map(|variant| {
523 let variant_name = &variant.ident;
524
525 let mut fields = variant.fields.iter();
526 let field = fields.next().unwrap();
527 assert!(fields.next().is_none(), "Only one field is supported");
528 (variant_name, field)
529 })
530 .collect::<Vec<_>>();
531 let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
532 let mut new_generics = generics.clone();
533 let first_ty_generic = ast
534 .generics
535 .params
536 .first()
537 .and_then(|param| match param {
538 GenericParam::Type(type_param) => Some(&type_param.ident),
539 _ => None,
540 })
541 .unwrap_or_else(|| {
542 new_generics.params.push(syn::parse_quote! { F });
543 &default_ty_generic
544 });
545 let (pre_compute_size_arms, metered_pre_compute_arms, _metered_handler_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(|(variant_name, field)| {
548 let field_ty = &field.ty;
549 let pre_compute_size_arm = quote! {
550 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic>>::metered_pre_compute_size(x)
551 };
552 let metered_pre_compute_arm = quote! {
553 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic>>::metered_pre_compute(x, chip_idx, pc, instruction, data)
554 };
555 let metered_handler_arm = quote! {
556 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic>>::metered_handler(x, chip_idx, pc, instruction, data)
557 };
558 let where_predicate = syn::parse_quote! {
559 #field_ty: ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic>
560 };
561 (pre_compute_size_arm, metered_pre_compute_arm, metered_handler_arm, where_predicate)
562 }));
563 let where_clause = new_generics.make_where_clause();
564 for predicate in where_predicates {
565 where_clause.predicates.push(predicate);
566 }
567 let (impl_generics, _, where_clause) = new_generics.split_for_impl();
569
570 #[cfg(feature = "tco")]
573 let metered_handler = quote! {
574 fn metered_handler<Ctx>(
575 &self,
576 chip_idx: usize,
577 pc: u32,
578 instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
579 data: &mut [u8],
580 ) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
581 where
582 Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait,
583 {
584 match self {
585 #(#_metered_handler_arms,)*
586 }
587 }
588 };
589 #[cfg(not(feature = "tco"))]
590 let metered_handler = quote! {};
591
592 quote! {
593 impl #impl_generics ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic> for #name #ty_generics #where_clause {
594 #[inline(always)]
595 fn metered_pre_compute_size(&self) -> usize {
596 match self {
597 #(#pre_compute_size_arms,)*
598 }
599 }
600
601 #[cfg(not(feature = "tco"))]
602 #[inline(always)]
603 fn metered_pre_compute<Ctx>(
604 &self,
605 chip_idx: usize,
606 pc: u32,
607 instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
608 data: &mut [u8],
609 ) -> Result<::openvm_circuit::arch::ExecuteFunc<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
610 where
611 Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait, {
612 match self {
613 #(#metered_pre_compute_arms,)*
614 }
615 }
616
617 #metered_handler
618 }
619 }
620 .into()
621 }
622 Data::Union(_) => unimplemented!("Unions are not supported"),
623 }
624}
625
626#[proc_macro_derive(AotMeteredExecutor)]
627pub fn aot_metered_executor_derive(input: TokenStream) -> TokenStream {
628 let ast: syn::DeriveInput = syn::parse(input).unwrap();
629
630 let name = &ast.ident;
631 let generics = &ast.generics;
632 let (_, ty_generics, _) = generics.split_for_impl();
633
634 match &ast.data {
635 Data::Struct(inner) => {
636 let inner_ty = match &inner.fields {
637 Fields::Unnamed(fields) => {
638 if fields.unnamed.len() != 1 {
639 panic!("Only one unnamed field is supported");
640 }
641 fields.unnamed.first().unwrap().ty.clone()
642 }
643 _ => panic!("Only unnamed fields are supported"),
644 };
645 let mut new_generics = generics.clone();
646 let where_clause = new_generics.make_where_clause();
647 where_clause.predicates.push(
648 syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::AotMeteredExecutor<F> },
649 );
650 let (impl_generics, _, where_clause) = new_generics.split_for_impl();
651
652 quote! {
653 #[cfg(feature = "aot")]
654 impl #impl_generics ::openvm_circuit::arch::AotMeteredExecutor<F> for #name #ty_generics #where_clause {
655 #[inline(always)]
656 fn is_aot_metered_supported(&self, inst: &::openvm_instructions::instruction::Instruction<F>) -> bool {
657 self.0.is_aot_metered_supported(inst)
658 }
659
660 fn generate_x86_metered_asm(
661 &self,
662 inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
663 pc: u32,
664 chip_idx: usize,
665 config: &::openvm_circuit::arch::SystemConfig,
666 ) -> ::std::result::Result<
667 ::std::string::String,
668 ::openvm_circuit::arch::AotError,
669 > {
670 self.0.generate_x86_metered_asm(inst, pc, chip_idx, config)
671 }
672 }
673 }
674 .into()
675 }
676 Data::Enum(e) => {
677 let variants = e
678 .variants
679 .iter()
680 .map(|variant| {
681 let variant_name = &variant.ident;
682 let mut fields = variant.fields.iter();
683 let field = fields.next().unwrap();
684 assert!(fields.next().is_none(), "Only one field is supported");
685 (variant_name, field)
686 })
687 .collect::<Vec<_>>();
688 let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
689 let mut new_generics = generics.clone();
690 let first_ty_generic = ast
691 .generics
692 .params
693 .first()
694 .and_then(|param| match param {
695 GenericParam::Type(type_param) => Some(&type_param.ident),
696 _ => None,
697 })
698 .unwrap_or_else(|| {
699 new_generics.params.push(syn::parse_quote! { F });
700 &default_ty_generic
701 });
702 let (
703 is_aot_metered_supported_arms,
704 generate_x86_metered_asm_arms,
705 where_predicates,
706 ): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(
707 |(variant_name, field)| {
708 let field_ty = &field.ty;
709 let is_aot_metered_supported_arm = quote! {
710 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AotMeteredExecutor<#first_ty_generic>>::is_aot_metered_supported(x, inst)
711 };
712 let generate_x86_metered_asm_arm = quote! {
713 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AotMeteredExecutor<#first_ty_generic>>::generate_x86_metered_asm(
714 x,
715 inst,
716 pc,
717 chip_idx,
718 config,
719 )
720 };
721 let where_predicate =
722 syn::parse_quote! { #field_ty: ::openvm_circuit::arch::AotMeteredExecutor<#first_ty_generic> };
723 (
724 is_aot_metered_supported_arm,
725 generate_x86_metered_asm_arm,
726 where_predicate,
727 )
728 },
729 ));
730 let where_clause = new_generics.make_where_clause();
731 for predicate in where_predicates {
732 where_clause.predicates.push(predicate);
733 }
734 let (impl_generics, _, where_clause) = new_generics.split_for_impl();
735
736 quote! {
737 #[cfg(feature = "aot")]
738 impl #impl_generics ::openvm_circuit::arch::AotMeteredExecutor<#first_ty_generic> for #name #ty_generics #where_clause {
739 #[inline(always)]
740 fn is_aot_metered_supported(&self, inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>) -> bool {
741 match self {
742 #(#is_aot_metered_supported_arms,)*
743 }
744 }
745
746 fn generate_x86_metered_asm(
747 &self,
748 inst: &::openvm_circuit::arch::instructions::instruction::Instruction<#first_ty_generic>,
749 pc: u32,
750 chip_idx: usize,
751 config: &::openvm_circuit::arch::SystemConfig,
752 ) -> ::std::result::Result<
753 ::std::string::String,
754 ::openvm_circuit::arch::AotError,
755 > {
756 match self {
757 #(#generate_x86_metered_asm_arms,)*
758 }
759 }
760 }
761 }
762 .into()
763 }
764 Data::Union(_) => unimplemented!("Unions are not supported"),
765 }
766}
767
768#[proc_macro_derive(AnyEnum, attributes(any_enum))]
774pub fn any_enum_derive(input: TokenStream) -> TokenStream {
775 let ast: syn::DeriveInput = syn::parse(input).unwrap();
776
777 let name = &ast.ident;
778 let generics = &ast.generics;
779 let (impl_generics, ty_generics, _) = generics.split_for_impl();
780
781 match &ast.data {
782 Data::Enum(e) => {
783 let variants = e
784 .variants
785 .iter()
786 .map(|variant| {
787 let variant_name = &variant.ident;
788
789 let is_enum = variant
791 .attrs
792 .iter()
793 .any(|attr| attr.path().is_ident("any_enum"));
794 let mut fields = variant.fields.iter();
795 let field = fields.next().unwrap();
796 assert!(fields.next().is_none(), "Only one field is supported");
797 (variant_name, field, is_enum)
798 })
799 .collect::<Vec<_>>();
800 let (arms, arms_mut): (Vec<_>, Vec<_>) =
801 variants.iter().map(|(variant_name, field, is_enum)| {
802 let field_ty = &field.ty;
803
804 if *is_enum {
805 (quote! {
807 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AnyEnum>::as_any_kind(x)
808 },
809 quote! {
810 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AnyEnum>::as_any_kind_mut(x)
811 })
812 } else {
813 (quote! {
814 #name::#variant_name(x) => x
815 },
816 quote! {
817 #name::#variant_name(x) => x
818 })
819 }
820 }).unzip();
821 quote! {
822 impl #impl_generics ::openvm_circuit::arch::AnyEnum for #name #ty_generics {
823 fn as_any_kind(&self) -> &dyn std::any::Any {
824 match self {
825 #(#arms,)*
826 }
827 }
828
829 fn as_any_kind_mut(&mut self) -> &mut dyn std::any::Any {
830 match self {
831 #(#arms_mut,)*
832 }
833 }
834 }
835 }
836 .into()
837 }
838 _ => syn::Error::new(name.span(), "Only enums are supported")
839 .to_compile_error()
840 .into(),
841 }
842}
843
844#[proc_macro_derive(VmConfig, attributes(config, extension))]
845pub fn vm_generic_config_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
846 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
847 let name = &ast.ident;
848
849 match &ast.data {
850 syn::Data::Struct(inner) => match generate_config_traits_impl(name, inner) {
851 Ok(tokens) => tokens,
852 Err(err) => err.to_compile_error().into(),
853 },
854 _ => syn::Error::new(name.span(), "Only structs are supported")
855 .to_compile_error()
856 .into(),
857 }
858}
859
860fn generate_config_traits_impl(name: &Ident, inner: &DataStruct) -> syn::Result<TokenStream> {
861 let gen_name_with_uppercase_idents = |ident: &Ident| {
862 let mut name = ident.to_string().chars().collect::<Vec<_>>();
863 assert!(name[0].is_lowercase(), "Field name must not be capitalized");
864 let res_lower = Ident::new(&name.iter().collect::<String>(), Span::call_site().into());
865 name[0] = name[0].to_ascii_uppercase();
866 let res_upper = Ident::new(&name.iter().collect::<String>(), Span::call_site().into());
867 (res_lower, res_upper)
868 };
869
870 let fields = match &inner.fields {
871 Fields::Named(named) => named.named.iter().collect(),
872 Fields::Unnamed(_) => {
873 return Err(syn::Error::new(
874 name.span(),
875 "Only named fields are supported",
876 ))
877 }
878 Fields::Unit => vec![],
879 };
880
881 let source_field = fields
882 .iter()
883 .filter(|f| f.attrs.iter().any(|attr| attr.path().is_ident("config")))
884 .exactly_one()
885 .map_err(|_| {
886 syn::Error::new(
887 name.span(),
888 "Exactly one field must have the #[config] attribute",
889 )
890 })?;
891 let (source_name, source_name_upper) =
892 gen_name_with_uppercase_idents(source_field.ident.as_ref().unwrap());
893
894 let extensions = fields
895 .iter()
896 .filter(|f| f.attrs.iter().any(|attr| attr.path().is_ident("extension")))
897 .cloned()
898 .collect::<Vec<_>>();
899
900 let mut executor_enum_fields = Vec::new();
901 let mut create_executors = Vec::new();
902 let mut create_airs = Vec::new();
903 let mut execution_where_predicates: Vec<syn::WherePredicate> = Vec::new();
904 let mut circuit_where_predicates: Vec<syn::WherePredicate> = Vec::new();
905 execution_where_predicates.push(parse_quote! { F: ::openvm_circuit::arch::VmField });
906
907 let source_field_ty = source_field.ty.clone();
908
909 for e in extensions.iter() {
910 let (ext_field_name, ext_name_upper) =
911 gen_name_with_uppercase_idents(e.ident.as_ref().expect("field must be named"));
912 let executor_type = parse_executor_type(e, false)?;
913 executor_enum_fields.push(quote! {
914 #[any_enum]
915 #ext_name_upper(#executor_type),
916 });
917 create_executors.push(quote! {
918 let inventory: ::openvm_circuit::arch::ExecutorInventory<Self::Executor> = inventory.extend::<F, _, _>(&self.#ext_field_name)?;
919 });
920 let extension_ty = e.ty.clone();
921 execution_where_predicates.push(parse_quote! {
922 #extension_ty: ::openvm_circuit::arch::VmExecutionExtension<F, Executor = #executor_type>
923 });
924 create_airs.push(quote! {
925 inventory.start_new_extension();
926 ::openvm_circuit::arch::VmCircuitExtension::extend_circuit(&self.#ext_field_name, &mut inventory)?;
927 });
928 circuit_where_predicates.push(parse_quote! {
929 #extension_ty: ::openvm_circuit::arch::VmCircuitExtension<SC>
930 });
931 }
932
933 let source_executor_type = parse_executor_type(source_field, true)?;
935 execution_where_predicates.push(parse_quote! {
936 #source_field_ty: ::openvm_circuit::arch::VmExecutionConfig<F, Executor = #source_executor_type>
937 });
938 circuit_where_predicates.push(parse_quote! {
939 #source_field_ty: ::openvm_circuit::arch::VmCircuitConfig<SC>
940 });
941 let execution_where_clause = quote! { where #(#execution_where_predicates),* };
942 let circuit_where_clause = quote! { where #(#circuit_where_predicates),* };
943
944 let executor_type = Ident::new(&format!("{name}Executor"), name.span());
945
946 let token_stream = TokenStream::from(quote! {
947 #[derive(
948 Clone,
949 ::derive_more::derive::From,
950 ::openvm_circuit::derive::AnyEnum,
951 ::openvm_circuit::derive::Executor,
952 ::openvm_circuit::derive::MeteredExecutor,
953 ::openvm_circuit::derive::PreflightExecutor,
954 )]
955 #[cfg_attr(feature = "aot", derive(::openvm_circuit::derive::AotExecutor, ::openvm_circuit::derive::AotMeteredExecutor))]
956 pub enum #executor_type<F: ::openvm_circuit::arch::VmField> #execution_where_clause {
957 #[any_enum]
958 #source_name_upper(#source_executor_type),
959 #(#executor_enum_fields)*
960 }
961
962 impl<F: ::openvm_circuit::arch::VmField> ::openvm_circuit::arch::VmExecutionConfig<F> for #name #execution_where_clause {
963 type Executor = #executor_type<F>;
964
965 fn create_executors(
966 &self,
967 ) -> Result<::openvm_circuit::arch::ExecutorInventory<Self::Executor>, ::openvm_circuit::arch::ExecutorInventoryError> {
968 let inventory = self.#source_name.create_executors()?.transmute::<Self::Executor>();
969 #(#create_executors)*
970 Ok(inventory)
971 }
972 }
973
974 impl<SC: openvm_stark_backend::config::StarkGenericConfig> ::openvm_circuit::arch::VmCircuitConfig<SC> for #name #circuit_where_clause {
975 fn create_airs(
976 &self,
977 ) -> Result<::openvm_circuit::arch::AirInventory<SC>, ::openvm_circuit::arch::AirInventoryError> {
978 let mut inventory = self.#source_name.create_airs()?;
979 #(#create_airs)*
980 Ok(inventory)
981 }
982 }
983
984 impl AsRef<SystemConfig> for #name {
985 fn as_ref(&self) -> &SystemConfig {
986 self.#source_name.as_ref()
987 }
988 }
989
990 impl AsMut<SystemConfig> for #name {
991 fn as_mut(&mut self) -> &mut SystemConfig {
992 self.#source_name.as_mut()
993 }
994 }
995 });
996 Ok(token_stream)
997}
998
999fn parse_executor_type(
1003 f: &Field,
1004 default_needs_generics: bool,
1005) -> syn::Result<proc_macro2::TokenStream> {
1006 let mut executor_type = None;
1009 let executor_name = syn::parse_str::<Ident>(&format!("{}Executor", f.ty.to_token_stream()));
1011
1012 if let Some(attr) = f
1013 .attrs
1014 .iter()
1015 .find(|attr| attr.path().is_ident("extension") || attr.path().is_ident("config"))
1016 {
1017 match attr.meta {
1018 Meta::Path(_) => {}
1019 Meta::NameValue(_) => {
1020 return Err(syn::Error::new(
1021 f.ty.span(),
1022 "Only `#[config]`, `#[extension]`, `#[config(...)]` or `#[extension(...)]` formats are supported",
1023 ))
1024 }
1025 _ => {
1026 let nested = attr
1027 .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
1028 for meta in nested {
1029 match meta {
1030 Meta::NameValue(nv) => {
1031 if nv.path.is_ident("executor") {
1032 executor_type = match nv.value {
1033 syn::Expr::Lit(syn::ExprLit {
1034 lit: syn::Lit::Str(lit_str), ..
1035 }) => {
1036 let executor_type: syn::Type = syn::parse_str(&lit_str.value())?;
1037 Some(quote! { #executor_type })
1038 },
1039 syn::Expr::Path(path) => {
1040 Some(path.to_token_stream())
1042 },
1043 _ => {
1044 return Err(syn::Error::new(
1045 nv.value.span(),
1046 "executor value must be a string literal or identifier"
1047 ));
1048 }
1049 };
1050 } else if nv.path.is_ident("generics") {
1051 let value_str = nv.value.to_token_stream().to_string();
1053 let needs_generics = match value_str.as_str() {
1054 "true" => true,
1055 "false" => false,
1056 _ => return Err(syn::Error::new(
1057 nv.value.span(),
1058 "generics attribute must be either true or false"
1059 ))
1060 };
1061 let executor_name = executor_name.clone()?;
1062 executor_type = Some(if needs_generics {
1063 quote! { #executor_name<F> }
1064 } else {
1065 quote! { #executor_name }
1066 });
1067 } else {
1068 return Err(syn::Error::new(nv.span(), "only executor and generics keys are supported"));
1069 }
1070 }
1071 _ => {
1072 return Err(syn::Error::new(meta.span(), "only name = value format is supported"));
1073 }
1074 }
1075 }
1076 }
1077 }
1078 }
1079 if let Some(executor_type) = executor_type {
1080 Ok(executor_type)
1081 } else {
1082 let executor_name = executor_name?;
1083 Ok(if default_needs_generics {
1084 quote! { #executor_name<F> }
1085 } else {
1086 quote! { #executor_name }
1087 })
1088 }
1089}
1090
1091#[proc_macro_attribute]
1121pub fn create_handler(_attr: TokenStream, item: TokenStream) -> TokenStream {
1122 #[cfg(feature = "tco")]
1123 {
1124 tco::tco_impl(item)
1125 }
1126 #[cfg(not(feature = "tco"))]
1127 {
1128 nontco::nontco_impl(item)
1129 }
1130}