openvm_ecc_transpiler/
lib.rs1use openvm_ecc_guest::{SwBaseFunct7, OPCODE, SW_FUNCT3};
2use openvm_instructions::{
3 instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, VmOpcode,
4};
5use openvm_instructions_derive::LocalOpcode;
6use openvm_stark_backend::p3_field::PrimeField32;
7use openvm_transpiler::{util::from_r_type, TranspilerExtension, TranspilerOutput};
8use rrs_lib::instruction_formats::RType;
9use strum::{EnumCount, EnumIter, FromRepr};
10
11#[derive(
12 Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode,
13)]
14#[opcode_offset = 0x600]
15#[allow(non_camel_case_types)]
16#[repr(usize)]
17pub enum Rv32WeierstrassOpcode {
18 EC_ADD_NE,
19 SETUP_EC_ADD_NE,
20 EC_DOUBLE,
21 SETUP_EC_DOUBLE,
22}
23
/// Transpiler extension that lowers the short-Weierstrass ECC custom RISC-V instructions into
/// [`Rv32WeierstrassOpcode`] VM instructions.
///
/// The type carries no state, so it is trivially constructible via [`Default`] and `Copy`.
#[derive(Clone, Copy, Debug, Default)]
pub struct EccTranspilerExtension;
26
27impl<F: PrimeField32> TranspilerExtension<F> for EccTranspilerExtension {
28 fn process_custom(&self, instruction_stream: &[u32]) -> Option<TranspilerOutput<F>> {
29 if instruction_stream.is_empty() {
30 return None;
31 }
32 let instruction_u32 = instruction_stream[0];
33 let opcode = (instruction_u32 & 0x7f) as u8;
34 let funct3 = ((instruction_u32 >> 12) & 0b111) as u8;
35
36 if opcode != OPCODE {
37 return None;
38 }
39 if funct3 != SW_FUNCT3 {
40 return None;
41 }
42
43 let instruction = {
44 assert!(
46 Rv32WeierstrassOpcode::COUNT <= SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize
47 );
48 let dec_insn = RType::new(instruction_u32);
49 let base_funct7 = (dec_insn.funct7 as u8) % SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS;
50 let curve_idx =
51 ((dec_insn.funct7 as u8) / SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS) as usize;
52 let curve_idx_shift = curve_idx * Rv32WeierstrassOpcode::COUNT;
53 if base_funct7 == SwBaseFunct7::SwSetup as u8 {
54 let local_opcode = match dec_insn.rs2 {
55 0 => Rv32WeierstrassOpcode::SETUP_EC_DOUBLE,
56 _ => Rv32WeierstrassOpcode::SETUP_EC_ADD_NE,
57 };
58 Some(Instruction::new(
59 VmOpcode::from_usize(local_opcode.global_opcode().as_usize() + curve_idx_shift),
60 F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rd),
61 F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
62 F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs2),
63 F::ONE, F::TWO, F::ZERO,
66 F::ZERO,
67 ))
68 } else {
69 let global_opcode = match SwBaseFunct7::from_repr(base_funct7) {
70 Some(SwBaseFunct7::SwAddNe) => {
71 Rv32WeierstrassOpcode::EC_ADD_NE as usize
72 + Rv32WeierstrassOpcode::CLASS_OFFSET
73 }
74 Some(SwBaseFunct7::SwDouble) => {
75 assert!(dec_insn.rs2 == 0);
76 Rv32WeierstrassOpcode::EC_DOUBLE as usize
77 + Rv32WeierstrassOpcode::CLASS_OFFSET
78 }
79 _ => unimplemented!(),
80 };
81 let global_opcode = global_opcode + curve_idx_shift;
82 Some(from_r_type(global_opcode, 2, &dec_insn, true))
83 }
84 };
85 instruction.map(TranspilerOutput::one_to_one)
86 }
87}