openvm_circuit_primitives_derive/
lib.rs
1extern crate alloc;
3extern crate proc_macro;
4
5use itertools::multiunzip;
6use proc_macro::TokenStream;
7use quote::quote;
8use syn::{parse_macro_input, Data, DeriveInput, Fields, GenericParam, LitStr, Meta};
9
10#[proc_macro_derive(AlignedBorrow)]
11pub fn aligned_borrow_derive(input: TokenStream) -> TokenStream {
12 let ast = parse_macro_input!(input as DeriveInput);
13 let name = &ast.ident;
14
15 let type_generic = ast
17 .generics
18 .params
19 .iter()
20 .map(|param| match param {
21 GenericParam::Type(type_param) => &type_param.ident,
22 _ => panic!("Expected first generic to be a type"),
23 })
24 .next()
25 .expect("Expected at least one generic");
26
27 let non_first_generics = ast
30 .generics
31 .params
32 .iter()
33 .skip(1)
34 .filter_map(|param| match param {
35 GenericParam::Type(type_param) => Some(&type_param.ident),
36 GenericParam::Const(const_param) => Some(&const_param.ident),
37 _ => None,
38 })
39 .collect::<Vec<_>>();
40
41 let (impl_generics, type_generics, where_clause) = ast.generics.split_for_impl();
43
44 let methods = quote! {
45 impl #impl_generics core::borrow::Borrow<#name #type_generics> for [#type_generic] #where_clause {
46 fn borrow(&self) -> &#name #type_generics {
47 debug_assert_eq!(self.len(), #name::#type_generics::width());
48 let (prefix, shorts, _suffix) = unsafe { self.align_to::<#name #type_generics>() };
49 debug_assert!(prefix.is_empty(), "Alignment should match");
50 debug_assert_eq!(shorts.len(), 1);
51 &shorts[0]
52 }
53 }
54
55 impl #impl_generics core::borrow::BorrowMut<#name #type_generics> for [#type_generic] #where_clause {
56 fn borrow_mut(&mut self) -> &mut #name #type_generics {
57 debug_assert_eq!(self.len(), #name::#type_generics::width());
58 let (prefix, shorts, _suffix) = unsafe { self.align_to_mut::<#name #type_generics>() };
59 debug_assert!(prefix.is_empty(), "Alignment should match");
60 debug_assert_eq!(shorts.len(), 1);
61 &mut shorts[0]
62 }
63 }
64
65 impl #impl_generics #name #type_generics {
66 pub const fn width() -> usize {
67 std::mem::size_of::<#name<u8 #(, #non_first_generics)*>>()
68 }
69 }
70 };
71
72 TokenStream::from(methods)
73}
74
75#[proc_macro_derive(Chip, attributes(chip))]
76pub fn chip_derive(input: TokenStream) -> TokenStream {
77 let ast: syn::DeriveInput = syn::parse(input).unwrap();
79
80 let name = &ast.ident;
81 let generics = &ast.generics;
82 let (_impl_generics, ty_generics, _where_clause) = generics.split_for_impl();
83
84 match &ast.data {
85 Data::Struct(inner) => {
86 let generics = &ast.generics;
87 let mut new_generics = generics.clone();
88 new_generics
89 .params
90 .push(syn::parse_quote! { SC: openvm_stark_backend::config::StarkGenericConfig });
91 let (impl_generics, _, _) = new_generics.split_for_impl();
92
93 let inner_ty = match &inner.fields {
95 Fields::Unnamed(fields) => {
96 if fields.unnamed.len() != 1 {
97 panic!("Only one unnamed field is supported");
98 }
99 fields.unnamed.first().unwrap().ty.clone()
100 }
101 _ => panic!("Only unnamed fields are supported"),
102 };
103 let mut new_generics = generics.clone();
104 let where_clause = new_generics.make_where_clause();
105 where_clause
106 .predicates
107 .push(syn::parse_quote! { #inner_ty: openvm_stark_backend::Chip<SC> });
108 quote! {
109 impl #impl_generics openvm_stark_backend::Chip<SC> for #name #ty_generics #where_clause {
110 fn air(&self) -> openvm_stark_backend::AirRef<SC> {
111 self.0.air()
112 }
113 fn generate_air_proof_input(self) -> openvm_stark_backend::prover::types::AirProofInput<SC> {
114 self.0.generate_air_proof_input()
115 }
116 fn generate_air_proof_input_with_id(self, air_id: usize) -> (usize, openvm_stark_backend::prover::types::AirProofInput<SC>) {
117 self.0.generate_air_proof_input_with_id(air_id)
118 }
119 }
120 }.into()
121 }
122 Data::Enum(e) => {
123 let variants = e
124 .variants
125 .iter()
126 .map(|variant| {
127 let variant_name = &variant.ident;
128
129 let mut fields = variant.fields.iter();
130 let field = fields.next().unwrap();
131 assert!(fields.next().is_none(), "Only one field is supported");
132 (variant_name, field)
133 })
134 .collect::<Vec<_>>();
135
136 let (air_arms, generate_air_proof_input_arms, generate_air_proof_input_with_id_arms): (Vec<_>, Vec<_>, Vec<_>) =
137 multiunzip(variants.iter().map(|(variant_name, field)| {
138 let field_ty = &field.ty;
139 let air_arm = quote! {
140 #name::#variant_name(x) => <#field_ty as openvm_stark_backend::Chip<SC>>::air(x)
141 };
142 let generate_air_proof_input_arm = quote! {
143 #name::#variant_name(x) => <#field_ty as openvm_stark_backend::Chip<SC>>::generate_air_proof_input(x)
144 };
145 let generate_air_proof_input_with_id_arm = quote! {
146 #name::#variant_name(x) => <#field_ty as openvm_stark_backend::Chip<SC>>::generate_air_proof_input_with_id(x, air_id)
147 };
148 (air_arm, generate_air_proof_input_arm, generate_air_proof_input_with_id_arm)
149 }));
150
151 let generics = &ast.generics;
153 let mut new_generics = generics.clone();
154 new_generics
155 .params
156 .push(syn::parse_quote! { SC: openvm_stark_backend::config::StarkGenericConfig });
157 let (impl_generics, _, _) = new_generics.split_for_impl();
158
159 let mut new_generics = generics.clone();
161 let where_clause = new_generics.make_where_clause();
162 where_clause.predicates.push(syn::parse_quote! { openvm_stark_backend::config::Domain<SC>: openvm_stark_backend::p3_commit::PolynomialSpace<Val = F>
163 });
164 let attributes = ast.attrs.iter().find(|&attr| attr.path().is_ident("chip"));
165 if let Some(attr) = attributes {
166 let mut fail_flag = false;
167
168 match &attr.meta {
169 Meta::List(meta_list) => {
170 meta_list
171 .parse_nested_meta(|meta| {
172 if meta.path.is_ident("where") {
173 let value = meta.value()?; let s: LitStr = value.parse()?;
175 let where_value = s.value();
176 where_clause.predicates.push(syn::parse_str(&where_value)?);
177 } else {
178 fail_flag = true;
179 }
180 Ok(())
181 })
182 .unwrap();
183 }
184 _ => fail_flag = true,
185 }
186 if fail_flag {
187 return syn::Error::new(
188 name.span(),
189 "Only `#[chip(where = ...)]` format is supported",
190 )
191 .to_compile_error()
192 .into();
193 }
194 }
195
196 quote! {
197 impl #impl_generics openvm_stark_backend::Chip<SC> for #name #ty_generics #where_clause {
198 fn air(&self) -> openvm_stark_backend::AirRef<SC> {
199 match self {
200 #(#air_arms,)*
201 }
202 }
203 fn generate_air_proof_input(self) -> openvm_stark_backend::prover::types::AirProofInput<SC> {
204 match self {
205 #(#generate_air_proof_input_arms,)*
206 }
207 }
208 fn generate_air_proof_input_with_id(self, air_id: usize) -> (usize, openvm_stark_backend::prover::types::AirProofInput<SC>) {
209 match self {
210 #(#generate_air_proof_input_with_id_arms,)*
211 }
212 }
213 }
214 }.into()
215 }
216 Data::Union(_) => unimplemented!("Unions are not supported"),
217 }
218}
219
220#[proc_macro_derive(ChipUsageGetter)]
221pub fn chip_usage_getter_derive(input: TokenStream) -> TokenStream {
222 let ast: syn::DeriveInput = syn::parse(input).unwrap();
223
224 let name = &ast.ident;
225 let generics = &ast.generics;
226 let (impl_generics, ty_generics, _) = generics.split_for_impl();
227
228 match &ast.data {
229 Data::Struct(inner) => {
230 let inner_ty = match &inner.fields {
232 Fields::Unnamed(fields) => {
233 if fields.unnamed.len() != 1 {
234 panic!("Only one unnamed field is supported");
235 }
236 fields.unnamed.first().unwrap().ty.clone()
237 }
238 _ => panic!("Only unnamed fields are supported"),
239 };
240 let mut new_generics = generics.clone();
242 let where_clause = new_generics.make_where_clause();
243 where_clause
244 .predicates
245 .push(syn::parse_quote! { #inner_ty: openvm_stark_backend::ChipUsageGetter });
246 quote! {
247 impl #impl_generics openvm_stark_backend::ChipUsageGetter for #name #ty_generics #where_clause {
248 fn air_name(&self) -> String {
249 self.0.air_name()
250 }
251 fn constant_trace_height(&self) -> Option<usize> {
252 self.0.constant_trace_height()
253 }
254 fn current_trace_height(&self) -> usize {
255 self.0.current_trace_height()
256 }
257 fn trace_width(&self) -> usize {
258 self.0.trace_width()
259 }
260 }
261 }
262 .into()
263 }
264 Data::Enum(e) => {
265 let (air_name_arms, constant_trace_height_arms, current_trace_height_arms, trace_width_arms): (Vec<_>, Vec<_>, Vec<_>, Vec<_>) =
266 multiunzip(e.variants.iter().map(|variant| {
267 let variant_name = &variant.ident;
268 let air_name_arm = quote! {
269 #name::#variant_name(x) => openvm_stark_backend::ChipUsageGetter::air_name(x)
270 };
271 let constant_trace_height_arm = quote! {
272 #name::#variant_name(x) => openvm_stark_backend::ChipUsageGetter::constant_trace_height(x)
273 };
274 let current_trace_height_arm = quote! {
275 #name::#variant_name(x) => openvm_stark_backend::ChipUsageGetter::current_trace_height(x)
276 };
277 let trace_width_arm = quote! {
278 #name::#variant_name(x) => openvm_stark_backend::ChipUsageGetter::trace_width(x)
279 };
280 (air_name_arm, constant_trace_height_arm, current_trace_height_arm, trace_width_arm)
281 }));
282
283 quote! {
284 impl #impl_generics openvm_stark_backend::ChipUsageGetter for #name #ty_generics {
285 fn air_name(&self) -> String {
286 match self {
287 #(#air_name_arms,)*
288 }
289 }
290 fn constant_trace_height(&self) -> Option<usize> {
291 match self {
292 #(#constant_trace_height_arms,)*
293 }
294 }
295 fn current_trace_height(&self) -> usize {
296 match self {
297 #(#current_trace_height_arms,)*
298 }
299 }
300 fn trace_width(&self) -> usize {
301 match self {
302 #(#trace_width_arms,)*
303 }
304 }
305
306 }
307 }
308 .into()
309 }
310 Data::Union(_) => unimplemented!("Unions are not supported"),
311 }
312}
313
314#[proc_macro_derive(BytesStateful)]
315pub fn bytes_stateful_derive(input: TokenStream) -> TokenStream {
316 let ast: syn::DeriveInput = syn::parse(input).unwrap();
317
318 let name = &ast.ident;
319 let generics = &ast.generics;
320 let (impl_generics, ty_generics, _) = generics.split_for_impl();
321
322 match &ast.data {
323 Data::Struct(inner) => {
324 let inner_ty = match &inner.fields {
326 Fields::Unnamed(fields) => {
327 if fields.unnamed.len() != 1 {
328 panic!("Only one unnamed field is supported");
329 }
330 fields.unnamed.first().unwrap().ty.clone()
331 }
332 _ => panic!("Only unnamed fields are supported"),
333 };
334 let mut new_generics = generics.clone();
337 let where_clause = new_generics.make_where_clause();
338 where_clause
339 .predicates
340 .push(syn::parse_quote! { #inner_ty: ::openvm_stark_backend::Stateful<Vec<u8>> });
341
342 quote! {
343 impl #impl_generics ::openvm_stark_backend::Stateful<Vec<u8>> for #name #ty_generics #where_clause {
344 fn load_state(&mut self, state: Vec<u8>) {
345 self.0.load_state(state)
346 }
347
348 fn store_state(&self) -> Vec<u8> {
349 self.0.store_state()
350 }
351 }
352 }
353 .into()
354 }
355 Data::Enum(e) => {
356 let variants = e
357 .variants
358 .iter()
359 .map(|variant| {
360 let variant_name = &variant.ident;
361
362 let mut fields = variant.fields.iter();
363 let field = fields.next().unwrap();
364 assert!(fields.next().is_none(), "Only one field is supported");
365 (variant_name, field)
366 })
367 .collect::<Vec<_>>();
368 let (load_state_arms, store_state_arms): (Vec<_>, Vec<_>) =
370 multiunzip(variants.iter().map(|(variant_name, field)| {
371 let field_ty = &field.ty;
372 let load_state_arm = quote! {
373 #name::#variant_name(x) => <#field_ty as ::openvm_stark_backend::Stateful<Vec<u8>>>::load_state(x, state)
374 };
375 let store_state_arm = quote! {
376 #name::#variant_name(x) => <#field_ty as ::openvm_stark_backend::Stateful<Vec<u8>>>::store_state(x)
377 };
378
379 (load_state_arm, store_state_arm)
380 }));
381 quote! {
382 impl #impl_generics ::openvm_stark_backend::Stateful<Vec<u8>> for #name #ty_generics {
383 fn load_state(&mut self, state: Vec<u8>) {
384 match self {
385 #(#load_state_arms,)*
386 }
387 }
388
389 fn store_state(&self) -> Vec<u8> {
390 match self {
391 #(#store_state_arms,)*
392 }
393 }
394 }
395 }
396 .into()
397 }
398 _ => unimplemented!(),
399 }
400}