openvm_ecc_circuit/extension/
weierstrass.rs

1use std::sync::Arc;
2
3use hex_literal::hex;
4use lazy_static::lazy_static;
5use num_bigint::BigUint;
6use num_traits::{FromPrimitive, Zero};
7use once_cell::sync::Lazy;
8use openvm_circuit::{
9    arch::{
10        AirInventory, AirInventoryError, ChipInventory, ChipInventoryError, ExecutionBridge,
11        ExecutorInventoryBuilder, ExecutorInventoryError, RowMajorMatrixArena, VmCircuitExtension,
12        VmExecutionExtension, VmProverExtension,
13    },
14    system::{memory::SharedMemoryHelper, SystemPort},
15};
16use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor};
17use openvm_circuit_primitives::{
18    bitwise_op_lookup::{
19        BitwiseOperationLookupAir, BitwiseOperationLookupBus, BitwiseOperationLookupChip,
20        SharedBitwiseOperationLookupChip,
21    },
22    var_range::VariableRangeCheckerBus,
23};
24use openvm_ecc_transpiler::Rv32WeierstrassOpcode;
25use openvm_instructions::{LocalOpcode, VmOpcode};
26use openvm_mod_circuit_builder::ExprBuilderConfig;
27use openvm_stark_backend::{
28    config::{StarkGenericConfig, Val},
29    engine::StarkEngine,
30    p3_field::PrimeField32,
31    prover::cpu::{CpuBackend, CpuDevice},
32};
33use serde::{Deserialize, Serialize};
34use serde_with::{serde_as, DisplayFromStr};
35use strum::EnumCount;
36
37use crate::{
38    get_ec_addne_air, get_ec_addne_chip, get_ec_addne_step, get_ec_double_air, get_ec_double_chip,
39    get_ec_double_step, EcAddNeExecutor, EcDoubleExecutor, EccCpuProverExt, WeierstrassAir,
40};
41
42#[serde_as]
43#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
44pub struct CurveConfig {
45    /// The name of the curve struct as defined by moduli_declare.
46    pub struct_name: String,
47    /// The coordinate modulus of the curve.
48    #[serde_as(as = "DisplayFromStr")]
49    pub modulus: BigUint,
50    /// The scalar field modulus of the curve.
51    #[serde_as(as = "DisplayFromStr")]
52    pub scalar: BigUint,
53    /// The coefficient a of y^2 = x^3 + ax + b.
54    #[serde_as(as = "DisplayFromStr")]
55    pub a: BigUint,
56    /// The coefficient b of y^2 = x^3 + ax + b.
57    #[serde_as(as = "DisplayFromStr")]
58    pub b: BigUint,
59}
60
61pub static SECP256K1_CONFIG: Lazy<CurveConfig> = Lazy::new(|| CurveConfig {
62    struct_name: SECP256K1_ECC_STRUCT_NAME.to_string(),
63    modulus: SECP256K1_MODULUS.clone(),
64    scalar: SECP256K1_ORDER.clone(),
65    a: BigUint::zero(),
66    b: BigUint::from_u8(7u8).unwrap(),
67});
68
69pub static P256_CONFIG: Lazy<CurveConfig> = Lazy::new(|| CurveConfig {
70    struct_name: P256_ECC_STRUCT_NAME.to_string(),
71    modulus: P256_MODULUS.clone(),
72    scalar: P256_ORDER.clone(),
73    a: BigUint::from_bytes_le(&P256_A),
74    b: BigUint::from_bytes_le(&P256_B),
75});
76
77#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
78pub struct WeierstrassExtension {
79    pub supported_curves: Vec<CurveConfig>,
80}
81
82impl WeierstrassExtension {
83    pub fn generate_sw_init(&self) -> String {
84        let supported_curves = self
85            .supported_curves
86            .iter()
87            .map(|curve_config| format!("\"{}\"", curve_config.struct_name))
88            .collect::<Vec<String>>()
89            .join(", ");
90
91        format!("openvm_ecc_guest::sw_macros::sw_init! {{ {supported_curves} }}")
92    }
93}
94
95#[derive(Clone, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)]
96pub enum WeierstrassExtensionExecutor {
97    // 32 limbs prime
98    EcAddNeRv32_32(EcAddNeExecutor<2, 32>),
99    EcDoubleRv32_32(EcDoubleExecutor<2, 32>),
100    // 48 limbs prime
101    EcAddNeRv32_48(EcAddNeExecutor<6, 16>),
102    EcDoubleRv32_48(EcDoubleExecutor<6, 16>),
103}
104
105impl<F: PrimeField32> VmExecutionExtension<F> for WeierstrassExtension {
106    type Executor = WeierstrassExtensionExecutor;
107
108    fn extend_execution(
109        &self,
110        inventory: &mut ExecutorInventoryBuilder<F, WeierstrassExtensionExecutor>,
111    ) -> Result<(), ExecutorInventoryError> {
112        let pointer_max_bits = inventory.pointer_max_bits();
113        // TODO: somehow get the range checker bus from `ExecutorInventory`
114        let dummy_range_checker_bus = VariableRangeCheckerBus::new(u16::MAX, 16);
115        for (i, curve) in self.supported_curves.iter().enumerate() {
116            let start_offset =
117                Rv32WeierstrassOpcode::CLASS_OFFSET + i * Rv32WeierstrassOpcode::COUNT;
118            let bytes = curve.modulus.bits().div_ceil(8);
119
120            if bytes <= 32 {
121                let config = ExprBuilderConfig {
122                    modulus: curve.modulus.clone(),
123                    num_limbs: 32,
124                    limb_bits: 8,
125                };
126                let addne = get_ec_addne_step(
127                    config.clone(),
128                    dummy_range_checker_bus,
129                    pointer_max_bits,
130                    start_offset,
131                );
132
133                inventory.add_executor(
134                    WeierstrassExtensionExecutor::EcAddNeRv32_32(addne),
135                    ((Rv32WeierstrassOpcode::EC_ADD_NE as usize)
136                        ..=(Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize))
137                        .map(|x| VmOpcode::from_usize(x + start_offset)),
138                )?;
139
140                let double = get_ec_double_step(
141                    config,
142                    dummy_range_checker_bus,
143                    pointer_max_bits,
144                    start_offset,
145                    curve.a.clone(),
146                );
147
148                inventory.add_executor(
149                    WeierstrassExtensionExecutor::EcDoubleRv32_32(double),
150                    ((Rv32WeierstrassOpcode::EC_DOUBLE as usize)
151                        ..=(Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize))
152                        .map(|x| VmOpcode::from_usize(x + start_offset)),
153                )?;
154            } else if bytes <= 48 {
155                let config = ExprBuilderConfig {
156                    modulus: curve.modulus.clone(),
157                    num_limbs: 48,
158                    limb_bits: 8,
159                };
160                let addne = get_ec_addne_step(
161                    config.clone(),
162                    dummy_range_checker_bus,
163                    pointer_max_bits,
164                    start_offset,
165                );
166
167                inventory.add_executor(
168                    WeierstrassExtensionExecutor::EcAddNeRv32_48(addne),
169                    ((Rv32WeierstrassOpcode::EC_ADD_NE as usize)
170                        ..=(Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize))
171                        .map(|x| VmOpcode::from_usize(x + start_offset)),
172                )?;
173
174                let double = get_ec_double_step(
175                    config,
176                    dummy_range_checker_bus,
177                    pointer_max_bits,
178                    start_offset,
179                    curve.a.clone(),
180                );
181
182                inventory.add_executor(
183                    WeierstrassExtensionExecutor::EcDoubleRv32_48(double),
184                    ((Rv32WeierstrassOpcode::EC_DOUBLE as usize)
185                        ..=(Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize))
186                        .map(|x| VmOpcode::from_usize(x + start_offset)),
187                )?;
188            } else {
189                panic!("Modulus too large");
190            }
191        }
192
193        Ok(())
194    }
195}
196
197impl<SC: StarkGenericConfig> VmCircuitExtension<SC> for WeierstrassExtension {
198    fn extend_circuit(&self, inventory: &mut AirInventory<SC>) -> Result<(), AirInventoryError> {
199        let SystemPort {
200            execution_bus,
201            program_bus,
202            memory_bridge,
203        } = inventory.system().port();
204
205        let exec_bridge = ExecutionBridge::new(execution_bus, program_bus);
206        let range_checker_bus = inventory.range_checker().bus;
207        let pointer_max_bits = inventory.pointer_max_bits();
208
209        let bitwise_lu = {
210            // A trick to get around Rust's borrow rules
211            let existing_air = inventory.find_air::<BitwiseOperationLookupAir<8>>().next();
212            if let Some(air) = existing_air {
213                air.bus
214            } else {
215                let bus = BitwiseOperationLookupBus::new(inventory.new_bus_idx());
216                let air = BitwiseOperationLookupAir::<8>::new(bus);
217                inventory.add_air(air);
218                air.bus
219            }
220        };
221        for (i, curve) in self.supported_curves.iter().enumerate() {
222            let start_offset =
223                Rv32WeierstrassOpcode::CLASS_OFFSET + i * Rv32WeierstrassOpcode::COUNT;
224            let bytes = curve.modulus.bits().div_ceil(8);
225
226            if bytes <= 32 {
227                let config = ExprBuilderConfig {
228                    modulus: curve.modulus.clone(),
229                    num_limbs: 32,
230                    limb_bits: 8,
231                };
232
233                let addne = get_ec_addne_air::<2, 32>(
234                    exec_bridge,
235                    memory_bridge,
236                    config.clone(),
237                    range_checker_bus,
238                    bitwise_lu,
239                    pointer_max_bits,
240                    start_offset,
241                );
242                inventory.add_air(addne);
243
244                let double = get_ec_double_air::<2, 32>(
245                    exec_bridge,
246                    memory_bridge,
247                    config,
248                    range_checker_bus,
249                    bitwise_lu,
250                    pointer_max_bits,
251                    start_offset,
252                    curve.a.clone(),
253                );
254                inventory.add_air(double);
255            } else if bytes <= 48 {
256                let config = ExprBuilderConfig {
257                    modulus: curve.modulus.clone(),
258                    num_limbs: 48,
259                    limb_bits: 8,
260                };
261
262                let addne = get_ec_addne_air::<6, 16>(
263                    exec_bridge,
264                    memory_bridge,
265                    config.clone(),
266                    range_checker_bus,
267                    bitwise_lu,
268                    pointer_max_bits,
269                    start_offset,
270                );
271                inventory.add_air(addne);
272
273                let double = get_ec_double_air::<6, 16>(
274                    exec_bridge,
275                    memory_bridge,
276                    config,
277                    range_checker_bus,
278                    bitwise_lu,
279                    pointer_max_bits,
280                    start_offset,
281                    curve.a.clone(),
282                );
283                inventory.add_air(double);
284            } else {
285                panic!("Modulus too large");
286            }
287        }
288
289        Ok(())
290    }
291}
292
293// This implementation is specific to CpuBackend because the lookup chips (VariableRangeChecker,
294// BitwiseOperationLookupChip) are specific to CpuBackend.
295impl<E, SC, RA> VmProverExtension<E, RA, WeierstrassExtension> for EccCpuProverExt
296where
297    SC: StarkGenericConfig,
298    E: StarkEngine<SC = SC, PB = CpuBackend<SC>, PD = CpuDevice<SC>>,
299    RA: RowMajorMatrixArena<Val<SC>>,
300    Val<SC>: PrimeField32,
301{
302    fn extend_prover(
303        &self,
304        extension: &WeierstrassExtension,
305        inventory: &mut ChipInventory<SC, RA, CpuBackend<SC>>,
306    ) -> Result<(), ChipInventoryError> {
307        let range_checker = inventory.range_checker()?.clone();
308        let timestamp_max_bits = inventory.timestamp_max_bits();
309        let pointer_max_bits = inventory.airs().pointer_max_bits();
310        let mem_helper = SharedMemoryHelper::new(range_checker.clone(), timestamp_max_bits);
311        let bitwise_lu = {
312            let existing_chip = inventory
313                .find_chip::<SharedBitwiseOperationLookupChip<8>>()
314                .next();
315            if let Some(chip) = existing_chip {
316                chip.clone()
317            } else {
318                let air: &BitwiseOperationLookupAir<8> = inventory.next_air()?;
319                let chip = Arc::new(BitwiseOperationLookupChip::new(air.bus));
320                inventory.add_periphery_chip(chip.clone());
321                chip
322            }
323        };
324        for curve in extension.supported_curves.iter() {
325            let bytes = curve.modulus.bits().div_ceil(8);
326
327            if bytes <= 32 {
328                let config = ExprBuilderConfig {
329                    modulus: curve.modulus.clone(),
330                    num_limbs: 32,
331                    limb_bits: 8,
332                };
333
334                inventory.next_air::<WeierstrassAir<2, 2, 32>>()?;
335                let addne = get_ec_addne_chip::<Val<SC>, 2, 32>(
336                    config.clone(),
337                    mem_helper.clone(),
338                    range_checker.clone(),
339                    bitwise_lu.clone(),
340                    pointer_max_bits,
341                );
342                inventory.add_executor_chip(addne);
343
344                inventory.next_air::<WeierstrassAir<1, 2, 32>>()?;
345                let double = get_ec_double_chip::<Val<SC>, 2, 32>(
346                    config,
347                    mem_helper.clone(),
348                    range_checker.clone(),
349                    bitwise_lu.clone(),
350                    pointer_max_bits,
351                    curve.a.clone(),
352                );
353                inventory.add_executor_chip(double);
354            } else if bytes <= 48 {
355                let config = ExprBuilderConfig {
356                    modulus: curve.modulus.clone(),
357                    num_limbs: 48,
358                    limb_bits: 8,
359                };
360
361                inventory.next_air::<WeierstrassAir<2, 6, 16>>()?;
362                let addne = get_ec_addne_chip::<Val<SC>, 6, 16>(
363                    config.clone(),
364                    mem_helper.clone(),
365                    range_checker.clone(),
366                    bitwise_lu.clone(),
367                    pointer_max_bits,
368                );
369                inventory.add_executor_chip(addne);
370
371                inventory.next_air::<WeierstrassAir<1, 6, 16>>()?;
372                let double = get_ec_double_chip::<Val<SC>, 6, 16>(
373                    config,
374                    mem_helper.clone(),
375                    range_checker.clone(),
376                    bitwise_lu.clone(),
377                    pointer_max_bits,
378                    curve.a.clone(),
379                );
380                inventory.add_executor_chip(double);
381            } else {
382                panic!("Modulus too large");
383            }
384        }
385
386        Ok(())
387    }
388}
389
390// Convenience constants for constructors
391lazy_static! {
392    // The constants are taken from: https://en.bitcoin.it/wiki/Secp256k1
393    pub static ref SECP256K1_MODULUS: BigUint = BigUint::from_bytes_be(&hex!(
394        "FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE FFFFFC2F"
395    ));
396    pub static ref SECP256K1_ORDER: BigUint = BigUint::from_bytes_be(&hex!(
397        "FFFFFFFF FFFFFFFF FFFFFFFF FFFFFFFE BAAEDCE6 AF48A03B BFD25E8C D0364141"
398    ));
399}
400
401lazy_static! {
402    // The constants are taken from: https://neuromancer.sk/std/secg/secp256r1
403    pub static ref P256_MODULUS: BigUint = BigUint::from_bytes_be(&hex!(
404        "ffffffff00000001000000000000000000000000ffffffffffffffffffffffff"
405    ));
406    pub static ref P256_ORDER: BigUint = BigUint::from_bytes_be(&hex!(
407        "ffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551"
408    ));
409}
410// little-endian
411const P256_A: [u8; 32] = hex!("fcffffffffffffffffffffff00000000000000000000000001000000ffffffff");
412// little-endian
413const P256_B: [u8; 32] = hex!("4b60d2273e3cce3bf6b053ccb0061d65bc86987655bdebb3e7933aaad835c65a");
414
/// `struct_name` used by [`SECP256K1_CONFIG`].
pub const SECP256K1_ECC_STRUCT_NAME: &str = "Secp256k1Point";
/// `struct_name` used by [`P256_CONFIG`].
pub const P256_ECC_STRUCT_NAME: &str = "P256Point";