openvm_circuit_primitives_derive/
lib.rs

1// AlignedBorrow is copied from valida-derive under MIT license
2extern crate alloc;
3extern crate proc_macro;
4
5use itertools::multiunzip;
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericParam, LitStr, Meta};
9
10#[proc_macro_derive(AlignedBorrow)]
11pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream {
12    let ast = parse_macro_input!(input as DeriveInput);
13    let name = &ast.ident;
14
15    // Get first generic which must be type (ex. `T`) for input <T, N: NumLimbs, const M: usize>
16    let type_generic = ast
17        .generics
18        .params
19        .iter()
20        .map(|param| match param {
21            GenericParam::Type(type_param) => &type_param.ident,
22            _ => panic!("Expected first generic to be a type"),
23        })
24        .next()
25        .expect("Expected at least one generic");
26
27    // Get generics after the first (ex. `N: NumLimbs, const M: usize`)
28    // We need this because when we assert the size, we want to substitute u8 for T.
29    let non_first_generics = ast
30        .generics
31        .params
32        .iter()
33        .skip(1)
34        .filter_map(|param| match param {
35            GenericParam::Type(type_param) => Some(&type_param.ident),
36            GenericParam::Const(const_param) => Some(&const_param.ident),
37            _ => None,
38        })
39        .collect::<Vec<_>>();
40
41    // Get impl generics (`<T, N: NumLimbs, const M: usize>`), type generics (`<T, N>`), where clause (`where T: Clone`)
42    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
43
44    let methods = quote! {
45        impl #impl_generics core::borrow::Borrow<#name #type_generics> for [#type_generic] #where_clause {
46            fn borrow(&self) -> &#name #type_generics {
47                debug_assert_eq!(self.len(), #name::#type_generics::width());
48                let (prefix, shorts, _suffix) = unsafe { self.align_to::<#name #type_generics>() };
49                debug_assert!(prefix.is_empty(), "Alignment should match");
50                debug_assert_eq!(shorts.len(), 1);
51                &shorts[0]
52            }
53        }
54
55        impl #impl_generics core::borrow::BorrowMut<#name #type_generics> for [#type_generic] #where_clause {
56            fn borrow_mut(&mut self) -> &mut #name #type_generics {
57                debug_assert_eq!(self.len(), #name::#type_generics::width());
58                let (prefix, shorts, _suffix) = unsafe { self.align_to_mut::<#name #type_generics>() };
59                debug_assert!(prefix.is_empty(), "Alignment should match");
60                debug_assert_eq!(shorts.len(), 1);
61                &mut shorts[0]
62            }
63        }
64
65        impl #impl_generics #name #type_generics {
66            pub const fn width() -> usize {
67                std::mem::size_of::<#name<u8 #(, #non_first_generics)*>>()
68            }
69        }
70    };
71
72    TokenStream::from(methods)
73}
74
75#[proc_macro_derive(Chip, attributes(chip))]
76pub fn chip_derive(input: TokenStream) -> TokenStream {
77    // Parse the attributes from the struct or enum
78    let ast: syn::DeriveInput = syn::parse(input).unwrap();
79
80    let name = &ast.ident;
81    let generics = &ast.generics;
82    let (_impl_generics, ty_generics, _where_clause) = generics.split_for_impl();
83
84    match &ast.data {
85        Data::Struct(inner) => {
86            let generics = &ast.generics;
87            let mut new_generics = generics.clone();
88            new_generics
89                .params
90                .push(syn::parse_quote! { SC: openvm_stark_backend::config::StarkGenericConfig });
91            let (impl_generics, _, _) = new_generics.split_for_impl();
92
93            // Check if the struct has only one unnamed field
94            let inner_ty = match &inner.fields {
95                Fields::Unnamed(fields) => {
96                    if fields.unnamed.len() != 1 {
97                        panic!("Only one unnamed field is supported");
98                    }
99                    fields.unnamed.first().unwrap().ty.clone()
100                }
101                _ => panic!("Only unnamed fields are supported"),
102            };
103            let mut new_generics = generics.clone();
104            let where_clause = new_generics.make_where_clause();
105            where_clause
106                .predicates
107                .push(syn::parse_quote! { #inner_ty: openvm_stark_backend::Chip<SC> });
108            quote! {
109                impl #impl_generics openvm_stark_backend::Chip<SC> for #name #ty_generics #where_clause {
110                    fn air(&self) -> openvm_stark_backend::AirRef<SC> {
111                        self.0.air()
112                    }
113                    fn generate_air_proof_input(self) -> openvm_stark_backend::prover::types::AirProofInput<SC> {
114                        self.0.generate_air_proof_input()
115                    }
116                    fn generate_air_proof_input_with_id(self, air_id: usize) -> (usize, openvm_stark_backend::prover::types::AirProofInput<SC>) {
117                        self.0.generate_air_proof_input_with_id(air_id)
118                    }
119                }
120            }.into()
121        }
122        Data::Enum(e) => {
123            let variants = e
124                .variants
125                .iter()
126                .map(|variant| {
127                    let variant_name = &variant.ident;
128
129                    let mut fields = variant.fields.iter();
130                    let field = fields.next().unwrap();
131                    assert!(fields.next().is_none(), "Only one field is supported");
132                    (variant_name, field)
133                })
134                .collect::<Vec<_>>();
135
136            let (air_arms, generate_air_proof_input_arms, generate_air_proof_input_with_id_arms): (Vec<_>, Vec<_>, Vec<_>) =
137                multiunzip(variants.iter().map(|(variant_name, field)| {
138                let field_ty = &field.ty;
139                let air_arm = quote! {
140                    #name::#variant_name(x) => <#field_ty as openvm_stark_backend::Chip<SC>>::air(x)
141                };
142                let generate_air_proof_input_arm = quote! {
143                    #name::#variant_name(x) => <#field_ty as openvm_stark_backend::Chip<SC>>::generate_air_proof_input(x)
144                };
145                let generate_air_proof_input_with_id_arm = quote! {
146                    #name::#variant_name(x) => <#field_ty as openvm_stark_backend::Chip<SC>>::generate_air_proof_input_with_id(x, air_id)
147                };
148                (air_arm, generate_air_proof_input_arm, generate_air_proof_input_with_id_arm)
149            }));
150
151            // Attach an extra generic SC: StarkGenericConfig to the impl_generics
152            let generics = &ast.generics;
153            let mut new_generics = generics.clone();
154            new_generics
155                .params
156                .push(syn::parse_quote! { SC: openvm_stark_backend::config::StarkGenericConfig });
157            let (impl_generics, _, _) = new_generics.split_for_impl();
158
159            // Implement Chip whenever the inner type implements Chip
160            let mut new_generics = generics.clone();
161            let where_clause = new_generics.make_where_clause();
162            where_clause.predicates.push(syn::parse_quote! { openvm_stark_backend::config::Domain<SC>: openvm_stark_backend::p3_commit::PolynomialSpace<Val = F>
163            });
164            let attributes = ast.attrs.iter().find(|&attr| attr.path().is_ident("chip"));
165            if let Some(attr) = attributes {
166                let mut fail_flag = false;
167
168                match &attr.meta {
169                    Meta::List(meta_list) => {
170                        meta_list
171                            .parse_nested_meta(|meta| {
172                                if meta.path.is_ident("where") {
173                                    let value = meta.value()?; // this parses the `=`
174                                    let s: LitStr = value.parse()?;
175                                    let where_value = s.value();
176                                    where_clause.predicates.push(syn::parse_str(&where_value)?);
177                                } else {
178                                    fail_flag = true;
179                                }
180                                Ok(())
181                            })
182                            .unwrap();
183                    }
184                    _ => fail_flag = true,
185                }
186                if fail_flag {
187                    return syn::Error::new(
188                        name.span(),
189                        "Only `#[chip(where = ...)]` format is supported",
190                    )
191                    .to_compile_error()
192                    .into();
193                }
194            }
195
196            quote! {
197                impl #impl_generics openvm_stark_backend::Chip<SC> for #name #ty_generics #where_clause {
198                    fn air(&self) -> openvm_stark_backend::AirRef<SC> {
199                        match self {
200                            #(#air_arms,)*
201                        }
202                    }
203                    fn generate_air_proof_input(self) -> openvm_stark_backend::prover::types::AirProofInput<SC> {
204                        match self {
205                            #(#generate_air_proof_input_arms,)*
206                        }
207                    }
208                    fn generate_air_proof_input_with_id(self, air_id: usize) -> (usize, openvm_stark_backend::prover::types::AirProofInput<SC>) {
209                        match self {
210                            #(#generate_air_proof_input_with_id_arms,)*
211                        }
212                    }
213                }
214            }.into()
215        }
216        Data::Union(_) => unimplemented!("Unions are not supported"),
217    }
218}
219
220#[proc_macro_derive(ChipUsageGetter)]
221pub fn chip_usage_getter_derive(input: TokenStream) -> TokenStream {
222    let ast: syn::DeriveInput = syn::parse(input).unwrap();
223
224    let name = &ast.ident;
225    let generics = &ast.generics;
226    let (impl_generics, ty_generics, _) = generics.split_for_impl();
227
228    match &ast.data {
229        Data::Struct(inner) => {
230            // Check if the struct has only one unnamed field
231            let inner_ty = match &inner.fields {
232                Fields::Unnamed(fields) => {
233                    if fields.unnamed.len() != 1 {
234                        panic!("Only one unnamed field is supported");
235                    }
236                    fields.unnamed.first().unwrap().ty.clone()
237                }
238                _ => panic!("Only unnamed fields are supported"),
239            };
240            // Implement ChipUsageGetter whenever the inner type implements ChipUsageGetter
241            let mut new_generics = generics.clone();
242            let where_clause = new_generics.make_where_clause();
243            where_clause
244                .predicates
245                .push(syn::parse_quote! { #inner_ty: openvm_stark_backend::ChipUsageGetter });
246            quote! {
247                impl #impl_generics openvm_stark_backend::ChipUsageGetter for #name #ty_generics #where_clause {
248                    fn air_name(&self) -> String {
249                        self.0.air_name()
250                    }
251                    fn constant_trace_height(&self) -> Option<usize> {
252                        self.0.constant_trace_height()
253                    }
254                    fn current_trace_height(&self) -> usize {
255                        self.0.current_trace_height()
256                    }
257                    fn trace_width(&self) -> usize {
258                        self.0.trace_width()
259                    }
260                }
261            }
262            .into()
263        }
264        Data::Enum(e) => {
265            let (air_name_arms, constant_trace_height_arms, current_trace_height_arms, trace_width_arms): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) =
266                multiunzip(e.variants.iter().map(|variant| {
267                    let variant_name = &variant.ident;
268                    let air_name_arm = quote! {
269                    #name::#variant_name(x) => openvm_stark_backend::ChipUsageGetter::air_name(x)
270                };
271                    let constant_trace_height_arm = quote! {
272                    #name::#variant_name(x) => openvm_stark_backend::ChipUsageGetter::constant_trace_height(x)
273                };
274                    let current_trace_height_arm = quote! {
275                    #name::#variant_name(x) => openvm_stark_backend::ChipUsageGetter::current_trace_height(x)
276                };
277                    let trace_width_arm = quote! {
278                    #name::#variant_name(x) => openvm_stark_backend::ChipUsageGetter::trace_width(x)
279                };
280                    (air_name_arm, constant_trace_height_arm, current_trace_height_arm, trace_width_arm)
281                }));
282
283            quote! {
284                impl #impl_generics openvm_stark_backend::ChipUsageGetter for #name #ty_generics {
285                    fn air_name(&self) -> String {
286                        match self {
287                            #(#air_name_arms,)*
288                        }
289                    }
290                    fn constant_trace_height(&self) -> Option<usize> {
291                        match self {
292                            #(#constant_trace_height_arms,)*
293                        }
294                    }
295                    fn current_trace_height(&self) -> usize {
296                        match self {
297                            #(#current_trace_height_arms,)*
298                        }
299                    }
300                    fn trace_width(&self) -> usize {
301                        match self {
302                            #(#trace_width_arms,)*
303                        }
304                    }
305
306                }
307            }
308            .into()
309        }
310        Data::Union(_) => unimplemented!("Unions are not supported"),
311    }
312}
313
314#[proc_macro_derive(BytesStateful)]
315pub fn bytes_stateful_derive(input: TokenStream) -> TokenStream {
316    let ast: syn::DeriveInput = syn::parse(input).unwrap();
317
318    let name = &ast.ident;
319    let generics = &ast.generics;
320    let (impl_generics, ty_generics, _) = generics.split_for_impl();
321
322    match &ast.data {
323        Data::Struct(inner) => {
324            // Check if the struct has only one unnamed field
325            let inner_ty = match &inner.fields {
326                Fields::Unnamed(fields) => {
327                    if fields.unnamed.len() != 1 {
328                        panic!("Only one unnamed field is supported");
329                    }
330                    fields.unnamed.first().unwrap().ty.clone()
331                }
332                _ => panic!("Only unnamed fields are supported"),
333            };
334            // Use full path ::openvm_circuit... so it can be used either within or outside the vm crate.
335            // Assume F is already generic of the field.
336            let mut new_generics = generics.clone();
337            let where_clause = new_generics.make_where_clause();
338            where_clause
339                .predicates
340                .push(syn::parse_quote! { #inner_ty: ::openvm_stark_backend::Stateful<Vec<u8>> });
341
342            quote! {
343                impl #impl_generics ::openvm_stark_backend::Stateful<Vec<u8>> for #name #ty_generics #where_clause {
344                    fn load_state(&mut self, state: Vec<u8>) {
345                        self.0.load_state(state)
346                    }
347
348                    fn store_state(&self) -> Vec<u8> {
349                        self.0.store_state()
350                    }
351                }
352            }
353            .into()
354        }
355        Data::Enum(e) => {
356            let variants = e
357                .variants
358                .iter()
359                .map(|variant| {
360                    let variant_name = &variant.ident;
361
362                    let mut fields = variant.fields.iter();
363                    let field = fields.next().unwrap();
364                    assert!(fields.next().is_none(), "Only one field is supported");
365                    (variant_name, field)
366                })
367                .collect::<Vec<_>>();
368            // Use full path ::openvm_stark_backend... so it can be used either within or outside the vm crate.
369            let (load_state_arms, store_state_arms): (Vec<_>, Vec<_>) =
370                multiunzip(variants.iter().map(|(variant_name, field)| {
371                    let field_ty = &field.ty;
372                    let load_state_arm = quote! {
373                        #name::#variant_name(x) => <#field_ty as ::openvm_stark_backend::Stateful<Vec<u8>>>::load_state(x, state)
374                    };
375                    let store_state_arm = quote! {
376                        #name::#variant_name(x) => <#field_ty as ::openvm_stark_backend::Stateful<Vec<u8>>>::store_state(x)
377                    };
378
379                    (load_state_arm, store_state_arm)
380                }));
381            quote! {
382                impl #impl_generics ::openvm_stark_backend::Stateful<Vec<u8>> for #name #ty_generics {
383                    fn load_state(&mut self, state: Vec<u8>) {
384                        match self {
385                            #(#load_state_arms,)*
386                        }
387                    }
388
389                    fn store_state(&self) -> Vec<u8> {
390                        match self {
391                            #(#store_state_arms,)*
392                        }
393                    }
394                }
395            }
396            .into()
397        }
398        _ => unimplemented!(),
399    }
400}