openvm_algebra_transpiler/
lib.rs
1use openvm_algebra_guest::{
2 ComplexExtFieldBaseFunct7, ModArithBaseFunct7, COMPLEX_EXT_FIELD_FUNCT3,
3 MODULAR_ARITHMETIC_FUNCT3, OPCODE,
4};
5use openvm_instructions::{
6 instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, VmOpcode,
7};
8use openvm_instructions_derive::LocalOpcode;
9use openvm_stark_backend::p3_field::PrimeField32;
10use openvm_transpiler::{util::from_r_type, TranspilerExtension, TranspilerOutput};
11use rrs_lib::instruction_formats::RType;
12use strum::{EnumCount, EnumIter, FromRepr};
13
14#[derive(
15 Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode,
16)]
17#[opcode_offset = 0x500]
18#[repr(usize)]
19#[allow(non_camel_case_types)]
20pub enum Rv32ModularArithmeticOpcode {
21 ADD,
22 SUB,
23 SETUP_ADDSUB,
24 MUL,
25 DIV,
26 SETUP_MULDIV,
27 IS_EQ,
28 SETUP_ISEQ,
29}
30
31#[derive(
32 Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode,
33)]
34#[opcode_offset = 0x710]
35#[repr(usize)]
36#[allow(non_camel_case_types)]
37pub enum Fp2Opcode {
38 ADD,
39 SUB,
40 SETUP_ADDSUB,
41 MUL,
42 DIV,
43 SETUP_MULDIV,
44}
45
/// Stateless transpiler extension that recognises instructions whose funct3 field is
/// `MODULAR_ARITHMETIC_FUNCT3` and maps them onto `Rv32ModularArithmeticOpcode`.
#[derive(Default)]
pub struct ModularTranspilerExtension;
48
/// Stateless transpiler extension that recognises instructions whose funct3 field is
/// `COMPLEX_EXT_FIELD_FUNCT3` and maps them onto `Fp2Opcode`.
#[derive(Default)]
pub struct Fp2TranspilerExtension;
51
52impl<F: PrimeField32> TranspilerExtension<F> for ModularTranspilerExtension {
53 fn process_custom(&self, instruction_stream: &[u32]) -> Option<TranspilerOutput<F>> {
54 if instruction_stream.is_empty() {
55 return None;
56 }
57 let instruction_u32 = instruction_stream[0];
58 let opcode = (instruction_u32 & 0x7f) as u8;
59 let funct3 = ((instruction_u32 >> 12) & 0b111) as u8;
60
61 if opcode != OPCODE {
62 return None;
63 }
64 if funct3 != MODULAR_ARITHMETIC_FUNCT3 {
65 return None;
66 }
67
68 let instruction = {
69 let dec_insn = RType::new(instruction_u32);
70 let base_funct7 =
71 (dec_insn.funct7 as u8) % ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS;
72 assert!(
73 Rv32ModularArithmeticOpcode::COUNT
74 <= ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize
75 );
76 let mod_idx_shift = ((dec_insn.funct7 as u8)
77 / ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS)
78 as usize
79 * Rv32ModularArithmeticOpcode::COUNT;
80 if base_funct7 == ModArithBaseFunct7::SetupMod as u8 {
81 let local_opcode = match dec_insn.rs2 {
82 0 => Rv32ModularArithmeticOpcode::SETUP_ADDSUB,
83 1 => Rv32ModularArithmeticOpcode::SETUP_MULDIV,
84 2 => Rv32ModularArithmeticOpcode::SETUP_ISEQ,
85 _ => panic!("invalid opcode"),
86 };
87 if local_opcode == Rv32ModularArithmeticOpcode::SETUP_ISEQ && dec_insn.rd == 0 {
88 panic!("SETUP_ISEQ is not valid for rd = x0");
89 } else {
90 Some(Instruction::new(
91 VmOpcode::from_usize(
92 local_opcode.global_opcode().as_usize() + mod_idx_shift,
93 ),
94 F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rd),
95 F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
96 F::ZERO, F::ONE, F::TWO, F::ZERO,
100 F::ZERO,
101 ))
102 }
103 } else {
104 let global_opcode = match ModArithBaseFunct7::from_repr(base_funct7) {
105 Some(ModArithBaseFunct7::AddMod) => {
106 Rv32ModularArithmeticOpcode::ADD as usize
107 + Rv32ModularArithmeticOpcode::CLASS_OFFSET
108 }
109 Some(ModArithBaseFunct7::SubMod) => {
110 Rv32ModularArithmeticOpcode::SUB as usize
111 + Rv32ModularArithmeticOpcode::CLASS_OFFSET
112 }
113 Some(ModArithBaseFunct7::MulMod) => {
114 Rv32ModularArithmeticOpcode::MUL as usize
115 + Rv32ModularArithmeticOpcode::CLASS_OFFSET
116 }
117 Some(ModArithBaseFunct7::DivMod) => {
118 Rv32ModularArithmeticOpcode::DIV as usize
119 + Rv32ModularArithmeticOpcode::CLASS_OFFSET
120 }
121 Some(ModArithBaseFunct7::IsEqMod) => {
122 Rv32ModularArithmeticOpcode::IS_EQ as usize
123 + Rv32ModularArithmeticOpcode::CLASS_OFFSET
124 }
125 _ => unimplemented!(),
126 };
127 let global_opcode = global_opcode + mod_idx_shift;
128 let allow_rd_zero =
131 ModArithBaseFunct7::from_repr(base_funct7) != Some(ModArithBaseFunct7::IsEqMod);
132 Some(from_r_type(global_opcode, 2, &dec_insn, allow_rd_zero))
133 }
134 };
135 instruction.map(TranspilerOutput::one_to_one)
136 }
137}
138
139impl<F: PrimeField32> TranspilerExtension<F> for Fp2TranspilerExtension {
140 fn process_custom(&self, instruction_stream: &[u32]) -> Option<TranspilerOutput<F>> {
141 if instruction_stream.is_empty() {
142 return None;
143 }
144 let instruction_u32 = instruction_stream[0];
145 let opcode = (instruction_u32 & 0x7f) as u8;
146 let funct3 = ((instruction_u32 >> 12) & 0b111) as u8;
147
148 if opcode != OPCODE {
149 return None;
150 }
151 if funct3 != COMPLEX_EXT_FIELD_FUNCT3 {
152 return None;
153 }
154
155 let instruction = {
156 assert!(
157 Fp2Opcode::COUNT <= ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS as usize
158 );
159 let dec_insn = RType::new(instruction_u32);
160 let base_funct7 =
161 (dec_insn.funct7 as u8) % ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS;
162 let complex_idx_shift = ((dec_insn.funct7 as u8)
163 / ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS)
164 as usize
165 * Fp2Opcode::COUNT;
166
167 if base_funct7 == ComplexExtFieldBaseFunct7::Setup as u8 {
168 let local_opcode = match dec_insn.rs2 {
169 0 => Fp2Opcode::SETUP_ADDSUB,
170 1 => Fp2Opcode::SETUP_MULDIV,
171 _ => panic!("invalid opcode"),
172 };
173 Some(Instruction::new(
174 VmOpcode::from_usize(
175 local_opcode.global_opcode().as_usize() + complex_idx_shift,
176 ),
177 F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rd),
178 F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
179 F::ZERO, F::ONE, F::TWO, F::ZERO,
183 F::ZERO,
184 ))
185 } else {
186 let global_opcode = match ComplexExtFieldBaseFunct7::from_repr(base_funct7) {
187 Some(ComplexExtFieldBaseFunct7::Add) => {
188 Fp2Opcode::ADD as usize + Fp2Opcode::CLASS_OFFSET
189 }
190 Some(ComplexExtFieldBaseFunct7::Sub) => {
191 Fp2Opcode::SUB as usize + Fp2Opcode::CLASS_OFFSET
192 }
193 Some(ComplexExtFieldBaseFunct7::Mul) => {
194 Fp2Opcode::MUL as usize + Fp2Opcode::CLASS_OFFSET
195 }
196 Some(ComplexExtFieldBaseFunct7::Div) => {
197 Fp2Opcode::DIV as usize + Fp2Opcode::CLASS_OFFSET
198 }
199 _ => unimplemented!(),
200 };
201 let global_opcode = global_opcode + complex_idx_shift;
202 Some(from_r_type(global_opcode, 2, &dec_insn, true))
203 }
204 };
205 instruction.map(TranspilerOutput::one_to_one)
206 }
207}