// openvm_pairing_transpiler/lib.rs

1use openvm_instructions::{
2    instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, PhantomDiscriminant,
3};
4use openvm_instructions_derive::LocalOpcode;
5use openvm_pairing_guest::{PairingBaseFunct7, OPCODE, PAIRING_FUNCT3};
6use openvm_stark_backend::p3_field::PrimeField32;
7use openvm_transpiler::{TranspilerExtension, TranspilerOutput};
8use rrs_lib::instruction_formats::RType;
9use strum::{EnumCount, EnumIter, FromRepr};
10
// NOTE: the following opcodes are enabled only in testing and not enabled in the VM Extension
/// Pairing opcodes, declared with `opcode_offset = 0x750`.
///
/// Variants take implicit `usize` discriminants starting at 0 in declaration order, so
/// reordering or inserting variants changes their numeric values (and presumably the opcodes
/// produced by the `LocalOpcode` derive — confirm before editing).
#[derive(
    Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode,
)]
#[opcode_offset = 0x750]
#[repr(usize)]
// Variant names are upper-case opcode mnemonics rather than CamelCase.
#[allow(non_camel_case_types)]
pub enum PairingOpcode {
    MILLER_DOUBLE_AND_ADD_STEP,
    MILLER_DOUBLE_STEP,
    EVALUATE_LINE,
    MUL_013_BY_013,
    MUL_023_BY_023,
    MUL_BY_01234,
    MUL_BY_02345,
}
27
// NOTE: Fp12 opcodes are only enabled in testing and not enabled in the VM Extension
/// Fp12 arithmetic opcodes, declared with `opcode_offset = 0x700`.
///
/// This base numbering is reused by the per-curve wrappers [`Bn254Fp12Opcode`] (unshifted) and
/// [`Bls12381Fp12Opcode`] (shifted by `FP12_OPS`).
///
/// Variants take implicit `usize` discriminants `0, 1, 2` in declaration order; reordering or
/// inserting variants changes those values and presumably the wrappers' opcode numbering too.
#[derive(
    Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode,
)]
#[opcode_offset = 0x700]
#[repr(usize)]
// Variant names are upper-case opcode mnemonics rather than CamelCase.
#[allow(non_camel_case_types)]
pub enum Fp12Opcode {
    ADD,
    SUB,
    MUL,
}
40const FP12_OPS: usize = 4;
41
42pub struct Bn254Fp12Opcode(Fp12Opcode);
43
44impl LocalOpcode for Bn254Fp12Opcode {
45    const CLASS_OFFSET: usize = Fp12Opcode::CLASS_OFFSET;
46
47    fn from_usize(value: usize) -> Self {
48        Self(Fp12Opcode::from_usize(value))
49    }
50
51    fn local_usize(&self) -> usize {
52        self.0.local_usize()
53    }
54}
55
56pub struct Bls12381Fp12Opcode(Fp12Opcode);
57
58impl LocalOpcode for Bls12381Fp12Opcode {
59    const CLASS_OFFSET: usize = Fp12Opcode::CLASS_OFFSET + FP12_OPS;
60
61    fn from_usize(value: usize) -> Self {
62        Self(Fp12Opcode::from_usize(value - FP12_OPS))
63    }
64
65    fn local_usize(&self) -> usize {
66        self.0.local_usize() + FP12_OPS
67    }
68}
69
70#[derive(Copy, Clone, Debug, PartialEq, Eq, FromRepr)]
71#[repr(u16)]
72pub enum PairingPhantom {
73    /// Uses `b` to determine the curve: `b` is the discriminant of `PairingCurve` kind.
74    /// Peeks at `[r32{0}(a)..r32{0}(a) + Fp::NUM_LIMBS * 12]_2` to get `f: Fp12` and then resets the hint stream to equal `final_exp_hint(f) = (residue_witness, scaling_factor): (Fp12, Fp12)` as `Fp::NUM_LIMBS * 12 * 2` bytes.
75    HintFinalExp = 0x30,
76}
77
/// Transpiler extension for the pairing guest's custom RISC-V instructions.
///
/// Stateless; see its `TranspilerExtension` impl for the instructions it recognizes.
#[derive(Clone, Copy, Debug, Default)]
pub struct PairingTranspilerExtension;
80
81impl<F: PrimeField32> TranspilerExtension<F> for PairingTranspilerExtension {
82    fn process_custom(&self, instruction_stream: &[u32]) -> Option<TranspilerOutput<F>> {
83        if instruction_stream.is_empty() {
84            return None;
85        }
86        let instruction_u32 = instruction_stream[0];
87        let opcode = (instruction_u32 & 0x7f) as u8;
88        let funct3 = ((instruction_u32 >> 12) & 0b111) as u8;
89
90        if opcode != OPCODE {
91            return None;
92        }
93        if funct3 != PAIRING_FUNCT3 {
94            return None;
95        }
96
97        let dec_insn = RType::new(instruction_u32);
98        let base_funct7 = (dec_insn.funct7 as u8) % PairingBaseFunct7::PAIRING_MAX_KINDS;
99        let pairing_idx = ((dec_insn.funct7 as u8) / PairingBaseFunct7::PAIRING_MAX_KINDS) as usize;
100        if let Some(PairingBaseFunct7::HintFinalExp) = PairingBaseFunct7::from_repr(base_funct7) {
101            assert_eq!(dec_insn.rd, 0);
102            // Return exits the outermost function
103            return Some(TranspilerOutput::one_to_one(Instruction::phantom(
104                PhantomDiscriminant(PairingPhantom::HintFinalExp as u16),
105                F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
106                F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs2),
107                pairing_idx as u16,
108            )));
109        }
110        None
111    }
112}