openvm_native_compiler_derive/
lib.rs

1// Initial version copied from sp1-derive under MIT license
2extern crate alloc;
3extern crate proc_macro;
4
5use proc_macro::TokenStream;
6use quote::quote;
7use syn::{
8    parse::{Parse, ParseStream},
9    parse_macro_input,
10    punctuated::Punctuated,
11    Data, DeriveInput, Expr, Fields, GenericParam, Generics, Token, TypeParamBound,
12};
13
14/// Returns true if the generic parameter C: Config exists.
15pub(crate) fn has_config_generic(generics: &Generics) -> bool {
16    generics.params.iter().any(|param| match param {
17        GenericParam::Type(ty) => {
18            ty.ident == "C"
19                && ty.bounds.iter().any(|b| match b {
20                    TypeParamBound::Trait(tr) => tr.path.segments.last().unwrap().ident == "Config",
21                    _ => false,
22                })
23        }
24        _ => false,
25    })
26}
27
28#[proc_macro_derive(DslVariable)]
29pub fn derive_variable(input: TokenStream) -> TokenStream {
30    let input = parse_macro_input!(input as DeriveInput);
31    let name = input.ident; // Struct name
32    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
33    assert!(
34        has_config_generic(&input.generics),
35        "DslVariable requires a generic parameter C: Config"
36    );
37
38    let gen = match input.data {
39        Data::Struct(data) => match data.fields {
40            Fields::Named(fields) => {
41                let fields_init = fields.named.iter().map(|f| {
42                    let fname = &f.ident;
43                    let ftype = &f.ty;
44                    let ftype_str = quote! { #ftype }.to_string();
45                    if ftype_str.contains("Array") {
46                        quote! {
47                            #fname: if builder.flags.static_only {
48                                builder.uninit_fixed_array(0)
49                            } else {
50                                Array::Dyn(builder.uninit(), builder.uninit())
51                            },
52                        }
53                    } else {
54                        quote! {
55                            #fname: <#ftype as Variable<C>>::uninit(builder),
56                        }
57                    }
58                });
59
60                let fields_assign = fields.named.iter().map(|f| {
61                    let fname = &f.ident;
62                    quote! {
63                        self.#fname.assign(src.#fname.into(), builder);
64                    }
65                });
66
67                let fields_assert_eq = fields.named.iter().map(|f| {
68                    let fname = &f.ident;
69                    let ftype = &f.ty;
70                    quote! {
71                        <#ftype as Variable<C>>::assert_eq(lhs.#fname, rhs.#fname, builder);
72                    }
73                });
74
75                let field_sizes = fields.named.iter().map(|f| {
76                    let ftype = &f.ty;
77                    quote! {
78                        <#ftype as MemVariable<C>>::size_of()
79                    }
80                });
81
82                let field_loads = fields.named.iter().map(|f| {
83                    let fname = &f.ident;
84                    let ftype = &f.ty;
85                    quote! {
86                        {
87                            // let address = builder.eval(ptr + Usize::Const(offset));
88                            self.#fname.load(ptr, index, builder);
89                            index.offset += <#ftype as MemVariable<C>>::size_of();
90                        }
91                    }
92                });
93
94                let field_stores = fields.named.iter().map(|f| {
95                    let fname = &f.ident;
96                    let ftype = &f.ty;
97                    quote! {
98                        {
99                            // let address = builder.eval(ptr + Usize::Const(offset));
100                            self.#fname.store(ptr, index, builder);
101                            index.offset += <#ftype as MemVariable<C>>::size_of();
102                        }
103                    }
104                });
105
106                quote! {
107                    impl #impl_generics Variable<C> for #name #ty_generics #where_clause {
108                        type Expression = Self;
109
110                        fn uninit(builder: &mut Builder<C>) -> Self {
111                            Self {
112                                #(#fields_init)*
113                            }
114                        }
115
116                        fn assign(&self, src: Self::Expression, builder: &mut Builder<C>) {
117                            #(#fields_assign)*
118                        }
119
120                        fn assert_eq(
121                            lhs: impl Into<Self::Expression>,
122                            rhs: impl Into<Self::Expression>,
123                            builder: &mut Builder<C>,
124                        ) {
125                            let lhs = lhs.into();
126                            let rhs = rhs.into();
127                            #(#fields_assert_eq)*
128                        }
129                    }
130
131                    impl #impl_generics MemVariable<C> for #name #ty_generics #where_clause {
132                        fn size_of() -> usize {
133                            let mut size = 0;
134                            #(size += #field_sizes;)*
135                            size
136                        }
137
138                        fn load(&self, ptr: Ptr<<C as Config>::N>,
139                            index: MemIndex<<C as Config>::N>,
140                            builder: &mut Builder<C>) {
141                            let mut index = index;
142                            #(#field_loads)*
143                        }
144
145                        fn store(&self, ptr: Ptr<<C as Config>::N>,
146                                 index: MemIndex<<C as Config>::N>,
147                                builder: &mut Builder<C>) {
148                            let mut index = index;
149                            #(#field_stores)*
150                        }
151                    }
152                }
153            }
154            _ => unimplemented!(),
155        },
156        _ => unimplemented!(),
157    };
158
159    gen.into()
160}
161
162struct IterZipArgs {
163    builder: Expr,
164    args: Punctuated<Expr, Token![,]>,
165}
166
167impl Parse for IterZipArgs {
168    fn parse(input: ParseStream) -> syn::Result<Self> {
169        let builder = input.parse()?;
170        let _: Token![,] = input.parse()?;
171        let args = Punctuated::parse_terminated(input)?;
172
173        Ok(IterZipArgs { builder, args })
174    }
175}
176
177#[proc_macro]
178pub fn iter_zip(input: TokenStream) -> TokenStream {
179    let IterZipArgs { builder, args } = parse_macro_input!(input as IterZipArgs);
180    let array_elements = args.iter().map(|arg| {
181        quote! {
182            Box::new(#arg.clone()) as Box<dyn ArrayLike<_>>
183        }
184    });
185
186    let expanded = quote! {
187        #builder.zip(&[
188            #(#array_elements),*
189        ])
190    };
191
192    expanded.into()
193}