// openvm_native_compiler_derive/lib.rs
1extern crate alloc;
3extern crate proc_macro;
4
5use proc_macro::TokenStream;
6use quote::quote;
7use syn::{
8 parse::{Parse, ParseStream},
9 parse_macro_input,
10 punctuated::Punctuated,
11 Data, DeriveInput, Expr, Fields, GenericParam, Generics, Token, TypeParamBound,
12};
13
14pub(crate) fn has_config_generic(generics: &Generics) -> bool {
16 generics.params.iter().any(|param| match param {
17 GenericParam::Type(ty) => {
18 ty.ident == "C"
19 && ty.bounds.iter().any(|b| match b {
20 TypeParamBound::Trait(tr) => tr.path.segments.last().unwrap().ident == "Config",
21 _ => false,
22 })
23 }
24 _ => false,
25 })
26}
27
28#[proc_macro_derive(DslVariable)]
29pub fn derive_variable(input: TokenStream) -> TokenStream {
30 let input = parse_macro_input!(input as DeriveInput);
31 let name = input.ident; let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
33 assert!(
34 has_config_generic(&input.generics),
35 "DslVariable requires a generic parameter C: Config"
36 );
37
38 let gen = match input.data {
39 Data::Struct(data) => match data.fields {
40 Fields::Named(fields) => {
41 let fields_init = fields.named.iter().map(|f| {
42 let fname = &f.ident;
43 let ftype = &f.ty;
44 let ftype_str = quote! { #ftype }.to_string();
45 if ftype_str.contains("Array") {
46 quote! {
47 #fname: if builder.flags.static_only {
48 builder.uninit_fixed_array(0)
49 } else {
50 Array::Dyn(builder.uninit(), builder.uninit())
51 },
52 }
53 } else {
54 quote! {
55 #fname: <#ftype as Variable<C>>::uninit(builder),
56 }
57 }
58 });
59
60 let fields_assign = fields.named.iter().map(|f| {
61 let fname = &f.ident;
62 quote! {
63 self.#fname.assign(src.#fname.into(), builder);
64 }
65 });
66
67 let fields_assert_eq = fields.named.iter().map(|f| {
68 let fname = &f.ident;
69 let ftype = &f.ty;
70 quote! {
71 <#ftype as Variable<C>>::assert_eq(lhs.#fname, rhs.#fname, builder);
72 }
73 });
74
75 let field_sizes = fields.named.iter().map(|f| {
76 let ftype = &f.ty;
77 quote! {
78 <#ftype as MemVariable<C>>::size_of()
79 }
80 });
81
82 let field_loads = fields.named.iter().map(|f| {
83 let fname = &f.ident;
84 let ftype = &f.ty;
85 quote! {
86 {
87 self.#fname.load(ptr, index, builder);
89 index.offset += <#ftype as MemVariable<C>>::size_of();
90 }
91 }
92 });
93
94 let field_stores = fields.named.iter().map(|f| {
95 let fname = &f.ident;
96 let ftype = &f.ty;
97 quote! {
98 {
99 self.#fname.store(ptr, index, builder);
101 index.offset += <#ftype as MemVariable<C>>::size_of();
102 }
103 }
104 });
105
106 quote! {
107 impl #impl_generics Variable<C> for #name #ty_generics #where_clause {
108 type Expression = Self;
109
110 fn uninit(builder: &mut Builder<C>) -> Self {
111 Self {
112 #(#fields_init)*
113 }
114 }
115
116 fn assign(&self, src: Self::Expression, builder: &mut Builder<C>) {
117 #(#fields_assign)*
118 }
119
120 fn assert_eq(
121 lhs: impl Into<Self::Expression>,
122 rhs: impl Into<Self::Expression>,
123 builder: &mut Builder<C>,
124 ) {
125 let lhs = lhs.into();
126 let rhs = rhs.into();
127 #(#fields_assert_eq)*
128 }
129 }
130
131 impl #impl_generics MemVariable<C> for #name #ty_generics #where_clause {
132 fn size_of() -> usize {
133 let mut size = 0;
134 #(size += #field_sizes;)*
135 size
136 }
137
138 fn load(&self, ptr: Ptr<<C as Config>::N>,
139 index: MemIndex<<C as Config>::N>,
140 builder: &mut Builder<C>) {
141 let mut index = index;
142 #(#field_loads)*
143 }
144
145 fn store(&self, ptr: Ptr<<C as Config>::N>,
146 index: MemIndex<<C as Config>::N>,
147 builder: &mut Builder<C>) {
148 let mut index = index;
149 #(#field_stores)*
150 }
151 }
152 }
153 }
154 _ => unimplemented!(),
155 },
156 _ => unimplemented!(),
157 };
158
159 gen.into()
160}
161
162struct IterZipArgs {
163 builder: Expr,
164 args: Punctuated<Expr, Token![,]>,
165}
166
167impl Parse for IterZipArgs {
168 fn parse(input: ParseStream) -> syn::Result<Self> {
169 let builder = input.parse()?;
170 let _: Token![,] = input.parse()?;
171 let args = Punctuated::parse_terminated(input)?;
172
173 Ok(IterZipArgs { builder, args })
174 }
175}
176
177#[proc_macro]
178pub fn iter_zip(input: TokenStream) -> TokenStream {
179 let IterZipArgs { builder, args } = parse_macro_input!(input as IterZipArgs);
180 let array_elements = args.iter().map(|arg| {
181 quote! {
182 Box::new(#arg.clone()) as Box<dyn ArrayLike<_>>
183 }
184 });
185
186 let expanded = quote! {
187 #builder.zip(&[
188 #(#array_elements),*
189 ])
190 };
191
192 expanded.into()
193}