openvm_ecc_transpiler/
lib.rs

1use openvm_ecc_guest::{SwBaseFunct7, OPCODE, SW_FUNCT3};
2use openvm_instructions::{
3    instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, VmOpcode,
4};
5use openvm_instructions_derive::LocalOpcode;
6use openvm_stark_backend::p3_field::PrimeField32;
7use openvm_transpiler::{util::from_r_type, TranspilerExtension, TranspilerOutput};
8use rrs_lib::instruction_formats::RType;
9use strum::{EnumCount, EnumIter, FromRepr};
10
11#[derive(
12    Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode,
13)]
14#[opcode_offset = 0x600]
15#[allow(non_camel_case_types)]
16#[repr(usize)]
17pub enum Rv32WeierstrassOpcode {
18    EC_ADD_NE,
19    SETUP_EC_ADD_NE,
20    EC_DOUBLE,
21    SETUP_EC_DOUBLE,
22}
23
/// Transpiler extension that recognizes the short Weierstrass elliptic-curve custom
/// instructions and lowers them to [`Rv32WeierstrassOpcode`] VM instructions.
///
/// The type is a stateless unit struct, so it is cheap to copy and can be
/// debug-printed alongside other extensions.
#[derive(Debug, Clone, Copy, Default)]
pub struct EccTranspilerExtension;
26
27impl<F: PrimeField32> TranspilerExtension<F> for EccTranspilerExtension {
28    fn process_custom(&self, instruction_stream: &[u32]) -> Option<TranspilerOutput<F>> {
29        if instruction_stream.is_empty() {
30            return None;
31        }
32        let instruction_u32 = instruction_stream[0];
33        let opcode = (instruction_u32 & 0x7f) as u8;
34        let funct3 = ((instruction_u32 >> 12) & 0b111) as u8;
35
36        if opcode != OPCODE {
37            return None;
38        }
39        if funct3 != SW_FUNCT3 {
40            return None;
41        }
42
43        let instruction = {
44            // short weierstrass ec
45            assert!(
46                Rv32WeierstrassOpcode::COUNT <= SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize
47            );
48            let dec_insn = RType::new(instruction_u32);
49            let base_funct7 = (dec_insn.funct7 as u8) % SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS;
50            let curve_idx =
51                ((dec_insn.funct7 as u8) / SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS) as usize;
52            let curve_idx_shift = curve_idx * Rv32WeierstrassOpcode::COUNT;
53            if base_funct7 == SwBaseFunct7::SwSetup as u8 {
54                let local_opcode = match dec_insn.rs2 {
55                    0 => Rv32WeierstrassOpcode::SETUP_EC_DOUBLE,
56                    _ => Rv32WeierstrassOpcode::SETUP_EC_ADD_NE,
57                };
58                Some(Instruction::new(
59                    VmOpcode::from_usize(local_opcode.global_opcode().as_usize() + curve_idx_shift),
60                    F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rd),
61                    F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
62                    F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs2),
63                    F::ONE, // d_as = 1
64                    F::TWO, // e_as = 2
65                    F::ZERO,
66                    F::ZERO,
67                ))
68            } else {
69                let global_opcode = match SwBaseFunct7::from_repr(base_funct7) {
70                    Some(SwBaseFunct7::SwAddNe) => {
71                        Rv32WeierstrassOpcode::EC_ADD_NE as usize
72                            + Rv32WeierstrassOpcode::CLASS_OFFSET
73                    }
74                    Some(SwBaseFunct7::SwDouble) => {
75                        assert!(dec_insn.rs2 == 0);
76                        Rv32WeierstrassOpcode::EC_DOUBLE as usize
77                            + Rv32WeierstrassOpcode::CLASS_OFFSET
78                    }
79                    _ => unimplemented!(),
80                };
81                let global_opcode = global_opcode + curve_idx_shift;
82                Some(from_r_type(global_opcode, 2, &dec_insn, true))
83            }
84        };
85        instruction.map(TranspilerOutput::one_to_one)
86    }
87}