openvm_ecc_sw_macros/
lib.rs

1extern crate proc_macro;
2
3use openvm_macros_common::MacroArgs;
4use proc_macro::TokenStream;
5use quote::format_ident;
6use syn::{
7    parse::{Parse, ParseStream},
8    parse_macro_input, Expr, ExprPath, Path, Token,
9};
10
11/// This macro generates the code to setup the elliptic curve for a given modular type. Also it
12/// places the curve parameters into a special static variable to be later extracted from the ELF
13/// and used by the VM. Usage:
14/// ```
15/// sw_declare! {
16///     [TODO]
17/// }
18/// ```
19///
20/// For this macro to work, you must import the `elliptic_curve` crate and the `openvm_ecc_guest`
21/// crate.
22#[proc_macro]
23pub fn sw_declare(input: TokenStream) -> TokenStream {
24    let MacroArgs { items } = parse_macro_input!(input as MacroArgs);
25
26    let mut output = Vec::new();
27
28    let span = proc_macro::Span::call_site();
29
30    for item in items.into_iter() {
31        let struct_name = item.name.to_string();
32        let struct_name = syn::Ident::new(&struct_name, span.into());
33        let struct_path: syn::Path = syn::parse_quote!(#struct_name);
34        let mut intmod_type: Option<syn::Path> = None;
35        let mut const_a: Option<syn::Expr> = None;
36        let mut const_b: Option<syn::Expr> = None;
37        for param in item.params {
38            match param.name.to_string().as_str() {
39                // Note that mod_type must have NUM_LIMBS divisible by 4
40                "mod_type" => {
41                    if let syn::Expr::Path(ExprPath { path, .. }) = param.value {
42                        intmod_type = Some(path)
43                    } else {
44                        return syn::Error::new_spanned(param.value, "Expected a type")
45                            .to_compile_error()
46                            .into();
47                    }
48                }
49                "a" => {
50                    // We currently leave it to the compiler to check if the expression is actually
51                    // a constant
52                    const_a = Some(param.value);
53                }
54                "b" => {
55                    // We currently leave it to the compiler to check if the expression is actually
56                    // a constant
57                    const_b = Some(param.value);
58                }
59                _ => {
60                    panic!("Unknown parameter {}", param.name);
61                }
62            }
63        }
64
65        let intmod_type = intmod_type.expect("mod_type parameter is required");
66        // const_a is optional, default to 0
67        let const_a = const_a
68            .unwrap_or(syn::parse_quote!(<#intmod_type as openvm_algebra_guest::IntMod>::ZERO));
69        let const_b = const_b.expect("constant b coefficient is required");
70
71        macro_rules! create_extern_func {
72            ($name:ident) => {
73                let $name = syn::Ident::new(
74                    &format!(
75                        "{}_{}",
76                        stringify!($name),
77                        struct_path
78                            .segments
79                            .iter()
80                            .map(|x| x.ident.to_string())
81                            .collect::<Vec<_>>()
82                            .join("_")
83                    ),
84                    span.into(),
85                );
86            };
87        }
88        create_extern_func!(sw_add_ne_extern_func);
89        create_extern_func!(sw_double_extern_func);
90        create_extern_func!(hint_decompress_extern_func);
91        create_extern_func!(hint_non_qr_extern_func);
92
93        let group_ops_mod_name = format_ident!("{}_ops", struct_name.to_string().to_lowercase());
94
95        let result = TokenStream::from(quote::quote_spanned! { span.into() =>
96            extern "C" {
97                fn #sw_add_ne_extern_func(rd: usize, rs1: usize, rs2: usize);
98                fn #sw_double_extern_func(rd: usize, rs1: usize);
99                fn #hint_decompress_extern_func(rs1: usize, rs2: usize);
100                fn #hint_non_qr_extern_func();
101            }
102
103            #[derive(Eq, PartialEq, Clone, Debug, serde::Serialize, serde::Deserialize)]
104            #[repr(C)]
105            pub struct #struct_name {
106                x: #intmod_type,
107                y: #intmod_type,
108            }
109            #[allow(non_upper_case_globals)]
110
111            impl #struct_name {
112                const fn identity() -> Self {
113                    Self {
114                        x: <#intmod_type as openvm_algebra_guest::IntMod>::ZERO,
115                        y: <#intmod_type as openvm_algebra_guest::IntMod>::ZERO,
116                    }
117                }
118                // Below are wrapper functions for the intrinsic instructions.
119                // Should not be called directly.
120                #[inline(always)]
121                fn add_ne(p1: &#struct_name, p2: &#struct_name) -> #struct_name {
122                    #[cfg(not(target_os = "zkvm"))]
123                    {
124                        use openvm_algebra_guest::DivUnsafe;
125                        let lambda = (&p2.y - &p1.y).div_unsafe(&p2.x - &p1.x);
126                        let x3 = &lambda * &lambda - &p1.x - &p2.x;
127                        let y3 = &lambda * &(&p1.x - &x3) - &p1.y;
128                        #struct_name { x: x3, y: y3 }
129                    }
130                    #[cfg(target_os = "zkvm")]
131                    {
132                        let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
133                        unsafe {
134                            #sw_add_ne_extern_func(
135                                uninit.as_mut_ptr() as usize,
136                                p1 as *const #struct_name as usize,
137                                p2 as *const #struct_name as usize
138                            )
139                        };
140                        unsafe { uninit.assume_init() }
141                    }
142                }
143
144                #[inline(always)]
145                fn add_ne_assign(&mut self, p2: &#struct_name) {
146                    #[cfg(not(target_os = "zkvm"))]
147                    {
148                        use openvm_algebra_guest::DivUnsafe;
149                        let lambda = (&p2.y - &self.y).div_unsafe(&p2.x - &self.x);
150                        let x3 = &lambda * &lambda - &self.x - &p2.x;
151                        let y3 = &lambda * &(&self.x - &x3) - &self.y;
152                        self.x = x3;
153                        self.y = y3;
154                    }
155                    #[cfg(target_os = "zkvm")]
156                    {
157                        unsafe {
158                            #sw_add_ne_extern_func(
159                                self as *mut #struct_name as usize,
160                                self as *const #struct_name as usize,
161                                p2 as *const #struct_name as usize
162                            )
163                        };
164                    }
165                }
166
167                /// Assumes that `p` is not identity.
168                #[inline(always)]
169                fn double_impl(p: &#struct_name) -> #struct_name {
170                    #[cfg(not(target_os = "zkvm"))]
171                    {
172                        use openvm_algebra_guest::DivUnsafe;
173                        let curve_a: #intmod_type = #const_a;
174                        let two = #intmod_type::from_u8(2);
175                        let lambda = (&p.x * &p.x * #intmod_type::from_u8(3) + &curve_a).div_unsafe(&p.y * &two);
176                        let x3 = &lambda * &lambda - &p.x * &two;
177                        let y3 = &lambda * &(&p.x - &x3) - &p.y;
178                        #struct_name { x: x3, y: y3 }
179                    }
180                    #[cfg(target_os = "zkvm")]
181                    {
182                        let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
183                        unsafe {
184                            #sw_double_extern_func(
185                                uninit.as_mut_ptr() as usize,
186                                p as *const #struct_name as usize,
187                            )
188                        };
189                        unsafe { uninit.assume_init() }
190                    }
191                }
192
193                #[inline(always)]
194                fn double_assign_impl(&mut self) {
195                    #[cfg(not(target_os = "zkvm"))]
196                    {
197                        *self = Self::double_impl(self);
198                    }
199                    #[cfg(target_os = "zkvm")]
200                    {
201                        unsafe {
202                            #sw_double_extern_func(
203                                self as *mut #struct_name as usize,
204                                self as *const #struct_name as usize
205                            )
206                        };
207                    }
208                }
209
210            }
211
212            impl ::openvm_ecc_guest::weierstrass::WeierstrassPoint for #struct_name {
213                const CURVE_A: #intmod_type = #const_a;
214                const CURVE_B: #intmod_type = #const_b;
215                const IDENTITY: Self = Self::identity();
216                type Coordinate = #intmod_type;
217
218                /// SAFETY: assumes that #intmod_type has a memory representation
219                /// such that with repr(C), two coordinates are packed contiguously.
220                fn as_le_bytes(&self) -> &[u8] {
221                    unsafe { &*core::ptr::slice_from_raw_parts(self as *const Self as *const u8, <#intmod_type as openvm_algebra_guest::IntMod>::NUM_LIMBS * 2) }
222                }
223
224                fn from_xy_unchecked(x: Self::Coordinate, y: Self::Coordinate) -> Self {
225                    Self { x, y }
226                }
227
228                fn x(&self) -> &Self::Coordinate {
229                    &self.x
230                }
231
232                fn y(&self) -> &Self::Coordinate {
233                    &self.y
234                }
235
236                fn x_mut(&mut self) -> &mut Self::Coordinate {
237                    &mut self.x
238                }
239
240                fn y_mut(&mut self) -> &mut Self::Coordinate {
241                    &mut self.y
242                }
243
244                fn into_coords(self) -> (Self::Coordinate, Self::Coordinate) {
245                    (self.x, self.y)
246                }
247
248                fn add_ne_nonidentity(&self, p2: &Self) -> Self {
249                    Self::add_ne(self, p2)
250                }
251
252                fn add_ne_assign_nonidentity(&mut self, p2: &Self) {
253                    Self::add_ne_assign(self, p2);
254                }
255
256                fn sub_ne_nonidentity(&self, p2: &Self) -> Self {
257                    Self::add_ne(self, &p2.clone().neg())
258                }
259
260                fn sub_ne_assign_nonidentity(&mut self, p2: &Self) {
261                    Self::add_ne_assign(self, &p2.clone().neg());
262                }
263
264                fn double_nonidentity(&self) -> Self {
265                    Self::double_impl(self)
266                }
267
268                fn double_assign_nonidentity(&mut self) {
269                    Self::double_assign_impl(self);
270                }
271            }
272
273            impl core::ops::Neg for #struct_name {
274                type Output = Self;
275
276                fn neg(self) -> Self::Output {
277                    #struct_name {
278                        x: self.x,
279                        y: -self.y,
280                    }
281                }
282            }
283
284            impl core::ops::Neg for &#struct_name {
285                type Output = #struct_name;
286
287                fn neg(self) -> #struct_name {
288                    #struct_name {
289                        x: self.x.clone(),
290                        y: core::ops::Neg::neg(&self.y),
291                    }
292                }
293            }
294
295            mod #group_ops_mod_name {
296                use ::openvm_ecc_guest::{weierstrass::{WeierstrassPoint, FromCompressed, DecompressionHint}, impl_sw_group_ops, algebra::{IntMod, DivUnsafe, DivAssignUnsafe, ExpBytes}};
297                use super::*;
298
299                impl_sw_group_ops!(#struct_name, #intmod_type);
300
301                impl FromCompressed<#intmod_type> for #struct_name {
302                    fn decompress(x: #intmod_type, rec_id: &u8) -> Option<Self> {
303                        match Self::honest_host_decompress(&x, rec_id) {
304                            // successfully decompressed
305                            Some(Some(ret)) => Some(ret),
306                            // successfully proved that the point cannot be decompressed
307                            Some(None) => None,
308                            None => {
309                                // host is dishonest, enter infinite loop
310                                loop {
311                                    openvm::io::println("ERROR: Decompression hint is invalid. Entering infinite loop.");
312                                }
313                            }
314                        }
315                    }
316
317                    fn hint_decompress(x: &#intmod_type, rec_id: &u8) -> Option<DecompressionHint<#intmod_type>> {
318                        #[cfg(not(target_os = "zkvm"))]
319                        {
320                            unimplemented!()
321                        }
322                        #[cfg(target_os = "zkvm")]
323                        {
324                            use openvm::platform as openvm_platform; // needed for hint_store_u32!
325
326                            let possible = core::mem::MaybeUninit::<u32>::uninit();
327                            let sqrt = core::mem::MaybeUninit::<#intmod_type>::uninit();
328                            unsafe {
329                                #hint_decompress_extern_func(x as *const _ as usize, rec_id as *const u8 as usize);
330                                let possible_ptr = possible.as_ptr() as *const u32;
331                                openvm_rv32im_guest::hint_store_u32!(possible_ptr);
332                                openvm_rv32im_guest::hint_buffer_u32!(sqrt.as_ptr() as *const u8, <#intmod_type as openvm_algebra_guest::IntMod>::NUM_LIMBS / 4);
333                                let possible = possible.assume_init();
334                                if possible == 0 || possible == 1 {
335                                    Some(DecompressionHint { possible: possible == 1, sqrt: sqrt.assume_init() })
336                                } else {
337                                    None
338                                }
339                            }
340                        }
341                    }
342                }
343
344                impl #struct_name {
345                    // Returns None if the hint is incorrect (i.e. the host is dishonest)
346                    // Returns Some(None) if the hint proves that the point cannot be decompressed
347                    fn honest_host_decompress(x: &#intmod_type, rec_id: &u8) -> Option<Option<Self>> {
348                        let hint = <#struct_name as FromCompressed<#intmod_type>>::hint_decompress(x, rec_id)?;
349
350                        if hint.possible {
351                            // ensure y < modulus
352                            hint.sqrt.assert_reduced();
353
354                            if hint.sqrt.as_le_bytes()[0] & 1 != *rec_id & 1 {
355                                None
356                            } else {
357                                let ret = <#struct_name as ::openvm_ecc_guest::weierstrass::WeierstrassPoint>::from_xy_nonidentity(x.clone(), hint.sqrt)?;
358                                Some(Some(ret))
359                            }
360                        } else {
361                            // ensure sqrt < modulus
362                            hint.sqrt.assert_reduced();
363
364                            let alpha = (x * x * x) + (x * &<#struct_name as ::openvm_ecc_guest::weierstrass::WeierstrassPoint>::CURVE_A) + &<#struct_name as ::openvm_ecc_guest::weierstrass::WeierstrassPoint>::CURVE_B;
365                            if &hint.sqrt * &hint.sqrt == alpha * Self::get_non_qr() {
366                                Some(None)
367                            } else {
368                                None
369                            }
370                        }
371                    }
372
373                    // Generate a non quadratic residue in the coordinate field by using a hint
374                    fn init_non_qr() -> alloc::boxed::Box<<Self as ::openvm_ecc_guest::weierstrass::WeierstrassPoint>::Coordinate> {
375                        #[cfg(not(target_os = "zkvm"))]
376                        {
377                            unimplemented!();
378                        }
379                        #[cfg(target_os = "zkvm")]
380                        {
381                            use openvm::platform as openvm_platform; // needed for hint_buffer_u32
382                            let mut non_qr_uninit = core::mem::MaybeUninit::<#intmod_type>::uninit();
383                            let mut non_qr;
384                            unsafe {
385                                #hint_non_qr_extern_func();
386                                let ptr = non_qr_uninit.as_ptr() as *const u8;
387                                openvm_rv32im_guest::hint_buffer_u32!(ptr, <#intmod_type as openvm_algebra_guest::IntMod>::NUM_LIMBS / 4);
388                                non_qr = non_qr_uninit.assume_init();
389                            }
390                            // ensure non_qr < modulus
391                            non_qr.assert_reduced();
392
393                            // construct exp = (p-1)/2 as an integer by first constraining exp = (p-1)/2 (mod p) and then exp < p
394                            let exp = -<#intmod_type as openvm_algebra_guest::IntMod>::ONE.div_unsafe(#intmod_type::from_const_u8(2));
395                            exp.assert_reduced();
396
397                            if non_qr.exp_bytes(true, &exp.to_be_bytes()) != -<#intmod_type as openvm_algebra_guest::IntMod>::ONE
398                            {
399                                // non_qr is not a non quadratic residue, so host is dishonest
400                                loop {
401                                    openvm::io::println("ERROR: Non quadratic residue hint is invalid. Entering infinite loop.");
402                                }
403                            }
404
405                            alloc::boxed::Box::new(non_qr)
406                        }
407                    }
408
409                    pub fn get_non_qr() -> &'static #intmod_type {
410                        static non_qr: ::openvm_ecc_guest::once_cell::race::OnceBox<#intmod_type> = ::openvm_ecc_guest::once_cell::race::OnceBox::new();
411                        &non_qr.get_or_init(Self::init_non_qr)
412                    }
413                }
414            }
415        });
416        output.push(result);
417    }
418
419    TokenStream::from_iter(output)
420}
421
422struct SwDefine {
423    items: Vec<Path>,
424}
425
426impl Parse for SwDefine {
427    fn parse(input: ParseStream) -> syn::Result<Self> {
428        let items = input.parse_terminated(<Expr as Parse>::parse, Token![,])?;
429        Ok(Self {
430            items: items
431                .into_iter()
432                .map(|e| {
433                    if let Expr::Path(p) = e {
434                        p.path
435                    } else {
436                        panic!("expected path");
437                    }
438                })
439                .collect(),
440        })
441    }
442}
443
444#[proc_macro]
445pub fn sw_init(input: TokenStream) -> TokenStream {
446    let SwDefine { items } = parse_macro_input!(input as SwDefine);
447
448    let mut externs = Vec::new();
449    let mut setups = Vec::new();
450    let mut setup_all_curves = Vec::new();
451
452    let span = proc_macro::Span::call_site();
453
454    for (ec_idx, item) in items.into_iter().enumerate() {
455        let str_path = item
456            .segments
457            .iter()
458            .map(|x| x.ident.to_string())
459            .collect::<Vec<_>>()
460            .join("_");
461        let add_ne_extern_func =
462            syn::Ident::new(&format!("sw_add_ne_extern_func_{}", str_path), span.into());
463        let double_extern_func =
464            syn::Ident::new(&format!("sw_double_extern_func_{}", str_path), span.into());
465        let hint_decompress_extern_func = syn::Ident::new(
466            &format!("hint_decompress_extern_func_{}", str_path),
467            span.into(),
468        );
469        let hint_non_qr_extern_func = syn::Ident::new(
470            &format!("hint_non_qr_extern_func_{}", str_path),
471            span.into(),
472        );
473        externs.push(quote::quote_spanned! { span.into() =>
474            #[no_mangle]
475            extern "C" fn #add_ne_extern_func(rd: usize, rs1: usize, rs2: usize) {
476                openvm::platform::custom_insn_r!(
477                    opcode = OPCODE,
478                    funct3 = SW_FUNCT3 as usize,
479                    funct7 = SwBaseFunct7::SwAddNe as usize + #ec_idx
480                        * (SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
481                    rd = In rd,
482                    rs1 = In rs1,
483                    rs2 = In rs2
484                );
485            }
486
487            #[no_mangle]
488            extern "C" fn #double_extern_func(rd: usize, rs1: usize) {
489                openvm::platform::custom_insn_r!(
490                    opcode = OPCODE,
491                    funct3 = SW_FUNCT3 as usize,
492                    funct7 = SwBaseFunct7::SwDouble as usize + #ec_idx
493                        * (SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
494                    rd = In rd,
495                    rs1 = In rs1,
496                    rs2 = Const "x0"
497                );
498            }
499
500            #[no_mangle]
501            extern "C" fn #hint_decompress_extern_func(rs1: usize, rs2: usize) {
502                openvm::platform::custom_insn_r!(
503                    opcode = OPCODE,
504                    funct3 = SW_FUNCT3 as usize,
505                    funct7 = SwBaseFunct7::HintDecompress as usize + #ec_idx
506                        * (SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
507                    rd = Const "x0",
508                    rs1 = In rs1,
509                    rs2 = In rs2
510                );
511            }
512
513            #[no_mangle]
514            extern "C" fn #hint_non_qr_extern_func() {
515                openvm::platform::custom_insn_r!(
516                    opcode = OPCODE,
517                    funct3 = SW_FUNCT3 as usize,
518                    funct7 = SwBaseFunct7::HintNonQr as usize + #ec_idx
519                        * (SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
520                    rd = Const "x0",
521                    rs1 = Const "x0",
522                    rs2 = Const "x0"
523                );
524            }
525        });
526
527        let setup_function = syn::Ident::new(&format!("setup_sw_{}", str_path), span.into());
528        setups.push(quote::quote_spanned! { span.into() =>
529            #[allow(non_snake_case)]
530            pub fn #setup_function() {
531                #[cfg(target_os = "zkvm")]
532                {
533                    // p1 is (x1, y1), and x1 must be the modulus.
534                    // y1 can be anything for SetupEcAdd, but must equal `a` for SetupEcDouble
535                    let modulus_bytes = <<#item as openvm_ecc_guest::weierstrass::WeierstrassPoint>::Coordinate as openvm_algebra_guest::IntMod>::MODULUS;
536                    let mut one = [0u8; <<#item as openvm_ecc_guest::weierstrass::WeierstrassPoint>::Coordinate as openvm_algebra_guest::IntMod>::NUM_LIMBS];
537                    one[0] = 1;
538                    let curve_a_bytes = openvm_algebra_guest::IntMod::as_le_bytes(&<#item as openvm_ecc_guest::weierstrass::WeierstrassPoint>::CURVE_A);
539                    // p1 should be (p, a)
540                    let p1 = [modulus_bytes.as_ref(), curve_a_bytes.as_ref()].concat();
541                    // (EcAdd only) p2 is (x2, y2), and x1 - x2 has to be non-zero to avoid division over zero in add.
542                    let p2 = [one.as_ref(), one.as_ref()].concat();
543                    let mut uninit: core::mem::MaybeUninit<[#item; 2]> = core::mem::MaybeUninit::uninit();
544                    openvm::platform::custom_insn_r!(
545                        opcode = ::openvm_ecc_guest::OPCODE,
546                        funct3 = ::openvm_ecc_guest::SW_FUNCT3 as usize,
547                        funct7 = ::openvm_ecc_guest::SwBaseFunct7::SwSetup as usize
548                            + #ec_idx
549                                * (::openvm_ecc_guest::SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
550                        rd = In uninit.as_mut_ptr(),
551                        rs1 = In p1.as_ptr(),
552                        rs2 = In p2.as_ptr()
553                    );
554                    openvm::platform::custom_insn_r!(
555                        opcode = ::openvm_ecc_guest::OPCODE,
556                        funct3 = ::openvm_ecc_guest::SW_FUNCT3 as usize,
557                        funct7 = ::openvm_ecc_guest::SwBaseFunct7::SwSetup as usize
558                            + #ec_idx
559                                * (::openvm_ecc_guest::SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
560                        rd = In uninit.as_mut_ptr(),
561                        rs1 = In p1.as_ptr(),
562                        rs2 = Const "x0" // will be parsed as 0 and therefore transpiled to SETUP_EC_DOUBLE
563                    );
564                }
565            }
566        });
567
568        setup_all_curves.push(quote::quote_spanned! { span.into() =>
569            #setup_function();
570        });
571    }
572
573    TokenStream::from(quote::quote_spanned! { span.into() =>
574        #[cfg(target_os = "zkvm")]
575        mod openvm_intrinsics_ffi_2 {
576            use ::openvm_ecc_guest::{OPCODE, SW_FUNCT3, SwBaseFunct7};
577
578            #(#externs)*
579        }
580        #(#setups)*
581        pub fn setup_all_curves() {
582            #(#setup_all_curves)*
583        }
584    })
585}