openvm_pairing_circuit/
pairing_extension.rs

1use derive_more::derive::From;
2use num_bigint::BigUint;
3use num_traits::{FromPrimitive, Zero};
4use openvm_circuit::{
5    arch::{
6        AirInventory, AirInventoryError, ChipInventory, ChipInventoryError,
7        ExecutorInventoryBuilder, ExecutorInventoryError, VmCircuitExtension, VmExecutionExtension,
8        VmProverExtension,
9    },
10    system::phantom::PhantomExecutor,
11};
12use openvm_circuit_derive::{AnyEnum, Executor, MeteredExecutor, PreflightExecutor};
13use openvm_ecc_circuit::CurveConfig;
14use openvm_instructions::PhantomDiscriminant;
15use openvm_pairing_guest::{
16    bls12_381::{
17        BLS12_381_ECC_STRUCT_NAME, BLS12_381_MODULUS, BLS12_381_ORDER, BLS12_381_XI_ISIZE,
18    },
19    bn254::{BN254_ECC_STRUCT_NAME, BN254_MODULUS, BN254_ORDER, BN254_XI_ISIZE},
20};
21use openvm_pairing_transpiler::PairingPhantom;
22use openvm_stark_backend::{config::StarkGenericConfig, engine::StarkEngine, p3_field::Field};
23use serde::{Deserialize, Serialize};
24use strum::FromRepr;
25
26// All the supported pairing curves.
27#[derive(Clone, Copy, Debug, FromRepr, Serialize, Deserialize)]
28#[repr(usize)]
29pub enum PairingCurve {
30    Bn254,
31    Bls12_381,
32}
33
34impl PairingCurve {
35    pub fn curve_config(&self) -> CurveConfig {
36        match self {
37            PairingCurve::Bn254 => CurveConfig::new(
38                BN254_ECC_STRUCT_NAME.to_string(),
39                BN254_MODULUS.clone(),
40                BN254_ORDER.clone(),
41                BigUint::zero(),
42                BigUint::from_u8(3).unwrap(),
43            ),
44            PairingCurve::Bls12_381 => CurveConfig::new(
45                BLS12_381_ECC_STRUCT_NAME.to_string(),
46                BLS12_381_MODULUS.clone(),
47                BLS12_381_ORDER.clone(),
48                BigUint::zero(),
49                BigUint::from_u8(4).unwrap(),
50            ),
51        }
52    }
53
54    pub fn xi(&self) -> [isize; 2] {
55        match self {
56            PairingCurve::Bn254 => BN254_XI_ISIZE,
57            PairingCurve::Bls12_381 => BLS12_381_XI_ISIZE,
58        }
59    }
60}
61
62#[derive(Clone, Debug, From, derive_new::new, Serialize, Deserialize)]
63pub struct PairingExtension {
64    pub supported_curves: Vec<PairingCurve>,
65}
66
67#[derive(Clone, AnyEnum, Executor, MeteredExecutor, PreflightExecutor)]
68#[cfg_attr(
69    feature = "aot",
70    derive(
71        openvm_circuit_derive::AotExecutor,
72        openvm_circuit_derive::AotMeteredExecutor
73    )
74)]
75pub enum PairingExtensionExecutor<F: Field> {
76    Phantom(PhantomExecutor<F>),
77}
78
79impl<F: Field> VmExecutionExtension<F> for PairingExtension {
80    type Executor = PairingExtensionExecutor<F>;
81
82    fn extend_execution(
83        &self,
84        inventory: &mut ExecutorInventoryBuilder<F, PairingExtensionExecutor<F>>,
85    ) -> Result<(), ExecutorInventoryError> {
86        inventory.add_phantom_sub_executor(
87            phantom::PairingHintSubEx,
88            PhantomDiscriminant(PairingPhantom::HintFinalExp as u16),
89        )?;
90        Ok(())
91    }
92}
93
94impl<SC: StarkGenericConfig> VmCircuitExtension<SC> for PairingExtension {
95    fn extend_circuit(&self, _inventory: &mut AirInventory<SC>) -> Result<(), AirInventoryError> {
96        Ok(())
97    }
98}
99
100pub struct PairingProverExt;
101impl<E, RA> VmProverExtension<E, RA, PairingExtension> for PairingProverExt
102where
103    E: StarkEngine,
104{
105    fn extend_prover(
106        &self,
107        _: &PairingExtension,
108        _inventory: &mut ChipInventory<E::SC, RA, E::PB>,
109    ) -> Result<(), ChipInventoryError> {
110        Ok(())
111    }
112}
113
114pub(crate) mod phantom {
115    use std::collections::VecDeque;
116
117    use eyre::bail;
118    use halo2curves_axiom::ff;
119    use openvm_circuit::{
120        arch::{PhantomSubExecutor, Streams},
121        system::memory::online::GuestMemory,
122    };
123    use openvm_ecc_guest::{algebra::field::FieldExtension, AffinePoint};
124    use openvm_instructions::{
125        riscv::{RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS},
126        PhantomDiscriminant,
127    };
128    use openvm_pairing_guest::{
129        bls12_381::BLS12_381_NUM_LIMBS,
130        bn254::BN254_NUM_LIMBS,
131        pairing::{FinalExp, MultiMillerLoop},
132    };
133    use openvm_rv32im_circuit::adapters::{memory_read, read_rv32_register};
134    use openvm_stark_backend::p3_field::Field;
135    use rand::rngs::StdRng;
136
137    use super::PairingCurve;
138
139    pub struct PairingHintSubEx;
140
141    impl<F: Field> PhantomSubExecutor<F> for PairingHintSubEx {
142        fn phantom_execute(
143            &self,
144            memory: &GuestMemory,
145            streams: &mut Streams<F>,
146            _: &mut StdRng,
147            _: PhantomDiscriminant,
148            a: u32,
149            b: u32,
150            c_upper: u16,
151        ) -> eyre::Result<()> {
152            let rs1 = read_rv32_register(memory, a);
153            let rs2 = read_rv32_register(memory, b);
154            hint_pairing(memory, &mut streams.hint_stream, rs1, rs2, c_upper)
155        }
156    }
157
158    fn hint_pairing<F: Field>(
159        memory: &GuestMemory,
160        hint_stream: &mut VecDeque<F>,
161        rs1: u32,
162        rs2: u32,
163        c_upper: u16,
164    ) -> eyre::Result<()> {
165        let p_ptr = u32::from_le_bytes(memory_read(memory, RV32_MEMORY_AS, rs1));
166        // len in bytes
167        let p_len = u32::from_le_bytes(memory_read(
168            memory,
169            RV32_MEMORY_AS,
170            rs1 + RV32_REGISTER_NUM_LIMBS as u32,
171        ));
172
173        let q_ptr = u32::from_le_bytes(memory_read(memory, RV32_MEMORY_AS, rs2));
174        // len in bytes
175        let q_len = u32::from_le_bytes(memory_read(
176            memory,
177            RV32_MEMORY_AS,
178            rs2 + RV32_REGISTER_NUM_LIMBS as u32,
179        ));
180
181        match PairingCurve::from_repr(c_upper as usize) {
182            Some(PairingCurve::Bn254) => {
183                use halo2curves_axiom::bn256::{Fq, Fq12, Fq2};
184                use openvm_pairing_guest::halo2curves_shims::bn254::Bn254;
185                const N: usize = BN254_NUM_LIMBS;
186                if p_len != q_len {
187                    bail!("hint_pairing: p_len={p_len} != q_len={q_len}");
188                }
189                let p = (0..p_len)
190                    .map(|i| -> eyre::Result<_> {
191                        let ptr = p_ptr + i * 2 * (N as u32);
192                        let x = read_fp::<N, Fq>(memory, ptr)?;
193                        let y = read_fp::<N, Fq>(memory, ptr + N as u32)?;
194                        Ok(AffinePoint::new(x, y))
195                    })
196                    .collect::<eyre::Result<Vec<_>>>()?;
197                let q = (0..q_len)
198                    .map(|i| -> eyre::Result<_> {
199                        let mut ptr = q_ptr + i * 4 * (N as u32);
200                        let mut read_fp2 = || -> eyre::Result<_> {
201                            let c0 = read_fp::<N, Fq>(memory, ptr)?;
202                            let c1 = read_fp::<N, Fq>(memory, ptr + N as u32)?;
203                            ptr += 2 * N as u32;
204                            Ok(Fq2::new(c0, c1))
205                        };
206                        let x = read_fp2()?;
207                        let y = read_fp2()?;
208                        Ok(AffinePoint::new(x, y))
209                    })
210                    .collect::<eyre::Result<Vec<_>>>()?;
211
212                let f: Fq12 = Bn254::multi_miller_loop(&p, &q);
213                let (c, u) = Bn254::final_exp_hint(&f);
214                hint_stream.clear();
215                hint_stream.extend(
216                    c.to_coeffs()
217                        .into_iter()
218                        .chain(u.to_coeffs())
219                        .flat_map(|fp2| fp2.to_coeffs())
220                        .flat_map(|fp| fp.to_bytes())
221                        .map(F::from_canonical_u8),
222                );
223            }
224            Some(PairingCurve::Bls12_381) => {
225                use halo2curves_axiom::bls12_381::{Fq, Fq12, Fq2};
226                use openvm_pairing_guest::halo2curves_shims::bls12_381::Bls12_381;
227                const N: usize = BLS12_381_NUM_LIMBS;
228                if p_len != q_len {
229                    bail!("hint_pairing: p_len={p_len} != q_len={q_len}");
230                }
231                let p = (0..p_len)
232                    .map(|i| -> eyre::Result<_> {
233                        let ptr = p_ptr + i * 2 * (N as u32);
234                        let x = read_fp::<N, Fq>(memory, ptr)?;
235                        let y = read_fp::<N, Fq>(memory, ptr + N as u32)?;
236                        Ok(AffinePoint::new(x, y))
237                    })
238                    .collect::<eyre::Result<Vec<_>>>()?;
239                let q = (0..q_len)
240                    .map(|i| -> eyre::Result<_> {
241                        let mut ptr = q_ptr + i * 4 * (N as u32);
242                        let mut read_fp2 = || -> eyre::Result<_> {
243                            let c0 = read_fp::<N, Fq>(memory, ptr)?;
244                            let c1 = read_fp::<N, Fq>(memory, ptr + N as u32)?;
245                            ptr += 2 * N as u32;
246                            Ok(Fq2 { c0, c1 })
247                        };
248                        let x = read_fp2()?;
249                        let y = read_fp2()?;
250                        Ok(AffinePoint::new(x, y))
251                    })
252                    .collect::<eyre::Result<Vec<_>>>()?;
253
254                let f: Fq12 = Bls12_381::multi_miller_loop(&p, &q);
255                let (c, u) = Bls12_381::final_exp_hint(&f);
256                hint_stream.clear();
257                hint_stream.extend(
258                    c.to_coeffs()
259                        .into_iter()
260                        .chain(u.to_coeffs())
261                        .flat_map(|fp2| fp2.to_coeffs())
262                        .flat_map(|fp| fp.to_bytes())
263                        .map(F::from_canonical_u8),
264                );
265            }
266            _ => {
267                bail!("hint_pairing: invalid PairingCurve={c_upper}");
268            }
269        }
270        Ok(())
271    }
272
273    fn read_fp<const N: usize, Fp: ff::PrimeField>(
274        memory: &GuestMemory,
275        ptr: u32,
276    ) -> eyre::Result<Fp>
277    where
278        Fp::Repr: From<[u8; N]>,
279    {
280        // SAFETY:
281        // - RV32_MEMORY_AS consists of `u8`s
282        // - RV32_MEMORY_AS is in bounds
283        let repr: &[u8; N] = unsafe {
284            memory
285                .memory
286                .get_slice::<u8>((RV32_MEMORY_AS, ptr), N)
287                .try_into()
288                .unwrap()
289        };
290        Fp::from_repr((*repr).into())
291            .into_option()
292            .ok_or(eyre::eyre!("bad ff::PrimeField repr"))
293    }
294}