// openvm_ecc_transpiler/lib.rs
1use openvm_ecc_guest::{SwBaseFunct7, OPCODE, SW_FUNCT3};
2use openvm_instructions::{
3 instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, PhantomDiscriminant,
4 VmOpcode,
5};
6use openvm_instructions_derive::LocalOpcode;
7use openvm_stark_backend::p3_field::PrimeField32;
8use openvm_transpiler::{util::from_r_type, TranspilerExtension, TranspilerOutput};
9use rrs_lib::instruction_formats::RType;
10use strum::{EnumCount, EnumIter, FromRepr};
11
12#[derive(
13 Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode,
14)]
15#[opcode_offset = 0x600]
16#[allow(non_camel_case_types)]
17#[repr(usize)]
18pub enum Rv32WeierstrassOpcode {
19 EC_ADD_NE,
20 SETUP_EC_ADD_NE,
21 EC_DOUBLE,
22 SETUP_EC_DOUBLE,
23}
24
25#[derive(Copy, Clone, Debug, PartialEq, Eq, FromRepr)]
26#[repr(u16)]
27pub enum EccPhantom {
28 HintDecompress = 0x40,
29 HintNonQr = 0x41,
30}
31
/// Transpiler extension for the short Weierstrass elliptic-curve guest instructions.
///
/// A stateless unit struct: construct it as `EccTranspilerExtension` or via `Default`.
/// It derives the common `Clone`/`Copy`/`Debug` traits so it can be logged and freely
/// duplicated when assembling transpiler configurations.
#[derive(Clone, Copy, Debug, Default)]
pub struct EccTranspilerExtension;
34
35impl<F: PrimeField32> TranspilerExtension<F> for EccTranspilerExtension {
36 fn process_custom(&self, instruction_stream: &[u32]) -> Option<TranspilerOutput<F>> {
37 if instruction_stream.is_empty() {
38 return None;
39 }
40 let instruction_u32 = instruction_stream[0];
41 let opcode = (instruction_u32 & 0x7f) as u8;
42 let funct3 = ((instruction_u32 >> 12) & 0b111) as u8;
43
44 if opcode != OPCODE {
45 return None;
46 }
47 if funct3 != SW_FUNCT3 {
48 return None;
49 }
50
51 let instruction = {
52 assert!(
54 Rv32WeierstrassOpcode::COUNT <= SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize
55 );
56 let dec_insn = RType::new(instruction_u32);
57 let base_funct7 = (dec_insn.funct7 as u8) % SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS;
58 let curve_idx =
59 ((dec_insn.funct7 as u8) / SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS) as usize;
60 let curve_idx_shift = curve_idx * Rv32WeierstrassOpcode::COUNT;
61 if let Some(SwBaseFunct7::HintDecompress) = SwBaseFunct7::from_repr(base_funct7) {
62 assert_eq!(dec_insn.rd, 0);
63 return Some(TranspilerOutput::one_to_one(Instruction::phantom(
64 PhantomDiscriminant(EccPhantom::HintDecompress as u16),
65 F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
66 F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs2),
67 curve_idx as u16,
68 )));
69 }
70 if let Some(SwBaseFunct7::HintNonQr) = SwBaseFunct7::from_repr(base_funct7) {
71 assert_eq!(dec_insn.rd, 0);
72 assert_eq!(dec_insn.rs1, 0);
73 assert_eq!(dec_insn.rs2, 0);
74 return Some(TranspilerOutput::one_to_one(Instruction::phantom(
75 PhantomDiscriminant(EccPhantom::HintNonQr as u16),
76 F::ZERO,
77 F::ZERO,
78 curve_idx as u16,
79 )));
80 }
81 if base_funct7 == SwBaseFunct7::SwSetup as u8 {
82 let local_opcode = match dec_insn.rs2 {
83 0 => Rv32WeierstrassOpcode::SETUP_EC_DOUBLE,
84 _ => Rv32WeierstrassOpcode::SETUP_EC_ADD_NE,
85 };
86 Some(Instruction::new(
87 VmOpcode::from_usize(local_opcode.global_opcode().as_usize() + curve_idx_shift),
88 F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rd),
89 F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
90 F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs2),
91 F::ONE, F::TWO, F::ZERO,
94 F::ZERO,
95 ))
96 } else {
97 let global_opcode = match SwBaseFunct7::from_repr(base_funct7) {
98 Some(SwBaseFunct7::SwAddNe) => {
99 Rv32WeierstrassOpcode::EC_ADD_NE as usize
100 + Rv32WeierstrassOpcode::CLASS_OFFSET
101 }
102 Some(SwBaseFunct7::SwDouble) => {
103 assert!(dec_insn.rs2 == 0);
104 Rv32WeierstrassOpcode::EC_DOUBLE as usize
105 + Rv32WeierstrassOpcode::CLASS_OFFSET
106 }
107 _ => unimplemented!(),
108 };
109 let global_opcode = global_opcode + curve_idx_shift;
110 Some(from_r_type(global_opcode, 2, &dec_insn, true))
111 }
112 };
113 instruction.map(TranspilerOutput::one_to_one)
114 }
115}