1extern crate alloc;
2extern crate proc_macro;
3
4use itertools::{multiunzip, Itertools};
5use proc_macro::{Span, TokenStream};
6use quote::{quote, ToTokens};
7use syn::{
8 parse_quote, punctuated::Punctuated, spanned::Spanned, Data, DataStruct, Field, Fields,
9 GenericParam, Ident, Meta, Token,
10};
11
12mod common;
13#[cfg(not(feature = "tco"))]
14mod nontco;
15#[cfg(feature = "tco")]
16mod tco;
17
18#[proc_macro_derive(PreflightExecutor)]
19pub fn preflight_executor_derive(input: TokenStream) -> TokenStream {
20 let ast: syn::DeriveInput = syn::parse(input).unwrap();
21
22 let name = &ast.ident;
23 let generics = &ast.generics;
24 let (_, ty_generics, _) = generics.split_for_impl();
25
26 let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
27 let mut new_generics = generics.clone();
28 new_generics.params.push(syn::parse_quote! { RA });
29 let field_ty_generic = generics
30 .params
31 .first()
32 .and_then(|param| match param {
33 GenericParam::Type(type_param) => Some(&type_param.ident),
34 _ => None,
35 })
36 .unwrap_or_else(|| {
37 new_generics.params.push(syn::parse_quote! { F });
38 &default_ty_generic
39 });
40
41 match &ast.data {
42 Data::Struct(inner) => {
43 let inner_ty = match &inner.fields {
45 Fields::Unnamed(fields) => {
46 if fields.unnamed.len() != 1 {
47 panic!("Only one unnamed field is supported");
48 }
49 fields.unnamed.first().unwrap().ty.clone()
50 }
51 _ => panic!("Only unnamed fields are supported"),
52 };
53 let where_clause = new_generics.make_where_clause();
56 where_clause.predicates.push(
57 syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA> },
58 );
59 let (impl_generics, _, where_clause) = new_generics.split_for_impl();
60 quote! {
61 impl #impl_generics ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA> for #name #ty_generics #where_clause {
62 fn execute(
63 &self,
64 state: ::openvm_circuit::arch::VmStateMut<#field_ty_generic, ::openvm_circuit::system::memory::online::TracingMemory, RA>,
65 instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<#field_ty_generic>,
66 ) -> Result<(), ::openvm_circuit::arch::ExecutionError> {
67 self.0.execute(state, instruction)
68 }
69
70 fn get_opcode_name(&self, opcode: usize) -> String {
71 self.0.get_opcode_name(opcode)
72 }
73 }
74 }
75 .into()
76 }
77 Data::Enum(e) => {
78 let variants = e
79 .variants
80 .iter()
81 .map(|variant| {
82 let variant_name = &variant.ident;
83
84 let mut fields = variant.fields.iter();
85 let field = fields.next().unwrap();
86 assert!(fields.next().is_none(), "Only one field is supported");
87 (variant_name, field)
88 })
89 .collect::<Vec<_>>();
90 let (execute_arms, get_opcode_name_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>) =
93 multiunzip(variants.iter().map(|(variant_name, field)| {
94 let field_ty = &field.ty;
95 let execute_arm = quote! {
96 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA>>::execute(x, state, instruction)
97 };
98 let get_opcode_name_arm = quote! {
99 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA>>::get_opcode_name(x, opcode)
100 };
101 let where_predicate = syn::parse_quote! {
102 #field_ty: ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA>
103 };
104 (execute_arm, get_opcode_name_arm, where_predicate)
105 }));
106 let where_clause = new_generics.make_where_clause();
107 for predicate in where_predicates {
108 where_clause.predicates.push(predicate);
109 }
110 let (impl_generics, _, where_clause) = new_generics.split_for_impl();
112 quote! {
113 impl #impl_generics ::openvm_circuit::arch::PreflightExecutor<#field_ty_generic, RA> for #name #ty_generics #where_clause {
114 fn execute(
115 &self,
116 state: ::openvm_circuit::arch::VmStateMut<#field_ty_generic, ::openvm_circuit::system::memory::online::TracingMemory, RA>,
117 instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<#field_ty_generic>,
118 ) -> Result<(), ::openvm_circuit::arch::ExecutionError> {
119 match self {
120 #(#execute_arms,)*
121 }
122 }
123
124 fn get_opcode_name(&self, opcode: usize) -> String {
125 match self {
126 #(#get_opcode_name_arms,)*
127 }
128 }
129 }
130 }
131 .into()
132 }
133 Data::Union(_) => unimplemented!("Unions are not supported"),
134 }
135}
136
137#[proc_macro_derive(Executor)]
138pub fn executor_derive(input: TokenStream) -> TokenStream {
139 let ast: syn::DeriveInput = syn::parse(input).unwrap();
140
141 let name = &ast.ident;
142 let generics = &ast.generics;
143 let (impl_generics, ty_generics, _) = generics.split_for_impl();
144
145 match &ast.data {
146 Data::Struct(inner) => {
147 let inner_ty = match &inner.fields {
149 Fields::Unnamed(fields) => {
150 if fields.unnamed.len() != 1 {
151 panic!("Only one unnamed field is supported");
152 }
153 fields.unnamed.first().unwrap().ty.clone()
154 }
155 _ => panic!("Only unnamed fields are supported"),
156 };
157 let mut new_generics = generics.clone();
160 let where_clause = new_generics.make_where_clause();
161 where_clause.predicates.push(
162 syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::InterpreterExecutor<F> },
163 );
164
165 #[cfg(feature = "tco")]
168 let handler = quote! {
169 fn handler<Ctx>(
170 &self,
171 pc: u32,
172 inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
173 data: &mut [u8],
174 ) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
175 where
176 Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
177 self.0.handler(pc, inst, data)
178 }
179 };
180 #[cfg(not(feature = "tco"))]
181 let handler = quote! {};
182
183 quote! {
184 impl #impl_generics ::openvm_circuit::arch::InterpreterExecutor<F> for #name #ty_generics #where_clause {
185 #[inline(always)]
186 fn pre_compute_size(&self) -> usize {
187 self.0.pre_compute_size()
188 }
189 #[cfg(not(feature = "tco"))]
190 #[inline(always)]
191 fn pre_compute<Ctx>(
192 &self,
193 pc: u32,
194 inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
195 data: &mut [u8],
196 ) -> Result<::openvm_circuit::arch::ExecuteFunc<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
197 where
198 Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
199 self.0.pre_compute(pc, inst, data)
200 }
201
202 #handler
203 }
204 }
205 .into()
206 }
207 Data::Enum(e) => {
208 let variants = e
209 .variants
210 .iter()
211 .map(|variant| {
212 let variant_name = &variant.ident;
213
214 let mut fields = variant.fields.iter();
215 let field = fields.next().unwrap();
216 assert!(fields.next().is_none(), "Only one field is supported");
217 (variant_name, field)
218 })
219 .collect::<Vec<_>>();
220 let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
221 let mut new_generics = generics.clone();
222 let first_ty_generic = ast
223 .generics
224 .params
225 .first()
226 .and_then(|param| match param {
227 GenericParam::Type(type_param) => Some(&type_param.ident),
228 _ => None,
229 })
230 .unwrap_or_else(|| {
231 new_generics.params.push(syn::parse_quote! { F });
232 &default_ty_generic
233 });
234 let (pre_compute_size_arms, pre_compute_arms, _handler_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(|(variant_name, field)| {
237 let field_ty = &field.ty;
238 let pre_compute_size_arm = quote! {
239 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic>>::pre_compute_size(x)
240 };
241 let pre_compute_arm = quote! {
242 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic>>::pre_compute(x, pc, instruction, data)
243 };
244 let handler_arm = quote! {
245 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic>>::handler(x, pc, instruction, data)
246 };
247 let where_predicate = syn::parse_quote! {
248 #field_ty: ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic>
249 };
250 (pre_compute_size_arm, pre_compute_arm, handler_arm, where_predicate)
251 }));
252 let where_clause = new_generics.make_where_clause();
253 for predicate in where_predicates {
254 where_clause.predicates.push(predicate);
255 }
256 #[cfg(feature = "tco")]
259 let handler = quote! {
260 fn handler<Ctx>(
261 &self,
262 pc: u32,
263 instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
264 data: &mut [u8],
265 ) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
266 where
267 Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
268 match self {
269 #(#_handler_arms,)*
270 }
271 }
272 };
273 #[cfg(not(feature = "tco"))]
274 let handler = quote! {};
275
276 let (impl_generics, _, where_clause) = new_generics.split_for_impl();
278
279 quote! {
280 impl #impl_generics ::openvm_circuit::arch::InterpreterExecutor<#first_ty_generic> for #name #ty_generics #where_clause {
281 #[inline(always)]
282 fn pre_compute_size(&self) -> usize {
283 match self {
284 #(#pre_compute_size_arms,)*
285 }
286 }
287
288 #[cfg(not(feature = "tco"))]
289 #[inline(always)]
290 fn pre_compute<Ctx>(
291 &self,
292 pc: u32,
293 instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
294 data: &mut [u8],
295 ) -> Result<::openvm_circuit::arch::ExecuteFunc<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
296 where
297 Ctx: ::openvm_circuit::arch::execution_mode::ExecutionCtxTrait, {
298 match self {
299 #(#pre_compute_arms,)*
300 }
301 }
302 #handler
303 }
304 }
305 .into()
306 }
307 Data::Union(_) => unimplemented!("Unions are not supported"),
308 }
309}
310
311#[proc_macro_derive(AotExecutor)]
312pub fn aot_executor_derive(input: TokenStream) -> TokenStream {
313 let ast: syn::DeriveInput = syn::parse(input).unwrap();
314
315 let name = &ast.ident;
316 let generics = &ast.generics;
317 let (_, ty_generics, _) = generics.split_for_impl();
318
319 match &ast.data {
320 Data::Struct(inner) => {
321 let inner_ty = match &inner.fields {
322 Fields::Unnamed(fields) => {
323 if fields.unnamed.len() != 1 {
324 panic!("Only one unnamed field is supported");
325 }
326 fields.unnamed.first().unwrap().ty.clone()
327 }
328 _ => panic!("Only unnamed fields are supported"),
329 };
330 let mut new_generics = generics.clone();
331 let where_clause = new_generics.make_where_clause();
332 where_clause
333 .predicates
334 .push(syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::AotExecutor<F> });
335 let (impl_generics, _, where_clause) = new_generics.split_for_impl();
336
337 quote! {
338 #[cfg(feature = "aot")]
339 impl #impl_generics ::openvm_circuit::arch::AotExecutor<F> for #name #ty_generics #where_clause {
340 #[inline(always)]
341 fn is_aot_supported(&self, inst: &::openvm_instructions::instruction::Instruction<F>) -> bool {
342 self.0.is_aot_supported(inst)
343 }
344
345 fn generate_x86_asm(
346 &self,
347 inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
348 pc: u32,
349 ) -> ::std::result::Result<
350 ::std::string::String,
351 ::openvm_circuit::arch::AotError,
352 > {
353 self.0.generate_x86_asm(inst, pc)
354 }
355 }
356 }
357 .into()
358 }
359 Data::Enum(e) => {
360 let variants = e
361 .variants
362 .iter()
363 .map(|variant| {
364 let variant_name = &variant.ident;
365 let mut fields = variant.fields.iter();
366 let field = fields.next().unwrap();
367 assert!(fields.next().is_none(), "Only one field is supported");
368 (variant_name, field)
369 })
370 .collect::<Vec<_>>();
371 let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
372 let mut new_generics = generics.clone();
373 let first_ty_generic = ast
374 .generics
375 .params
376 .first()
377 .and_then(|param| match param {
378 GenericParam::Type(type_param) => Some(&type_param.ident),
379 _ => None,
380 })
381 .unwrap_or_else(|| {
382 new_generics.params.push(syn::parse_quote! { F });
383 &default_ty_generic
384 });
385 let (
386 is_aot_supported_arms,
387 generate_x86_asm_arms,
388 where_predicates,
389 ): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(
390 |(variant_name, field)| {
391 let field_ty = &field.ty;
392 let is_aot_supported_arm = quote! {
393 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AotExecutor<#first_ty_generic>>::is_aot_supported(x, inst)
394 };
395 let generate_x86_asm_arm = quote! {
396 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AotExecutor<#first_ty_generic>>::generate_x86_asm(
397 x,
398 inst,
399 pc,
400 )
401 };
402 let where_predicate =
403 syn::parse_quote! { #field_ty: ::openvm_circuit::arch::AotExecutor<#first_ty_generic> };
404 (
405 is_aot_supported_arm,
406 generate_x86_asm_arm,
407 where_predicate,
408 )
409 },
410 ));
411 let where_clause = new_generics.make_where_clause();
412 for predicate in where_predicates {
413 where_clause.predicates.push(predicate);
414 }
415 let (impl_generics, _, where_clause) = new_generics.split_for_impl();
416
417 quote! {
418 #[cfg(feature = "aot")]
419 impl #impl_generics ::openvm_circuit::arch::AotExecutor<#first_ty_generic> for #name #ty_generics #where_clause {
420 #[inline(always)]
421 fn is_aot_supported(&self, inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>) -> bool {
422 match self {
423 #(#is_aot_supported_arms,)*
424 }
425 }
426
427 fn generate_x86_asm(
428 &self,
429 inst: &::openvm_circuit::arch::instructions::instruction::Instruction<#first_ty_generic>,
430 pc: u32,
431 ) -> ::std::result::Result<
432 ::std::string::String,
433 ::openvm_circuit::arch::AotError,
434 > {
435 match self {
436 #(#generate_x86_asm_arms,)*
437 }
438 }
439 }
440 }
441 .into()
442 }
443 Data::Union(_) => unimplemented!("Unions are not supported"),
444 }
445}
446
447#[proc_macro_derive(MeteredExecutor)]
448pub fn metered_executor_derive(input: TokenStream) -> TokenStream {
449 let ast: syn::DeriveInput = syn::parse(input).unwrap();
450
451 let name = &ast.ident;
452 let generics = &ast.generics;
453 let (impl_generics, ty_generics, _) = generics.split_for_impl();
454
455 match &ast.data {
456 Data::Struct(inner) => {
457 let inner_ty = match &inner.fields {
459 Fields::Unnamed(fields) => {
460 if fields.unnamed.len() != 1 {
461 panic!("Only one unnamed field is supported");
462 }
463 fields.unnamed.first().unwrap().ty.clone()
464 }
465 _ => panic!("Only unnamed fields are supported"),
466 };
467 let mut new_generics = generics.clone();
470 let where_clause = new_generics.make_where_clause();
471 where_clause
472 .predicates
473 .push(syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::InterpreterMeteredExecutor<F> });
474
475 #[cfg(feature = "tco")]
478 let metered_handler = quote! {
479 fn metered_handler<Ctx>(
480 &self,
481 chip_idx: usize,
482 pc: u32,
483 inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
484 data: &mut [u8],
485 ) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
486 where
487 Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait, {
488 self.0.metered_handler(chip_idx, pc, inst, data)
489 }
490 };
491 #[cfg(not(feature = "tco"))]
492 let metered_handler = quote! {};
493
494 quote! {
495 impl #impl_generics ::openvm_circuit::arch::InterpreterMeteredExecutor<F> for #name #ty_generics #where_clause {
496 #[inline(always)]
497 fn metered_pre_compute_size(&self) -> usize {
498 self.0.metered_pre_compute_size()
499 }
500 #[cfg(not(feature = "tco"))]
501 #[inline(always)]
502 fn metered_pre_compute<Ctx>(
503 &self,
504 chip_idx: usize,
505 pc: u32,
506 inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
507 data: &mut [u8],
508 ) -> Result<::openvm_circuit::arch::ExecuteFunc<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
509 where
510 Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait, {
511 self.0.metered_pre_compute(chip_idx, pc, inst, data)
512 }
513 #metered_handler
514 }
515 }
516 .into()
517 }
518 Data::Enum(e) => {
519 let variants = e
520 .variants
521 .iter()
522 .map(|variant| {
523 let variant_name = &variant.ident;
524
525 let mut fields = variant.fields.iter();
526 let field = fields.next().unwrap();
527 assert!(fields.next().is_none(), "Only one field is supported");
528 (variant_name, field)
529 })
530 .collect::<Vec<_>>();
531 let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
532 let mut new_generics = generics.clone();
533 let first_ty_generic = ast
534 .generics
535 .params
536 .first()
537 .and_then(|param| match param {
538 GenericParam::Type(type_param) => Some(&type_param.ident),
539 _ => None,
540 })
541 .unwrap_or_else(|| {
542 new_generics.params.push(syn::parse_quote! { F });
543 &default_ty_generic
544 });
545 let (pre_compute_size_arms, metered_pre_compute_arms, _metered_handler_arms, where_predicates): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(|(variant_name, field)| {
548 let field_ty = &field.ty;
549 let pre_compute_size_arm = quote! {
550 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic>>::metered_pre_compute_size(x)
551 };
552 let metered_pre_compute_arm = quote! {
553 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic>>::metered_pre_compute(x, chip_idx, pc, instruction, data)
554 };
555 let metered_handler_arm = quote! {
556 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic>>::metered_handler(x, chip_idx, pc, instruction, data)
557 };
558 let where_predicate = syn::parse_quote! {
559 #field_ty: ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic>
560 };
561 (pre_compute_size_arm, metered_pre_compute_arm, metered_handler_arm, where_predicate)
562 }));
563 let where_clause = new_generics.make_where_clause();
564 for predicate in where_predicates {
565 where_clause.predicates.push(predicate);
566 }
567 let (impl_generics, _, where_clause) = new_generics.split_for_impl();
569
570 #[cfg(feature = "tco")]
573 let metered_handler = quote! {
574 fn metered_handler<Ctx>(
575 &self,
576 chip_idx: usize,
577 pc: u32,
578 instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
579 data: &mut [u8],
580 ) -> Result<::openvm_circuit::arch::Handler<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
581 where
582 Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait,
583 {
584 match self {
585 #(#_metered_handler_arms,)*
586 }
587 }
588 };
589 #[cfg(not(feature = "tco"))]
590 let metered_handler = quote! {};
591
592 quote! {
593 impl #impl_generics ::openvm_circuit::arch::InterpreterMeteredExecutor<#first_ty_generic> for #name #ty_generics #where_clause {
594 #[inline(always)]
595 fn metered_pre_compute_size(&self) -> usize {
596 match self {
597 #(#pre_compute_size_arms,)*
598 }
599 }
600
601 #[cfg(not(feature = "tco"))]
602 #[inline(always)]
603 fn metered_pre_compute<Ctx>(
604 &self,
605 chip_idx: usize,
606 pc: u32,
607 instruction: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
608 data: &mut [u8],
609 ) -> Result<::openvm_circuit::arch::ExecuteFunc<F, Ctx>, ::openvm_circuit::arch::StaticProgramError>
610 where
611 Ctx: ::openvm_circuit::arch::execution_mode::MeteredExecutionCtxTrait, {
612 match self {
613 #(#metered_pre_compute_arms,)*
614 }
615 }
616
617 #metered_handler
618 }
619 }
620 .into()
621 }
622 Data::Union(_) => unimplemented!("Unions are not supported"),
623 }
624}
625
626#[proc_macro_derive(AotMeteredExecutor)]
627pub fn aot_metered_executor_derive(input: TokenStream) -> TokenStream {
628 let ast: syn::DeriveInput = syn::parse(input).unwrap();
629
630 let name = &ast.ident;
631 let generics = &ast.generics;
632 let (_, ty_generics, _) = generics.split_for_impl();
633
634 match &ast.data {
635 Data::Struct(inner) => {
636 let inner_ty = match &inner.fields {
637 Fields::Unnamed(fields) => {
638 if fields.unnamed.len() != 1 {
639 panic!("Only one unnamed field is supported");
640 }
641 fields.unnamed.first().unwrap().ty.clone()
642 }
643 _ => panic!("Only unnamed fields are supported"),
644 };
645 let mut new_generics = generics.clone();
646 let where_clause = new_generics.make_where_clause();
647 where_clause.predicates.push(
648 syn::parse_quote! { #inner_ty: ::openvm_circuit::arch::AotMeteredExecutor<F> },
649 );
650 let (impl_generics, _, where_clause) = new_generics.split_for_impl();
651
652 quote! {
653 #[cfg(feature = "aot")]
654 impl #impl_generics ::openvm_circuit::arch::AotMeteredExecutor<F> for #name #ty_generics #where_clause {
655 #[inline(always)]
656 fn is_aot_metered_supported(&self, inst: &::openvm_instructions::instruction::Instruction<F>) -> bool {
657 self.0.is_aot_metered_supported(inst)
658 }
659
660 fn generate_x86_metered_asm(
661 &self,
662 inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>,
663 pc: u32,
664 chip_idx: usize,
665 config: &::openvm_circuit::arch::SystemConfig,
666 ) -> ::std::result::Result<
667 ::std::string::String,
668 ::openvm_circuit::arch::AotError,
669 > {
670 self.0.generate_x86_metered_asm(inst, pc, chip_idx, config)
671 }
672 }
673 }
674 .into()
675 }
676 Data::Enum(e) => {
677 let variants = e
678 .variants
679 .iter()
680 .map(|variant| {
681 let variant_name = &variant.ident;
682 let mut fields = variant.fields.iter();
683 let field = fields.next().unwrap();
684 assert!(fields.next().is_none(), "Only one field is supported");
685 (variant_name, field)
686 })
687 .collect::<Vec<_>>();
688 let default_ty_generic = Ident::new("F", proc_macro2::Span::call_site());
689 let mut new_generics = generics.clone();
690 let first_ty_generic = ast
691 .generics
692 .params
693 .first()
694 .and_then(|param| match param {
695 GenericParam::Type(type_param) => Some(&type_param.ident),
696 _ => None,
697 })
698 .unwrap_or_else(|| {
699 new_generics.params.push(syn::parse_quote! { F });
700 &default_ty_generic
701 });
702 let (
703 is_aot_metered_supported_arms,
704 generate_x86_metered_asm_arms,
705 where_predicates,
706 ): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(variants.iter().map(
707 |(variant_name, field)| {
708 let field_ty = &field.ty;
709 let is_aot_metered_supported_arm = quote! {
710 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AotMeteredExecutor<#first_ty_generic>>::is_aot_metered_supported(x, inst)
711 };
712 let generate_x86_metered_asm_arm = quote! {
713 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AotMeteredExecutor<#first_ty_generic>>::generate_x86_metered_asm(
714 x,
715 inst,
716 pc,
717 chip_idx,
718 config,
719 )
720 };
721 let where_predicate =
722 syn::parse_quote! { #field_ty: ::openvm_circuit::arch::AotMeteredExecutor<#first_ty_generic> };
723 (
724 is_aot_metered_supported_arm,
725 generate_x86_metered_asm_arm,
726 where_predicate,
727 )
728 },
729 ));
730 let where_clause = new_generics.make_where_clause();
731 for predicate in where_predicates {
732 where_clause.predicates.push(predicate);
733 }
734 let (impl_generics, _, where_clause) = new_generics.split_for_impl();
735
736 quote! {
737 #[cfg(feature = "aot")]
738 impl #impl_generics ::openvm_circuit::arch::AotMeteredExecutor<#first_ty_generic> for #name #ty_generics #where_clause {
739 #[inline(always)]
740 fn is_aot_metered_supported(&self, inst: &::openvm_circuit::arch::instructions::instruction::Instruction<F>) -> bool {
741 match self {
742 #(#is_aot_metered_supported_arms,)*
743 }
744 }
745
746 fn generate_x86_metered_asm(
747 &self,
748 inst: &::openvm_circuit::arch::instructions::instruction::Instruction<#first_ty_generic>,
749 pc: u32,
750 chip_idx: usize,
751 config: &::openvm_circuit::arch::SystemConfig,
752 ) -> ::std::result::Result<
753 ::std::string::String,
754 ::openvm_circuit::arch::AotError,
755 > {
756 match self {
757 #(#generate_x86_metered_asm_arms,)*
758 }
759 }
760 }
761 }
762 .into()
763 }
764 Data::Union(_) => unimplemented!("Unions are not supported"),
765 }
766}
767
768#[proc_macro_derive(AnyEnum, attributes(any_enum))]
774pub fn any_enum_derive(input: TokenStream) -> TokenStream {
775 let ast: syn::DeriveInput = syn::parse(input).unwrap();
776
777 let name = &ast.ident;
778 let generics = &ast.generics;
779 let (impl_generics, ty_generics, _) = generics.split_for_impl();
780
781 match &ast.data {
782 Data::Enum(e) => {
783 let variants = e
784 .variants
785 .iter()
786 .map(|variant| {
787 let variant_name = &variant.ident;
788
789 let is_enum = variant
791 .attrs
792 .iter()
793 .any(|attr| attr.path().is_ident("any_enum"));
794 let mut fields = variant.fields.iter();
795 let field = fields.next().unwrap();
796 assert!(fields.next().is_none(), "Only one field is supported");
797 (variant_name, field, is_enum)
798 })
799 .collect::<Vec<_>>();
800 let (arms, arms_mut): (Vec<_>, Vec<_>) =
801 variants.iter().map(|(variant_name, field, is_enum)| {
802 let field_ty = &field.ty;
803
804 if *is_enum {
805 (quote! {
807 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AnyEnum>::as_any_kind(x)
808 },
809 quote! {
810 #name::#variant_name(x) => <#field_ty as ::openvm_circuit::arch::AnyEnum>::as_any_kind_mut(x)
811 })
812 } else {
813 (quote! {
814 #name::#variant_name(x) => x
815 },
816 quote! {
817 #name::#variant_name(x) => x
818 })
819 }
820 }).unzip();
821 quote! {
822 impl #impl_generics ::openvm_circuit::arch::AnyEnum for #name #ty_generics {
823 fn as_any_kind(&self) -> &dyn std::any::Any {
824 match self {
825 #(#arms,)*
826 }
827 }
828
829 fn as_any_kind_mut(&mut self) -> &mut dyn std::any::Any {
830 match self {
831 #(#arms_mut,)*
832 }
833 }
834 }
835 }
836 .into()
837 }
838 _ => syn::Error::new(name.span(), "Only enums are supported")
839 .to_compile_error()
840 .into(),
841 }
842}
843
844#[proc_macro_derive(VmConfig, attributes(config, extension))]
845pub fn vm_generic_config_derive(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
846 let ast = syn::parse_macro_input!(input as syn::DeriveInput);
847 let name = &ast.ident;
848
849 match &ast.data {
850 syn::Data::Struct(inner) => match generate_config_traits_impl(name, inner) {
851 Ok(tokens) => tokens,
852 Err(err) => err.to_compile_error().into(),
853 },
854 _ => syn::Error::new(name.span(), "Only structs are supported")
855 .to_compile_error()
856 .into(),
857 }
858}
859
860fn generate_config_traits_impl(name: &Ident, inner: &DataStruct) -> syn::Result<TokenStream> {
861 let gen_name_with_uppercase_idents = |ident: &Ident| {
862 let mut name = ident.to_string().chars().collect::<Vec<_>>();
863 assert!(name[0].is_lowercase(), "Field name must not be capitalized");
864 let res_lower = Ident::new(&name.iter().collect::<String>(), Span::call_site().into());
865 name[0] = name[0].to_ascii_uppercase();
866 let res_upper = Ident::new(&name.iter().collect::<String>(), Span::call_site().into());
867 (res_lower, res_upper)
868 };
869
870 let fields = match &inner.fields {
871 Fields::Named(named) => named.named.iter().collect(),
872 Fields::Unnamed(_) => {
873 return Err(syn::Error::new(
874 name.span(),
875 "Only named fields are supported",
876 ))
877 }
878 Fields::Unit => vec![],
879 };
880
881 let source_field = fields
882 .iter()
883 .filter(|f| f.attrs.iter().any(|attr| attr.path().is_ident("config")))
884 .exactly_one()
885 .map_err(|_| {
886 syn::Error::new(
887 name.span(),
888 "Exactly one field must have the #[config] attribute",
889 )
890 })?;
891 let (source_name, source_name_upper) =
892 gen_name_with_uppercase_idents(source_field.ident.as_ref().unwrap());
893
894 let extensions = fields
895 .iter()
896 .filter(|f| f.attrs.iter().any(|attr| attr.path().is_ident("extension")))
897 .cloned()
898 .collect::<Vec<_>>();
899
900 let mut executor_enum_fields = Vec::new();
901 let mut create_executors = Vec::new();
902 let mut create_airs = Vec::new();
903 let mut execution_where_predicates: Vec<syn::WherePredicate> = Vec::new();
904 let mut circuit_where_predicates: Vec<syn::WherePredicate> = Vec::new();
905
906 let source_field_ty = source_field.ty.clone();
907
908 for e in extensions.iter() {
909 let (ext_field_name, ext_name_upper) =
910 gen_name_with_uppercase_idents(e.ident.as_ref().expect("field must be named"));
911 let executor_type = parse_executor_type(e, false)?;
912 executor_enum_fields.push(quote! {
913 #[any_enum]
914 #ext_name_upper(#executor_type),
915 });
916 create_executors.push(quote! {
917 let inventory: ::openvm_circuit::arch::ExecutorInventory<Self::Executor> = inventory.extend::<F, _, _>(&self.#ext_field_name)?;
918 });
919 let extension_ty = e.ty.clone();
920 execution_where_predicates.push(parse_quote! {
921 #extension_ty: ::openvm_circuit::arch::VmExecutionExtension<F, Executor = #executor_type>
922 });
923 create_airs.push(quote! {
924 inventory.start_new_extension();
925 ::openvm_circuit::arch::VmCircuitExtension::extend_circuit(&self.#ext_field_name, &mut inventory)?;
926 });
927 circuit_where_predicates.push(parse_quote! {
928 #extension_ty: ::openvm_circuit::arch::VmCircuitExtension<SC>
929 });
930 }
931
932 let source_executor_type = parse_executor_type(source_field, true)?;
934 execution_where_predicates.push(parse_quote! {
935 #source_field_ty: ::openvm_circuit::arch::VmExecutionConfig<F, Executor = #source_executor_type>
936 });
937 circuit_where_predicates.push(parse_quote! {
938 #source_field_ty: ::openvm_circuit::arch::VmCircuitConfig<SC>
939 });
940 let execution_where_clause = quote! { where #(#execution_where_predicates),* };
941 let circuit_where_clause = quote! { where #(#circuit_where_predicates),* };
942
943 let executor_type = Ident::new(&format!("{name}Executor"), name.span());
944
945 let token_stream = TokenStream::from(quote! {
946 #[derive(
947 Clone,
948 ::derive_more::derive::From,
949 ::openvm_circuit::derive::AnyEnum,
950 ::openvm_circuit::derive::Executor,
951 ::openvm_circuit::derive::MeteredExecutor,
952 ::openvm_circuit::derive::PreflightExecutor,
953 )]
954 #[cfg_attr(feature = "aot", derive(::openvm_circuit::derive::AotExecutor, ::openvm_circuit::derive::AotMeteredExecutor))]
955 pub enum #executor_type<F: openvm_stark_backend::p3_field::Field> {
956 #[any_enum]
957 #source_name_upper(#source_executor_type),
958 #(#executor_enum_fields)*
959 }
960
961 impl<F: openvm_stark_backend::p3_field::Field> ::openvm_circuit::arch::VmExecutionConfig<F> for #name #execution_where_clause {
962 type Executor = #executor_type<F>;
963
964 fn create_executors(
965 &self,
966 ) -> Result<::openvm_circuit::arch::ExecutorInventory<Self::Executor>, ::openvm_circuit::arch::ExecutorInventoryError> {
967 let inventory = self.#source_name.create_executors()?.transmute::<Self::Executor>();
968 #(#create_executors)*
969 Ok(inventory)
970 }
971 }
972
973 impl<SC: openvm_stark_backend::config::StarkGenericConfig> ::openvm_circuit::arch::VmCircuitConfig<SC> for #name #circuit_where_clause {
974 fn create_airs(
975 &self,
976 ) -> Result<::openvm_circuit::arch::AirInventory<SC>, ::openvm_circuit::arch::AirInventoryError> {
977 let mut inventory = self.#source_name.create_airs()?;
978 #(#create_airs)*
979 Ok(inventory)
980 }
981 }
982
983 impl AsRef<SystemConfig> for #name {
984 fn as_ref(&self) -> &SystemConfig {
985 self.#source_name.as_ref()
986 }
987 }
988
989 impl AsMut<SystemConfig> for #name {
990 fn as_mut(&mut self) -> &mut SystemConfig {
991 self.#source_name.as_mut()
992 }
993 }
994 });
995 Ok(token_stream)
996}
997
998fn parse_executor_type(
1002 f: &Field,
1003 default_needs_generics: bool,
1004) -> syn::Result<proc_macro2::TokenStream> {
1005 let mut executor_type = None;
1008 let executor_name = syn::parse_str::<Ident>(&format!("{}Executor", f.ty.to_token_stream()));
1010
1011 if let Some(attr) = f
1012 .attrs
1013 .iter()
1014 .find(|attr| attr.path().is_ident("extension") || attr.path().is_ident("config"))
1015 {
1016 match attr.meta {
1017 Meta::Path(_) => {}
1018 Meta::NameValue(_) => {
1019 return Err(syn::Error::new(
1020 f.ty.span(),
1021 "Only `#[config]`, `#[extension]`, `#[config(...)]` or `#[extension(...)]` formats are supported",
1022 ))
1023 }
1024 _ => {
1025 let nested = attr
1026 .parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated)?;
1027 for meta in nested {
1028 match meta {
1029 Meta::NameValue(nv) => {
1030 if nv.path.is_ident("executor") {
1031 executor_type = match nv.value {
1032 syn::Expr::Lit(syn::ExprLit {
1033 lit: syn::Lit::Str(lit_str), ..
1034 }) => {
1035 let executor_type: syn::Type = syn::parse_str(&lit_str.value())?;
1036 Some(quote! { #executor_type })
1037 },
1038 syn::Expr::Path(path) => {
1039 Some(path.to_token_stream())
1041 },
1042 _ => {
1043 return Err(syn::Error::new(
1044 nv.value.span(),
1045 "executor value must be a string literal or identifier"
1046 ));
1047 }
1048 };
1049 } else if nv.path.is_ident("generics") {
1050 let value_str = nv.value.to_token_stream().to_string();
1052 let needs_generics = match value_str.as_str() {
1053 "true" => true,
1054 "false" => false,
1055 _ => return Err(syn::Error::new(
1056 nv.value.span(),
1057 "generics attribute must be either true or false"
1058 ))
1059 };
1060 let executor_name = executor_name.clone()?;
1061 executor_type = Some(if needs_generics {
1062 quote! { #executor_name<F> }
1063 } else {
1064 quote! { #executor_name }
1065 });
1066 } else {
1067 return Err(syn::Error::new(nv.span(), "only executor and generics keys are supported"));
1068 }
1069 }
1070 _ => {
1071 return Err(syn::Error::new(meta.span(), "only name = value format is supported"));
1072 }
1073 }
1074 }
1075 }
1076 }
1077 }
1078 if let Some(executor_type) = executor_type {
1079 Ok(executor_type)
1080 } else {
1081 let executor_name = executor_name?;
1082 Ok(if default_needs_generics {
1083 quote! { #executor_name<F> }
1084 } else {
1085 quote! { #executor_name }
1086 })
1087 }
1088}
1089
1090#[proc_macro_attribute]
1120pub fn create_handler(_attr: TokenStream, item: TokenStream) -> TokenStream {
1121 #[cfg(feature = "tco")]
1122 {
1123 tco::tco_impl(item)
1124 }
1125 #[cfg(not(feature = "tco"))]
1126 {
1127 nontco::nontco_impl(item)
1128 }
1129}