1extern crate proc_macro;
2
3use openvm_macros_common::MacroArgs;
4use proc_macro::TokenStream;
5use quote::format_ident;
6use syn::{
7 parse::{Parse, ParseStream},
8 parse_macro_input, Expr, ExprPath, Path, Token,
9};
10
11#[proc_macro]
23pub fn sw_declare(input: TokenStream) -> TokenStream {
24 let MacroArgs { items } = parse_macro_input!(input as MacroArgs);
25
26 let mut output = Vec::new();
27
28 let span = proc_macro::Span::call_site();
29
30 for item in items.into_iter() {
31 let struct_name = item.name.to_string();
32 let struct_name = syn::Ident::new(&struct_name, span.into());
33 let struct_path: syn::Path = syn::parse_quote!(#struct_name);
34 let mut intmod_type: Option<syn::Path> = None;
35 let mut const_a: Option<syn::Expr> = None;
36 let mut const_b: Option<syn::Expr> = None;
37 for param in item.params {
38 match param.name.to_string().as_str() {
39 "mod_type" => {
41 if let syn::Expr::Path(ExprPath { path, .. }) = param.value {
42 intmod_type = Some(path)
43 } else {
44 return syn::Error::new_spanned(param.value, "Expected a type")
45 .to_compile_error()
46 .into();
47 }
48 }
49 "a" => {
50 const_a = Some(param.value);
53 }
54 "b" => {
55 const_b = Some(param.value);
58 }
59 _ => {
60 panic!("Unknown parameter {}", param.name);
61 }
62 }
63 }
64
65 let intmod_type = intmod_type.expect("mod_type parameter is required");
66 let const_a = const_a
68 .unwrap_or(syn::parse_quote!(<#intmod_type as openvm_algebra_guest::IntMod>::ZERO));
69 let const_b = const_b.expect("constant b coefficient is required");
70
71 macro_rules! create_extern_func {
72 ($name:ident) => {
73 let $name = syn::Ident::new(
74 &format!(
75 "{}_{}",
76 stringify!($name),
77 struct_path
78 .segments
79 .iter()
80 .map(|x| x.ident.to_string())
81 .collect::<Vec<_>>()
82 .join("_")
83 ),
84 span.into(),
85 );
86 };
87 }
88 create_extern_func!(sw_add_ne_extern_func);
89 create_extern_func!(sw_double_extern_func);
90 create_extern_func!(hint_decompress_extern_func);
91 create_extern_func!(hint_non_qr_extern_func);
92
93 let group_ops_mod_name = format_ident!("{}_ops", struct_name.to_string().to_lowercase());
94
95 let result = TokenStream::from(quote::quote_spanned! { span.into() =>
96 extern "C" {
97 fn #sw_add_ne_extern_func(rd: usize, rs1: usize, rs2: usize);
98 fn #sw_double_extern_func(rd: usize, rs1: usize);
99 fn #hint_decompress_extern_func(rs1: usize, rs2: usize);
100 fn #hint_non_qr_extern_func();
101 }
102
103 #[derive(Eq, PartialEq, Clone, Debug, serde::Serialize, serde::Deserialize)]
104 #[repr(C)]
105 pub struct #struct_name {
106 x: #intmod_type,
107 y: #intmod_type,
108 }
109 #[allow(non_upper_case_globals)]
110
111 impl #struct_name {
112 const fn identity() -> Self {
113 Self {
114 x: <#intmod_type as openvm_algebra_guest::IntMod>::ZERO,
115 y: <#intmod_type as openvm_algebra_guest::IntMod>::ZERO,
116 }
117 }
118 #[inline(always)]
121 fn add_ne(p1: &#struct_name, p2: &#struct_name) -> #struct_name {
122 #[cfg(not(target_os = "zkvm"))]
123 {
124 use openvm_algebra_guest::DivUnsafe;
125 let lambda = (&p2.y - &p1.y).div_unsafe(&p2.x - &p1.x);
126 let x3 = &lambda * &lambda - &p1.x - &p2.x;
127 let y3 = &lambda * &(&p1.x - &x3) - &p1.y;
128 #struct_name { x: x3, y: y3 }
129 }
130 #[cfg(target_os = "zkvm")]
131 {
132 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
133 unsafe {
134 #sw_add_ne_extern_func(
135 uninit.as_mut_ptr() as usize,
136 p1 as *const #struct_name as usize,
137 p2 as *const #struct_name as usize
138 )
139 };
140 unsafe { uninit.assume_init() }
141 }
142 }
143
144 #[inline(always)]
145 fn add_ne_assign(&mut self, p2: &#struct_name) {
146 #[cfg(not(target_os = "zkvm"))]
147 {
148 use openvm_algebra_guest::DivUnsafe;
149 let lambda = (&p2.y - &self.y).div_unsafe(&p2.x - &self.x);
150 let x3 = &lambda * &lambda - &self.x - &p2.x;
151 let y3 = &lambda * &(&self.x - &x3) - &self.y;
152 self.x = x3;
153 self.y = y3;
154 }
155 #[cfg(target_os = "zkvm")]
156 {
157 unsafe {
158 #sw_add_ne_extern_func(
159 self as *mut #struct_name as usize,
160 self as *const #struct_name as usize,
161 p2 as *const #struct_name as usize
162 )
163 };
164 }
165 }
166
167 #[inline(always)]
169 fn double_impl(p: &#struct_name) -> #struct_name {
170 #[cfg(not(target_os = "zkvm"))]
171 {
172 use openvm_algebra_guest::DivUnsafe;
173 let curve_a: #intmod_type = #const_a;
174 let two = #intmod_type::from_u8(2);
175 let lambda = (&p.x * &p.x * #intmod_type::from_u8(3) + &curve_a).div_unsafe(&p.y * &two);
176 let x3 = &lambda * &lambda - &p.x * &two;
177 let y3 = &lambda * &(&p.x - &x3) - &p.y;
178 #struct_name { x: x3, y: y3 }
179 }
180 #[cfg(target_os = "zkvm")]
181 {
182 let mut uninit: core::mem::MaybeUninit<#struct_name> = core::mem::MaybeUninit::uninit();
183 unsafe {
184 #sw_double_extern_func(
185 uninit.as_mut_ptr() as usize,
186 p as *const #struct_name as usize,
187 )
188 };
189 unsafe { uninit.assume_init() }
190 }
191 }
192
193 #[inline(always)]
194 fn double_assign_impl(&mut self) {
195 #[cfg(not(target_os = "zkvm"))]
196 {
197 *self = Self::double_impl(self);
198 }
199 #[cfg(target_os = "zkvm")]
200 {
201 unsafe {
202 #sw_double_extern_func(
203 self as *mut #struct_name as usize,
204 self as *const #struct_name as usize
205 )
206 };
207 }
208 }
209
210 }
211
212 impl ::openvm_ecc_guest::weierstrass::WeierstrassPoint for #struct_name {
213 const CURVE_A: #intmod_type = #const_a;
214 const CURVE_B: #intmod_type = #const_b;
215 const IDENTITY: Self = Self::identity();
216 type Coordinate = #intmod_type;
217
218 fn as_le_bytes(&self) -> &[u8] {
221 unsafe { &*core::ptr::slice_from_raw_parts(self as *const Self as *const u8, <#intmod_type as openvm_algebra_guest::IntMod>::NUM_LIMBS * 2) }
222 }
223
224 fn from_xy_unchecked(x: Self::Coordinate, y: Self::Coordinate) -> Self {
225 Self { x, y }
226 }
227
228 fn x(&self) -> &Self::Coordinate {
229 &self.x
230 }
231
232 fn y(&self) -> &Self::Coordinate {
233 &self.y
234 }
235
236 fn x_mut(&mut self) -> &mut Self::Coordinate {
237 &mut self.x
238 }
239
240 fn y_mut(&mut self) -> &mut Self::Coordinate {
241 &mut self.y
242 }
243
244 fn into_coords(self) -> (Self::Coordinate, Self::Coordinate) {
245 (self.x, self.y)
246 }
247
248 fn add_ne_nonidentity(&self, p2: &Self) -> Self {
249 Self::add_ne(self, p2)
250 }
251
252 fn add_ne_assign_nonidentity(&mut self, p2: &Self) {
253 Self::add_ne_assign(self, p2);
254 }
255
256 fn sub_ne_nonidentity(&self, p2: &Self) -> Self {
257 Self::add_ne(self, &p2.clone().neg())
258 }
259
260 fn sub_ne_assign_nonidentity(&mut self, p2: &Self) {
261 Self::add_ne_assign(self, &p2.clone().neg());
262 }
263
264 fn double_nonidentity(&self) -> Self {
265 Self::double_impl(self)
266 }
267
268 fn double_assign_nonidentity(&mut self) {
269 Self::double_assign_impl(self);
270 }
271 }
272
273 impl core::ops::Neg for #struct_name {
274 type Output = Self;
275
276 fn neg(self) -> Self::Output {
277 #struct_name {
278 x: self.x,
279 y: -self.y,
280 }
281 }
282 }
283
284 impl core::ops::Neg for &#struct_name {
285 type Output = #struct_name;
286
287 fn neg(self) -> #struct_name {
288 #struct_name {
289 x: self.x.clone(),
290 y: core::ops::Neg::neg(&self.y),
291 }
292 }
293 }
294
295 mod #group_ops_mod_name {
296 use ::openvm_ecc_guest::{weierstrass::{WeierstrassPoint, FromCompressed, DecompressionHint}, impl_sw_group_ops, algebra::{IntMod, DivUnsafe, DivAssignUnsafe, ExpBytes}};
297 use super::*;
298
299 impl_sw_group_ops!(#struct_name, #intmod_type);
300
301 impl FromCompressed<#intmod_type> for #struct_name {
302 fn decompress(x: #intmod_type, rec_id: &u8) -> Option<Self> {
303 match Self::honest_host_decompress(&x, rec_id) {
304 Some(Some(ret)) => Some(ret),
306 Some(None) => None,
308 None => {
309 loop {
311 openvm::io::println("ERROR: Decompression hint is invalid. Entering infinite loop.");
312 }
313 }
314 }
315 }
316
317 fn hint_decompress(x: &#intmod_type, rec_id: &u8) -> Option<DecompressionHint<#intmod_type>> {
318 #[cfg(not(target_os = "zkvm"))]
319 {
320 unimplemented!()
321 }
322 #[cfg(target_os = "zkvm")]
323 {
324 use openvm::platform as openvm_platform; let possible = core::mem::MaybeUninit::<u32>::uninit();
327 let sqrt = core::mem::MaybeUninit::<#intmod_type>::uninit();
328 unsafe {
329 #hint_decompress_extern_func(x as *const _ as usize, rec_id as *const u8 as usize);
330 let possible_ptr = possible.as_ptr() as *const u32;
331 openvm_rv32im_guest::hint_store_u32!(possible_ptr);
332 openvm_rv32im_guest::hint_buffer_u32!(sqrt.as_ptr() as *const u8, <#intmod_type as openvm_algebra_guest::IntMod>::NUM_LIMBS / 4);
333 let possible = possible.assume_init();
334 if possible == 0 || possible == 1 {
335 Some(DecompressionHint { possible: possible == 1, sqrt: sqrt.assume_init() })
336 } else {
337 None
338 }
339 }
340 }
341 }
342 }
343
344 impl #struct_name {
345 fn honest_host_decompress(x: &#intmod_type, rec_id: &u8) -> Option<Option<Self>> {
348 let hint = <#struct_name as FromCompressed<#intmod_type>>::hint_decompress(x, rec_id)?;
349
350 if hint.possible {
351 hint.sqrt.assert_reduced();
353
354 if hint.sqrt.as_le_bytes()[0] & 1 != *rec_id & 1 {
355 None
356 } else {
357 let ret = <#struct_name as ::openvm_ecc_guest::weierstrass::WeierstrassPoint>::from_xy_nonidentity(x.clone(), hint.sqrt)?;
358 Some(Some(ret))
359 }
360 } else {
361 hint.sqrt.assert_reduced();
363
364 let alpha = (x * x * x) + (x * &<#struct_name as ::openvm_ecc_guest::weierstrass::WeierstrassPoint>::CURVE_A) + &<#struct_name as ::openvm_ecc_guest::weierstrass::WeierstrassPoint>::CURVE_B;
365 if &hint.sqrt * &hint.sqrt == alpha * Self::get_non_qr() {
366 Some(None)
367 } else {
368 None
369 }
370 }
371 }
372
373 fn init_non_qr() -> alloc::boxed::Box<<Self as ::openvm_ecc_guest::weierstrass::WeierstrassPoint>::Coordinate> {
375 #[cfg(not(target_os = "zkvm"))]
376 {
377 unimplemented!();
378 }
379 #[cfg(target_os = "zkvm")]
380 {
381 use openvm::platform as openvm_platform; let mut non_qr_uninit = core::mem::MaybeUninit::<#intmod_type>::uninit();
383 let mut non_qr;
384 unsafe {
385 #hint_non_qr_extern_func();
386 let ptr = non_qr_uninit.as_ptr() as *const u8;
387 openvm_rv32im_guest::hint_buffer_u32!(ptr, <#intmod_type as openvm_algebra_guest::IntMod>::NUM_LIMBS / 4);
388 non_qr = non_qr_uninit.assume_init();
389 }
390 non_qr.assert_reduced();
392
393 let exp = -<#intmod_type as openvm_algebra_guest::IntMod>::ONE.div_unsafe(#intmod_type::from_const_u8(2));
395 exp.assert_reduced();
396
397 if non_qr.exp_bytes(true, &exp.to_be_bytes()) != -<#intmod_type as openvm_algebra_guest::IntMod>::ONE
398 {
399 loop {
401 openvm::io::println("ERROR: Non quadratic residue hint is invalid. Entering infinite loop.");
402 }
403 }
404
405 alloc::boxed::Box::new(non_qr)
406 }
407 }
408
409 pub fn get_non_qr() -> &'static #intmod_type {
410 static non_qr: ::openvm_ecc_guest::once_cell::race::OnceBox<#intmod_type> = ::openvm_ecc_guest::once_cell::race::OnceBox::new();
411 &non_qr.get_or_init(Self::init_non_qr)
412 }
413 }
414 }
415 });
416 output.push(result);
417 }
418
419 TokenStream::from_iter(output)
420}
421
422struct SwDefine {
423 items: Vec<Path>,
424}
425
426impl Parse for SwDefine {
427 fn parse(input: ParseStream) -> syn::Result<Self> {
428 let items = input.parse_terminated(<Expr as Parse>::parse, Token![,])?;
429 Ok(Self {
430 items: items
431 .into_iter()
432 .map(|e| {
433 if let Expr::Path(p) = e {
434 p.path
435 } else {
436 panic!("expected path");
437 }
438 })
439 .collect(),
440 })
441 }
442}
443
444#[proc_macro]
445pub fn sw_init(input: TokenStream) -> TokenStream {
446 let SwDefine { items } = parse_macro_input!(input as SwDefine);
447
448 let mut externs = Vec::new();
449 let mut setups = Vec::new();
450 let mut setup_all_curves = Vec::new();
451
452 let span = proc_macro::Span::call_site();
453
454 for (ec_idx, item) in items.into_iter().enumerate() {
455 let str_path = item
456 .segments
457 .iter()
458 .map(|x| x.ident.to_string())
459 .collect::<Vec<_>>()
460 .join("_");
461 let add_ne_extern_func =
462 syn::Ident::new(&format!("sw_add_ne_extern_func_{}", str_path), span.into());
463 let double_extern_func =
464 syn::Ident::new(&format!("sw_double_extern_func_{}", str_path), span.into());
465 let hint_decompress_extern_func = syn::Ident::new(
466 &format!("hint_decompress_extern_func_{}", str_path),
467 span.into(),
468 );
469 let hint_non_qr_extern_func = syn::Ident::new(
470 &format!("hint_non_qr_extern_func_{}", str_path),
471 span.into(),
472 );
473 externs.push(quote::quote_spanned! { span.into() =>
474 #[no_mangle]
475 extern "C" fn #add_ne_extern_func(rd: usize, rs1: usize, rs2: usize) {
476 openvm::platform::custom_insn_r!(
477 opcode = OPCODE,
478 funct3 = SW_FUNCT3 as usize,
479 funct7 = SwBaseFunct7::SwAddNe as usize + #ec_idx
480 * (SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
481 rd = In rd,
482 rs1 = In rs1,
483 rs2 = In rs2
484 );
485 }
486
487 #[no_mangle]
488 extern "C" fn #double_extern_func(rd: usize, rs1: usize) {
489 openvm::platform::custom_insn_r!(
490 opcode = OPCODE,
491 funct3 = SW_FUNCT3 as usize,
492 funct7 = SwBaseFunct7::SwDouble as usize + #ec_idx
493 * (SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
494 rd = In rd,
495 rs1 = In rs1,
496 rs2 = Const "x0"
497 );
498 }
499
500 #[no_mangle]
501 extern "C" fn #hint_decompress_extern_func(rs1: usize, rs2: usize) {
502 openvm::platform::custom_insn_r!(
503 opcode = OPCODE,
504 funct3 = SW_FUNCT3 as usize,
505 funct7 = SwBaseFunct7::HintDecompress as usize + #ec_idx
506 * (SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
507 rd = Const "x0",
508 rs1 = In rs1,
509 rs2 = In rs2
510 );
511 }
512
513 #[no_mangle]
514 extern "C" fn #hint_non_qr_extern_func() {
515 openvm::platform::custom_insn_r!(
516 opcode = OPCODE,
517 funct3 = SW_FUNCT3 as usize,
518 funct7 = SwBaseFunct7::HintNonQr as usize + #ec_idx
519 * (SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
520 rd = Const "x0",
521 rs1 = Const "x0",
522 rs2 = Const "x0"
523 );
524 }
525 });
526
527 let setup_function = syn::Ident::new(&format!("setup_sw_{}", str_path), span.into());
528 setups.push(quote::quote_spanned! { span.into() =>
529 #[allow(non_snake_case)]
530 pub fn #setup_function() {
531 #[cfg(target_os = "zkvm")]
532 {
533 let modulus_bytes = <<#item as openvm_ecc_guest::weierstrass::WeierstrassPoint>::Coordinate as openvm_algebra_guest::IntMod>::MODULUS;
536 let mut one = [0u8; <<#item as openvm_ecc_guest::weierstrass::WeierstrassPoint>::Coordinate as openvm_algebra_guest::IntMod>::NUM_LIMBS];
537 one[0] = 1;
538 let curve_a_bytes = openvm_algebra_guest::IntMod::as_le_bytes(&<#item as openvm_ecc_guest::weierstrass::WeierstrassPoint>::CURVE_A);
539 let p1 = [modulus_bytes.as_ref(), curve_a_bytes.as_ref()].concat();
541 let p2 = [one.as_ref(), one.as_ref()].concat();
543 let mut uninit: core::mem::MaybeUninit<[#item; 2]> = core::mem::MaybeUninit::uninit();
544 openvm::platform::custom_insn_r!(
545 opcode = ::openvm_ecc_guest::OPCODE,
546 funct3 = ::openvm_ecc_guest::SW_FUNCT3 as usize,
547 funct7 = ::openvm_ecc_guest::SwBaseFunct7::SwSetup as usize
548 + #ec_idx
549 * (::openvm_ecc_guest::SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
550 rd = In uninit.as_mut_ptr(),
551 rs1 = In p1.as_ptr(),
552 rs2 = In p2.as_ptr()
553 );
554 openvm::platform::custom_insn_r!(
555 opcode = ::openvm_ecc_guest::OPCODE,
556 funct3 = ::openvm_ecc_guest::SW_FUNCT3 as usize,
557 funct7 = ::openvm_ecc_guest::SwBaseFunct7::SwSetup as usize
558 + #ec_idx
559 * (::openvm_ecc_guest::SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
560 rd = In uninit.as_mut_ptr(),
561 rs1 = In p1.as_ptr(),
562 rs2 = Const "x0" );
564 }
565 }
566 });
567
568 setup_all_curves.push(quote::quote_spanned! { span.into() =>
569 #setup_function();
570 });
571 }
572
573 TokenStream::from(quote::quote_spanned! { span.into() =>
574 #[cfg(target_os = "zkvm")]
575 mod openvm_intrinsics_ffi_2 {
576 use ::openvm_ecc_guest::{OPCODE, SW_FUNCT3, SwBaseFunct7};
577
578 #(#externs)*
579 }
580 #(#setups)*
581 pub fn setup_all_curves() {
582 #(#setup_all_curves)*
583 }
584 })
585}