1use derive_more::derive::From;
2use num_bigint::BigUint;
3use num_traits::{FromPrimitive, Zero};
4use once_cell::sync::Lazy;
5use openvm_algebra_guest::IntMod;
6use openvm_circuit::{
7 arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError},
8 system::phantom::PhantomChip,
9};
10use openvm_circuit_derive::{AnyEnum, InstructionExecutor};
11use openvm_circuit_primitives::bitwise_op_lookup::{
12 BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip,
13};
14use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter};
15use openvm_ecc_guest::{
16 k256::{SECP256K1_MODULUS, SECP256K1_ORDER},
17 p256::{CURVE_A as P256_A, CURVE_B as P256_B, P256_MODULUS, P256_ORDER},
18};
19use openvm_ecc_transpiler::{EccPhantom, Rv32WeierstrassOpcode};
20use openvm_instructions::{LocalOpcode, PhantomDiscriminant, VmOpcode};
21use openvm_mod_circuit_builder::ExprBuilderConfig;
22use openvm_rv32_adapters::Rv32VecHeapAdapterChip;
23use openvm_stark_backend::p3_field::PrimeField32;
24use serde::{Deserialize, Serialize};
25use serde_with::{serde_as, DisplayFromStr};
26use strum::EnumCount;
27
28use super::{EcAddNeChip, EcDoubleChip};
29
/// Parameters of a short Weierstrass curve `y^2 = x^3 + a*x + b` over a prime field.
///
/// Every field is (de)serialized as a decimal string through `DisplayFromStr`
/// (i.e. via `BigUint`'s `Display`/`FromStr` impls).
#[serde_as]
#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
pub struct CurveConfig {
    /// Modulus of the coordinate (base) field; also the modulus of the field-arithmetic chips.
    #[serde_as(as = "DisplayFromStr")]
    pub modulus: BigUint,
    /// Scalar-field modulus (for the built-in configs below, the curve group order).
    #[serde_as(as = "DisplayFromStr")]
    pub scalar: BigUint,
    /// Coefficient `a` of the curve equation.
    #[serde_as(as = "DisplayFromStr")]
    pub a: BigUint,
    /// Coefficient `b` of the curve equation.
    #[serde_as(as = "DisplayFromStr")]
    pub b: BigUint,
}
46
47pub static SECP256K1_CONFIG: Lazy<CurveConfig> = Lazy::new(|| CurveConfig {
48 modulus: SECP256K1_MODULUS.clone(),
49 scalar: SECP256K1_ORDER.clone(),
50 a: BigUint::zero(),
51 b: BigUint::from_u8(7u8).unwrap(),
52});
53
54pub static P256_CONFIG: Lazy<CurveConfig> = Lazy::new(|| CurveConfig {
55 modulus: P256_MODULUS.clone(),
56 scalar: P256_ORDER.clone(),
57 a: BigUint::from_bytes_le(P256_A.as_le_bytes()),
58 b: BigUint::from_bytes_le(P256_B.as_le_bytes()),
59});
60
/// VM extension that adds short-Weierstrass elliptic-curve instructions for a list of curves.
///
/// The position of a curve in `supported_curves` determines both its opcode window
/// (`Rv32WeierstrassOpcode::CLASS_OFFSET + index * Rv32WeierstrassOpcode::COUNT`) and the
/// curve index that the ECC phantom hints receive in `c_upper`.
#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
pub struct WeierstrassExtension {
    /// Curves to support, in opcode-assignment order.
    pub supported_curves: Vec<CurveConfig>,
}
65
/// Executor chips registered by [`WeierstrassExtension`]: one `EC_ADD_NE` chip and one
/// `EC_DOUBLE` chip per supported curve.
///
/// `build` picks the `*_32` variants for moduli of at most 32 bytes (adapter with 2 blocks of
/// 32 bytes) and the `*_48` variants for moduli of at most 48 bytes (6 blocks of 16 bytes).
#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum)]
pub enum WeierstrassExtensionExecutor<F: PrimeField32> {
    // NOTE(review): the const generics presumably are (number of blocks, block size in bytes),
    // matching the Rv32VecHeapAdapterChip parameters used in `build` — confirm against EcAddNeChip.
    EcAddNeRv32_32(EcAddNeChip<F, 2, 32>),
    EcDoubleRv32_32(EcDoubleChip<F, 2, 32>),
    EcAddNeRv32_48(EcAddNeChip<F, 6, 16>),
    EcDoubleRv32_48(EcDoubleChip<F, 6, 16>),
}
75
/// Periphery chips that [`WeierstrassExtension`] may own.
#[derive(ChipUsageGetter, Chip, AnyEnum, From)]
pub enum WeierstrassExtensionPeriphery<F: PrimeField32> {
    // 8-bit bitwise-operation lookup; `build` only adds one when no existing chip of this type
    // is found in the builder.
    BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>),
    // NOTE(review): `build` in this file never constructs this variant; presumably present for
    // the extension framework's periphery conventions — confirm.
    Phantom(PhantomChip<F>),
}
81
82impl<F: PrimeField32> VmExtension<F> for WeierstrassExtension {
83 type Executor = WeierstrassExtensionExecutor<F>;
84 type Periphery = WeierstrassExtensionPeriphery<F>;
85
86 fn build(
87 &self,
88 builder: &mut VmInventoryBuilder<F>,
89 ) -> Result<VmInventory<Self::Executor, Self::Periphery>, VmInventoryError> {
90 let mut inventory = VmInventory::new();
91 let SystemPort {
92 execution_bus,
93 program_bus,
94 memory_bridge,
95 } = builder.system_port();
96 let bitwise_lu_chip = if let Some(&chip) = builder
97 .find_chip::<SharedBitwiseOperationLookupChip<8>>()
98 .first()
99 {
100 chip.clone()
101 } else {
102 let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx());
103 let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus);
104 inventory.add_periphery_chip(chip.clone());
105 chip
106 };
107 let offline_memory = builder.system_base().offline_memory();
108 let range_checker = builder.system_base().range_checker_chip.clone();
109 let pointer_bits = builder.system_config().memory_config.pointer_max_bits;
110 let ec_add_ne_opcodes = (Rv32WeierstrassOpcode::EC_ADD_NE as usize)
111 ..=(Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize);
112 let ec_double_opcodes = (Rv32WeierstrassOpcode::EC_DOUBLE as usize)
113 ..=(Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize);
114
115 for (i, curve) in self.supported_curves.iter().enumerate() {
116 let start_offset =
117 Rv32WeierstrassOpcode::CLASS_OFFSET + i * Rv32WeierstrassOpcode::COUNT;
118 let bytes = curve.modulus.bits().div_ceil(8);
119 let config32 = ExprBuilderConfig {
120 modulus: curve.modulus.clone(),
121 num_limbs: 32,
122 limb_bits: 8,
123 };
124 let config48 = ExprBuilderConfig {
125 modulus: curve.modulus.clone(),
126 num_limbs: 48,
127 limb_bits: 8,
128 };
129 if bytes <= 32 {
130 let add_ne_chip = EcAddNeChip::new(
131 Rv32VecHeapAdapterChip::<F, 2, 2, 2, 32, 32>::new(
132 execution_bus,
133 program_bus,
134 memory_bridge,
135 pointer_bits,
136 bitwise_lu_chip.clone(),
137 ),
138 config32.clone(),
139 start_offset,
140 range_checker.clone(),
141 offline_memory.clone(),
142 );
143 inventory.add_executor(
144 WeierstrassExtensionExecutor::EcAddNeRv32_32(add_ne_chip),
145 ec_add_ne_opcodes
146 .clone()
147 .map(|x| VmOpcode::from_usize(x + start_offset)),
148 )?;
149 let double_chip = EcDoubleChip::new(
150 Rv32VecHeapAdapterChip::<F, 1, 2, 2, 32, 32>::new(
151 execution_bus,
152 program_bus,
153 memory_bridge,
154 pointer_bits,
155 bitwise_lu_chip.clone(),
156 ),
157 range_checker.clone(),
158 config32.clone(),
159 start_offset,
160 curve.a.clone(),
161 offline_memory.clone(),
162 );
163 inventory.add_executor(
164 WeierstrassExtensionExecutor::EcDoubleRv32_32(double_chip),
165 ec_double_opcodes
166 .clone()
167 .map(|x| VmOpcode::from_usize(x + start_offset)),
168 )?;
169 } else if bytes <= 48 {
170 let add_ne_chip = EcAddNeChip::new(
171 Rv32VecHeapAdapterChip::<F, 2, 6, 6, 16, 16>::new(
172 execution_bus,
173 program_bus,
174 memory_bridge,
175 pointer_bits,
176 bitwise_lu_chip.clone(),
177 ),
178 config48.clone(),
179 start_offset,
180 range_checker.clone(),
181 offline_memory.clone(),
182 );
183 inventory.add_executor(
184 WeierstrassExtensionExecutor::EcAddNeRv32_48(add_ne_chip),
185 ec_add_ne_opcodes
186 .clone()
187 .map(|x| VmOpcode::from_usize(x + start_offset)),
188 )?;
189 let double_chip = EcDoubleChip::new(
190 Rv32VecHeapAdapterChip::<F, 1, 6, 6, 16, 16>::new(
191 execution_bus,
192 program_bus,
193 memory_bridge,
194 pointer_bits,
195 bitwise_lu_chip.clone(),
196 ),
197 range_checker.clone(),
198 config48.clone(),
199 start_offset,
200 curve.a.clone(),
201 offline_memory.clone(),
202 );
203 inventory.add_executor(
204 WeierstrassExtensionExecutor::EcDoubleRv32_48(double_chip),
205 ec_double_opcodes
206 .clone()
207 .map(|x| VmOpcode::from_usize(x + start_offset)),
208 )?;
209 } else {
210 panic!("Modulus too large");
211 }
212 }
213 let non_qr_hint_sub_ex = phantom::NonQrHintSubEx::new(self.supported_curves.clone());
214 builder.add_phantom_sub_executor(
215 non_qr_hint_sub_ex.clone(),
216 PhantomDiscriminant(EccPhantom::HintNonQr as u16),
217 )?;
218 builder.add_phantom_sub_executor(
219 phantom::DecompressHintSubEx::new(non_qr_hint_sub_ex),
220 PhantomDiscriminant(EccPhantom::HintDecompress as u16),
221 )?;
222
223 Ok(inventory)
224 }
225}
226
227pub(crate) mod phantom {
228 use std::{
229 iter::{once, repeat},
230 ops::Deref,
231 };
232
233 use eyre::bail;
234 use num_bigint::{BigUint, RandBigInt};
235 use num_integer::Integer;
236 use num_traits::{FromPrimitive, One};
237 use openvm_circuit::{
238 arch::{PhantomSubExecutor, Streams},
239 system::memory::MemoryController,
240 };
241 use openvm_ecc_guest::weierstrass::DecompressionHint;
242 use openvm_instructions::{riscv::RV32_MEMORY_AS, PhantomDiscriminant};
243 use openvm_rv32im_circuit::adapters::unsafe_read_rv32_register;
244 use openvm_stark_backend::p3_field::PrimeField32;
245 use rand::{rngs::StdRng, SeedableRng};
246
247 use super::CurveConfig;
248
249 #[derive(derive_new::new)]
250 pub struct DecompressHintSubEx(NonQrHintSubEx);
251
252 impl Deref for DecompressHintSubEx {
253 type Target = NonQrHintSubEx;
254
255 fn deref(&self) -> &NonQrHintSubEx {
256 &self.0
257 }
258 }
259
260 impl<F: PrimeField32> PhantomSubExecutor<F> for DecompressHintSubEx {
261 fn phantom_execute(
262 &mut self,
263 memory: &MemoryController<F>,
264 streams: &mut Streams<F>,
265 _: PhantomDiscriminant,
266 a: F,
267 b: F,
268 c_upper: u16,
269 ) -> eyre::Result<()> {
270 let c_idx = c_upper as usize;
271 if c_idx >= self.supported_curves.len() {
272 bail!(
273 "Curve index {c_idx} out of range: {} supported curves",
274 self.supported_curves.len()
275 );
276 }
277 let curve = &self.supported_curves[c_idx];
278 let rs1 = unsafe_read_rv32_register(memory, a);
279 let num_limbs: usize = if curve.modulus.bits().div_ceil(8) <= 32 {
280 32
281 } else if curve.modulus.bits().div_ceil(8) <= 48 {
282 48
283 } else {
284 bail!("Modulus too large")
285 };
286 let mut x_limbs: Vec<u8> = Vec::with_capacity(num_limbs);
287 for i in 0..num_limbs {
288 let limb = memory.unsafe_read_cell(
289 F::from_canonical_u32(RV32_MEMORY_AS),
290 F::from_canonical_u32(rs1 + i as u32),
291 );
292 x_limbs.push(limb.as_canonical_u32() as u8);
293 }
294 let x = BigUint::from_bytes_le(&x_limbs);
295 let rs2 = unsafe_read_rv32_register(memory, b);
296 let rec_id = memory.unsafe_read_cell(
297 F::from_canonical_u32(RV32_MEMORY_AS),
298 F::from_canonical_u32(rs2),
299 );
300 let hint = self.decompress_point(x, rec_id.as_canonical_u32() & 1 == 1, c_idx);
301 let hint_bytes = once(F::from_bool(hint.possible))
302 .chain(repeat(F::ZERO))
303 .take(4)
304 .chain(
305 hint.sqrt
306 .to_bytes_le()
307 .into_iter()
308 .map(F::from_canonical_u8)
309 .chain(repeat(F::ZERO))
310 .take(num_limbs),
311 )
312 .collect();
313 streams.hint_stream = hint_bytes;
314 Ok(())
315 }
316 }
317
318 impl DecompressHintSubEx {
319 fn decompress_point(
326 &self,
327 x: BigUint,
328 is_y_odd: bool,
329 curve_idx: usize,
330 ) -> DecompressionHint<BigUint> {
331 let curve = &self.supported_curves[curve_idx];
332 let alpha = ((&x * &x * &x) + (&x * &curve.a) + &curve.b) % &curve.modulus;
333 match mod_sqrt(&alpha, &curve.modulus, &self.non_qrs[curve_idx]) {
334 Some(beta) => {
335 if is_y_odd == beta.is_odd() {
336 DecompressionHint {
337 possible: true,
338 sqrt: beta,
339 }
340 } else {
341 DecompressionHint {
342 possible: true,
343 sqrt: &curve.modulus - &beta,
344 }
345 }
346 }
347 None => {
348 debug_assert_eq!(
349 self.non_qrs[curve_idx]
350 .modpow(&((&curve.modulus - BigUint::one()) >> 1), &curve.modulus),
351 &curve.modulus - BigUint::one()
352 );
353 let sqrt = mod_sqrt(
354 &(&alpha * &self.non_qrs[curve_idx]),
355 &curve.modulus,
356 &self.non_qrs[curve_idx],
357 )
358 .unwrap();
359 DecompressionHint {
360 possible: false,
361 sqrt,
362 }
363 }
364 }
365 }
366 }
367
368 pub fn mod_sqrt(x: &BigUint, modulus: &BigUint, non_qr: &BigUint) -> Option<BigUint> {
371 if modulus % 4u32 == BigUint::from_u8(3).unwrap() {
372 let exponent = (modulus + BigUint::one()) >> 2;
374 let ret = x.modpow(&exponent, modulus);
375 if &ret * &ret % modulus == x % modulus {
376 Some(ret)
377 } else {
378 None
379 }
380 } else {
381 let mut q = modulus - BigUint::one();
384 let mut s = 0;
385 while &q % 2u32 == BigUint::ZERO {
386 s += 1;
387 q /= 2u32;
388 }
389 let z = non_qr;
390 let mut m = s;
391 let mut c = z.modpow(&q, modulus);
392 let mut t = x.modpow(&q, modulus);
393 let mut r = x.modpow(&((q + BigUint::one()) >> 1), modulus);
394 loop {
395 if t == BigUint::ZERO {
396 return Some(BigUint::ZERO);
397 }
398 if t == BigUint::one() {
399 return Some(r);
400 }
401 let mut i = 0;
402 let mut tmp = t.clone();
403 while tmp != BigUint::one() && i < m {
404 tmp = &tmp * &tmp % modulus;
405 i += 1;
406 }
407 if i == m {
408 return None;
410 }
411 for _ in 0..m - i - 1 {
412 c = &c * &c % modulus;
413 }
414 let b = c;
415 m = i;
416 c = &b * &b % modulus;
417 t = ((t * &b % modulus) * &b) % modulus;
418 r = (r * b) % modulus;
419 }
420 }
421 }
422
423 #[derive(Clone)]
424 pub struct NonQrHintSubEx {
425 pub supported_curves: Vec<CurveConfig>,
426 pub non_qrs: Vec<BigUint>,
427 }
428
429 impl NonQrHintSubEx {
430 pub fn new(supported_curves: Vec<CurveConfig>) -> Self {
431 let non_qrs = supported_curves
432 .iter()
433 .map(|curve| find_non_qr(&curve.modulus))
434 .collect();
435 Self {
436 supported_curves,
437 non_qrs,
438 }
439 }
440 }
441
442 impl<F: PrimeField32> PhantomSubExecutor<F> for NonQrHintSubEx {
443 fn phantom_execute(
444 &mut self,
445 _: &MemoryController<F>,
446 streams: &mut Streams<F>,
447 _: PhantomDiscriminant,
448 _: F,
449 _: F,
450 c_upper: u16,
451 ) -> eyre::Result<()> {
452 let c_idx = c_upper as usize;
453 if c_idx >= self.supported_curves.len() {
454 bail!(
455 "Curve index {c_idx} out of range: {} supported curves",
456 self.supported_curves.len()
457 );
458 }
459 let curve = &self.supported_curves[c_idx];
460
461 let num_limbs: usize = if curve.modulus.bits().div_ceil(8) <= 32 {
462 32
463 } else if curve.modulus.bits().div_ceil(8) <= 48 {
464 48
465 } else {
466 bail!("Modulus too large")
467 };
468
469 let hint_bytes = self.non_qrs[c_idx]
470 .to_bytes_le()
471 .into_iter()
472 .map(F::from_canonical_u8)
473 .chain(repeat(F::ZERO))
474 .take(num_limbs)
475 .collect();
476 streams.hint_stream = hint_bytes;
477 Ok(())
478 }
479 }
480
481 fn find_non_qr(modulus: &BigUint) -> BigUint {
483 if modulus % 4u32 == BigUint::from(3u8) {
484 modulus - BigUint::one()
486 } else if modulus % 8u32 == BigUint::from(5u8) {
487 BigUint::from_u8(2u8).unwrap()
490 } else {
491 let mut rng = StdRng::from_entropy();
492 let mut non_qr = rng.gen_biguint_range(
493 &BigUint::from_u8(2).unwrap(),
494 &(modulus - BigUint::from_u8(1).unwrap()),
495 );
496 let exponent = (modulus - BigUint::one()) >> 1;
500 while non_qr.modpow(&exponent, modulus) != modulus - BigUint::one() {
501 non_qr = rng.gen_biguint_range(
502 &BigUint::from_u8(2).unwrap(),
503 &(modulus - BigUint::from_u8(1).unwrap()),
504 );
505 }
506 non_qr
507 }
508 }
509}