1use std::sync::Arc;
2
3use hex_literal::hex;
4use lazy_static::lazy_static;
5use num_bigint::BigUint;
6use num_traits::{FromPrimitive, Zero};
7use once_cell::sync::Lazy;
8use openvm_circuit::{
9 arch::{
10 AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge,
11 ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension,
12 VmExecutionExtension, VmProverExtension,
13 },
14 system::{memory::SharedMemoryHelper, SystemPort},
15};
16use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor};
17use openvm_circuit_primitives::{
18 bitwise_op_lookup::{
19 BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip,
20 SharedBitwiseOperationLookupChip,
21 },
22 var_range::VariableRangeCheckerBus,
23};
24use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
25use openvm_instructions::{LocalOpcode, VmOpcode};
26use openvm_mod_circuit_builder::ExprBuilderConfig;
27use openvm_stark_backend::{
28 config::{StarkGenericConfig, Val},
29 engine::StarkEngine,
30 p3_field::PrimeField32,
31 prover::cpu::{CpuBackend, CpuDevice},
32};
33use serde::{Deserialize, Serialize};
34use serde_with::{serde_as, DisplayFromStr};
35use strum::EnumCount;
36
37use crate::{
38 get_ec_addne_air, get_ec_addne_chip, get_ec_addne_step, get_ec_double_air, get_ec_double_chip,
39 get_ec_double_step, EcAddNeExecutor, EcDoubleExecutor, EccCpuProverExt, WeierstrassAir,
40};
41
/// Parameters of a short Weierstrass curve `y^2 = x^3 + a*x + b` supported by
/// [`WeierstrassExtension`].
///
/// The big-integer fields are (de)serialized through their `Display` / `FromStr`
/// implementations (decimal text for `BigUint`) instead of serde's default representation.
#[serde_as]
#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
pub struct CurveConfig {
    /// Name of the guest-side point struct; `generate_sw_init` emits it as a quoted entry
    /// of the `sw_init!` invocation.
    pub struct_name: String,
    /// Modulus of the coordinate field. Its byte length selects the limb layout of the
    /// chips (at most 32 bytes or at most 48 bytes); wider moduli cause a panic when the
    /// extension is built.
    #[serde_as(as = "DisplayFromStr")]
    pub modulus: BigUint,
    /// Scalar-field modulus; for the built-in curves this is the `*_ORDER` constant.
    #[serde_as(as = "DisplayFromStr")]
    pub scalar: BigUint,
    /// Curve coefficient `a`; it is passed to the EC_DOUBLE executors, AIRs and chips.
    #[serde_as(as = "DisplayFromStr")]
    pub a: BigUint,
    /// Curve coefficient `b`.
    #[serde_as(as = "DisplayFromStr")]
    pub b: BigUint,
}
60
61pub static SECP256K1_CONFIG: Lazy<CurveConfig> = Lazy::new(|| CurveConfig {
62 struct_name: SECP256K1_ECC_STRUCT_NAME.to_string(),
63 modulus: SECP256K1_MODULUS.clone(),
64 scalar: SECP256K1_ORDER.clone(),
65 a: BigUint::zero(),
66 b: BigUint::from_u8(7u8).unwrap(),
67});
68
69pub static P256_CONFIG: Lazy<CurveConfig> = Lazy::new(|| CurveConfig {
70 struct_name: P256_ECC_STRUCT_NAME.to_string(),
71 modulus: P256_MODULUS.clone(),
72 scalar: P256_ORDER.clone(),
73 a: BigUint::from_bytes_le(&P256_A),
74 b: BigUint::from_bytes_le(&P256_B),
75});
76
/// VM extension that adds elliptic-curve addition (EC_ADD_NE) and doubling (EC_DOUBLE)
/// opcodes for a list of short Weierstrass curves.
///
/// The position of a curve in `supported_curves` determines its opcode block: curve `i`
/// uses the `Rv32WeierstrassOpcode::COUNT` opcodes starting at
/// `Rv32WeierstrassOpcode::CLASS_OFFSET + i * Rv32WeierstrassOpcode::COUNT`. Reordering
/// the list therefore changes the opcode assignment.
#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
pub struct WeierstrassExtension {
    /// Curves to support, in opcode-block order.
    pub supported_curves: Vec<CurveConfig>,
}
81
82impl WeierstrassExtension {
83 pub fn generate_sw_init(&self) -> String {
84 let supported_curves = self
85 .supported_curves
86 .iter()
87 .map(|curve_config| format!("\"{}\"", curve_config.struct_name))
88 .collect::<Vec<String>>()
89 .join(", ");
90
91 format!("openvm_ecc_guest::sw_macros::sw_init! {{ {supported_curves} }}")
92 }
93}
94
/// Executors registered by [`WeierstrassExtension`], one variant per
/// (operation, limb layout) pair.
///
/// `*_32` variants are used for curves whose modulus fits in 32 bytes and `*_48` variants
/// for moduli of 33 to 48 bytes (see `extend_execution`). The two const generic arguments
/// presumably describe how a point is laid out in memory blocks (2 x 32 bytes vs 6 x 16
/// bytes, i.e. two coordinates of 32 or 48 bytes). NOTE(review): confirm against the
/// executor definitions.
#[derive(Clone, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)]
#[cfg_attr(
    feature = "aot",
    derive(
        openvm_circuit_derive::AotExecutor,
        openvm_circuit_derive::AotMeteredExecutor
    )
)]
pub enum WeierstrassExtensionExecutor {
    /// EC_ADD_NE ..= SETUP_EC_ADD_NE for moduli of at most 32 bytes.
    EcAddNeRv32_32(EcAddNeExecutor<2, 32>),
    /// EC_DOUBLE ..= SETUP_EC_DOUBLE for moduli of at most 32 bytes.
    EcDoubleRv32_32(EcDoubleExecutor<2, 32>),
    /// EC_ADD_NE ..= SETUP_EC_ADD_NE for moduli of 33 to 48 bytes.
    EcAddNeRv32_48(EcAddNeExecutor<6, 16>),
    /// EC_DOUBLE ..= SETUP_EC_DOUBLE for moduli of 33 to 48 bytes.
    EcDoubleRv32_48(EcDoubleExecutor<6, 16>),
}
111
112impl<F: PrimeField32> VmExecutionExtension<F> for WeierstrassExtension {
113 type Executor = WeierstrassExtensionExecutor;
114
115 fn extend_execution(
116 &self,
117 inventory: &mut ExecutorInventoryBuilder<F, WeierstrassExtensionExecutor>,
118 ) -> Result<(), ExecutorInventoryError> {
119 let pointer_max_bits = inventory.pointer_max_bits();
120 let dummy_range_checker_bus = VariableRangeCheckerBus::new(u16::MAX, 16);
122 for (i, curve) in self.supported_curves.iter().enumerate() {
123 let start_offset =
124 Rv32WeierstrassOpcode::CLASS_OFFSET + i * Rv32WeierstrassOpcode::COUNT;
125 let bytes = curve.modulus.bits().div_ceil(8);
126
127 if bytes <= 32 {
128 let config = ExprBuilderConfig {
129 modulus: curve.modulus.clone(),
130 num_limbs: 32,
131 limb_bits: 8,
132 };
133 let addne = get_ec_addne_step(
134 config.clone(),
135 dummy_range_checker_bus,
136 pointer_max_bits,
137 start_offset,
138 );
139
140 inventory.add_executor(
141 WeierstrassExtensionExecutor::EcAddNeRv32_32(addne),
142 ((Rv32WeierstrassOpcode::EC_ADD_NE as usize)
143 ..=(Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize))
144 .map(|x| VmOpcode::from_usize(x + start_offset)),
145 )?;
146
147 let double = get_ec_double_step(
148 config,
149 dummy_range_checker_bus,
150 pointer_max_bits,
151 start_offset,
152 curve.a.clone(),
153 );
154
155 inventory.add_executor(
156 WeierstrassExtensionExecutor::EcDoubleRv32_32(double),
157 ((Rv32WeierstrassOpcode::EC_DOUBLE as usize)
158 ..=(Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize))
159 .map(|x| VmOpcode::from_usize(x + start_offset)),
160 )?;
161 } else if bytes <= 48 {
162 let config = ExprBuilderConfig {
163 modulus: curve.modulus.clone(),
164 num_limbs: 48,
165 limb_bits: 8,
166 };
167 let addne = get_ec_addne_step(
168 config.clone(),
169 dummy_range_checker_bus,
170 pointer_max_bits,
171 start_offset,
172 );
173
174 inventory.add_executor(
175 WeierstrassExtensionExecutor::EcAddNeRv32_48(addne),
176 ((Rv32WeierstrassOpcode::EC_ADD_NE as usize)
177 ..=(Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize))
178 .map(|x| VmOpcode::from_usize(x + start_offset)),
179 )?;
180
181 let double = get_ec_double_step(
182 config,
183 dummy_range_checker_bus,
184 pointer_max_bits,
185 start_offset,
186 curve.a.clone(),
187 );
188
189 inventory.add_executor(
190 WeierstrassExtensionExecutor::EcDoubleRv32_48(double),
191 ((Rv32WeierstrassOpcode::EC_DOUBLE as usize)
192 ..=(Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize))
193 .map(|x| VmOpcode::from_usize(x + start_offset)),
194 )?;
195 } else {
196 panic!("Modulus too large");
197 }
198 }
199
200 Ok(())
201 }
202}
203
204impl<SC: StarkGenericConfig> VmCircuitExtension<SC> for WeierstrassExtension {
205 fn extend_circuit(&self, inventory: &mut AirInventory<SC>) -> Result<(), AirInventoryError> {
206 let SystemPort {
207 execution_bus,
208 program_bus,
209 memory_bridge,
210 } = inventory.system().port();
211
212 let exec_bridge = ExecutionBridge::new(execution_bus, program_bus);
213 let range_checker_bus = inventory.range_checker().bus;
214 let pointer_max_bits = inventory.pointer_max_bits();
215
216 let bitwise_lu = {
217 let existing_air = inventory.find_air::<BitwiseOperationLookupAir<8>>().next();
219 if let Some(air) = existing_air {
220 air.bus
221 } else {
222 let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx());
223 let air = BitwiseOperationLookupAir::<8>::new(bus);
224 inventory.add_air(air);
225 air.bus
226 }
227 };
228 for (i, curve) in self.supported_curves.iter().enumerate() {
229 let start_offset =
230 Rv32WeierstrassOpcode::CLASS_OFFSET + i * Rv32WeierstrassOpcode::COUNT;
231 let bytes = curve.modulus.bits().div_ceil(8);
232
233 if bytes <= 32 {
234 let config = ExprBuilderConfig {
235 modulus: curve.modulus.clone(),
236 num_limbs: 32,
237 limb_bits: 8,
238 };
239
240 let addne = get_ec_addne_air::<2, 32>(
241 exec_bridge,
242 memory_bridge,
243 config.clone(),
244 range_checker_bus,
245 bitwise_lu,
246 pointer_max_bits,
247 start_offset,
248 );
249 inventory.add_air(addne);
250
251 let double = get_ec_double_air::<2, 32>(
252 exec_bridge,
253 memory_bridge,
254 config,
255 range_checker_bus,
256 bitwise_lu,
257 pointer_max_bits,
258 start_offset,
259 curve.a.clone(),
260 );
261 inventory.add_air(double);
262 } else if bytes <= 48 {
263 let config = ExprBuilderConfig {
264 modulus: curve.modulus.clone(),
265 num_limbs: 48,
266 limb_bits: 8,
267 };
268
269 let addne = get_ec_addne_air::<6, 16>(
270 exec_bridge,
271 memory_bridge,
272 config.clone(),
273 range_checker_bus,
274 bitwise_lu,
275 pointer_max_bits,
276 start_offset,
277 );
278 inventory.add_air(addne);
279
280 let double = get_ec_double_air::<6, 16>(
281 exec_bridge,
282 memory_bridge,
283 config,
284 range_checker_bus,
285 bitwise_lu,
286 pointer_max_bits,
287 start_offset,
288 curve.a.clone(),
289 );
290 inventory.add_air(double);
291 } else {
292 panic!("Modulus too large");
293 }
294 }
295
296 Ok(())
297 }
298}
299
300impl<E, SC, RA> VmProverExtension<E, RA, WeierstrassExtension> for EccCpuProverExt
303where
304 SC: StarkGenericConfig,
305 E: StarkEngine<SC = SC, PB = CpuBackend<SC>, PD = CpuDevice<SC>>,
306 RA: RowMajorMatrixArena<Val<SC>>,
307 Val<SC>: PrimeField32,
308{
309 fn extend_prover(
310 &self,
311 extension: &WeierstrassExtension,
312 inventory: &mut ChipInventory<SC, RA, CpuBackend<SC>>,
313 ) -> Result<(), ChipInventoryError> {
314 let range_checker = inventory.range_checker()?.clone();
315 let timestamp_max_bits = inventory.timestamp_max_bits();
316 let pointer_max_bits = inventory.airs().pointer_max_bits();
317 let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits);
318 let bitwise_lu = {
319 let existing_chip = inventory
320 .find_chip::<SharedBitwiseOperationLookupChip<8>>()
321 .next();
322 if let Some(chip) = existing_chip {
323 chip.clone()
324 } else {
325 let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?;
326 let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus));
327 inventory.add_periphery_chip(chip.clone());
328 chip
329 }
330 };
331 for curve in extension.supported_curves.iter() {
332 let bytes = curve.modulus.bits().div_ceil(8);
333
334 if bytes <= 32 {
335 let config = ExprBuilderConfig {
336 modulus: curve.modulus.clone(),
337 num_limbs: 32,
338 limb_bits: 8,
339 };
340
341 inventory.next_air::<WeierstrassAir<2, 2, 32>>()?;
342 let addne = get_ec_addne_chip::<Val<SC>, 2, 32>(
343 config.clone(),
344 mem_helper.clone(),
345 range_checker.clone(),
346 bitwise_lu.clone(),
347 pointer_max_bits,
348 );
349 inventory.add_executor_chip(addne);
350
351 inventory.next_air::<WeierstrassAir<1, 2, 32>>()?;
352 let double = get_ec_double_chip::<Val<SC>, 2, 32>(
353 config,
354 mem_helper.clone(),
355 range_checker.clone(),
356 bitwise_lu.clone(),
357 pointer_max_bits,
358 curve.a.clone(),
359 );
360 inventory.add_executor_chip(double);
361 } else if bytes <= 48 {
362 let config = ExprBuilderConfig {
363 modulus: curve.modulus.clone(),
364 num_limbs: 48,
365 limb_bits: 8,
366 };
367
368 inventory.next_air::<WeierstrassAir<2, 6, 16>>()?;
369 let addne = get_ec_addne_chip::<Val<SC>, 6, 16>(
370 config.clone(),
371 mem_helper.clone(),
372 range_checker.clone(),
373 bitwise_lu.clone(),
374 pointer_max_bits,
375 );
376 inventory.add_executor_chip(addne);
377
378 inventory.next_air::<WeierstrassAir<1, 6, 16>>()?;
379 let double = get_ec_double_chip::<Val<SC>, 6, 16>(
380 config,
381 mem_helper.clone(),
382 range_checker.clone(),
383 bitwise_lu.clone(),
384 pointer_max_bits,
385 curve.a.clone(),
386 );
387 inventory.add_executor_chip(double);
388 } else {
389 panic!("Modulus too large");
390 }
391 }
392
393 Ok(())
394 }
395}
396
lazy_static! {
    /// Coordinate-field modulus of secp256k1, `p = 2^256 - 2^32 - 977`
    /// (decoded from big-endian hex).
    pub static ref SECP256K1_MODULUS: BigUint = BigUint::from_bytes_be(&hex!(
        "FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2F"
    ));
    /// Order of the secp256k1 group; used as `CurveConfig::scalar` in `SECP256K1_CONFIG`.
    pub static ref SECP256K1_ORDER: BigUint = BigUint::from_bytes_be(&hex!(
        "FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141"
    ));
}
407
lazy_static! {
    /// Coordinate-field modulus of NIST P-256, `p = 2^256 - 2^224 + 2^192 + 2^96 - 1`
    /// (decoded from big-endian hex).
    pub static ref P256_MODULUS: BigUint = BigUint::from_bytes_be(&hex!(
        "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff"
    ));
    /// Order of the P-256 group; used as `CurveConfig::scalar` in `P256_CONFIG`.
    pub static ref P256_ORDER: BigUint = BigUint::from_bytes_be(&hex!(
        "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551"
    ));
}
/// P-256 coefficient `a = p - 3`, encoded LITTLE-endian (decoded with
/// `BigUint::from_bytes_le` in `P256_CONFIG`, unlike the big-endian modulus/order above).
const P256_A: [u8; 32] = hex!("fcffffffffffffffffffffff00000000000000000000000001000000ffffffff");
/// P-256 coefficient `b`, encoded LITTLE-endian (decoded with `BigUint::from_bytes_le`
/// in `P256_CONFIG`).
const P256_B: [u8; 32] = hex!("4b60d2273e3cce3bf6b053ccb0061d65bc86987655bdebb3e7933aaad835c65a");
421
/// Guest-side point struct name for secp256k1 (`struct_name` of `SECP256K1_CONFIG`).
pub const SECP256K1_ECC_STRUCT_NAME: &str = "Secp256k1Point";
/// Guest-side point struct name for P-256 (`struct_name` of `P256_CONFIG`).
pub const P256_ECC_STRUCT_NAME: &str = "P256Point";