openvm_circuit_primitives_derive/
lib.rs1extern crate alloc;
3extern crate proc_macro;
4
5use itertools::multiunzip;
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericParam, LitStr, Meta};
9
10#[proc_macro_derive(AlignedBorrow)]
11pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream {
12 let ast = parse_macro_input!(input as DeriveInput);
13 let name = &ast.ident;
14
15 let type_generic = ast
17 .generics
18 .params
19 .iter()
20 .map(|param| match param {
21 GenericParam::Type(type_param) => &type_param.ident,
22 _ => panic!("Expected first generic to be a type"),
23 })
24 .next()
25 .expect("Expected at least one generic");
26
27 let non_first_generics = ast
30 .generics
31 .params
32 .iter()
33 .skip(1)
34 .filter_map(|param| match param {
35 GenericParam::Type(type_param) => Some(&type_param.ident),
36 GenericParam::Const(const_param) => Some(&const_param.ident),
37 _ => None,
38 })
39 .collect::<Vec<_>>();
40
41 let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
44
45 let methods = quote! {
46 impl #impl_generics core::borrow::Borrow<#name #type_generics> for [#type_generic] #where_clause {
47 fn borrow(&self) -> &#name #type_generics {
48 debug_assert_eq!(self.len(), #name::#type_generics::width());
49 let (prefix, shorts, _suffix) = unsafe { self.align_to::<#name #type_generics>() };
50 debug_assert!(prefix.is_empty(), "Alignment should match");
51 debug_assert_eq!(shorts.len(), 1);
52 &shorts[0]
53 }
54 }
55
56 impl #impl_generics core::borrow::BorrowMut<#name #type_generics> for [#type_generic] #where_clause {
57 fn borrow_mut(&mut self) -> &mut #name #type_generics {
58 debug_assert_eq!(self.len(), #name::#type_generics::width());
59 let (prefix, shorts, _suffix) = unsafe { self.align_to_mut::<#name #type_generics>() };
60 debug_assert!(prefix.is_empty(), "Alignment should match");
61 debug_assert_eq!(shorts.len(), 1);
62 &mut shorts[0]
63 }
64 }
65
66 impl #impl_generics #name #type_generics {
67 pub const fn width() -> usize {
68 std::mem::size_of::<#name<u8 #(, #non_first_generics)*>>()
69 }
70 }
71 };
72
73 TokenStream::from(methods)
74}
75
76#[proc_macro_derive(AlignedBytesBorrow)]
82pub fn aligned_bytes_borrow_derive(input: TokenStream) -> TokenStream {
83 let ast = parse_macro_input!(input as DeriveInput);
84 let name = &ast.ident;
85
86 let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
89
90 let methods = quote! {
91 impl #impl_generics core::borrow::Borrow<#name #type_generics> for [u8]
92 where
93 #where_clause
94 {
95 fn borrow(&self) -> &#name #type_generics {
96 use core::mem::{align_of, size_of_val};
97 debug_assert!(size_of_val(self) >= core::mem::size_of::<#name #type_generics>());
98 debug_assert_eq!(self.as_ptr() as usize % align_of::<#name #type_generics>(), 0);
99 unsafe { &*(self.as_ptr() as *const #name #type_generics) }
100 }
101 }
102
103 impl #impl_generics core::borrow::BorrowMut<#name #type_generics> for [u8]
104 where
105 #where_clause
106 {
107 fn borrow_mut(&mut self) -> &mut #name #type_generics {
108 use core::mem::{align_of, size_of_val};
109 debug_assert!(size_of_val(self) >= core::mem::size_of::<#name #type_generics>());
110 debug_assert_eq!(self.as_ptr() as usize % align_of::<#name #type_generics>(), 0);
111 unsafe { &mut *(self.as_mut_ptr() as *mut #name #type_generics) }
112 }
113 }
114 };
115
116 TokenStream::from(methods)
117}
118
119#[proc_macro_derive(Chip, attributes(chip))]
120pub fn chip_derive(input: TokenStream) -> TokenStream {
121 let ast: syn::DeriveInput = syn::parse(input).unwrap();
123
124 let name = &ast.ident;
125 let generics = &ast.generics;
126 let (_impl_generics, ty_generics, _where_clause) = generics.split_for_impl();
127
128 match &ast.data {
129 Data::Struct(inner) => {
130 let generics = &ast.generics;
131 let mut new_generics = generics.clone();
132 new_generics.params.push(syn::parse_quote! { R });
133 new_generics
134 .params
135 .push(syn::parse_quote! { PB: openvm_stark_backend::prover::hal::ProverBackend });
136 let (impl_generics, _, _) = new_generics.split_for_impl();
137
138 let inner_ty = match &inner.fields {
140 Fields::Unnamed(fields) => {
141 if fields.unnamed.len() != 1 {
142 panic!("Only one unnamed field is supported");
143 }
144 fields.unnamed.first().unwrap().ty.clone()
145 }
146 _ => panic!("Only unnamed fields are supported"),
147 };
148 let mut new_generics = generics.clone();
149 let where_clause = new_generics.make_where_clause();
150 where_clause
151 .predicates
152 .push(syn::parse_quote! { #inner_ty: openvm_stark_backend::Chip<R, PB> });
153 quote! {
154 impl #impl_generics openvm_stark_backend::Chip<R, PB> for #name #ty_generics #where_clause {
155 fn generate_proving_ctx(&self, records: R) -> openvm_stark_backend::prover::types::AirProvingContext<PB> {
156 self.0.generate_proving_ctx(records)
157 }
158 }
159 }.into()
160 }
161 Data::Enum(e) => {
162 let variants = e
163 .variants
164 .iter()
165 .map(|variant| {
166 let variant_name = &variant.ident;
167
168 let mut fields = variant.fields.iter();
169 let field = fields.next().unwrap();
170 assert!(fields.next().is_none(), "Only one field is supported");
171 (variant_name, field)
172 })
173 .collect::<Vec<_>>();
174
175 let (generate_proving_ctx_arms, where_predicates): (Vec<_>, Vec<_>) =
176 variants.iter().map(|(variant_name, field)| {
177 let field_ty = &field.ty;
178 let generate_proving_ctx_arm = quote! {
179 #name::#variant_name(x) => <#field_ty as openvm_stark_backend::Chip<R, PB>>::generate_proving_ctx(x, records)
180 };
181 let where_predicate =
182 syn::parse_quote! { #field_ty: openvm_stark_backend::Chip<R, PB> };
183 (generate_proving_ctx_arm, where_predicate)
184 }).collect();
185
186 let generics = &ast.generics;
188 let mut new_generics = generics.clone();
189 new_generics.params.push(syn::parse_quote! { R });
190 new_generics
191 .params
192 .push(syn::parse_quote! { PB: openvm_stark_backend::prover::hal::ProverBackend });
193 let (impl_generics, _, _) = new_generics.split_for_impl();
194
195 let mut new_generics = generics.clone();
197 let where_clause = new_generics.make_where_clause();
198 for predicate in where_predicates {
199 where_clause.predicates.push(predicate);
200 }
201 let attributes = ast.attrs.iter().find(|&attr| attr.path().is_ident("chip"));
202 if let Some(attr) = attributes {
203 let mut fail_flag = false;
204
205 match &attr.meta {
206 Meta::List(meta_list) => {
207 meta_list
208 .parse_nested_meta(|meta| {
209 if meta.path.is_ident("where") {
210 let value = meta.value()?; let s: LitStr = value.parse()?;
212 let where_value = s.value();
213 where_clause.predicates.push(syn::parse_str(&where_value)?);
214 } else {
215 fail_flag = true;
216 }
217 Ok(())
218 })
219 .unwrap();
220 }
221 _ => fail_flag = true,
222 }
223 if fail_flag {
224 return syn::Error::new(
225 name.span(),
226 "Only `#[chip(where = ...)]` format is supported",
227 )
228 .to_compile_error()
229 .into();
230 }
231 }
232
233 quote! {
234 impl #impl_generics openvm_stark_backend::Chip<R, PB> for #name #ty_generics #where_clause {
235 fn generate_proving_ctx(&self, records: R) -> openvm_stark_backend::prover::types::AirProvingContext<PB> {
236 match self {
237 #(#generate_proving_ctx_arms,)*
238 }
239 }
240 }
241 }.into()
242 }
243 Data::Union(_) => unimplemented!("Unions are not supported"),
244 }
245}
246
247#[proc_macro_derive(ChipUsageGetter)]
248pub fn chip_usage_getter_derive(input: TokenStream) -> TokenStream {
249 let ast: syn::DeriveInput = syn::parse(input).unwrap();
250
251 let name = &ast.ident;
252 let generics = &ast.generics;
253 let (impl_generics, ty_generics, _) = generics.split_for_impl();
254
255 match &ast.data {
256 Data::Struct(inner) => {
257 let inner_ty = match &inner.fields {
259 Fields::Unnamed(fields) => {
260 if fields.unnamed.len() != 1 {
261 panic!("Only one unnamed field is supported");
262 }
263 fields.unnamed.first().unwrap().ty.clone()
264 }
265 _ => panic!("Only unnamed fields are supported"),
266 };
267 let mut new_generics = generics.clone();
269 let where_clause = new_generics.make_where_clause();
270 where_clause
271 .predicates
272 .push(syn::parse_quote! { #inner_ty: openvm_stark_backend::ChipUsageGetter });
273 quote! {
274 impl #impl_generics openvm_stark_backend::ChipUsageGetter for #name #ty_generics #where_clause {
275 fn air_name(&self) -> String {
276 self.0.air_name()
277 }
278 fn constant_trace_height(&self) -> Option<usize> {
279 self.0.constant_trace_height()
280 }
281 fn current_trace_height(&self) -> usize {
282 self.0.current_trace_height()
283 }
284 fn trace_width(&self) -> usize {
285 self.0.trace_width()
286 }
287 }
288 }
289 .into()
290 }
291 Data::Enum(e) => {
292 let (air_name_arms, constant_trace_height_arms, current_trace_height_arms, trace_width_arms): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) =
293 multiunzip(e.variants.iter().map(|variant| {
294 let variant_name = &variant.ident;
295 let air_name_arm = quote! {
296 #name::#variant_name(x) => openvm_stark_backend::ChipUsageGetter::air_name(x)
297 };
298 let constant_trace_height_arm = quote! {
299 #name::#variant_name(x) => openvm_stark_backend::ChipUsageGetter::constant_trace_height(x)
300 };
301 let current_trace_height_arm = quote! {
302 #name::#variant_name(x) => openvm_stark_backend::ChipUsageGetter::current_trace_height(x)
303 };
304 let trace_width_arm = quote! {
305 #name::#variant_name(x) => openvm_stark_backend::ChipUsageGetter::trace_width(x)
306 };
307 (air_name_arm, constant_trace_height_arm, current_trace_height_arm, trace_width_arm)
308 }));
309
310 quote! {
311 impl #impl_generics openvm_stark_backend::ChipUsageGetter for #name #ty_generics {
312 fn air_name(&self) -> String {
313 match self {
314 #(#air_name_arms,)*
315 }
316 }
317 fn constant_trace_height(&self) -> Option<usize> {
318 match self {
319 #(#constant_trace_height_arms,)*
320 }
321 }
322 fn current_trace_height(&self) -> usize {
323 match self {
324 #(#current_trace_height_arms,)*
325 }
326 }
327 fn trace_width(&self) -> usize {
328 match self {
329 #(#trace_width_arms,)*
330 }
331 }
332
333 }
334 }
335 .into()
336 }
337 Data::Union(_) => unimplemented!("Unions are not supported"),
338 }
339}
340
341#[proc_macro_derive(BytesStateful)]
342pub fn bytes_stateful_derive(input: TokenStream) -> TokenStream {
343 let ast: syn::DeriveInput = syn::parse(input).unwrap();
344
345 let name = &ast.ident;
346 let generics = &ast.generics;
347 let (impl_generics, ty_generics, _) = generics.split_for_impl();
348
349 match &ast.data {
350 Data::Struct(inner) => {
351 let inner_ty = match &inner.fields {
353 Fields::Unnamed(fields) => {
354 if fields.unnamed.len() != 1 {
355 panic!("Only one unnamed field is supported");
356 }
357 fields.unnamed.first().unwrap().ty.clone()
358 }
359 _ => panic!("Only unnamed fields are supported"),
360 };
361 let mut new_generics = generics.clone();
364 let where_clause = new_generics.make_where_clause();
365 where_clause
366 .predicates
367 .push(syn::parse_quote! { #inner_ty: ::openvm_stark_backend::Stateful<Vec<u8>> });
368
369 quote! {
370 impl #impl_generics ::openvm_stark_backend::Stateful<Vec<u8>> for #name #ty_generics #where_clause {
371 fn load_state(&mut self, state: Vec<u8>) {
372 self.0.load_state(state)
373 }
374
375 fn store_state(&self) -> Vec<u8> {
376 self.0.store_state()
377 }
378 }
379 }
380 .into()
381 }
382 Data::Enum(e) => {
383 let variants = e
384 .variants
385 .iter()
386 .map(|variant| {
387 let variant_name = &variant.ident;
388
389 let mut fields = variant.fields.iter();
390 let field = fields.next().unwrap();
391 assert!(fields.next().is_none(), "Only one field is supported");
392 (variant_name, field)
393 })
394 .collect::<Vec<_>>();
395 let (load_state_arms, store_state_arms): (Vec<_>, Vec<_>) =
398 multiunzip(variants.iter().map(|(variant_name, field)| {
399 let field_ty = &field.ty;
400 let load_state_arm = quote! {
401 #name::#variant_name(x) => <#field_ty as ::openvm_stark_backend::Stateful<Vec<u8>>>::load_state(x, state)
402 };
403 let store_state_arm = quote! {
404 #name::#variant_name(x) => <#field_ty as ::openvm_stark_backend::Stateful<Vec<u8>>>::store_state(x)
405 };
406
407 (load_state_arm, store_state_arm)
408 }));
409 quote! {
410 impl #impl_generics ::openvm_stark_backend::Stateful<Vec<u8>> for #name #ty_generics {
411 fn load_state(&mut self, state: Vec<u8>) {
412 match self {
413 #(#load_state_arms,)*
414 }
415 }
416
417 fn store_state(&self) -> Vec<u8> {
418 match self {
419 #(#store_state_arms,)*
420 }
421 }
422 }
423 }
424 .into()
425 }
426 _ => unimplemented!(),
427 }
428}