// openvm_ecc_circuit/weierstrass_extension.rs

1use derive_more::derive::From;
2use num_bigint::BigUint;
3use num_traits::{FromPrimitive, Zero};
4use once_cell::sync::Lazy;
5use openvm_algebra_guest::IntMod;
6use openvm_circuit::{
7    arch::{SystemPort, VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError},
8    system::phantom::PhantomChip,
9};
10use openvm_circuit_derive::{AnyEnum, InstructionExecutor};
11use openvm_circuit_primitives::bitwise_op_lookup::{
12    BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip,
13};
14use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter};
15use openvm_ecc_guest::{
16    k256::{SECP256K1_MODULUS, SECP256K1_ORDER},
17    p256::{CURVE_A as P256_A, CURVE_B as P256_B, P256_MODULUS, P256_ORDER},
18};
19use openvm_ecc_transpiler::{EccPhantom, Rv32WeierstrassOpcode};
20use openvm_instructions::{LocalOpcode, PhantomDiscriminant, VmOpcode};
21use openvm_mod_circuit_builder::ExprBuilderConfig;
22use openvm_rv32_adapters::Rv32VecHeapAdapterChip;
23use openvm_stark_backend::p3_field::PrimeField32;
24use serde::{Deserialize, Serialize};
25use serde_with::{serde_as, DisplayFromStr};
26use strum::EnumCount;
27
28use super::{EcAddNeChip, EcDoubleChip};
29
/// Parameters of a short Weierstrass curve `y^2 = x^3 + ax + b` supported by
/// [`WeierstrassExtension`].
///
/// Every field is (de)serialized as a string through its `Display`/`FromStr`
/// implementations (`serde_with::DisplayFromStr`). `CurveConfig::new` (from
/// `derive_new`) takes the fields in declaration order.
#[serde_as]
#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
pub struct CurveConfig {
    /// The coordinate modulus of the curve (the base-field modulus `p`).
    /// Its byte length decides whether 32-byte or 48-byte chips are used.
    #[serde_as(as = "DisplayFromStr")]
    pub modulus: BigUint,
    /// The scalar field modulus of the curve.
    #[serde_as(as = "DisplayFromStr")]
    pub scalar: BigUint,
    /// The coefficient a of y^2 = x^3 + ax + b.
    #[serde_as(as = "DisplayFromStr")]
    pub a: BigUint,
    /// The coefficient b of y^2 = x^3 + ax + b.
    #[serde_as(as = "DisplayFromStr")]
    pub b: BigUint,
}
46
47pub static SECP256K1_CONFIG: Lazy<CurveConfig> = Lazy::new(|| CurveConfig {
48    modulus: SECP256K1_MODULUS.clone(),
49    scalar: SECP256K1_ORDER.clone(),
50    a: BigUint::zero(),
51    b: BigUint::from_u8(7u8).unwrap(),
52});
53
54pub static P256_CONFIG: Lazy<CurveConfig> = Lazy::new(|| CurveConfig {
55    modulus: P256_MODULUS.clone(),
56    scalar: P256_ORDER.clone(),
57    a: BigUint::from_bytes_le(P256_A.as_le_bytes()),
58    b: BigUint::from_bytes_le(P256_B.as_le_bytes()),
59});
60
/// VM extension that adds short-Weierstrass elliptic-curve opcodes and hint
/// phantoms for a list of curves.
///
/// Curve `i` gets its own opcode block starting at
/// `Rv32WeierstrassOpcode::CLASS_OFFSET + i * Rv32WeierstrassOpcode::COUNT`, so the
/// order of `supported_curves` is significant.
#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
pub struct WeierstrassExtension {
    /// Curves to support. A curve's index here is also the curve index used by the
    /// `HintNonQr` / `HintDecompress` phantom instructions.
    pub supported_curves: Vec<CurveConfig>,
}
65
/// Executor chips registered by [`WeierstrassExtension`].
///
/// The variant used for a curve depends on the byte length of its modulus:
/// the `*_32` variants serve moduli of at most 32 bytes, and the `*_48` variants
/// serve moduli of 33 to 48 bytes.
#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum)]
pub enum WeierstrassExtensionExecutor<F: PrimeField32> {
    // Moduli of at most 32 bytes; const params <2, 32> (presumably 2 blocks of
    // 32 bytes each — confirm against the chip definition).
    EcAddNeRv32_32(EcAddNeChip<F, 2, 32>),
    EcDoubleRv32_32(EcDoubleChip<F, 2, 32>),
    // Moduli of 33..=48 bytes; const params <6, 16> (presumably 6 blocks of 16 bytes,
    // i.e. 96 bytes = two 48-byte coordinates — confirm against the chip definition).
    EcAddNeRv32_48(EcAddNeChip<F, 6, 16>),
    EcDoubleRv32_48(EcDoubleChip<F, 6, 16>),
}
75
/// Periphery chips that [`WeierstrassExtension`] may register.
///
/// `build` adds a `BitwiseOperationLookup` chip only when the inventory builder
/// does not already contain a `SharedBitwiseOperationLookupChip<8>`.
#[derive(ChipUsageGetter, Chip, AnyEnum, From)]
pub enum WeierstrassExtensionPeriphery<F: PrimeField32> {
    BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>),
    // NOTE(review): not constructed by `build` in this file.
    Phantom(PhantomChip<F>),
}
81
82impl<F: PrimeField32> VmExtension<F> for WeierstrassExtension {
83    type Executor = WeierstrassExtensionExecutor<F>;
84    type Periphery = WeierstrassExtensionPeriphery<F>;
85
86    fn build(
87        &self,
88        builder: &mut VmInventoryBuilder<F>,
89    ) -> Result<VmInventory<Self::Executor, Self::Periphery>, VmInventoryError> {
90        let mut inventory = VmInventory::new();
91        let SystemPort {
92            execution_bus,
93            program_bus,
94            memory_bridge,
95        } = builder.system_port();
96        let bitwise_lu_chip = if let Some(&chip) = builder
97            .find_chip::<SharedBitwiseOperationLookupChip<8>>()
98            .first()
99        {
100            chip.clone()
101        } else {
102            let bitwise_lu_bus = BitwiseOperationLookupBus::new(builder.new_bus_idx());
103            let chip = SharedBitwiseOperationLookupChip::new(bitwise_lu_bus);
104            inventory.add_periphery_chip(chip.clone());
105            chip
106        };
107        let offline_memory = builder.system_base().offline_memory();
108        let range_checker = builder.system_base().range_checker_chip.clone();
109        let pointer_bits = builder.system_config().memory_config.pointer_max_bits;
110        let ec_add_ne_opcodes = (Rv32WeierstrassOpcode::EC_ADD_NE as usize)
111            ..=(Rv32WeierstrassOpcode::SETUP_EC_ADD_NE as usize);
112        let ec_double_opcodes = (Rv32WeierstrassOpcode::EC_DOUBLE as usize)
113            ..=(Rv32WeierstrassOpcode::SETUP_EC_DOUBLE as usize);
114
115        for (i, curve) in self.supported_curves.iter().enumerate() {
116            let start_offset =
117                Rv32WeierstrassOpcode::CLASS_OFFSET + i * Rv32WeierstrassOpcode::COUNT;
118            let bytes = curve.modulus.bits().div_ceil(8);
119            let config32 = ExprBuilderConfig {
120                modulus: curve.modulus.clone(),
121                num_limbs: 32,
122                limb_bits: 8,
123            };
124            let config48 = ExprBuilderConfig {
125                modulus: curve.modulus.clone(),
126                num_limbs: 48,
127                limb_bits: 8,
128            };
129            if bytes <= 32 {
130                let add_ne_chip = EcAddNeChip::new(
131                    Rv32VecHeapAdapterChip::<F, 2, 2, 2, 32, 32>::new(
132                        execution_bus,
133                        program_bus,
134                        memory_bridge,
135                        pointer_bits,
136                        bitwise_lu_chip.clone(),
137                    ),
138                    config32.clone(),
139                    start_offset,
140                    range_checker.clone(),
141                    offline_memory.clone(),
142                );
143                inventory.add_executor(
144                    WeierstrassExtensionExecutor::EcAddNeRv32_32(add_ne_chip),
145                    ec_add_ne_opcodes
146                        .clone()
147                        .map(|x| VmOpcode::from_usize(x + start_offset)),
148                )?;
149                let double_chip = EcDoubleChip::new(
150                    Rv32VecHeapAdapterChip::<F, 1, 2, 2, 32, 32>::new(
151                        execution_bus,
152                        program_bus,
153                        memory_bridge,
154                        pointer_bits,
155                        bitwise_lu_chip.clone(),
156                    ),
157                    range_checker.clone(),
158                    config32.clone(),
159                    start_offset,
160                    curve.a.clone(),
161                    offline_memory.clone(),
162                );
163                inventory.add_executor(
164                    WeierstrassExtensionExecutor::EcDoubleRv32_32(double_chip),
165                    ec_double_opcodes
166                        .clone()
167                        .map(|x| VmOpcode::from_usize(x + start_offset)),
168                )?;
169            } else if bytes <= 48 {
170                let add_ne_chip = EcAddNeChip::new(
171                    Rv32VecHeapAdapterChip::<F, 2, 6, 6, 16, 16>::new(
172                        execution_bus,
173                        program_bus,
174                        memory_bridge,
175                        pointer_bits,
176                        bitwise_lu_chip.clone(),
177                    ),
178                    config48.clone(),
179                    start_offset,
180                    range_checker.clone(),
181                    offline_memory.clone(),
182                );
183                inventory.add_executor(
184                    WeierstrassExtensionExecutor::EcAddNeRv32_48(add_ne_chip),
185                    ec_add_ne_opcodes
186                        .clone()
187                        .map(|x| VmOpcode::from_usize(x + start_offset)),
188                )?;
189                let double_chip = EcDoubleChip::new(
190                    Rv32VecHeapAdapterChip::<F, 1, 6, 6, 16, 16>::new(
191                        execution_bus,
192                        program_bus,
193                        memory_bridge,
194                        pointer_bits,
195                        bitwise_lu_chip.clone(),
196                    ),
197                    range_checker.clone(),
198                    config48.clone(),
199                    start_offset,
200                    curve.a.clone(),
201                    offline_memory.clone(),
202                );
203                inventory.add_executor(
204                    WeierstrassExtensionExecutor::EcDoubleRv32_48(double_chip),
205                    ec_double_opcodes
206                        .clone()
207                        .map(|x| VmOpcode::from_usize(x + start_offset)),
208                )?;
209            } else {
210                panic!("Modulus too large");
211            }
212        }
213        let non_qr_hint_sub_ex = phantom::NonQrHintSubEx::new(self.supported_curves.clone());
214        builder.add_phantom_sub_executor(
215            non_qr_hint_sub_ex.clone(),
216            PhantomDiscriminant(EccPhantom::HintNonQr as u16),
217        )?;
218        builder.add_phantom_sub_executor(
219            phantom::DecompressHintSubEx::new(non_qr_hint_sub_ex),
220            PhantomDiscriminant(EccPhantom::HintDecompress as u16),
221        )?;
222
223        Ok(inventory)
224    }
225}
226
227pub(crate) mod phantom {
228    use std::{
229        iter::{once, repeat},
230        ops::Deref,
231    };
232
233    use eyre::bail;
234    use num_bigint::{BigUint, RandBigInt};
235    use num_integer::Integer;
236    use num_traits::{FromPrimitive, One};
237    use openvm_circuit::{
238        arch::{PhantomSubExecutor, Streams},
239        system::memory::MemoryController,
240    };
241    use openvm_ecc_guest::weierstrass::DecompressionHint;
242    use openvm_instructions::{riscv::RV32_MEMORY_AS, PhantomDiscriminant};
243    use openvm_rv32im_circuit::adapters::unsafe_read_rv32_register;
244    use openvm_stark_backend::p3_field::PrimeField32;
245    use rand::{rngs::StdRng, SeedableRng};
246
247    use super::CurveConfig;
248
    /// Phantom sub-executor for `HintDecompress`. Given an x-coordinate and a
    /// recovery byte in guest memory, it writes a decompression hint to the hint stream.
    ///
    /// It wraps a [`NonQrHintSubEx`] so both hints use the same curve list and
    /// non-residues; those fields are reached through the `Deref` impl.
    #[derive(derive_new::new)]
    pub struct DecompressHintSubEx(NonQrHintSubEx);
251
252    impl Deref for DecompressHintSubEx {
253        type Target = NonQrHintSubEx;
254
255        fn deref(&self) -> &NonQrHintSubEx {
256            &self.0
257        }
258    }
259
260    impl<F: PrimeField32> PhantomSubExecutor<F> for DecompressHintSubEx {
261        fn phantom_execute(
262            &mut self,
263            memory: &MemoryController<F>,
264            streams: &mut Streams<F>,
265            _: PhantomDiscriminant,
266            a: F,
267            b: F,
268            c_upper: u16,
269        ) -> eyre::Result<()> {
270            let c_idx = c_upper as usize;
271            if c_idx >= self.supported_curves.len() {
272                bail!(
273                    "Curve index {c_idx} out of range: {} supported curves",
274                    self.supported_curves.len()
275                );
276            }
277            let curve = &self.supported_curves[c_idx];
278            let rs1 = unsafe_read_rv32_register(memory, a);
279            let num_limbs: usize = if curve.modulus.bits().div_ceil(8) <= 32 {
280                32
281            } else if curve.modulus.bits().div_ceil(8) <= 48 {
282                48
283            } else {
284                bail!("Modulus too large")
285            };
286            let mut x_limbs: Vec<u8> = Vec::with_capacity(num_limbs);
287            for i in 0..num_limbs {
288                let limb = memory.unsafe_read_cell(
289                    F::from_canonical_u32(RV32_MEMORY_AS),
290                    F::from_canonical_u32(rs1 + i as u32),
291                );
292                x_limbs.push(limb.as_canonical_u32() as u8);
293            }
294            let x = BigUint::from_bytes_le(&x_limbs);
295            let rs2 = unsafe_read_rv32_register(memory, b);
296            let rec_id = memory.unsafe_read_cell(
297                F::from_canonical_u32(RV32_MEMORY_AS),
298                F::from_canonical_u32(rs2),
299            );
300            let hint = self.decompress_point(x, rec_id.as_canonical_u32() & 1 == 1, c_idx);
301            let hint_bytes = once(F::from_bool(hint.possible))
302                .chain(repeat(F::ZERO))
303                .take(4)
304                .chain(
305                    hint.sqrt
306                        .to_bytes_le()
307                        .into_iter()
308                        .map(F::from_canonical_u8)
309                        .chain(repeat(F::ZERO))
310                        .take(num_limbs),
311                )
312                .collect();
313            streams.hint_stream = hint_bytes;
314            Ok(())
315        }
316    }
317
318    impl DecompressHintSubEx {
319        /// Given `x` in the coordinate field of the curve, and the recovery id,
320        /// return the unique `y` such that `(x, y)` is a point on the curve and
321        /// `y` has the same parity as the recovery id.
322        ///
323        /// If no such `y` exists, return the square root of `(x^3 + ax + b) * non_qr`
324        /// where `non_qr` is a quadratic nonresidue of the field.
325        fn decompress_point(
326            &self,
327            x: BigUint,
328            is_y_odd: bool,
329            curve_idx: usize,
330        ) -> DecompressionHint<BigUint> {
331            let curve = &self.supported_curves[curve_idx];
332            let alpha = ((&x * &x * &x) + (&x * &curve.a) + &curve.b) % &curve.modulus;
333            match mod_sqrt(&alpha, &curve.modulus, &self.non_qrs[curve_idx]) {
334                Some(beta) => {
335                    if is_y_odd == beta.is_odd() {
336                        DecompressionHint {
337                            possible: true,
338                            sqrt: beta,
339                        }
340                    } else {
341                        DecompressionHint {
342                            possible: true,
343                            sqrt: &curve.modulus - &beta,
344                        }
345                    }
346                }
347                None => {
348                    debug_assert_eq!(
349                        self.non_qrs[curve_idx]
350                            .modpow(&((&curve.modulus - BigUint::one()) >> 1), &curve.modulus),
351                        &curve.modulus - BigUint::one()
352                    );
353                    let sqrt = mod_sqrt(
354                        &(&alpha * &self.non_qrs[curve_idx]),
355                        &curve.modulus,
356                        &self.non_qrs[curve_idx],
357                    )
358                    .unwrap();
359                    DecompressionHint {
360                        possible: false,
361                        sqrt,
362                    }
363                }
364            }
365        }
366    }
367
368    /// Find the square root of `x` modulo `modulus` with `non_qr` a
369    /// quadratic nonresidue of the field.
370    pub fn mod_sqrt(x: &BigUint, modulus: &BigUint, non_qr: &BigUint) -> Option<BigUint> {
371        if modulus % 4u32 == BigUint::from_u8(3).unwrap() {
372            // x^(1/2) = x^((p+1)/4) when p = 3 mod 4
373            let exponent = (modulus + BigUint::one()) >> 2;
374            let ret = x.modpow(&exponent, modulus);
375            if &ret * &ret % modulus == x % modulus {
376                Some(ret)
377            } else {
378                None
379            }
380        } else {
381            // Tonelli-Shanks algorithm
382            // https://en.wikipedia.org/wiki/Tonelli%E2%80%93Shanks_algorithm#The_algorithm
383            let mut q = modulus - BigUint::one();
384            let mut s = 0;
385            while &q % 2u32 == BigUint::ZERO {
386                s += 1;
387                q /= 2u32;
388            }
389            let z = non_qr;
390            let mut m = s;
391            let mut c = z.modpow(&q, modulus);
392            let mut t = x.modpow(&q, modulus);
393            let mut r = x.modpow(&((q + BigUint::one()) >> 1), modulus);
394            loop {
395                if t == BigUint::ZERO {
396                    return Some(BigUint::ZERO);
397                }
398                if t == BigUint::one() {
399                    return Some(r);
400                }
401                let mut i = 0;
402                let mut tmp = t.clone();
403                while tmp != BigUint::one() && i < m {
404                    tmp = &tmp * &tmp % modulus;
405                    i += 1;
406                }
407                if i == m {
408                    // self is not a quadratic residue
409                    return None;
410                }
411                for _ in 0..m - i - 1 {
412                    c = &c * &c % modulus;
413                }
414                let b = c;
415                m = i;
416                c = &b * &b % modulus;
417                t = ((t * &b % modulus) * &b) % modulus;
418                r = (r * b) % modulus;
419            }
420        }
421    }
422
    /// Phantom sub-executor for `HintNonQr`. It writes a quadratic non-residue of the
    /// selected curve's coordinate field to the hint stream.
    #[derive(Clone)]
    pub struct NonQrHintSubEx {
        /// Curves addressable by index (the `c_upper` operand of the phantom instruction).
        pub supported_curves: Vec<CurveConfig>,
        /// `non_qrs[i]` is a quadratic non-residue modulo `supported_curves[i].modulus`.
        pub non_qrs: Vec<BigUint>,
    }
428
429    impl NonQrHintSubEx {
430        pub fn new(supported_curves: Vec<CurveConfig>) -> Self {
431            let non_qrs = supported_curves
432                .iter()
433                .map(|curve| find_non_qr(&curve.modulus))
434                .collect();
435            Self {
436                supported_curves,
437                non_qrs,
438            }
439        }
440    }
441
442    impl<F: PrimeField32> PhantomSubExecutor<F> for NonQrHintSubEx {
443        fn phantom_execute(
444            &mut self,
445            _: &MemoryController<F>,
446            streams: &mut Streams<F>,
447            _: PhantomDiscriminant,
448            _: F,
449            _: F,
450            c_upper: u16,
451        ) -> eyre::Result<()> {
452            let c_idx = c_upper as usize;
453            if c_idx >= self.supported_curves.len() {
454                bail!(
455                    "Curve index {c_idx} out of range: {} supported curves",
456                    self.supported_curves.len()
457                );
458            }
459            let curve = &self.supported_curves[c_idx];
460
461            let num_limbs: usize = if curve.modulus.bits().div_ceil(8) <= 32 {
462                32
463            } else if curve.modulus.bits().div_ceil(8) <= 48 {
464                48
465            } else {
466                bail!("Modulus too large")
467            };
468
469            let hint_bytes = self.non_qrs[c_idx]
470                .to_bytes_le()
471                .into_iter()
472                .map(F::from_canonical_u8)
473                .chain(repeat(F::ZERO))
474                .take(num_limbs)
475                .collect();
476            streams.hint_stream = hint_bytes;
477            Ok(())
478        }
479    }
480
481    // Returns a non-quadratic residue in the field
482    fn find_non_qr(modulus: &BigUint) -> BigUint {
483        if modulus % 4u32 == BigUint::from(3u8) {
484            // p = 3 mod 4 then -1 is a quadratic residue
485            modulus - BigUint::one()
486        } else if modulus % 8u32 == BigUint::from(5u8) {
487            // p = 5 mod 8 then 2 is a non-quadratic residue
488            // since 2^((p-1)/2) = (-1)^((p^2-1)/8)
489            BigUint::from_u8(2u8).unwrap()
490        } else {
491            let mut rng = StdRng::from_entropy();
492            let mut non_qr = rng.gen_biguint_range(
493                &BigUint::from_u8(2).unwrap(),
494                &(modulus - BigUint::from_u8(1).unwrap()),
495            );
496            // To check if non_qr is a quadratic nonresidue, we compute non_qr^((p-1)/2)
497            // If the result is p-1, then non_qr is a quadratic nonresidue
498            // Otherwise, non_qr is a quadratic residue
499            let exponent = (modulus - BigUint::one()) >> 1;
500            while non_qr.modpow(&exponent, modulus) != modulus - BigUint::one() {
501                non_qr = rng.gen_biguint_range(
502                    &BigUint::from_u8(2).unwrap(),
503                    &(modulus - BigUint::from_u8(1).unwrap()),
504                );
505            }
506            non_qr
507        }
508    }
509}