openvm_ecc_circuit/extension/weierstrass.rs

1use std::sync::Arc;
2
3use hex_literal::hex;
4use lazy_static::lazy_static;
5use num_bigint::BigUint;
6use num_traits::{FromPrimitive, Zero};
7use once_cell::sync::Lazy;
8use openvm_circuit::{
9    arch::{
10        AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge,
11        ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension,
12        VmExecutionExtension, VmProverExtension,
13    },
14    system::{memory::SharedMemoryHelper, SystemPort},
15};
16use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor};
17use openvm_circuit_primitives::{
18    bitwise_op_lookup::{
19        BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip,
20        SharedBitwiseOperationLookupChip,
21    },
22    var_range::VariableRangeCheckerBus,
23};
24use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
25use openvm_instructions::{LocalOpcode, VmOpcode};
26use openvm_mod_circuit_builder::ExprBuilderConfig;
27use openvm_stark_backend::{
28    config::{StarkGenericConfig, Val},
29    engine::StarkEngine,
30    p3_field::PrimeField32,
31    prover::cpu::{CpuBackend, CpuDevice},
32};
33use serde::{Deserialize, Serialize};
34use serde_with::{serde_as, DisplayFromStr};
35use strum::EnumCount;
36
37use crate::{
38    get_ec_addne_air, get_ec_addne_chip, get_ec_addne_step, get_ec_double_air, get_ec_double_chip,
39    get_ec_double_step, EcAddNeExecutor, EcDoubleExecutor, EccCpuProverExt, WeierstrassAir,
40};
41
42#[serde_as]
43#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
44pub struct CurveConfig {
45    /// The name of the curve struct as defined by moduli_declare.
46    pub struct_name: String,
47    /// The coordinate modulus of the curve.
48    #[serde_as(as = "DisplayFromStr")]
49    pub modulus: BigUint,
50    /// The scalar field modulus of the curve.
51    #[serde_as(as = "DisplayFromStr")]
52    pub scalar: BigUint,
53    /// The coefficient a of y^2 = x^3 + ax + b.
54    #[serde_as(as = "DisplayFromStr")]
55    pub a: BigUint,
56    /// The coefficient b of y^2 = x^3 + ax + b.
57    #[serde_as(as = "DisplayFromStr")]
58    pub b: BigUint,
59}
60
61pub static SECP256K1_CONFIG: Lazy<CurveConfig> = Lazy::new(|| CurveConfig {
62    struct_name: SECP256K1_ECC_STRUCT_NAME.to_string(),
63    modulus: SECP256K1_MODULUS.clone(),
64    scalar: SECP256K1_ORDER.clone(),
65    a: BigUint::zero(),
66    b: BigUint::from_u8(7u8).unwrap(),
67});
68
69pub static P256_CONFIG: Lazy<CurveConfig> = Lazy::new(|| CurveConfig {
70    struct_name: P256_ECC_STRUCT_NAME.to_string(),
71    modulus: P256_MODULUS.clone(),
72    scalar: P256_ORDER.clone(),
73    a: BigUint::from_bytes_le(&P256_A),
74    b: BigUint::from_bytes_le(&P256_B),
75});
76
77#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
78pub struct WeierstrassExtension {
79    pub supported_curves: Vec<CurveConfig>,
80}
81
82impl WeierstrassExtension {
83    pub fn generate_sw_init(&self) -> String {
84        let supported_curves = self
85            .supported_curves
86            .iter()
87            .map(|curve_config| format!("\"{}\"", curve_config.struct_name))
88            .collect::<Vec<String>>()
89            .join(", ");
90
91        format!("openvm_ecc_guest::sw_macros::sw_init! {{ {supported_curves} }}")
92    }
93}
94
95#[derive(Clone, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)]
96#[cfg_attr(
97    feature = "aot",
98    derive(
99        openvm_circuit_derive::AotExecutor,
100        openvm_circuit_derive::AotMeteredExecutor
101    )
102)]
103pub enum WeierstrassExtensionExecutor {
104    // 32 limbs prime
105    EcAddNeRv32_32(EcAddNeExecutor<2, 32>),
106    EcDoubleRv32_32(EcDoubleExecutor<2, 32>),
107    // 48 limbs prime
108    EcAddNeRv32_48(EcAddNeExecutor<6, 16>),
109    EcDoubleRv32_48(EcDoubleExecutor<6, 16>),
110}
111
112impl<F: PrimeField32> VmExecutionExtension<F> for WeierstrassExtension {
113    type Executor = WeierstrassExtensionExecutor;
114
115    fn extend_execution(
116        &self,
117        inventory: &mut ExecutorInventoryBuilder<F, WeierstrassExtensionExecutor>,
118    ) -> Result<(), ExecutorInventoryError> {
119        let pointer_max_bits = inventory.pointer_max_bits();
120        // TODO: somehow get the range checker bus from `ExecutorInventory`
121        let dummy_range_checker_bus = VariableRangeCheckerBus::new(u16::MAX, 16);
122        for (i, curve) in self.supported_curves.iter().enumerate() {
123            let start_offset =
124                Rv32WeierstrassOpcode::CLASS_OFFSET + i * Rv32WeierstrassOpcode::COUNT;
125            let bytes = curve.modulus.bits().div_ceil(8);
126
127            if bytes <= 32 {
128                let config = ExprBuilderConfig {
129                    modulus: curve.modulus.clone(),
130                    num_limbs: 32,
131                    limb_bits: 8,
132                };
133                let addne = get_ec_addne_step(
134                    config.clone(),
135                    dummy_range_checker_bus,
136                    pointer_max_bits,
137                    start_offset,
138                );
139
140                inventory.add_executor(
141                    WeierstrassExtensionExecutor::EcAddNeRv32_32(addne),
142                    ((Rv32WeierstrassOpcode::EC_ADD_NE as usize)
143                        ..=(Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize))
144                        .map(|x| VmOpcode::from_usize(x + start_offset)),
145                )?;
146
147                let double = get_ec_double_step(
148                    config,
149                    dummy_range_checker_bus,
150                    pointer_max_bits,
151                    start_offset,
152                    curve.a.clone(),
153                );
154
155                inventory.add_executor(
156                    WeierstrassExtensionExecutor::EcDoubleRv32_32(double),
157                    ((Rv32WeierstrassOpcode::EC_DOUBLE as usize)
158                        ..=(Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize))
159                        .map(|x| VmOpcode::from_usize(x + start_offset)),
160                )?;
161            } else if bytes <= 48 {
162                let config = ExprBuilderConfig {
163                    modulus: curve.modulus.clone(),
164                    num_limbs: 48,
165                    limb_bits: 8,
166                };
167                let addne = get_ec_addne_step(
168                    config.clone(),
169                    dummy_range_checker_bus,
170                    pointer_max_bits,
171                    start_offset,
172                );
173
174                inventory.add_executor(
175                    WeierstrassExtensionExecutor::EcAddNeRv32_48(addne),
176                    ((Rv32WeierstrassOpcode::EC_ADD_NE as usize)
177                        ..=(Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize))
178                        .map(|x| VmOpcode::from_usize(x + start_offset)),
179                )?;
180
181                let double = get_ec_double_step(
182                    config,
183                    dummy_range_checker_bus,
184                    pointer_max_bits,
185                    start_offset,
186                    curve.a.clone(),
187                );
188
189                inventory.add_executor(
190                    WeierstrassExtensionExecutor::EcDoubleRv32_48(double),
191                    ((Rv32WeierstrassOpcode::EC_DOUBLE as usize)
192                        ..=(Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize))
193                        .map(|x| VmOpcode::from_usize(x + start_offset)),
194                )?;
195            } else {
196                panic!("Modulus too large");
197            }
198        }
199
200        Ok(())
201    }
202}
203
204impl<SC: StarkGenericConfig> VmCircuitExtension<SC> for WeierstrassExtension {
205    fn extend_circuit(&self, inventory: &mut AirInventory<SC>) -> Result<(), AirInventoryError> {
206        let SystemPort {
207            execution_bus,
208            program_bus,
209            memory_bridge,
210        } = inventory.system().port();
211
212        let exec_bridge = ExecutionBridge::new(execution_bus, program_bus);
213        let range_checker_bus = inventory.range_checker().bus;
214        let pointer_max_bits = inventory.pointer_max_bits();
215
216        let bitwise_lu = {
217            // A trick to get around Rust's borrow rules
218            let existing_air = inventory.find_air::<BitwiseOperationLookupAir<8>>().next();
219            if let Some(air) = existing_air {
220                air.bus
221            } else {
222                let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx());
223                let air = BitwiseOperationLookupAir::<8>::new(bus);
224                inventory.add_air(air);
225                air.bus
226            }
227        };
228        for (i, curve) in self.supported_curves.iter().enumerate() {
229            let start_offset =
230                Rv32WeierstrassOpcode::CLASS_OFFSET + i * Rv32WeierstrassOpcode::COUNT;
231            let bytes = curve.modulus.bits().div_ceil(8);
232
233            if bytes <= 32 {
234                let config = ExprBuilderConfig {
235                    modulus: curve.modulus.clone(),
236                    num_limbs: 32,
237                    limb_bits: 8,
238                };
239
240                let addne = get_ec_addne_air::<2, 32>(
241                    exec_bridge,
242                    memory_bridge,
243                    config.clone(),
244                    range_checker_bus,
245                    bitwise_lu,
246                    pointer_max_bits,
247                    start_offset,
248                );
249                inventory.add_air(addne);
250
251                let double = get_ec_double_air::<2, 32>(
252                    exec_bridge,
253                    memory_bridge,
254                    config,
255                    range_checker_bus,
256                    bitwise_lu,
257                    pointer_max_bits,
258                    start_offset,
259                    curve.a.clone(),
260                );
261                inventory.add_air(double);
262            } else if bytes <= 48 {
263                let config = ExprBuilderConfig {
264                    modulus: curve.modulus.clone(),
265                    num_limbs: 48,
266                    limb_bits: 8,
267                };
268
269                let addne = get_ec_addne_air::<6, 16>(
270                    exec_bridge,
271                    memory_bridge,
272                    config.clone(),
273                    range_checker_bus,
274                    bitwise_lu,
275                    pointer_max_bits,
276                    start_offset,
277                );
278                inventory.add_air(addne);
279
280                let double = get_ec_double_air::<6, 16>(
281                    exec_bridge,
282                    memory_bridge,
283                    config,
284                    range_checker_bus,
285                    bitwise_lu,
286                    pointer_max_bits,
287                    start_offset,
288                    curve.a.clone(),
289                );
290                inventory.add_air(double);
291            } else {
292                panic!("Modulus too large");
293            }
294        }
295
296        Ok(())
297    }
298}
299
300// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker,
301// BitwiseOperationLookupChip) are specific to CpuBackend.
302impl<E, SC, RA> VmProverExtension<E, RA, WeierstrassExtension> for EccCpuProverExt
303where
304    SC: StarkGenericConfig,
305    E: StarkEngine<SC = SC, PB = CpuBackend<SC>, PD = CpuDevice<SC>>,
306    RA: RowMajorMatrixArena<Val<SC>>,
307    Val<SC>: PrimeField32,
308{
309    fn extend_prover(
310        &self,
311        extension: &WeierstrassExtension,
312        inventory: &mut ChipInventory<SC, RA, CpuBackend<SC>>,
313    ) -> Result<(), ChipInventoryError> {
314        let range_checker = inventory.range_checker()?.clone();
315        let timestamp_max_bits = inventory.timestamp_max_bits();
316        let pointer_max_bits = inventory.airs().pointer_max_bits();
317        let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits);
318        let bitwise_lu = {
319            let existing_chip = inventory
320                .find_chip::<SharedBitwiseOperationLookupChip<8>>()
321                .next();
322            if let Some(chip) = existing_chip {
323                chip.clone()
324            } else {
325                let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?;
326                let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus));
327                inventory.add_periphery_chip(chip.clone());
328                chip
329            }
330        };
331        for curve in extension.supported_curves.iter() {
332            let bytes = curve.modulus.bits().div_ceil(8);
333
334            if bytes <= 32 {
335                let config = ExprBuilderConfig {
336                    modulus: curve.modulus.clone(),
337                    num_limbs: 32,
338                    limb_bits: 8,
339                };
340
341                inventory.next_air::<WeierstrassAir<2, 2, 32>>()?;
342                let addne = get_ec_addne_chip::<Val<SC>, 2, 32>(
343                    config.clone(),
344                    mem_helper.clone(),
345                    range_checker.clone(),
346                    bitwise_lu.clone(),
347                    pointer_max_bits,
348                );
349                inventory.add_executor_chip(addne);
350
351                inventory.next_air::<WeierstrassAir<1, 2, 32>>()?;
352                let double = get_ec_double_chip::<Val<SC>, 2, 32>(
353                    config,
354                    mem_helper.clone(),
355                    range_checker.clone(),
356                    bitwise_lu.clone(),
357                    pointer_max_bits,
358                    curve.a.clone(),
359                );
360                inventory.add_executor_chip(double);
361            } else if bytes <= 48 {
362                let config = ExprBuilderConfig {
363                    modulus: curve.modulus.clone(),
364                    num_limbs: 48,
365                    limb_bits: 8,
366                };
367
368                inventory.next_air::<WeierstrassAir<2, 6, 16>>()?;
369                let addne = get_ec_addne_chip::<Val<SC>, 6, 16>(
370                    config.clone(),
371                    mem_helper.clone(),
372                    range_checker.clone(),
373                    bitwise_lu.clone(),
374                    pointer_max_bits,
375                );
376                inventory.add_executor_chip(addne);
377
378                inventory.next_air::<WeierstrassAir<1, 6, 16>>()?;
379                let double = get_ec_double_chip::<Val<SC>, 6, 16>(
380                    config,
381                    mem_helper.clone(),
382                    range_checker.clone(),
383                    bitwise_lu.clone(),
384                    pointer_max_bits,
385                    curve.a.clone(),
386                );
387                inventory.add_executor_chip(double);
388            } else {
389                panic!("Modulus too large");
390            }
391        }
392
393        Ok(())
394    }
395}
396
397// Convenience constants for constructors
398lazy_static! {
399    // The constants are taken from: https://en.bitcoin.it/wiki/Secp256k1
400    pub static ref SECP256K1_MODULUS: BigUint = BigUint::from_bytes_be(&hex!(
401        "FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2F"
402    ));
403    pub static ref SECP256K1_ORDER: BigUint = BigUint::from_bytes_be(&hex!(
404        "FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141"
405    ));
406}
407
408lazy_static! {
409    // The constants are taken from: https://neuromancer.sk/std/secg/secp256r1
410    pub static ref P256_MODULUS: BigUint = BigUint::from_bytes_be(&hex!(
411        "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff"
412    ));
413    pub static ref P256_ORDER: BigUint = BigUint::from_bytes_be(&hex!(
414        "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551"
415    ));
416}
417// little-endian
418const P256_A: [u8; 32] = hex!("fcffffffffffffffffffffff00000000000000000000000001000000ffffffff");
419// little-endian
420const P256_B: [u8; 32] = hex!("4b60d2273e3cce3bf6b053ccb0061d65bc86987655bdebb3e7933aaad835c65a");
421
/// Guest point-struct name for secp256k1; used as `SECP256K1_CONFIG.struct_name`.
pub const SECP256K1_ECC_STRUCT_NAME: &str = "Secp256k1Point";
/// Guest point-struct name for P-256; used as `P256_CONFIG.struct_name`.
pub const P256_ECC_STRUCT_NAME: &str = "P256Point";