openvm_custom_insn/
lib.rs

1use proc_macro2::{Span, TokenStream};
2use syn::{
3    parse::{Parse, ParseStream},
4    Ident, Token,
5};
6
7enum AsmArg {
8    In(TokenStream),
9    Out(TokenStream),
10    InOut(TokenStream),
11    ConstExpr(TokenStream),
12    ConstLit(syn::LitStr),
13}
14
15struct CustomInsnR {
16    pub rd: AsmArg,
17    pub rs1: AsmArg,
18    pub rs2: AsmArg,
19    pub opcode: TokenStream,
20    pub funct3: TokenStream,
21    pub funct7: TokenStream,
22}
23
24struct CustomInsnI {
25    pub rd: AsmArg,
26    pub rs1: AsmArg,
27    pub imm: AsmArg,
28    pub opcode: TokenStream,
29    pub funct3: TokenStream,
30}
31
32/// Returns `(rd, rs1, opcode, funct3)`.
33#[allow(clippy::type_complexity)]
34fn parse_common_fields(
35    input: ParseStream,
36) -> syn::Result<(
37    Option<AsmArg>,
38    Option<AsmArg>,
39    Option<TokenStream>,
40    Option<TokenStream>,
41)> {
42    let mut rd = None;
43    let mut rs1 = None;
44    let mut opcode = None;
45    let mut funct3 = None;
46
47    while !input.is_empty() {
48        let key: Ident = input.parse()?;
49        input.parse::<Token![=]>()?;
50
51        let value = if key == "opcode" || key == "funct3" {
52            let mut tokens = TokenStream::new();
53            while !input.is_empty() && !input.peek(Token![,]) {
54                tokens.extend(TokenStream::from(input.parse::<proc_macro2::TokenTree>()?));
55            }
56            match key.to_string().as_str() {
57                "opcode" => opcode = Some(tokens),
58                "funct3" => funct3 = Some(tokens),
59                _ => unreachable!(),
60            }
61            None
62        } else if key == "rd" || key == "rs1" {
63            Some(parse_asm_arg(input)?)
64        } else {
65            while !input.is_empty() && !input.peek(Token![,]) {
66                input.parse::<proc_macro2::TokenTree>()?;
67            }
68            None
69        };
70
71        match key.to_string().as_str() {
72            "rd" => rd = value,
73            "rs1" => rs1 = value,
74            "opcode" | "funct3" => (),
75            // Skip other fields instead of returning an error
76            _ => {
77                if !input.is_empty() {
78                    input.parse::<Token![,]>()?;
79                }
80                continue;
81            }
82        }
83
84        if !input.is_empty() {
85            input.parse::<Token![,]>()?;
86        }
87    }
88
89    Ok((rd, rs1, opcode, funct3))
90}
91
92// Helper function to parse AsmArg
93fn parse_asm_arg(input: ParseStream) -> syn::Result<AsmArg> {
94    let lookahead = input.lookahead1();
95    if lookahead.peek(kw::In) {
96        input.parse::<kw::In>()?;
97        let mut tokens = TokenStream::new();
98        while !input.is_empty() && !input.peek(Token![,]) {
99            tokens.extend(TokenStream::from(input.parse::<proc_macro2::TokenTree>()?));
100        }
101        Ok(AsmArg::In(tokens))
102    } else if lookahead.peek(kw::Out) {
103        // ... similar for Out
104        input.parse::<kw::Out>()?;
105        let mut tokens = TokenStream::new();
106        while !input.is_empty() && !input.peek(Token![,]) {
107            tokens.extend(TokenStream::from(input.parse::<proc_macro2::TokenTree>()?));
108        }
109        Ok(AsmArg::Out(tokens))
110    } else if lookahead.peek(kw::InOut) {
111        // ... similar for InOut
112        input.parse::<kw::InOut>()?;
113        let mut tokens = TokenStream::new();
114        while !input.is_empty() && !input.peek(Token![,]) {
115            tokens.extend(TokenStream::from(input.parse::<proc_macro2::TokenTree>()?));
116        }
117        Ok(AsmArg::InOut(tokens))
118    } else if lookahead.peek(kw::Const) {
119        input.parse::<kw::Const>()?;
120        if input.peek(syn::LitStr) {
121            Ok(AsmArg::ConstLit(input.parse()?))
122        } else {
123            let mut tokens = TokenStream::new();
124            while !input.is_empty() && !input.peek(Token![,]) {
125                tokens.extend(TokenStream::from(input.parse::<proc_macro2::TokenTree>()?));
126            }
127            Ok(AsmArg::ConstExpr(tokens))
128        }
129    } else {
130        Err(lookahead.error())
131    }
132}
133
134impl Parse for CustomInsnR {
135    fn parse(input: ParseStream) -> syn::Result<Self> {
136        let input_fork = input.fork();
137        let (rd, rs1, opcode, funct3) = parse_common_fields(input)?;
138
139        // Parse rs2 and funct7 from the forked input
140        let mut rs2 = None;
141        let mut funct7 = None;
142        while !input_fork.is_empty() {
143            let key: Ident = input_fork.parse()?;
144            input_fork.parse::<Token![=]>()?;
145
146            if key == "rs2" {
147                rs2 = Some(parse_asm_arg(&input_fork)?);
148            } else if key == "funct7" {
149                let mut tokens = TokenStream::new();
150                while !input_fork.is_empty() && !input_fork.peek(Token![,]) {
151                    tokens.extend(TokenStream::from(
152                        input_fork.parse::<proc_macro2::TokenTree>()?,
153                    ));
154                }
155                funct7 = Some(tokens);
156            } else {
157                // Skip other fields
158                while !input_fork.is_empty() && !input_fork.peek(Token![,]) {
159                    input_fork.parse::<proc_macro2::TokenTree>()?;
160                }
161            }
162
163            if !input_fork.is_empty() {
164                input_fork.parse::<Token![,]>()?;
165            }
166        }
167
168        let opcode = opcode.ok_or_else(|| syn::Error::new(input.span(), "missing opcode field"))?;
169        let funct3 = funct3.ok_or_else(|| syn::Error::new(input.span(), "missing funct3 field"))?;
170        let funct7 = funct7.ok_or_else(|| syn::Error::new(input.span(), "missing funct7 field"))?;
171        let rd = rd.ok_or_else(|| syn::Error::new(input.span(), "missing rd field"))?;
172        let rs1 = rs1.ok_or_else(|| syn::Error::new(input.span(), "missing rs1 field"))?;
173        let rs2 = rs2.ok_or_else(|| syn::Error::new(input.span(), "missing rs2 field"))?;
174
175        Ok(CustomInsnR {
176            rd,
177            rs1,
178            rs2,
179            opcode,
180            funct3,
181            funct7,
182        })
183    }
184}
185
186impl Parse for CustomInsnI {
187    fn parse(input: ParseStream) -> syn::Result<Self> {
188        let input_fork = input.fork();
189        let (rd, rs1, opcode, funct3) = parse_common_fields(input)?;
190
191        // Parse imm from the forked input
192        let mut imm = None;
193        while !input_fork.is_empty() {
194            let key: Ident = input_fork.parse()?;
195            input_fork.parse::<Token![=]>()?;
196
197            if key == "imm" {
198                let value = parse_asm_arg(&input_fork)?;
199                match value {
200                    AsmArg::ConstLit(lit) => imm = Some(AsmArg::ConstLit(lit)),
201                    AsmArg::ConstExpr(expr) => imm = Some(AsmArg::ConstExpr(expr)),
202                    _ => return Err(syn::Error::new(key.span(), "imm must be a Const")),
203                }
204            } else {
205                // Skip other fields
206                while !input_fork.is_empty() && !input_fork.peek(Token![,]) {
207                    input_fork.parse::<proc_macro2::TokenTree>()?;
208                }
209            }
210
211            if !input_fork.is_empty() {
212                input_fork.parse::<Token![,]>()?;
213            }
214        }
215
216        let opcode = opcode.ok_or_else(|| syn::Error::new(input.span(), "missing opcode field"))?;
217        let funct3 = funct3.ok_or_else(|| syn::Error::new(input.span(), "missing funct3 field"))?;
218        let rd = rd.ok_or_else(|| syn::Error::new(input.span(), "missing rd field"))?;
219        let rs1 = rs1.ok_or_else(|| syn::Error::new(input.span(), "missing rs1 field"))?;
220        let imm = imm.ok_or_else(|| syn::Error::new(input.span(), "missing imm field"))?;
221
222        Ok(CustomInsnI {
223            rd,
224            rs1,
225            imm,
226            opcode,
227            funct3,
228        })
229    }
230}
231
232// Helper function for handling register arguments in both proc macros
233fn handle_reg_arg(
234    template: &mut String,
235    args: &mut Vec<proc_macro2::TokenStream>,
236    arg: &AsmArg,
237    reg_name: &str,
238) {
239    let reg_ident = syn::Ident::new(reg_name, Span::call_site());
240    match arg {
241        AsmArg::ConstLit(lit) => {
242            template.push_str(", ");
243            template.push_str(&lit.value());
244        }
245        AsmArg::In(tokens) => {
246            template.push_str(", {");
247            template.push_str(reg_name);
248            template.push('}');
249            args.push(quote::quote! { #reg_ident = in(reg) #tokens });
250        }
251        AsmArg::Out(tokens) => {
252            template.push_str(", {");
253            template.push_str(reg_name);
254            template.push('}');
255            args.push(quote::quote! { #reg_ident = out(reg) #tokens });
256        }
257        AsmArg::InOut(tokens) => {
258            template.push_str(", {");
259            template.push_str(reg_name);
260            template.push('}');
261            args.push(quote::quote! { #reg_ident = inout(reg) #tokens });
262        }
263        AsmArg::ConstExpr(tokens) => {
264            template.push_str(", {");
265            template.push_str(reg_name);
266            template.push('}');
267            args.push(quote::quote! { #reg_ident = const #tokens });
268        }
269    }
270}
271
272mod kw {
273    syn::custom_keyword!(In);
274    syn::custom_keyword!(Out);
275    syn::custom_keyword!(InOut);
276    syn::custom_keyword!(Const);
277}
278
279/// Custom RISC-V instruction macro for the zkVM.
280///
281/// This macro is used to define custom R-type RISC-V instructions for the zkVM.
282/// Usage:
283/// ```rust
284/// custom_insn_r!(
285///     opcode = OPCODE,
286///     funct3 = FUNCT3,
287///     funct7 = FUNCT7,
288///     rd = InOut x0,
289///     rs1 = In rs1,
290///     rs2 = In rs2
291/// );
292/// ```
293/// Here, `opcode`, `funct3`, and `funct7` are the opcode, funct3, and funct7 fields of the RISC-V
294/// instruction. `rd`, `rs1`, and `rs2` are the destination register, source register 1, and source
295/// register 2 respectively. The `In`, `Out`, `InOut`, and `Const` keywords are required to specify
296/// the type of the register arguments. They translate to `in(reg)`, `out(reg)`, `inout(reg)`, and
297/// `const` respectively, and mean
298/// - "read the value from this variable" before execution (`In`),
299/// - "write the value to this variable" after execution (`Out`),
300/// - "read the value from this variable, then write it back to the same variable" after execution
301///   (`InOut`), and
302/// - "use this constant value" (`Const`).
303#[proc_macro]
304pub fn custom_insn_r(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
305    let CustomInsnR {
306        rd,
307        rs1,
308        rs2,
309        opcode,
310        funct3,
311        funct7,
312    } = syn::parse_macro_input!(input as CustomInsnR);
313
314    let mut template = String::from(".insn r {opcode}, {funct3}, {funct7}");
315    let mut args = vec![];
316
317    // Helper function to handle register arguments
318    handle_reg_arg(&mut template, &mut args, &rd, "rd");
319    handle_reg_arg(&mut template, &mut args, &rs1, "rs1");
320    handle_reg_arg(&mut template, &mut args, &rs2, "rs2");
321
322    let expanded = quote::quote! {
323        #[cfg(target_os = "zkvm")]
324        unsafe {
325            core::arch::asm!(
326                #template,
327                opcode = const #opcode,
328                funct3 = const #funct3,
329                funct7 = const #funct7,
330                #(#args),*
331            )
332        }
333    };
334
335    expanded.into()
336}
337
338/// Custom RISC-V instruction macro for the zkVM.
339///
340/// This macro is used to define custom I-type RISC-V instructions for the zkVM.
341/// Usage:
342/// ```rust
343/// custom_insn_r!(
344///     opcode = OPCODE,
345///     funct3 = FUNCT3,
346///     rd = InOut x0,
347///     rs1 = In rs1,
348///     imm = Const 123
349/// );
350/// ```
351/// Here, `opcode`, `funct3` are the opcode and funct3 fields of the RISC-V instruction.
352/// `rd`, `rs1`, and `imm` are the destination register, source register 1, and immediate value
353/// respectively. The `In`, `Out`, `InOut`, and `Const` keywords are required to specify the type of
354/// the register arguments. They translate to `in(reg)`, `out(reg)`, `inout(reg)`, and `const`
355/// respectively, and mean
356/// - "read the value from this variable" before execution (`In`),
357/// - "write the value to this variable" after execution (`Out`),
358/// - "read the value from this variable, then write it back to the same variable" after execution
359///   (`InOut`), and
360/// - "use this constant value" (`Const`).
361///
362/// The `imm` argument is required to be a constant value.
363#[proc_macro]
364pub fn custom_insn_i(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
365    let CustomInsnI {
366        rd,
367        rs1,
368        imm,
369        opcode,
370        funct3,
371    } = syn::parse_macro_input!(input as CustomInsnI);
372
373    let mut template = String::from(".insn i {opcode}, {funct3}");
374    let mut args = vec![];
375
376    // Helper function to handle register arguments
377    handle_reg_arg(&mut template, &mut args, &rd, "rd");
378    handle_reg_arg(&mut template, &mut args, &rs1, "rs1");
379    handle_reg_arg(&mut template, &mut args, &imm, "imm");
380
381    let expanded = quote::quote! {
382        #[cfg(target_os = "zkvm")]
383        unsafe {
384            core::arch::asm!(
385                #template,
386                opcode = const #opcode,
387                funct3 = const #funct3,
388                #(#args),*
389            )
390        }
391    };
392
393    expanded.into()
394}