openvm_pairing_circuit/
pairing_extension.rs

1use derive_more::derive::From;
2use num_bigint::BigUint;
3use num_traits::{FromPrimitive, Zero};
4use openvm_circuit::{
5    arch::{
6        AirInventory, AirInventoryError, ChipInventory, ChipInventoryError,
7        ExecutorInventoryBuilder, ExecutorInventoryError, VmCircuitExtension, VmExecutionExtension,
8        VmProverExtension,
9    },
10    system::phantom::PhantomExecutor,
11};
12use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor};
13use openvm_ecc_circuit::CurveConfig;
14use openvm_instructions::PhantomDiscriminant;
15use openvm_pairing_guest::{
16    bls12_381::{
17        BLS12_381_ECC_STRUCT_NAME, BLS12_381_MODULUS, BLS12_381_ORDER, BLS12_381_XI_ISIZE,
18    },
19    bn254::{BN254_ECC_STRUCT_NAME, BN254_MODULUS, BN254_ORDER, BN254_XI_ISIZE},
20};
21use openvm_pairing_transpiler::PairingPhantom;
22use openvm_stark_backend::{config::StarkGenericConfig, engine::StarkEngine, p3_field::Field};
23use serde::{Deserialize, Serialize};
24use strum::FromRepr;
25
26// All the supported pairing curves.
27#[derive(Clone, Copy, Debug, FromRepr, Serialize, Deserialize)]
28#[repr(usize)]
29pub enum PairingCurve {
30    Bn254,
31    Bls12_381,
32}
33
34impl PairingCurve {
35    pub fn curve_config(&self) -> CurveConfig {
36        match self {
37            PairingCurve::Bn254 => CurveConfig::new(
38                BN254_ECC_STRUCT_NAME.to_string(),
39                BN254_MODULUS.clone(),
40                BN254_ORDER.clone(),
41                BigUint::zero(),
42                BigUint::from_u8(3).unwrap(),
43            ),
44            PairingCurve::Bls12_381 => CurveConfig::new(
45                BLS12_381_ECC_STRUCT_NAME.to_string(),
46                BLS12_381_MODULUS.clone(),
47                BLS12_381_ORDER.clone(),
48                BigUint::zero(),
49                BigUint::from_u8(4).unwrap(),
50            ),
51        }
52    }
53
54    pub fn xi(&self) -> [isize; 2] {
55        match self {
56            PairingCurve::Bn254 => BN254_XI_ISIZE,
57            PairingCurve::Bls12_381 => BLS12_381_XI_ISIZE,
58        }
59    }
60}
61
62#[derive(Clone, Debug, From, derive_new::new, Serialize, Deserialize)]
63pub struct PairingExtension {
64    pub supported_curves: Vec<PairingCurve>,
65}
66
67#[derive(Clone, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)]
68pub enum PairingExtensionExecutor<F: Field> {
69    Phantom(PhantomExecutor<F>),
70}
71
72impl<F: Field> VmExecutionExtension<F> for PairingExtension {
73    type Executor = PairingExtensionExecutor<F>;
74
75    fn extend_execution(
76        &self,
77        inventory: &mut ExecutorInventoryBuilder<F, PairingExtensionExecutor<F>>,
78    ) -> Result<(), ExecutorInventoryError> {
79        inventory.add_phantom_sub_executor(
80            phantom::PairingHintSubEx,
81            PhantomDiscriminant(PairingPhantom::HintFinalExp as u16),
82        )?;
83        Ok(())
84    }
85}
86
87impl<SC: StarkGenericConfig> VmCircuitExtension<SC> for PairingExtension {
88    fn extend_circuit(&self, _inventory: &mut AirInventory<SC>) -> Result<(), AirInventoryError> {
89        Ok(())
90    }
91}
92
93pub struct PairingProverExt;
94impl<E, RA> VmProverExtension<E, RA, PairingExtension> for PairingProverExt
95where
96    E: StarkEngine,
97{
98    fn extend_prover(
99        &self,
100        _: &PairingExtension,
101        _inventory: &mut ChipInventory<E::SC, RA, E::PB>,
102    ) -> Result<(), ChipInventoryError> {
103        Ok(())
104    }
105}
106
107pub(crate) mod phantom {
108    use std::collections::VecDeque;
109
110    use eyre::bail;
111    use halo2curves_axiom::ff;
112    use openvm_circuit::{
113        arch::{PhantomSubExecutor, Streams},
114        system::memory::online::GuestMemory,
115    };
116    use openvm_ecc_guest::{algebra::field::FieldExtension, AffinePoint};
117    use openvm_instructions::{
118        riscv::{RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS},
119        PhantomDiscriminant,
120    };
121    use openvm_pairing_guest::{
122        bls12_381::BLS12_381_NUM_LIMBS,
123        bn254::BN254_NUM_LIMBS,
124        pairing::{FinalExp, MultiMillerLoop},
125    };
126    use openvm_rv32im_circuit::adapters::{memory_read, read_rv32_register};
127    use openvm_stark_backend::p3_field::Field;
128    use rand::rngs::StdRng;
129
130    use super::PairingCurve;
131
132    pub struct PairingHintSubEx;
133
134    impl<F: Field> PhantomSubExecutor<F> for PairingHintSubEx {
135        fn phantom_execute(
136            &self,
137            memory: &GuestMemory,
138            streams: &mut Streams<F>,
139            _: &mut StdRng,
140            _: PhantomDiscriminant,
141            a: u32,
142            b: u32,
143            c_upper: u16,
144        ) -> eyre::Result<()> {
145            let rs1 = read_rv32_register(memory, a);
146            let rs2 = read_rv32_register(memory, b);
147            hint_pairing(memory, &mut streams.hint_stream, rs1, rs2, c_upper)
148        }
149    }
150
151    fn hint_pairing<F: Field>(
152        memory: &GuestMemory,
153        hint_stream: &mut VecDeque<F>,
154        rs1: u32,
155        rs2: u32,
156        c_upper: u16,
157    ) -> eyre::Result<()> {
158        let p_ptr = u32::from_le_bytes(memory_read(memory, RV32_MEMORY_AS, rs1));
159        // len in bytes
160        let p_len = u32::from_le_bytes(memory_read(
161            memory,
162            RV32_MEMORY_AS,
163            rs1 + RV32_REGISTER_NUM_LIMBS as u32,
164        ));
165
166        let q_ptr = u32::from_le_bytes(memory_read(memory, RV32_MEMORY_AS, rs2));
167        // len in bytes
168        let q_len = u32::from_le_bytes(memory_read(
169            memory,
170            RV32_MEMORY_AS,
171            rs2 + RV32_REGISTER_NUM_LIMBS as u32,
172        ));
173
174        match PairingCurve::from_repr(c_upper as usize) {
175            Some(PairingCurve::Bn254) => {
176                use halo2curves_axiom::bn256::{Fq, Fq12, Fq2};
177                use openvm_pairing_guest::halo2curves_shims::bn254::Bn254;
178                const N: usize = BN254_NUM_LIMBS;
179                if p_len != q_len {
180                    bail!("hint_pairing: p_len={p_len} != q_len={q_len}");
181                }
182                let p = (0..p_len)
183                    .map(|i| -> eyre::Result<_> {
184                        let ptr = p_ptr + i * 2 * (N as u32);
185                        let x = read_fp::<N, Fq>(memory, ptr)?;
186                        let y = read_fp::<N, Fq>(memory, ptr + N as u32)?;
187                        Ok(AffinePoint::new(x, y))
188                    })
189                    .collect::<eyre::Result<Vec<_>>>()?;
190                let q = (0..q_len)
191                    .map(|i| -> eyre::Result<_> {
192                        let mut ptr = q_ptr + i * 4 * (N as u32);
193                        let mut read_fp2 = || -> eyre::Result<_> {
194                            let c0 = read_fp::<N, Fq>(memory, ptr)?;
195                            let c1 = read_fp::<N, Fq>(memory, ptr + N as u32)?;
196                            ptr += 2 * N as u32;
197                            Ok(Fq2::new(c0, c1))
198                        };
199                        let x = read_fp2()?;
200                        let y = read_fp2()?;
201                        Ok(AffinePoint::new(x, y))
202                    })
203                    .collect::<eyre::Result<Vec<_>>>()?;
204
205                let f: Fq12 = Bn254::multi_miller_loop(&p, &q);
206                let (c, u) = Bn254::final_exp_hint(&f);
207                hint_stream.clear();
208                hint_stream.extend(
209                    c.to_coeffs()
210                        .into_iter()
211                        .chain(u.to_coeffs())
212                        .flat_map(|fp2| fp2.to_coeffs())
213                        .flat_map(|fp| fp.to_bytes())
214                        .map(F::from_canonical_u8),
215                );
216            }
217            Some(PairingCurve::Bls12_381) => {
218                use halo2curves_axiom::bls12_381::{Fq, Fq12, Fq2};
219                use openvm_pairing_guest::halo2curves_shims::bls12_381::Bls12_381;
220                const N: usize = BLS12_381_NUM_LIMBS;
221                if p_len != q_len {
222                    bail!("hint_pairing: p_len={p_len} != q_len={q_len}");
223                }
224                let p = (0..p_len)
225                    .map(|i| -> eyre::Result<_> {
226                        let ptr = p_ptr + i * 2 * (N as u32);
227                        let x = read_fp::<N, Fq>(memory, ptr)?;
228                        let y = read_fp::<N, Fq>(memory, ptr + N as u32)?;
229                        Ok(AffinePoint::new(x, y))
230                    })
231                    .collect::<eyre::Result<Vec<_>>>()?;
232                let q = (0..q_len)
233                    .map(|i| -> eyre::Result<_> {
234                        let mut ptr = q_ptr + i * 4 * (N as u32);
235                        let mut read_fp2 = || -> eyre::Result<_> {
236                            let c0 = read_fp::<N, Fq>(memory, ptr)?;
237                            let c1 = read_fp::<N, Fq>(memory, ptr + N as u32)?;
238                            ptr += 2 * N as u32;
239                            Ok(Fq2 { c0, c1 })
240                        };
241                        let x = read_fp2()?;
242                        let y = read_fp2()?;
243                        Ok(AffinePoint::new(x, y))
244                    })
245                    .collect::<eyre::Result<Vec<_>>>()?;
246
247                let f: Fq12 = Bls12_381::multi_miller_loop(&p, &q);
248                let (c, u) = Bls12_381::final_exp_hint(&f);
249                hint_stream.clear();
250                hint_stream.extend(
251                    c.to_coeffs()
252                        .into_iter()
253                        .chain(u.to_coeffs())
254                        .flat_map(|fp2| fp2.to_coeffs())
255                        .flat_map(|fp| fp.to_bytes())
256                        .map(F::from_canonical_u8),
257                );
258            }
259            _ => {
260                bail!("hint_pairing: invalid PairingCurve={c_upper}");
261            }
262        }
263        Ok(())
264    }
265
266    fn read_fp<const N: usize, Fp: ff::PrimeField>(
267        memory: &GuestMemory,
268        ptr: u32,
269    ) -> eyre::Result<Fp>
270    where
271        Fp::Repr: From<[u8; N]>,
272    {
273        // SAFETY:
274        // - RV32_MEMORY_AS consists of `u8`s
275        // - RV32_MEMORY_AS is in bounds
276        let repr: &[u8; N] = unsafe {
277            memory
278                .memory
279                .get_slice::<u8>((RV32_MEMORY_AS, ptr), N)
280                .try_into()
281                .unwrap()
282        };
283        Fp::from_repr((*repr).into())
284            .into_option()
285            .ok_or(eyre::eyre!("bad ff::PrimeField repr"))
286    }
287}