1extern crate proc_macro;
2
3use openvm_macros_common::MacroArgs;
4use proc_macro::TokenStream;
5use quote::format_ident;
6use syn::{
7 parse::{Parse, ParseStream},
8 parse_macro_input, ExprPath, LitStr, Token,
9};
10
11#[proc_macro]
23pub fn sw_declare(input: TokenStream) -> TokenStream {
24 let MacroArgs { items } = parse_macro_input!(input as MacroArgs);
25
26 let mut output = Vec::new();
27
28 let span = proc_macro::Span::call_site();
29
30 for item in items.into_iter() {
31 let struct_name_str = item.name.to_string();
32 let struct_name = syn::Ident::new(&struct_name_str, span.into());
33 let mut intmod_type: Option<syn::Path> = None;
34 let mut const_a: Option<syn::Expr> = None;
35 let mut const_b: Option<syn::Expr> = None;
36 for param in item.params {
37 match param.name.to_string().as_str() {
38 "mod_type" => {
40 if let syn::Expr::Path(ExprPath { path, .. }) = param.value {
41 intmod_type = Some(path)
42 } else {
43 return syn::Error::new_spanned(param.value, "Expected a type")
44 .to_compile_error()
45 .into();
46 }
47 }
48 "a" => {
49 const_a = Some(param.value);
52 }
53 "b" => {
54 const_b = Some(param.value);
57 }
58 _ => {
59 panic!("Unknown parameter {}", param.name);
60 }
61 }
62 }
63
64 let intmod_type = intmod_type.expect("mod_type parameter is required");
65 let const_a = const_a
67 .unwrap_or(syn::parse_quote!(<#intmod_type as openvm_algebra_guest::IntMod>::ZERO));
68 let const_b = const_b.expect("constant b coefficient is required");
69
70 macro_rules! create_extern_func {
71 ($name:ident) => {
72 let $name = syn::Ident::new(
73 &format!("{}_{}", stringify!($name), struct_name_str),
74 span.into(),
75 );
76 };
77 }
78 create_extern_func!(sw_add_ne_extern_func);
79 create_extern_func!(sw_double_extern_func);
80 create_extern_func!(sw_setup_extern_func);
81
82 let group_ops_mod_name = format_ident!("{}_ops", struct_name_str.to_lowercase());
83
84 let result = TokenStream::from(quote::quote_spanned! { span.into() =>
85 extern "C" {
86 fn #sw_add_ne_extern_func(rd: usize, rs1: usize, rs2: usize);
87 fn #sw_double_extern_func(rd: usize, rs1: usize);
88 fn #sw_setup_extern_func(uninit: *mut core::ffi::c_void, p1: *const u8, p2: *const u8);
89 }
90
91 #[derive(Eq, PartialEq, Clone, Debug, serde::Serialize, serde::Deserialize)]
92 #[repr(C)]
93 pub struct #struct_name {
94 x: #intmod_type,
95 y: #intmod_type,
96 }
97 #[allow(non_upper_case_globals)]
98
99 impl #struct_name {
100 const fn identity() -> Self {
101 Self {
102 x: <#intmod_type as openvm_algebra_guest::IntMod>::ZERO,
103 y: <#intmod_type as openvm_algebra_guest::IntMod>::ZERO,
104 }
105 }
106 #[inline(always)]
109 unsafe fn add_ne<const CHECK_SETUP: bool>(p1: &#struct_name, p2: &#struct_name) -> #struct_name {
110 #[cfg(not(target_os = "zkvm"))]
111 {
112 use openvm_algebra_guest::DivUnsafe;
113 let lambda = (&p2.y - &p1.y).div_unsafe(&p2.x - &p1.x);
114 let x3 = &lambda * &lambda - &p1.x - &p2.x;
115 let y3 = &lambda * &(&p1.x - &x3) - &p1.y;
116 #struct_name { x: x3, y: y3 }
117 }
118 #[cfg(target_os = "zkvm")]
119 {
120 if CHECK_SETUP {
121 Self::set_up_once();
122 }
123 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
124 #sw_add_ne_extern_func(
125 uninit.as_mut_ptr() as usize,
126 p1 as *const #struct_name as usize,
127 p2 as *const #struct_name as usize
128 );
129 uninit.assume_init()
130 }
131 }
132
133 #[inline(always)]
134 unsafe fn add_ne_assign<const CHECK_SETUP: bool>(&mut self, p2: &#struct_name) {
135 #[cfg(not(target_os = "zkvm"))]
136 {
137 use openvm_algebra_guest::DivUnsafe;
138 let lambda = (&p2.y - &self.y).div_unsafe(&p2.x - &self.x);
139 let x3 = &lambda * &lambda - &self.x - &p2.x;
140 let y3 = &lambda * &(&self.x - &x3) - &self.y;
141 self.x = x3;
142 self.y = y3;
143 }
144 #[cfg(target_os = "zkvm")]
145 {
146 if CHECK_SETUP {
147 Self::set_up_once();
148 }
149 #sw_add_ne_extern_func(
150 self as *mut #struct_name as usize,
151 self as *const #struct_name as usize,
152 p2 as *const #struct_name as usize
153 );
154 }
155 }
156
157 #[inline(always)]
159 unsafe fn double_impl<const CHECK_SETUP: bool>(p: &#struct_name) -> #struct_name {
160 #[cfg(not(target_os = "zkvm"))]
161 {
162 use openvm_algebra_guest::DivUnsafe;
163 let curve_a: #intmod_type = #const_a;
164 let two = #intmod_type::from_u8(2);
165 let lambda = (&p.x * &p.x * #intmod_type::from_u8(3) + &curve_a).div_unsafe(&p.y * &two);
166 let x3 = &lambda * &lambda - &p.x * &two;
167 let y3 = &lambda * &(&p.x - &x3) - &p.y;
168 #struct_name { x: x3, y: y3 }
169 }
170 #[cfg(target_os = "zkvm")]
171 {
172 if CHECK_SETUP {
173 Self::set_up_once();
174 }
175 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
176 #sw_double_extern_func(
177 uninit.as_mut_ptr() as usize,
178 p as *const #struct_name as usize,
179 );
180 uninit.assume_init()
181 }
182 }
183
184 #[inline(always)]
186 #[cfg(target_os = "zkvm")]
187 fn set_up_once() {
188 static is_setup: ::openvm_ecc_guest::once_cell::race::OnceBool = ::openvm_ecc_guest::once_cell::race::OnceBool::new();
189
190 is_setup.get_or_init(|| {
191 let modulus_bytes = <<Self as openvm_ecc_guest::weierstrass::WeierstrassPoint>::Coordinate as openvm_algebra_guest::IntMod>::MODULUS;
194 let mut one = [0u8; <<Self as openvm_ecc_guest::weierstrass::WeierstrassPoint>::Coordinate as openvm_algebra_guest::IntMod>::NUM_LIMBS];
195 one[0] = 1;
196 let curve_a_bytes = openvm_algebra_guest::IntMod::as_le_bytes(&<#struct_name as openvm_ecc_guest::weierstrass::WeierstrassPoint>::CURVE_A);
197 let p1 = [modulus_bytes.as_ref(), curve_a_bytes.as_ref()].concat();
199 let p2 = [one.as_ref(), one.as_ref()].concat();
201 let mut uninit: core::mem::MaybeUninit<[Self; 2]> = core::mem::MaybeUninit::uninit();
202
203 unsafe { #sw_setup_extern_func(uninit.as_mut_ptr() as *mut core::ffi::c_void, p1.as_ptr(), p2.as_ptr()); }
204 <#intmod_type as openvm_algebra_guest::IntMod>::set_up_once();
205 true
206 });
207 }
208
209 #[inline(always)]
210 #[cfg(not(target_os = "zkvm"))]
211 fn set_up_once() {
212 }
214
215 #[inline(always)]
216 fn is_identity_impl<const CHECK_SETUP: bool>(&self) -> bool {
217 use openvm_algebra_guest::IntMod;
218 unsafe {
220 self.x.eq_impl::<CHECK_SETUP>(&#intmod_type::ZERO) && self.y.eq_impl::<CHECK_SETUP>(&#intmod_type::ZERO)
221 }
222 }
223 }
224
225 impl ::openvm_ecc_guest::weierstrass::WeierstrassPoint for #struct_name {
226 const CURVE_A: #intmod_type = #const_a;
227 const CURVE_B: #intmod_type = #const_b;
228 const IDENTITY: Self = Self::identity();
229 type Coordinate = #intmod_type;
230
231 #[inline(always)]
234 fn as_le_bytes(&self) -> &[u8] {
235 unsafe { &*core::ptr::slice_from_raw_parts(self as *const Self as *const u8, <#intmod_type as openvm_algebra_guest::IntMod>::NUM_LIMBS * 2) }
236 }
237
238 #[inline(always)]
239 fn from_xy_unchecked(x: Self::Coordinate, y: Self::Coordinate) -> Self {
240 Self { x, y }
241 }
242
243 #[inline(always)]
244 fn x(&self) -> &Self::Coordinate {
245 &self.x
246 }
247
248 #[inline(always)]
249 fn y(&self) -> &Self::Coordinate {
250 &self.y
251 }
252
253 #[inline(always)]
254 fn x_mut(&mut self) -> &mut Self::Coordinate {
255 &mut self.x
256 }
257
258 #[inline(always)]
259 fn y_mut(&mut self) -> &mut Self::Coordinate {
260 &mut self.y
261 }
262
263 #[inline(always)]
264 fn into_coords(self) -> (Self::Coordinate, Self::Coordinate) {
265 (self.x, self.y)
266 }
267
268 #[inline(always)]
269 fn set_up_once() {
270 Self::set_up_once();
271 }
272
273 #[inline]
274 fn add_assign_impl<const CHECK_SETUP: bool>(&mut self, p2: &Self) {
275 use openvm_algebra_guest::IntMod;
276
277 if CHECK_SETUP {
278 #intmod_type::set_up_once();
280 }
281
282 if self.is_identity_impl::<CHECK_SETUP>() {
283 *self = p2.clone();
284 } else if p2.is_identity_impl::<CHECK_SETUP>() {
285 } else if unsafe { self.x.eq_impl::<false>(&p2.x) } { let sum_ys = unsafe { self.y.add_ref::<false>(&p2.y) };
288 if unsafe { IntMod::eq_impl::<false>(&sum_ys, &<#intmod_type as IntMod>::ZERO) } {
290 *self = Self::identity();
291 } else {
292 unsafe {
293 self.double_assign_nonidentity::<CHECK_SETUP>();
294 }
295 }
296 } else {
297 unsafe {
298 self.add_ne_assign_nonidentity::<CHECK_SETUP>(p2);
299 }
300 }
301 }
302
303 #[inline(always)]
304 fn double_assign_impl<const CHECK_SETUP: bool>(&mut self) {
305 if !self.is_identity_impl::<CHECK_SETUP>() {
306 unsafe {
307 self.double_assign_nonidentity::<CHECK_SETUP>();
308 }
309 }
310 }
311
312 #[inline(always)]
313 unsafe fn add_ne_nonidentity<const CHECK_SETUP: bool>(&self, p2: &Self) -> Self {
314 Self::add_ne::<CHECK_SETUP>(self, p2)
315 }
316
317 #[inline(always)]
318 unsafe fn add_ne_assign_nonidentity<const CHECK_SETUP: bool>(&mut self, p2: &Self) {
319 Self::add_ne_assign::<CHECK_SETUP>(self, p2);
320 }
321
322 #[inline(always)]
323 unsafe fn sub_ne_nonidentity<const CHECK_SETUP: bool>(&self, p2: &Self) -> Self {
324 Self::add_ne::<CHECK_SETUP>(self, &p2.clone().neg())
325 }
326
327 #[inline(always)]
328 unsafe fn sub_ne_assign_nonidentity<const CHECK_SETUP: bool>(&mut self, p2: &Self) {
329 Self::add_ne_assign::<CHECK_SETUP>(self, &p2.clone().neg());
330 }
331
332 #[inline(always)]
333 unsafe fn double_nonidentity<const CHECK_SETUP: bool>(&self) -> Self {
334 Self::double_impl::<CHECK_SETUP>(self)
335 }
336
337 #[inline(always)]
338 unsafe fn double_assign_nonidentity<const CHECK_SETUP: bool>(&mut self) {
339 #[cfg(not(target_os = "zkvm"))]
340 {
341 *self = Self::double_impl::<CHECK_SETUP>(self);
342 }
343 #[cfg(target_os = "zkvm")]
344 {
345 if CHECK_SETUP {
346 Self::set_up_once();
347 }
348 #sw_double_extern_func(
349 self as *mut #struct_name as usize,
350 self as *const #struct_name as usize
351 );
352 }
353 }
354 }
355
356 impl core::ops::Neg for #struct_name {
357 type Output = Self;
358
359 fn neg(self) -> Self::Output {
360 #struct_name {
361 x: self.x,
362 y: -self.y,
363 }
364 }
365 }
366
367 impl core::ops::Neg for &#struct_name {
368 type Output = #struct_name;
369
370 fn neg(self) -> #struct_name {
371 #struct_name {
372 x: self.x.clone(),
373 y: core::ops::Neg::neg(&self.y),
374 }
375 }
376 }
377
378 mod #group_ops_mod_name {
379 use ::openvm_ecc_guest::{weierstrass::{WeierstrassPoint, FromCompressed}, impl_sw_group_ops, algebra::IntMod};
380 use super::*;
381
382 impl_sw_group_ops!(#struct_name, #intmod_type);
383
384 impl FromCompressed<#intmod_type> for #struct_name {
385 fn decompress(x: #intmod_type, rec_id: &u8) -> Option<Self> {
386 use openvm_algebra_guest::Sqrt;
387 let y_squared = &x * &x * &x + &<#struct_name as ::openvm_ecc_guest::weierstrass::WeierstrassPoint>::CURVE_A * &x + &<#struct_name as ::openvm_ecc_guest::weierstrass::WeierstrassPoint>::CURVE_B;
388 let y = y_squared.sqrt();
389 match y {
390 None => None,
391 Some(y) => {
392 let correct_y = if y.as_le_bytes()[0] & 1 == *rec_id & 1 {
393 y
394 } else {
395 -y
396 };
397 if correct_y.as_le_bytes()[0] & 1 != *rec_id & 1 {
399 return None;
400 }
401 Some(<#struct_name as ::openvm_ecc_guest::weierstrass::WeierstrassPoint>::from_xy_unchecked(x, correct_y))
403 }
404 }
405 }
406 }
407 }
408 });
409 output.push(result);
410 }
411
412 TokenStream::from_iter(output)
413}
414
415struct SwDefine {
416 items: Vec<String>,
417}
418
419impl Parse for SwDefine {
420 fn parse(input: ParseStream) -> syn::Result<Self> {
421 let items = input.parse_terminated(<LitStr as Parse>::parse, Token![,])?;
422 Ok(Self {
423 items: items.into_iter().map(|e| e.value()).collect(),
424 })
425 }
426}
427
428#[proc_macro]
429pub fn sw_init(input: TokenStream) -> TokenStream {
430 let SwDefine { items } = parse_macro_input!(input as SwDefine);
431
432 let mut externs = Vec::new();
433
434 let span = proc_macro::Span::call_site();
435
436 for (ec_idx, struct_id) in items.into_iter().enumerate() {
437 let add_ne_extern_func =
440 syn::Ident::new(&format!("sw_add_ne_extern_func_{}", struct_id), span.into());
441 let double_extern_func =
442 syn::Ident::new(&format!("sw_double_extern_func_{}", struct_id), span.into());
443 let setup_extern_func =
444 syn::Ident::new(&format!("sw_setup_extern_func_{}", struct_id), span.into());
445
446 externs.push(quote::quote_spanned! { span.into() =>
447 #[no_mangle]
448 extern "C" fn #add_ne_extern_func(rd: usize, rs1: usize, rs2: usize) {
449 openvm::platform::custom_insn_r!(
450 opcode = OPCODE,
451 funct3 = SW_FUNCT3 as usize,
452 funct7 = SwBaseFunct7::SwAddNe as usize + #ec_idx
453 * (SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
454 rd = In rd,
455 rs1 = In rs1,
456 rs2 = In rs2
457 );
458 }
459
460 #[no_mangle]
461 extern "C" fn #double_extern_func(rd: usize, rs1: usize) {
462 openvm::platform::custom_insn_r!(
463 opcode = OPCODE,
464 funct3 = SW_FUNCT3 as usize,
465 funct7 = SwBaseFunct7::SwDouble as usize + #ec_idx
466 * (SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
467 rd = In rd,
468 rs1 = In rs1,
469 rs2 = Const "x0"
470 );
471 }
472
473 #[no_mangle]
474 extern "C" fn #setup_extern_func(uninit: *mut core::ffi::c_void, p1: *const u8, p2: *const u8) {
475 #[cfg(target_os = "zkvm")]
476 {
477 openvm::platform::custom_insn_r!(
478 opcode = ::openvm_ecc_guest::OPCODE,
479 funct3 = ::openvm_ecc_guest::SW_FUNCT3 as usize,
480 funct7 = ::openvm_ecc_guest::SwBaseFunct7::SwSetup as usize
481 + #ec_idx
482 * (::openvm_ecc_guest::SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
483 rd = In uninit,
484 rs1 = In p1,
485 rs2 = In p2
486 );
487 openvm::platform::custom_insn_r!(
488 opcode = ::openvm_ecc_guest::OPCODE,
489 funct3 = ::openvm_ecc_guest::SW_FUNCT3 as usize,
490 funct7 = ::openvm_ecc_guest::SwBaseFunct7::SwSetup as usize
491 + #ec_idx
492 * (::openvm_ecc_guest::SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
493 rd = In uninit,
494 rs1 = In p1,
495 rs2 = Const "x0" );
497
498
499 }
500 }
501 });
502 }
503
504 TokenStream::from(quote::quote_spanned! { span.into() =>
505 #[allow(non_snake_case)]
506 #[cfg(target_os = "zkvm")]
507 mod openvm_intrinsics_ffi_2 {
508 use ::openvm_ecc_guest::{OPCODE, SW_FUNCT3, SwBaseFunct7};
509
510 #(#externs)*
511 }
512 })
513}