openvm_ecc_transpiler/
lib.rs

1use openvm_ecc_guest::{SwBaseFunct7, OPCODE, SW_FUNCT3};
2use openvm_instructions::{
3    instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, PhantomDiscriminant,
4    VmOpcode,
5};
6use openvm_instructions_derive::LocalOpcode;
7use openvm_stark_backend::p3_field::PrimeField32;
8use openvm_transpiler::{util::from_r_type, TranspilerExtension, TranspilerOutput};
9use rrs_lib::instruction_formats::RType;
10use strum::{EnumCount, EnumIter, FromRepr};
11
12#[derive(
13    Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode,
14)]
15#[opcode_offset = 0x600]
16#[allow(non_camel_case_types)]
17#[repr(usize)]
18pub enum Rv32WeierstrassOpcode {
19    EC_ADD_NE,
20    SETUP_EC_ADD_NE,
21    EC_DOUBLE,
22    SETUP_EC_DOUBLE,
23}
24
25#[derive(Copy, Clone, Debug, PartialEq, Eq, FromRepr)]
26#[repr(u16)]
27pub enum EccPhantom {
28    HintDecompress = 0x40,
29    HintNonQr = 0x41,
30}
31
/// Transpiler extension that recognizes short-Weierstrass ECC custom instructions
/// (custom opcode `OPCODE` with `funct3 == SW_FUNCT3`).
///
/// The extension is stateless, so it is cheap to copy and print.
#[derive(Debug, Clone, Copy, Default)]
pub struct EccTranspilerExtension;
34
35impl<F: PrimeField32> TranspilerExtension<F> for EccTranspilerExtension {
36    fn process_custom(&self, instruction_stream: &[u32]) -> Option<TranspilerOutput<F>> {
37        if instruction_stream.is_empty() {
38            return None;
39        }
40        let instruction_u32 = instruction_stream[0];
41        let opcode = (instruction_u32 & 0x7f) as u8;
42        let funct3 = ((instruction_u32 >> 12) & 0b111) as u8;
43
44        if opcode != OPCODE {
45            return None;
46        }
47        if funct3 != SW_FUNCT3 {
48            return None;
49        }
50
51        let instruction = {
52            // short weierstrass ec
53            assert!(
54                Rv32WeierstrassOpcode::COUNT <= SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize
55            );
56            let dec_insn = RType::new(instruction_u32);
57            let base_funct7 = (dec_insn.funct7 as u8) % SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS;
58            let curve_idx =
59                ((dec_insn.funct7 as u8) / SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS) as usize;
60            let curve_idx_shift = curve_idx * Rv32WeierstrassOpcode::COUNT;
61            if let Some(SwBaseFunct7::HintDecompress) = SwBaseFunct7::from_repr(base_funct7) {
62                assert_eq!(dec_insn.rd, 0);
63                return Some(TranspilerOutput::one_to_one(Instruction::phantom(
64                    PhantomDiscriminant(EccPhantom::HintDecompress as u16),
65                    F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
66                    F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs2),
67                    curve_idx as u16,
68                )));
69            }
70            if let Some(SwBaseFunct7::HintNonQr) = SwBaseFunct7::from_repr(base_funct7) {
71                assert_eq!(dec_insn.rd, 0);
72                assert_eq!(dec_insn.rs1, 0);
73                assert_eq!(dec_insn.rs2, 0);
74                return Some(TranspilerOutput::one_to_one(Instruction::phantom(
75                    PhantomDiscriminant(EccPhantom::HintNonQr as u16),
76                    F::ZERO,
77                    F::ZERO,
78                    curve_idx as u16,
79                )));
80            }
81            if base_funct7 == SwBaseFunct7::SwSetup as u8 {
82                let local_opcode = match dec_insn.rs2 {
83                    0 => Rv32WeierstrassOpcode::SETUP_EC_DOUBLE,
84                    _ => Rv32WeierstrassOpcode::SETUP_EC_ADD_NE,
85                };
86                Some(Instruction::new(
87                    VmOpcode::from_usize(local_opcode.global_opcode().as_usize() + curve_idx_shift),
88                    F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rd),
89                    F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
90                    F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs2),
91                    F::ONE, // d_as = 1
92                    F::TWO, // e_as = 2
93                    F::ZERO,
94                    F::ZERO,
95                ))
96            } else {
97                let global_opcode = match SwBaseFunct7::from_repr(base_funct7) {
98                    Some(SwBaseFunct7::SwAddNe) => {
99                        Rv32WeierstrassOpcode::EC_ADD_NE as usize
100                            + Rv32WeierstrassOpcode::CLASS_OFFSET
101                    }
102                    Some(SwBaseFunct7::SwDouble) => {
103                        assert!(dec_insn.rs2 == 0);
104                        Rv32WeierstrassOpcode::EC_DOUBLE as usize
105                            + Rv32WeierstrassOpcode::CLASS_OFFSET
106                    }
107                    _ => unimplemented!(),
108                };
109                let global_opcode = global_opcode + curve_idx_shift;
110                Some(from_r_type(global_opcode, 2, &dec_insn, true))
111            }
112        };
113        instruction.map(TranspilerOutput::one_to_one)
114    }
115}