openvm_ecc_sw_macros/
lib.rs

1extern crate proc_macro;
2
3use openvm_macros_common::MacroArgs;
4use proc_macro::TokenStream;
5use quote::format_ident;
6use syn::{
7    parse::{Parse, ParseStream},
8    parse_macro_input, ExprPath, LitStr, Token,
9};
10
11/// This macro generates the code to setup the elliptic curve for a given modular type. Also it
12/// places the curve parameters into a special static variable to be later extracted from the ELF
13/// and used by the VM. Usage:
14/// ```
15/// sw_declare! {
16///     [TODO]
17/// }
18/// ```
19///
20/// For this macro to work, you must import the `elliptic_curve` crate and the `openvm_ecc_guest`
21/// crate.
22#[proc_macro]
23pub fn sw_declare(input: TokenStream) -> TokenStream {
24    let MacroArgs { items } = parse_macro_input!(input as MacroArgs);
25
26    let mut output = Vec::new();
27
28    let span = proc_macro::Span::call_site();
29
30    for item in items.into_iter() {
31        let struct_name_str = item.name.to_string();
32        let struct_name = syn::Ident::new(&struct_name_str, span.into());
33        let mut intmod_type: Option<syn::Path> = None;
34        let mut const_a: Option<syn::Expr> = None;
35        let mut const_b: Option<syn::Expr> = None;
36        for param in item.params {
37            match param.name.to_string().as_str() {
38                // Note that mod_type must have NUM_LIMBS divisible by 4
39                "mod_type" => {
40                    if let syn::Expr::Path(ExprPath { path, .. }) = param.value {
41                        intmod_type = Some(path)
42                    } else {
43                        return syn::Error::new_spanned(param.value, "Expected a type")
44                            .to_compile_error()
45                            .into();
46                    }
47                }
48                "a" => {
49                    // We currently leave it to the compiler to check if the expression is actually
50                    // a constant
51                    const_a = Some(param.value);
52                }
53                "b" => {
54                    // We currently leave it to the compiler to check if the expression is actually
55                    // a constant
56                    const_b = Some(param.value);
57                }
58                _ => {
59                    panic!("Unknown parameter {}", param.name);
60                }
61            }
62        }
63
64        let intmod_type = intmod_type.expect("mod_type parameter is required");
65        // const_a is optional, default to 0
66        let const_a = const_a
67            .unwrap_or(syn::parse_quote!(<#intmod_type as openvm_algebra_guest::IntMod>::ZERO));
68        let const_b = const_b.expect("constant b coefficient is required");
69
70        macro_rules! create_extern_func {
71            ($name:ident) => {
72                let $name = syn::Ident::new(
73                    &format!("{}_{}", stringify!($name), struct_name_str),
74                    span.into(),
75                );
76            };
77        }
78        create_extern_func!(sw_add_ne_extern_func);
79        create_extern_func!(sw_double_extern_func);
80        create_extern_func!(sw_setup_extern_func);
81
82        let group_ops_mod_name = format_ident!("{}_ops", struct_name_str.to_lowercase());
83
84        let result = TokenStream::from(quote::quote_spanned! { span.into() =>
85            extern "C" {
86                fn #sw_add_ne_extern_func(rd: usize, rs1: usize, rs2: usize);
87                fn #sw_double_extern_func(rd: usize, rs1: usize);
88                fn #sw_setup_extern_func(uninit: *mut core::ffi::c_void, p1: *const u8, p2: *const u8);
89            }
90
91            #[derive(Eq, PartialEq, Clone, Debug, serde::Serialize, serde::Deserialize)]
92            #[repr(C)]
93            pub struct #struct_name {
94                x: #intmod_type,
95                y: #intmod_type,
96            }
97            #[allow(non_upper_case_globals)]
98
99            impl #struct_name {
100                const fn identity() -> Self {
101                    Self {
102                        x: <#intmod_type as openvm_algebra_guest::IntMod>::ZERO,
103                        y: <#intmod_type as openvm_algebra_guest::IntMod>::ZERO,
104                    }
105                }
106                // Below are wrapper functions for the intrinsic instructions.
107                // Should not be called directly.
108                #[inline(always)]
109                unsafe fn add_ne<const CHECK_SETUP: bool>(p1: &#struct_name, p2: &#struct_name) -> #struct_name {
110                    #[cfg(not(target_os = "zkvm"))]
111                    {
112                        use openvm_algebra_guest::DivUnsafe;
113                        let lambda = (&p2.y - &p1.y).div_unsafe(&p2.x - &p1.x);
114                        let x3 = &lambda * &lambda - &p1.x - &p2.x;
115                        let y3 = &lambda * &(&p1.x - &x3) - &p1.y;
116                        #struct_name { x: x3, y: y3 }
117                    }
118                    #[cfg(target_os = "zkvm")]
119                    {
120                        if CHECK_SETUP {
121                            Self::set_up_once();
122                        }
123                        let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
124                        #sw_add_ne_extern_func(
125                            uninit.as_mut_ptr() as usize,
126                            p1 as *const #struct_name as usize,
127                            p2 as *const #struct_name as usize
128                        );
129                        uninit.assume_init()
130                    }
131                }
132
133                #[inline(always)]
134                unsafe fn add_ne_assign<const CHECK_SETUP: bool>(&mut self, p2: &#struct_name) {
135                    #[cfg(not(target_os = "zkvm"))]
136                    {
137                        use openvm_algebra_guest::DivUnsafe;
138                        let lambda = (&p2.y - &self.y).div_unsafe(&p2.x - &self.x);
139                        let x3 = &lambda * &lambda - &self.x - &p2.x;
140                        let y3 = &lambda * &(&self.x - &x3) - &self.y;
141                        self.x = x3;
142                        self.y = y3;
143                    }
144                    #[cfg(target_os = "zkvm")]
145                    {
146                        if CHECK_SETUP {
147                            Self::set_up_once();
148                        }
149                        #sw_add_ne_extern_func(
150                            self as *mut #struct_name as usize,
151                            self as *const #struct_name as usize,
152                            p2 as *const #struct_name as usize
153                        );
154                    }
155                }
156
157                /// Assumes that `p` is not identity.
158                #[inline(always)]
159                unsafe fn double_impl<const CHECK_SETUP: bool>(p: &#struct_name) -> #struct_name {
160                    #[cfg(not(target_os = "zkvm"))]
161                    {
162                        use openvm_algebra_guest::DivUnsafe;
163                        let curve_a: #intmod_type = #const_a;
164                        let two = #intmod_type::from_u8(2);
165                        let lambda = (&p.x * &p.x * #intmod_type::from_u8(3) + &curve_a).div_unsafe(&p.y * &two);
166                        let x3 = &lambda * &lambda - &p.x * &two;
167                        let y3 = &lambda * &(&p.x - &x3) - &p.y;
168                        #struct_name { x: x3, y: y3 }
169                    }
170                    #[cfg(target_os = "zkvm")]
171                    {
172                        if CHECK_SETUP {
173                            Self::set_up_once();
174                        }
175                        let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
176                        #sw_double_extern_func(
177                            uninit.as_mut_ptr() as usize,
178                            p as *const #struct_name as usize,
179                        );
180                        uninit.assume_init()
181                    }
182                }
183
184                // Helper function to call the setup instruction on first use
185                #[inline(always)]
186                #[cfg(target_os = "zkvm")]
187                fn set_up_once() {
188                    static is_setup: ::openvm_ecc_guest::once_cell::race::OnceBool = ::openvm_ecc_guest::once_cell::race::OnceBool::new();
189
190                    is_setup.get_or_init(|| {
191                        // p1 is (x1, y1), and x1 must be the modulus.
192                        // y1 can be anything for SetupEcAdd, but must equal `a` for SetupEcDouble
193                        let modulus_bytes = <<Self as openvm_ecc_guest::weierstrass::WeierstrassPoint>::Coordinate as openvm_algebra_guest::IntMod>::MODULUS;
194                        let mut one = [0u8; <<Self as openvm_ecc_guest::weierstrass::WeierstrassPoint>::Coordinate as openvm_algebra_guest::IntMod>::NUM_LIMBS];
195                        one[0] = 1;
196                        let curve_a_bytes = openvm_algebra_guest::IntMod::as_le_bytes(&<#struct_name as openvm_ecc_guest::weierstrass::WeierstrassPoint>::CURVE_A);
197                        // p1 should be (p, a)
198                        let p1 = [modulus_bytes.as_ref(), curve_a_bytes.as_ref()].concat();
199                        // (EcAdd only) p2 is (x2, y2), and x1 - x2 has to be non-zero to avoid division over zero in add.
200                        let p2 = [one.as_ref(), one.as_ref()].concat();
201                        let mut uninit: core::mem::MaybeUninit<[Self; 2]> = core::mem::MaybeUninit::uninit();
202
203                        unsafe { #sw_setup_extern_func(uninit.as_mut_ptr() as *mut core::ffi::c_void, p1.as_ptr(), p2.as_ptr()); }
204                        <#intmod_type as openvm_algebra_guest::IntMod>::set_up_once();
205                        true
206                    });
207                }
208
209                #[inline(always)]
210                #[cfg(not(target_os = "zkvm"))]
211                fn set_up_once() {
212                    // No-op for non-ZKVM targets
213                }
214
215                #[inline(always)]
216                fn is_identity_impl<const CHECK_SETUP: bool>(&self) -> bool {
217                    use openvm_algebra_guest::IntMod;
218                    // Safety: Self::set_up_once() ensures IntMod::set_up_once() has been called.
219                    unsafe {
220                        self.x.eq_impl::<CHECK_SETUP>(&#intmod_type::ZERO) && self.y.eq_impl::<CHECK_SETUP>(&#intmod_type::ZERO)
221                    }
222                }
223            }
224
225            impl ::openvm_ecc_guest::weierstrass::WeierstrassPoint for #struct_name {
226                const CURVE_A: #intmod_type = #const_a;
227                const CURVE_B: #intmod_type = #const_b;
228                const IDENTITY: Self = Self::identity();
229                type Coordinate = #intmod_type;
230
231                /// SAFETY: assumes that #intmod_type has a memory representation
232                /// such that with repr(C), two coordinates are packed contiguously.
233                #[inline(always)]
234                fn as_le_bytes(&self) -> &[u8] {
235                    unsafe { &*core::ptr::slice_from_raw_parts(self as *const Self as *const u8, <#intmod_type as openvm_algebra_guest::IntMod>::NUM_LIMBS * 2) }
236                }
237
238                #[inline(always)]
239                fn from_xy_unchecked(x: Self::Coordinate, y: Self::Coordinate) -> Self {
240                    Self { x, y }
241                }
242
243                #[inline(always)]
244                fn x(&self) -> &Self::Coordinate {
245                    &self.x
246                }
247
248                #[inline(always)]
249                fn y(&self) -> &Self::Coordinate {
250                    &self.y
251                }
252
253                #[inline(always)]
254                fn x_mut(&mut self) -> &mut Self::Coordinate {
255                    &mut self.x
256                }
257
258                #[inline(always)]
259                fn y_mut(&mut self) -> &mut Self::Coordinate {
260                    &mut self.y
261                }
262
263                #[inline(always)]
264                fn into_coords(self) -> (Self::Coordinate, Self::Coordinate) {
265                    (self.x, self.y)
266                }
267
268                #[inline(always)]
269                fn set_up_once() {
270                    Self::set_up_once();
271                }
272
273                #[inline]
274                fn add_assign_impl<const CHECK_SETUP: bool>(&mut self, p2: &Self) {
275                    use openvm_algebra_guest::IntMod;
276
277                    if CHECK_SETUP {
278                        // Call setup here so we skip it below
279                        #intmod_type::set_up_once();
280                    }
281
282                    if self.is_identity_impl::<CHECK_SETUP>() {
283                        *self = p2.clone();
284                    } else if p2.is_identity_impl::<CHECK_SETUP>() {
285                        // do nothing
286                    } else if unsafe { self.x.eq_impl::<false>(&p2.x) } { // Safety: we called IntMod setup above
287                        let sum_ys = unsafe { self.y.add_ref::<false>(&p2.y) };
288                        // Safety: we called IntMod setup above
289                        if unsafe { IntMod::eq_impl::<false>(&sum_ys, &<#intmod_type as IntMod>::ZERO) } {
290                            *self = Self::identity();
291                        } else {
292                            unsafe {
293                                self.double_assign_nonidentity::<CHECK_SETUP>();
294                            }
295                        }
296                    } else {
297                        unsafe {
298                            self.add_ne_assign_nonidentity::<CHECK_SETUP>(p2);
299                        }
300                    }
301                }
302
303                #[inline(always)]
304                fn double_assign_impl<const CHECK_SETUP: bool>(&mut self) {
305                    if !self.is_identity_impl::<CHECK_SETUP>() {
306                        unsafe {
307                            self.double_assign_nonidentity::<CHECK_SETUP>();
308                        }
309                    }
310                }
311
312                #[inline(always)]
313                unsafe fn add_ne_nonidentity<const CHECK_SETUP: bool>(&self, p2: &Self) -> Self {
314                    Self::add_ne::<CHECK_SETUP>(self, p2)
315                }
316
317                #[inline(always)]
318                unsafe fn add_ne_assign_nonidentity<const CHECK_SETUP: bool>(&mut self, p2: &Self) {
319                    Self::add_ne_assign::<CHECK_SETUP>(self, p2);
320                }
321
322                #[inline(always)]
323                unsafe fn sub_ne_nonidentity<const CHECK_SETUP: bool>(&self, p2: &Self) -> Self {
324                    Self::add_ne::<CHECK_SETUP>(self, &p2.clone().neg())
325                }
326
327                #[inline(always)]
328                unsafe fn sub_ne_assign_nonidentity<const CHECK_SETUP: bool>(&mut self, p2: &Self) {
329                    Self::add_ne_assign::<CHECK_SETUP>(self, &p2.clone().neg());
330                }
331
332                #[inline(always)]
333                unsafe fn double_nonidentity<const CHECK_SETUP: bool>(&self) -> Self {
334                    Self::double_impl::<CHECK_SETUP>(self)
335                }
336
337                #[inline(always)]
338                unsafe fn double_assign_nonidentity<const CHECK_SETUP: bool>(&mut self) {
339                    #[cfg(not(target_os = "zkvm"))]
340                    {
341                        *self = Self::double_impl::<CHECK_SETUP>(self);
342                    }
343                    #[cfg(target_os = "zkvm")]
344                    {
345                        if CHECK_SETUP {
346                            Self::set_up_once();
347                        }
348                        #sw_double_extern_func(
349                            self as *mut #struct_name as usize,
350                            self as *const #struct_name as usize
351                        );
352                    }
353                }
354            }
355
356            impl core::ops::Neg for #struct_name {
357                type Output = Self;
358
359                fn neg(self) -> Self::Output {
360                    #struct_name {
361                        x: self.x,
362                        y: -self.y,
363                    }
364                }
365            }
366
367            impl core::ops::Neg for &#struct_name {
368                type Output = #struct_name;
369
370                fn neg(self) -> #struct_name {
371                    #struct_name {
372                        x: self.x.clone(),
373                        y: core::ops::Neg::neg(&self.y),
374                    }
375                }
376            }
377
378            mod #group_ops_mod_name {
379                use ::openvm_ecc_guest::{weierstrass::{WeierstrassPoint, FromCompressed}, impl_sw_group_ops, algebra::IntMod};
380                use super::*;
381
382                impl_sw_group_ops!(#struct_name, #intmod_type);
383
384                impl FromCompressed<#intmod_type> for #struct_name {
385                    fn decompress(x: #intmod_type, rec_id: &u8) -> Option<Self> {
386                        use openvm_algebra_guest::Sqrt;
387                        let y_squared = &x * &x * &x + &<#struct_name as ::openvm_ecc_guest::weierstrass::WeierstrassPoint>::CURVE_A * &x + &<#struct_name as ::openvm_ecc_guest::weierstrass::WeierstrassPoint>::CURVE_B;
388                        let y = y_squared.sqrt();
389                        match y {
390                            None => None,
391                            Some(y) => {
392                                let correct_y = if y.as_le_bytes()[0] & 1 == *rec_id & 1 {
393                                    y
394                                } else {
395                                    -y
396                                };
397                                // If y = 0 then negating y doesn't change its parity
398                                if correct_y.as_le_bytes()[0] & 1 != *rec_id & 1 {
399                                    return None;
400                                }
401                                // In order for sqrt() to return Some, we are guaranteed that y * y == y_squared, which already proves (x, correct_y) is on the curve
402                                Some(<#struct_name as ::openvm_ecc_guest::weierstrass::WeierstrassPoint>::from_xy_unchecked(x, correct_y))
403                            }
404                        }
405                    }
406                }
407            }
408        });
409        output.push(result);
410    }
411
412    TokenStream::from_iter(output)
413}
414
415struct SwDefine {
416    items: Vec<String>,
417}
418
419impl Parse for SwDefine {
420    fn parse(input: ParseStream) -> syn::Result<Self> {
421        let items = input.parse_terminated(<LitStr as Parse>::parse, Token![,])?;
422        Ok(Self {
423            items: items.into_iter().map(|e| e.value()).collect(),
424        })
425    }
426}
427
428#[proc_macro]
429pub fn sw_init(input: TokenStream) -> TokenStream {
430    let SwDefine { items } = parse_macro_input!(input as SwDefine);
431
432    let mut externs = Vec::new();
433
434    let span = proc_macro::Span::call_site();
435
436    for (ec_idx, struct_id) in items.into_iter().enumerate() {
437        // Unique identifier shared by sw_define! and sw_init! used for naming the extern funcs.
438        // Currently it's just the struct type name.
439        let add_ne_extern_func =
440            syn::Ident::new(&format!("sw_add_ne_extern_func_{}", struct_id), span.into());
441        let double_extern_func =
442            syn::Ident::new(&format!("sw_double_extern_func_{}", struct_id), span.into());
443        let setup_extern_func =
444            syn::Ident::new(&format!("sw_setup_extern_func_{}", struct_id), span.into());
445
446        externs.push(quote::quote_spanned! { span.into() =>
447            #[no_mangle]
448            extern "C" fn #add_ne_extern_func(rd: usize, rs1: usize, rs2: usize) {
449                openvm::platform::custom_insn_r!(
450                    opcode = OPCODE,
451                    funct3 = SW_FUNCT3 as usize,
452                    funct7 = SwBaseFunct7::SwAddNe as usize + #ec_idx
453                        * (SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
454                    rd = In rd,
455                    rs1 = In rs1,
456                    rs2 = In rs2
457                );
458            }
459
460            #[no_mangle]
461            extern "C" fn #double_extern_func(rd: usize, rs1: usize) {
462                openvm::platform::custom_insn_r!(
463                    opcode = OPCODE,
464                    funct3 = SW_FUNCT3 as usize,
465                    funct7 = SwBaseFunct7::SwDouble as usize + #ec_idx
466                        * (SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
467                    rd = In rd,
468                    rs1 = In rs1,
469                    rs2 = Const "x0"
470                );
471            }
472
473            #[no_mangle]
474            extern "C" fn #setup_extern_func(uninit: *mut core::ffi::c_void, p1: *const u8, p2: *const u8) {
475                #[cfg(target_os = "zkvm")]
476                {
477                    openvm::platform::custom_insn_r!(
478                        opcode = ::openvm_ecc_guest::OPCODE,
479                        funct3 = ::openvm_ecc_guest::SW_FUNCT3 as usize,
480                        funct7 = ::openvm_ecc_guest::SwBaseFunct7::SwSetup as usize
481                            + #ec_idx
482                                * (::openvm_ecc_guest::SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
483                        rd = In uninit,
484                        rs1 = In p1,
485                        rs2 = In p2
486                    );
487                    openvm::platform::custom_insn_r!(
488                        opcode = ::openvm_ecc_guest::OPCODE,
489                        funct3 = ::openvm_ecc_guest::SW_FUNCT3 as usize,
490                        funct7 = ::openvm_ecc_guest::SwBaseFunct7::SwSetup as usize
491                            + #ec_idx
492                                * (::openvm_ecc_guest::SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
493                        rd = In uninit,
494                        rs1 = In p1,
495                        rs2 = Const "x0" // will be parsed as 0 and therefore transpiled to SETUP_EC_DOUBLE
496                    );
497
498
499                }
500            }
501        });
502    }
503
504    TokenStream::from(quote::quote_spanned! { span.into() =>
505        #[allow(non_snake_case)]
506        #[cfg(target_os = "zkvm")]
507        mod openvm_intrinsics_ffi_2 {
508            use ::openvm_ecc_guest::{OPCODE, SW_FUNCT3, SwBaseFunct7};
509
510            #(#externs)*
511        }
512    })
513}