openvm_circuit_primitives_derive/
lib.rs

1// AlignedBorrow is copied from valida-derive under MIT license
2extern crate alloc;
3extern crate proc_macro;
4
5use itertools::multiunzip;
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericParam, LitStr, Meta};
9
10#[proc_macro_derive(AlignedBorrow)]
11pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream {
12    let ast = parse_macro_input!(input as DeriveInput);
13    let name = &ast.ident;
14
15    // Get first generic which must be type (ex. `T`) for input <T, N: NumLimbs, const M: usize>
16    let type_generic = ast
17        .generics
18        .params
19        .iter()
20        .map(|param| match param {
21            GenericParam::Type(type_param) => &type_param.ident,
22            _ => panic!("Expected first generic to be a type"),
23        })
24        .next()
25        .expect("Expected at least one generic");
26
27    // Get generics after the first (ex. `N: NumLimbs, const M: usize`)
28    // We need this because when we assert the size, we want to substitute u8 for T.
29    let non_first_generics = ast
30        .generics
31        .params
32        .iter()
33        .skip(1)
34        .filter_map(|param| match param {
35            GenericParam::Type(type_param) => Some(&type_param.ident),
36            GenericParam::Const(const_param) => Some(&const_param.ident),
37            _ => None,
38        })
39        .collect::<Vec<_>>();
40
41    // Get impl generics (`<T, N: NumLimbs, const M: usize>`), type generics (`<T, N>`), where
42    // clause (`where T: Clone`)
43    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
44
45    let methods = quote! {
46        impl #impl_generics core::borrow::Borrow<#name #type_generics> for [#type_generic] #where_clause {
47            fn borrow(&self) -> &#name #type_generics {
48                debug_assert_eq!(self.len(), #name::#type_generics::width());
49                let (prefix, shorts, _suffix) = unsafe { self.align_to::<#name #type_generics>() };
50                debug_assert!(prefix.is_empty(), "Alignment should match");
51                debug_assert_eq!(shorts.len(), 1);
52                &shorts[0]
53            }
54        }
55
56        impl #impl_generics core::borrow::BorrowMut<#name #type_generics> for [#type_generic] #where_clause {
57            fn borrow_mut(&mut self) -> &mut #name #type_generics {
58                debug_assert_eq!(self.len(), #name::#type_generics::width());
59                let (prefix, shorts, _suffix) = unsafe { self.align_to_mut::<#name #type_generics>() };
60                debug_assert!(prefix.is_empty(), "Alignment should match");
61                debug_assert_eq!(shorts.len(), 1);
62                &mut shorts[0]
63            }
64        }
65
66        impl #impl_generics #name #type_generics {
67            pub const fn width() -> usize {
68                std::mem::size_of::<#name<u8 #(, #non_first_generics)*>>()
69            }
70        }
71    };
72
73    TokenStream::from(methods)
74}
75
76/// `S` is the type the derive macro is being called on
77/// Implements Borrow<S> and BorrowMut<S> for [u8]
78/// [u8] has to have (checked via `debug_assert!`s)
79/// - at least size_of(S) length
80/// - at least align_of(S) alignment
81#[proc_macro_derive(AlignedBytesBorrow)]
82pub fn aligned_bytes_borrow_derive(input: TokenStream) -> TokenStream {
83    let ast = parse_macro_input!(input as DeriveInput);
84    let name = &ast.ident;
85
86    // Get impl generics, type generics, where clause
87    // Note, need to add the new type generic to the `impl_generics`
88    let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
89
90    let methods = quote! {
91        impl #impl_generics core::borrow::Borrow<#name #type_generics> for [u8]
92        where
93            #where_clause
94        {
95            fn borrow(&self) -> &#name #type_generics {
96                use core::mem::{align_of, size_of_val};
97                debug_assert!(size_of_val(self) >= core::mem::size_of::<#name #type_generics>());
98                debug_assert_eq!(self.as_ptr() as usize % align_of::<#name #type_generics>(), 0);
99                unsafe { &*(self.as_ptr() as *const #name #type_generics) }
100            }
101        }
102
103        impl #impl_generics core::borrow::BorrowMut<#name #type_generics> for [u8]
104        where
105            #where_clause
106        {
107            fn borrow_mut(&mut self) -> &mut #name #type_generics {
108                use core::mem::{align_of, size_of_val};
109                debug_assert!(size_of_val(self) >= core::mem::size_of::<#name #type_generics>());
110                debug_assert_eq!(self.as_ptr() as usize % align_of::<#name #type_generics>(), 0);
111                unsafe { &mut *(self.as_mut_ptr() as *mut #name #type_generics) }
112            }
113        }
114    };
115
116    TokenStream::from(methods)
117}
118
119#[proc_macro_derive(Chip, attributes(chip))]
120pub fn chip_derive(input: TokenStream) -> TokenStream {
121    // Parse the attributes from the struct or enum
122    let ast: syn::DeriveInput = syn::parse(input).unwrap();
123
124    let name = &ast.ident;
125    let generics = &ast.generics;
126    let (_impl_generics, ty_generics, _where_clause) = generics.split_for_impl();
127
128    match &ast.data {
129        Data::Struct(inner) => {
130            let generics = &ast.generics;
131            let mut new_generics = generics.clone();
132            new_generics.params.push(syn::parse_quote! { R });
133            new_generics
134                .params
135                .push(syn::parse_quote! { PB: openvm_stark_backend::prover::hal::ProverBackend });
136            let (impl_generics, _, _) = new_generics.split_for_impl();
137
138            // Check if the struct has only one unnamed field
139            let inner_ty = match &inner.fields {
140                Fields::Unnamed(fields) => {
141                    if fields.unnamed.len() != 1 {
142                        panic!("Only one unnamed field is supported");
143                    }
144                    fields.unnamed.first().unwrap().ty.clone()
145                }
146                _ => panic!("Only unnamed fields are supported"),
147            };
148            let mut new_generics = generics.clone();
149            let where_clause = new_generics.make_where_clause();
150            where_clause
151                .predicates
152                .push(syn::parse_quote! { #inner_ty: openvm_stark_backend::Chip<R, PB> });
153            quote! {
154                impl #impl_generics openvm_stark_backend::Chip<R, PB> for #name #ty_generics #where_clause {
155                    fn generate_proving_ctx(&self, records: R) -> openvm_stark_backend::prover::types::AirProvingContext<PB> {
156                        self.0.generate_proving_ctx(records)
157                    }
158                }
159            }.into()
160        }
161        Data::Enum(e) => {
162            let variants = e
163                .variants
164                .iter()
165                .map(|variant| {
166                    let variant_name = &variant.ident;
167
168                    let mut fields = variant.fields.iter();
169                    let field = fields.next().unwrap();
170                    assert!(fields.next().is_none(), "Only one field is supported");
171                    (variant_name, field)
172                })
173                .collect::<Vec<_>>();
174
175            let (generate_proving_ctx_arms, where_predicates): (Vec<_>, Vec<_>) =
176                variants.iter().map(|(variant_name, field)| {
177                let field_ty = &field.ty;
178                let generate_proving_ctx_arm = quote! {
179                    #name::#variant_name(x) => <#field_ty as openvm_stark_backend::Chip<R, PB>>::generate_proving_ctx(x, records)
180                };
181                let where_predicate =
182                    syn::parse_quote! { #field_ty: openvm_stark_backend::Chip<R, PB> };
183                (generate_proving_ctx_arm, where_predicate)
184            }).collect();
185
186            // Attach extra generics R and PB to the impl_generics
187            let generics = &ast.generics;
188            let mut new_generics = generics.clone();
189            new_generics.params.push(syn::parse_quote! { R });
190            new_generics
191                .params
192                .push(syn::parse_quote! { PB: openvm_stark_backend::prover::hal::ProverBackend });
193            let (impl_generics, _, _) = new_generics.split_for_impl();
194
195            // Implement Chip whenever the inner type implements Chip
196            let mut new_generics = generics.clone();
197            let where_clause = new_generics.make_where_clause();
198            for predicate in where_predicates {
199                where_clause.predicates.push(predicate);
200            }
201            let attributes = ast.attrs.iter().find(|&attr| attr.path().is_ident("chip"));
202            if let Some(attr) = attributes {
203                let mut fail_flag = false;
204
205                match &attr.meta {
206                    Meta::List(meta_list) => {
207                        meta_list
208                            .parse_nested_meta(|meta| {
209                                if meta.path.is_ident("where") {
210                                    let value = meta.value()?; // this parses the `=`
211                                    let s: LitStr = value.parse()?;
212                                    let where_value = s.value();
213                                    where_clause.predicates.push(syn::parse_str(&where_value)?);
214                                } else {
215                                    fail_flag = true;
216                                }
217                                Ok(())
218                            })
219                            .unwrap();
220                    }
221                    _ => fail_flag = true,
222                }
223                if fail_flag {
224                    return syn::Error::new(
225                        name.span(),
226                        "Only `#[chip(where = ...)]` format is supported",
227                    )
228                    .to_compile_error()
229                    .into();
230                }
231            }
232
233            quote! {
234                impl #impl_generics openvm_stark_backend::Chip<R, PB> for #name #ty_generics #where_clause {
235                    fn generate_proving_ctx(&self, records: R) -> openvm_stark_backend::prover::types::AirProvingContext<PB> {
236                        match self {
237                            #(#generate_proving_ctx_arms,)*
238                        }
239                    }
240                }
241            }.into()
242        }
243        Data::Union(_) => unimplemented!("Unions are not supported"),
244    }
245}
246
247#[proc_macro_derive(ChipUsageGetter)]
248pub fn chip_usage_getter_derive(input: TokenStream) -> TokenStream {
249    let ast: syn::DeriveInput = syn::parse(input).unwrap();
250
251    let name = &ast.ident;
252    let generics = &ast.generics;
253    let (impl_generics, ty_generics, _) = generics.split_for_impl();
254
255    match &ast.data {
256        Data::Struct(inner) => {
257            // Check if the struct has only one unnamed field
258            let inner_ty = match &inner.fields {
259                Fields::Unnamed(fields) => {
260                    if fields.unnamed.len() != 1 {
261                        panic!("Only one unnamed field is supported");
262                    }
263                    fields.unnamed.first().unwrap().ty.clone()
264                }
265                _ => panic!("Only unnamed fields are supported"),
266            };
267            // Implement ChipUsageGetter whenever the inner type implements ChipUsageGetter
268            let mut new_generics = generics.clone();
269            let where_clause = new_generics.make_where_clause();
270            where_clause
271                .predicates
272                .push(syn::parse_quote! { #inner_ty: openvm_stark_backend::ChipUsageGetter });
273            quote! {
274                impl #impl_generics openvm_stark_backend::ChipUsageGetter for #name #ty_generics #where_clause {
275                    fn air_name(&self) -> String {
276                        self.0.air_name()
277                    }
278                    fn constant_trace_height(&self) -> Option<usize> {
279                        self.0.constant_trace_height()
280                    }
281                    fn current_trace_height(&self) -> usize {
282                        self.0.current_trace_height()
283                    }
284                    fn trace_width(&self) -> usize {
285                        self.0.trace_width()
286                    }
287                }
288            }
289            .into()
290        }
291        Data::Enum(e) => {
292            let (air_name_arms, constant_trace_height_arms, current_trace_height_arms, trace_width_arms): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) =
293                multiunzip(e.variants.iter().map(|variant| {
294                    let variant_name = &variant.ident;
295                    let air_name_arm = quote! {
296                    #name::#variant_name(x) => openvm_stark_backend::ChipUsageGetter::air_name(x)
297                };
298                    let constant_trace_height_arm = quote! {
299                    #name::#variant_name(x) => openvm_stark_backend::ChipUsageGetter::constant_trace_height(x)
300                };
301                    let current_trace_height_arm = quote! {
302                    #name::#variant_name(x) => openvm_stark_backend::ChipUsageGetter::current_trace_height(x)
303                };
304                    let trace_width_arm = quote! {
305                    #name::#variant_name(x) => openvm_stark_backend::ChipUsageGetter::trace_width(x)
306                };
307                    (air_name_arm, constant_trace_height_arm, current_trace_height_arm, trace_width_arm)
308                }));
309
310            quote! {
311                impl #impl_generics openvm_stark_backend::ChipUsageGetter for #name #ty_generics {
312                    fn air_name(&self) -> String {
313                        match self {
314                            #(#air_name_arms,)*
315                        }
316                    }
317                    fn constant_trace_height(&self) -> Option<usize> {
318                        match self {
319                            #(#constant_trace_height_arms,)*
320                        }
321                    }
322                    fn current_trace_height(&self) -> usize {
323                        match self {
324                            #(#current_trace_height_arms,)*
325                        }
326                    }
327                    fn trace_width(&self) -> usize {
328                        match self {
329                            #(#trace_width_arms,)*
330                        }
331                    }
332
333                }
334            }
335            .into()
336        }
337        Data::Union(_) => unimplemented!("Unions are not supported"),
338    }
339}
340
341#[proc_macro_derive(BytesStateful)]
342pub fn bytes_stateful_derive(input: TokenStream) -> TokenStream {
343    let ast: syn::DeriveInput = syn::parse(input).unwrap();
344
345    let name = &ast.ident;
346    let generics = &ast.generics;
347    let (impl_generics, ty_generics, _) = generics.split_for_impl();
348
349    match &ast.data {
350        Data::Struct(inner) => {
351            // Check if the struct has only one unnamed field
352            let inner_ty = match &inner.fields {
353                Fields::Unnamed(fields) => {
354                    if fields.unnamed.len() != 1 {
355                        panic!("Only one unnamed field is supported");
356                    }
357                    fields.unnamed.first().unwrap().ty.clone()
358                }
359                _ => panic!("Only unnamed fields are supported"),
360            };
361            // Use full path ::openvm_circuit... so it can be used either within or outside the vm
362            // crate. Assume F is already generic of the field.
363            let mut new_generics = generics.clone();
364            let where_clause = new_generics.make_where_clause();
365            where_clause
366                .predicates
367                .push(syn::parse_quote! { #inner_ty: ::openvm_stark_backend::Stateful<Vec<u8>> });
368
369            quote! {
370                impl #impl_generics ::openvm_stark_backend::Stateful<Vec<u8>> for #name #ty_generics #where_clause {
371                    fn load_state(&mut self, state: Vec<u8>) {
372                        self.0.load_state(state)
373                    }
374
375                    fn store_state(&self) -> Vec<u8> {
376                        self.0.store_state()
377                    }
378                }
379            }
380            .into()
381        }
382        Data::Enum(e) => {
383            let variants = e
384                .variants
385                .iter()
386                .map(|variant| {
387                    let variant_name = &variant.ident;
388
389                    let mut fields = variant.fields.iter();
390                    let field = fields.next().unwrap();
391                    assert!(fields.next().is_none(), "Only one field is supported");
392                    (variant_name, field)
393                })
394                .collect::<Vec<_>>();
395            // Use full path ::openvm_stark_backend... so it can be used either within or outside
396            // the vm crate.
397            let (load_state_arms, store_state_arms): (Vec<_>, Vec<_>) =
398                multiunzip(variants.iter().map(|(variant_name, field)| {
399                    let field_ty = &field.ty;
400                    let load_state_arm = quote! {
401                        #name::#variant_name(x) => <#field_ty as ::openvm_stark_backend::Stateful<Vec<u8>>>::load_state(x, state)
402                    };
403                    let store_state_arm = quote! {
404                        #name::#variant_name(x) => <#field_ty as ::openvm_stark_backend::Stateful<Vec<u8>>>::store_state(x)
405                    };
406
407                    (load_state_arm, store_state_arm)
408                }));
409            quote! {
410                impl #impl_generics ::openvm_stark_backend::Stateful<Vec<u8>> for #name #ty_generics {
411                    fn load_state(&mut self, state: Vec<u8>) {
412                        match self {
413                            #(#load_state_arms,)*
414                        }
415                    }
416
417                    fn store_state(&self) -> Vec<u8> {
418                        match self {
419                            #(#store_state_arms,)*
420                        }
421                    }
422                }
423            }
424            .into()
425        }
426        _ => unimplemented!(),
427    }
428}