// openvm_pairing_circuit/pairing_extension.rs

1use derive_more::derive::From;
2use num_bigint::BigUint;
3use num_traits::{FromPrimitive, Zero};
4use openvm_circuit::{
5    arch::{VmExtension, VmInventory, VmInventoryBuilder, VmInventoryError},
6    system::phantom::PhantomChip,
7};
8use openvm_circuit_derive::{AnyEnum, InstructionExecutor};
9use openvm_circuit_primitives::bitwise_op_lookup::SharedBitwiseOperationLookupChip;
10use openvm_circuit_primitives_derive::{Chip, ChipUsageGetter};
11use openvm_ecc_circuit::CurveConfig;
12use openvm_instructions::PhantomDiscriminant;
13use openvm_pairing_guest::{
14    bls12_381::{BLS12_381_MODULUS, BLS12_381_ORDER, BLS12_381_XI_ISIZE},
15    bn254::{BN254_MODULUS, BN254_ORDER, BN254_XI_ISIZE},
16};
17use openvm_pairing_transpiler::PairingPhantom;
18use openvm_stark_backend::p3_field::PrimeField32;
19use serde::{Deserialize, Serialize};
20use strum::FromRepr;
21
22use super::*;
23
24// All the supported pairing curves.
25#[derive(Clone, Copy, Debug, FromRepr, Serialize, Deserialize)]
26#[repr(usize)]
27pub enum PairingCurve {
28    Bn254,
29    Bls12_381,
30}
31
32impl PairingCurve {
33    pub fn curve_config(&self) -> CurveConfig {
34        match self {
35            PairingCurve::Bn254 => CurveConfig::new(
36                BN254_MODULUS.clone(),
37                BN254_ORDER.clone(),
38                BigUint::zero(),
39                BigUint::from_u8(3).unwrap(),
40            ),
41            PairingCurve::Bls12_381 => CurveConfig::new(
42                BLS12_381_MODULUS.clone(),
43                BLS12_381_ORDER.clone(),
44                BigUint::zero(),
45                BigUint::from_u8(4).unwrap(),
46            ),
47        }
48    }
49
50    pub fn xi(&self) -> [isize; 2] {
51        match self {
52            PairingCurve::Bn254 => BN254_XI_ISIZE,
53            PairingCurve::Bls12_381 => BLS12_381_XI_ISIZE,
54        }
55    }
56}
57
58#[derive(Clone, Debug, derive_new::new, Serialize, Deserialize)]
59pub struct PairingExtension {
60    pub supported_curves: Vec<PairingCurve>,
61}
62
/// Instruction executors declared by the pairing extension: a Miller double-and-add step chip
/// and an evaluate-line chip for each supported curve.
///
/// NOTE(review): `VmExtension::build` for `PairingExtension` returns `VmInventory::new()` without
/// adding any executors, so none of these variants are constructed there — confirm whether this
/// enum is still required (e.g. for config/serialization compatibility).
#[derive(Chip, ChipUsageGetter, InstructionExecutor, AnyEnum)]
pub enum PairingExtensionExecutor<F: PrimeField32> {
    // bn254 (32 limbs)
    // NOTE(review): the const generics look like (input blocks, output blocks, block size), e.g.
    // 4 blocks * 32 bytes = four 32-byte field elements — confirm against the chip definitions.
    MillerDoubleAndAddStepRv32_32(MillerDoubleAndAddStepChip<F, 4, 12, 32>),
    EvaluateLineRv32_32(EvaluateLineChip<F, 4, 2, 4, 32>),
    // bls12-381 (48 limbs), laid out as 16-byte blocks (e.g. 12 * 16 = 4 * 48 bytes)
    MillerDoubleAndAddStepRv32_48(MillerDoubleAndAddStepChip<F, 12, 36, 16>),
    EvaluateLineRv32_48(EvaluateLineChip<F, 12, 6, 12, 16>),
}
72
/// Periphery chips the pairing extension's inventory can hold.
///
/// `From` (derive_more) generates a conversion from each wrapped chip type into this enum.
#[derive(ChipUsageGetter, Chip, AnyEnum, From)]
pub enum PairingExtensionPeriphery<F: PrimeField32> {
    // 8-bit bitwise lookup table shared with other chips.
    BitwiseOperationLookup(SharedBitwiseOperationLookupChip<8>),
    Phantom(PhantomChip<F>),
}
78
79impl<F: PrimeField32> VmExtension<F> for PairingExtension {
80    type Executor = PairingExtensionExecutor<F>;
81    type Periphery = PairingExtensionPeriphery<F>;
82
83    fn build(
84        &self,
85        builder: &mut VmInventoryBuilder<F>,
86    ) -> Result<VmInventory<Self::Executor, Self::Periphery>, VmInventoryError> {
87        let inventory = VmInventory::new();
88
89        builder.add_phantom_sub_executor(
90            phantom::PairingHintSubEx,
91            PhantomDiscriminant(PairingPhantom::HintFinalExp as u16),
92        )?;
93
94        Ok(inventory)
95    }
96}
97
98pub(crate) mod phantom {
99    use std::collections::VecDeque;
100
101    use eyre::bail;
102    use openvm_circuit::{
103        arch::{PhantomSubExecutor, Streams},
104        system::memory::MemoryController,
105    };
106    use openvm_ecc_guest::{algebra::field::FieldExtension, halo2curves::ff, AffinePoint};
107    use openvm_instructions::{
108        riscv::{RV32_MEMORY_AS, RV32_REGISTER_NUM_LIMBS},
109        PhantomDiscriminant,
110    };
111    use openvm_pairing_guest::{
112        bls12_381::BLS12_381_NUM_LIMBS,
113        bn254::BN254_NUM_LIMBS,
114        pairing::{FinalExp, MultiMillerLoop},
115    };
116    use openvm_rv32im_circuit::adapters::{compose, unsafe_read_rv32_register};
117    use openvm_stark_backend::p3_field::PrimeField32;
118
119    use super::PairingCurve;
120
121    pub struct PairingHintSubEx;
122
123    impl<F: PrimeField32> PhantomSubExecutor<F> for PairingHintSubEx {
124        fn phantom_execute(
125            &mut self,
126            memory: &MemoryController<F>,
127            streams: &mut Streams<F>,
128            _: PhantomDiscriminant,
129            a: F,
130            b: F,
131            c_upper: u16,
132        ) -> eyre::Result<()> {
133            let rs1 = unsafe_read_rv32_register(memory, a);
134            let rs2 = unsafe_read_rv32_register(memory, b);
135            hint_pairing(memory, &mut streams.hint_stream, rs1, rs2, c_upper)
136        }
137    }
138
139    fn hint_pairing<F: PrimeField32>(
140        memory: &MemoryController<F>,
141        hint_stream: &mut VecDeque<F>,
142        rs1: u32,
143        rs2: u32,
144        c_upper: u16,
145    ) -> eyre::Result<()> {
146        let p_ptr = compose(memory.unsafe_read(
147            F::from_canonical_u32(RV32_MEMORY_AS),
148            F::from_canonical_u32(rs1),
149        ));
150        // len in bytes
151        let p_len = compose(memory.unsafe_read(
152            F::from_canonical_u32(RV32_MEMORY_AS),
153            F::from_canonical_u32(rs1 + RV32_REGISTER_NUM_LIMBS as u32),
154        ));
155        let q_ptr = compose(memory.unsafe_read(
156            F::from_canonical_u32(RV32_MEMORY_AS),
157            F::from_canonical_u32(rs2),
158        ));
159        // len in bytes
160        let q_len = compose(memory.unsafe_read(
161            F::from_canonical_u32(RV32_MEMORY_AS),
162            F::from_canonical_u32(rs2 + RV32_REGISTER_NUM_LIMBS as u32),
163        ));
164
165        match PairingCurve::from_repr(c_upper as usize) {
166            Some(PairingCurve::Bn254) => {
167                use openvm_ecc_guest::halo2curves::bn256::{Fq, Fq12, Fq2};
168                use openvm_pairing_guest::halo2curves_shims::bn254::Bn254;
169                const N: usize = BN254_NUM_LIMBS;
170                if p_len != q_len {
171                    bail!("hint_pairing: p_len={p_len} != q_len={q_len}");
172                }
173                let p = (0..p_len)
174                    .map(|i| -> eyre::Result<_> {
175                        let ptr = p_ptr + i * 2 * (N as u32);
176                        let x = read_fp::<N, F, Fq>(memory, ptr)?;
177                        let y = read_fp::<N, F, Fq>(memory, ptr + N as u32)?;
178                        Ok(AffinePoint::new(x, y))
179                    })
180                    .collect::<eyre::Result<Vec<_>>>()?;
181                let q = (0..q_len)
182                    .map(|i| -> eyre::Result<_> {
183                        let mut ptr = q_ptr + i * 4 * (N as u32);
184                        let mut read_fp2 = || -> eyre::Result<_> {
185                            let c0 = read_fp::<N, F, Fq>(memory, ptr)?;
186                            let c1 = read_fp::<N, F, Fq>(memory, ptr + N as u32)?;
187                            ptr += 2 * N as u32;
188                            Ok(Fq2::new(c0, c1))
189                        };
190                        let x = read_fp2()?;
191                        let y = read_fp2()?;
192                        Ok(AffinePoint::new(x, y))
193                    })
194                    .collect::<eyre::Result<Vec<_>>>()?;
195
196                let f: Fq12 = Bn254::multi_miller_loop(&p, &q);
197                let (c, u) = Bn254::final_exp_hint(&f);
198                hint_stream.clear();
199                hint_stream.extend(
200                    c.to_coeffs()
201                        .into_iter()
202                        .chain(u.to_coeffs())
203                        .flat_map(|fp2| fp2.to_coeffs())
204                        .flat_map(|fp| fp.to_bytes())
205                        .map(F::from_canonical_u8),
206                );
207            }
208            Some(PairingCurve::Bls12_381) => {
209                use openvm_ecc_guest::halo2curves::bls12_381::{Fq, Fq12, Fq2};
210                use openvm_pairing_guest::halo2curves_shims::bls12_381::Bls12_381;
211                const N: usize = BLS12_381_NUM_LIMBS;
212                if p_len != q_len {
213                    bail!("hint_pairing: p_len={p_len} != q_len={q_len}");
214                }
215                let p = (0..p_len)
216                    .map(|i| -> eyre::Result<_> {
217                        let ptr = p_ptr + i * 2 * (N as u32);
218                        let x = read_fp::<N, F, Fq>(memory, ptr)?;
219                        let y = read_fp::<N, F, Fq>(memory, ptr + N as u32)?;
220                        Ok(AffinePoint::new(x, y))
221                    })
222                    .collect::<eyre::Result<Vec<_>>>()?;
223                let q = (0..q_len)
224                    .map(|i| -> eyre::Result<_> {
225                        let mut ptr = q_ptr + i * 4 * (N as u32);
226                        let mut read_fp2 = || -> eyre::Result<_> {
227                            let c0 = read_fp::<N, F, Fq>(memory, ptr)?;
228                            let c1 = read_fp::<N, F, Fq>(memory, ptr + N as u32)?;
229                            ptr += 2 * N as u32;
230                            Ok(Fq2 { c0, c1 })
231                        };
232                        let x = read_fp2()?;
233                        let y = read_fp2()?;
234                        Ok(AffinePoint::new(x, y))
235                    })
236                    .collect::<eyre::Result<Vec<_>>>()?;
237
238                let f: Fq12 = Bls12_381::multi_miller_loop(&p, &q);
239                let (c, u) = Bls12_381::final_exp_hint(&f);
240                hint_stream.clear();
241                hint_stream.extend(
242                    c.to_coeffs()
243                        .into_iter()
244                        .chain(u.to_coeffs())
245                        .flat_map(|fp2| fp2.to_coeffs())
246                        .flat_map(|fp| fp.to_bytes())
247                        .map(F::from_canonical_u8),
248                );
249            }
250            _ => {
251                bail!("hint_pairing: invalid PairingCurve={c_upper}");
252            }
253        }
254        Ok(())
255    }
256
257    fn read_fp<const N: usize, F: PrimeField32, Fp: ff::PrimeField>(
258        memory: &MemoryController<F>,
259        ptr: u32,
260    ) -> eyre::Result<Fp>
261    where
262        Fp::Repr: From<[u8; N]>,
263    {
264        let mut repr = [0u8; N];
265        for (i, byte) in repr.iter_mut().enumerate() {
266            *byte = memory
267                .unsafe_read_cell(
268                    F::from_canonical_u32(RV32_MEMORY_AS),
269                    F::from_canonical_u32(ptr + i as u32),
270                )
271                .as_canonical_u32()
272                .try_into()?;
273        }
274        Fp::from_repr(repr.into())
275            .into_option()
276            .ok_or(eyre::eyre!("bad ff::PrimeField repr"))
277    }
278}