1use proc_macro2::{Span, TokenStream};
2use syn::{
3 parse::{Parse, ParseStream},
4 Ident, Token,
5};
6
7enum AsmArg {
8 In(TokenStream),
9 Out(TokenStream),
10 InOut(TokenStream),
11 ConstExpr(TokenStream),
12 ConstLit(syn::LitStr),
13}
14
15struct CustomInsnR {
16 pub rd: AsmArg,
17 pub rs1: AsmArg,
18 pub rs2: AsmArg,
19 pub opcode: TokenStream,
20 pub funct3: TokenStream,
21 pub funct7: TokenStream,
22}
23
24struct CustomInsnI {
25 pub rd: AsmArg,
26 pub rs1: AsmArg,
27 pub imm: AsmArg,
28 pub opcode: TokenStream,
29 pub funct3: TokenStream,
30}
31
32#[allow(clippy::type_complexity)]
34fn parse_common_fields(
35 input: ParseStream,
36) -> syn::Result<(
37 Option<AsmArg>,
38 Option<AsmArg>,
39 Option<TokenStream>,
40 Option<TokenStream>,
41)> {
42 let mut rd = None;
43 let mut rs1 = None;
44 let mut opcode = None;
45 let mut funct3 = None;
46
47 while !input.is_empty() {
48 let key: Ident = input.parse()?;
49 input.parse::<Token![=]>()?;
50
51 let value = if key == "opcode" || key == "funct3" {
52 let mut tokens = TokenStream::new();
53 while !input.is_empty() && !input.peek(Token![,]) {
54 tokens.extend(TokenStream::from(input.parse::<proc_macro2::TokenTree>()?));
55 }
56 match key.to_string().as_str() {
57 "opcode" => opcode = Some(tokens),
58 "funct3" => funct3 = Some(tokens),
59 _ => unreachable!(),
60 }
61 None
62 } else if key == "rd" || key == "rs1" {
63 Some(parse_asm_arg(input)?)
64 } else {
65 while !input.is_empty() && !input.peek(Token![,]) {
66 input.parse::<proc_macro2::TokenTree>()?;
67 }
68 None
69 };
70
71 match key.to_string().as_str() {
72 "rd" => rd = value,
73 "rs1" => rs1 = value,
74 "opcode" | "funct3" => (),
75 _ => {
77 if !input.is_empty() {
78 input.parse::<Token![,]>()?;
79 }
80 continue;
81 }
82 }
83
84 if !input.is_empty() {
85 input.parse::<Token![,]>()?;
86 }
87 }
88
89 Ok((rd, rs1, opcode, funct3))
90}
91
92fn parse_asm_arg(input: ParseStream) -> syn::Result<AsmArg> {
94 let lookahead = input.lookahead1();
95 if lookahead.peek(kw::In) {
96 input.parse::<kw::In>()?;
97 let mut tokens = TokenStream::new();
98 while !input.is_empty() && !input.peek(Token![,]) {
99 tokens.extend(TokenStream::from(input.parse::<proc_macro2::TokenTree>()?));
100 }
101 Ok(AsmArg::In(tokens))
102 } else if lookahead.peek(kw::Out) {
103 input.parse::<kw::Out>()?;
105 let mut tokens = TokenStream::new();
106 while !input.is_empty() && !input.peek(Token![,]) {
107 tokens.extend(TokenStream::from(input.parse::<proc_macro2::TokenTree>()?));
108 }
109 Ok(AsmArg::Out(tokens))
110 } else if lookahead.peek(kw::InOut) {
111 input.parse::<kw::InOut>()?;
113 let mut tokens = TokenStream::new();
114 while !input.is_empty() && !input.peek(Token![,]) {
115 tokens.extend(TokenStream::from(input.parse::<proc_macro2::TokenTree>()?));
116 }
117 Ok(AsmArg::InOut(tokens))
118 } else if lookahead.peek(kw::Const) {
119 input.parse::<kw::Const>()?;
120 if input.peek(syn::LitStr) {
121 Ok(AsmArg::ConstLit(input.parse()?))
122 } else {
123 let mut tokens = TokenStream::new();
124 while !input.is_empty() && !input.peek(Token![,]) {
125 tokens.extend(TokenStream::from(input.parse::<proc_macro2::TokenTree>()?));
126 }
127 Ok(AsmArg::ConstExpr(tokens))
128 }
129 } else {
130 Err(lookahead.error())
131 }
132}
133
134impl Parse for CustomInsnR {
135 fn parse(input: ParseStream) -> syn::Result<Self> {
136 let input_fork = input.fork();
137 let (rd, rs1, opcode, funct3) = parse_common_fields(input)?;
138
139 let mut rs2 = None;
141 let mut funct7 = None;
142 while !input_fork.is_empty() {
143 let key: Ident = input_fork.parse()?;
144 input_fork.parse::<Token![=]>()?;
145
146 if key == "rs2" {
147 rs2 = Some(parse_asm_arg(&input_fork)?);
148 } else if key == "funct7" {
149 let mut tokens = TokenStream::new();
150 while !input_fork.is_empty() && !input_fork.peek(Token![,]) {
151 tokens.extend(TokenStream::from(
152 input_fork.parse::<proc_macro2::TokenTree>()?,
153 ));
154 }
155 funct7 = Some(tokens);
156 } else {
157 while !input_fork.is_empty() && !input_fork.peek(Token![,]) {
159 input_fork.parse::<proc_macro2::TokenTree>()?;
160 }
161 }
162
163 if !input_fork.is_empty() {
164 input_fork.parse::<Token![,]>()?;
165 }
166 }
167
168 let opcode = opcode.ok_or_else(|| syn::Error::new(input.span(), "missing opcode field"))?;
169 let funct3 = funct3.ok_or_else(|| syn::Error::new(input.span(), "missing funct3 field"))?;
170 let funct7 = funct7.ok_or_else(|| syn::Error::new(input.span(), "missing funct7 field"))?;
171 let rd = rd.ok_or_else(|| syn::Error::new(input.span(), "missing rd field"))?;
172 let rs1 = rs1.ok_or_else(|| syn::Error::new(input.span(), "missing rs1 field"))?;
173 let rs2 = rs2.ok_or_else(|| syn::Error::new(input.span(), "missing rs2 field"))?;
174
175 Ok(CustomInsnR {
176 rd,
177 rs1,
178 rs2,
179 opcode,
180 funct3,
181 funct7,
182 })
183 }
184}
185
186impl Parse for CustomInsnI {
187 fn parse(input: ParseStream) -> syn::Result<Self> {
188 let input_fork = input.fork();
189 let (rd, rs1, opcode, funct3) = parse_common_fields(input)?;
190
191 let mut imm = None;
193 while !input_fork.is_empty() {
194 let key: Ident = input_fork.parse()?;
195 input_fork.parse::<Token![=]>()?;
196
197 if key == "imm" {
198 let value = parse_asm_arg(&input_fork)?;
199 match value {
200 AsmArg::ConstLit(lit) => imm = Some(AsmArg::ConstLit(lit)),
201 AsmArg::ConstExpr(expr) => imm = Some(AsmArg::ConstExpr(expr)),
202 _ => return Err(syn::Error::new(key.span(), "imm must be a Const")),
203 }
204 } else {
205 while !input_fork.is_empty() && !input_fork.peek(Token![,]) {
207 input_fork.parse::<proc_macro2::TokenTree>()?;
208 }
209 }
210
211 if !input_fork.is_empty() {
212 input_fork.parse::<Token![,]>()?;
213 }
214 }
215
216 let opcode = opcode.ok_or_else(|| syn::Error::new(input.span(), "missing opcode field"))?;
217 let funct3 = funct3.ok_or_else(|| syn::Error::new(input.span(), "missing funct3 field"))?;
218 let rd = rd.ok_or_else(|| syn::Error::new(input.span(), "missing rd field"))?;
219 let rs1 = rs1.ok_or_else(|| syn::Error::new(input.span(), "missing rs1 field"))?;
220 let imm = imm.ok_or_else(|| syn::Error::new(input.span(), "missing imm field"))?;
221
222 Ok(CustomInsnI {
223 rd,
224 rs1,
225 imm,
226 opcode,
227 funct3,
228 })
229 }
230}
231
232fn handle_reg_arg(
234 template: &mut String,
235 args: &mut Vec<proc_macro2::TokenStream>,
236 arg: &AsmArg,
237 reg_name: &str,
238) {
239 let reg_ident = syn::Ident::new(reg_name, Span::call_site());
240 match arg {
241 AsmArg::ConstLit(lit) => {
242 template.push_str(", ");
243 template.push_str(&lit.value());
244 }
245 AsmArg::In(tokens) => {
246 template.push_str(", {");
247 template.push_str(reg_name);
248 template.push('}');
249 args.push(quote::quote! { #reg_ident = in(reg) #tokens });
250 }
251 AsmArg::Out(tokens) => {
252 template.push_str(", {");
253 template.push_str(reg_name);
254 template.push('}');
255 args.push(quote::quote! { #reg_ident = out(reg) #tokens });
256 }
257 AsmArg::InOut(tokens) => {
258 template.push_str(", {");
259 template.push_str(reg_name);
260 template.push('}');
261 args.push(quote::quote! { #reg_ident = inout(reg) #tokens });
262 }
263 AsmArg::ConstExpr(tokens) => {
264 template.push_str(", {");
265 template.push_str(reg_name);
266 template.push('}');
267 args.push(quote::quote! { #reg_ident = const #tokens });
268 }
269 }
270}
271
272mod kw {
273 syn::custom_keyword!(In);
274 syn::custom_keyword!(Out);
275 syn::custom_keyword!(InOut);
276 syn::custom_keyword!(Const);
277}
278
279#[proc_macro]
304pub fn custom_insn_r(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
305 let CustomInsnR {
306 rd,
307 rs1,
308 rs2,
309 opcode,
310 funct3,
311 funct7,
312 } = syn::parse_macro_input!(input as CustomInsnR);
313
314 let mut template = String::from(".insn r {opcode}, {funct3}, {funct7}");
315 let mut args = vec![];
316
317 handle_reg_arg(&mut template, &mut args, &rd, "rd");
319 handle_reg_arg(&mut template, &mut args, &rs1, "rs1");
320 handle_reg_arg(&mut template, &mut args, &rs2, "rs2");
321
322 let expanded = quote::quote! {
323 #[cfg(target_os = "zkvm")]
324 unsafe {
325 core::arch::asm!(
326 #template,
327 opcode = const #opcode,
328 funct3 = const #funct3,
329 funct7 = const #funct7,
330 #(#args),*
331 )
332 }
333 };
334
335 expanded.into()
336}
337
338#[proc_macro]
364pub fn custom_insn_i(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
365 let CustomInsnI {
366 rd,
367 rs1,
368 imm,
369 opcode,
370 funct3,
371 } = syn::parse_macro_input!(input as CustomInsnI);
372
373 let mut template = String::from(".insn i {opcode}, {funct3}");
374 let mut args = vec![];
375
376 handle_reg_arg(&mut template, &mut args, &rd, "rd");
378 handle_reg_arg(&mut template, &mut args, &rs1, "rs1");
379 handle_reg_arg(&mut template, &mut args, &imm, "imm");
380
381 let expanded = quote::quote! {
382 #[cfg(target_os = "zkvm")]
383 unsafe {
384 core::arch::asm!(
385 #template,
386 opcode = const #opcode,
387 funct3 = const #funct3,
388 #(#args),*
389 )
390 }
391 };
392
393 expanded.into()
394}