openvm_algebra_transpiler/
lib.rs1use openvm_algebra_guest::{
2 ComplexExtFieldBaseFunct7, ModArithBaseFunct7, COMPLEX_EXT_FIELD_FUNCT3,
3 MODULAR_ARITHMETIC_FUNCT3, OPCODE,
4};
5use openvm_instructions::{
6 instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, PhantomDiscriminant,
7 VmOpcode,
8};
9use openvm_instructions_derive::LocalOpcode;
10use openvm_stark_backend::p3_field::PrimeField32;
11use openvm_transpiler::{util::from_r_type, TranspilerExtension, TranspilerOutput};
12use rrs_lib::instruction_formats::RType;
13use strum::{EnumCount, EnumIter, FromRepr};
14
15#[derive(
16 Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode,
17)]
18#[opcode_offset = 0x500]
19#[repr(usize)]
20#[allow(non_camel_case_types)]
21pub enum Rv32ModularArithmeticOpcode {
22 ADD,
23 SUB,
24 SETUP_ADDSUB,
25 MUL,
26 DIV,
27 SETUP_MULDIV,
28 IS_EQ,
29 SETUP_ISEQ,
30}
31
32#[derive(Copy, Clone, Debug, PartialEq, Eq, FromRepr)]
33#[repr(u16)]
34pub enum ModularPhantom {
35 HintNonQr = 0x50,
36 HintSqrt = 0x51,
37}
38
39#[derive(
40 Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode,
41)]
42#[opcode_offset = 0x710]
43#[repr(usize)]
44#[allow(non_camel_case_types)]
45pub enum Fp2Opcode {
46 ADD,
47 SUB,
48 SETUP_ADDSUB,
49 MUL,
50 DIV,
51 SETUP_MULDIV,
52}
53
/// Transpiler extension for custom instructions whose opcode is `OPCODE` and whose funct3 is
/// `MODULAR_ARITHMETIC_FUNCT3`.
///
/// Stateless; all decoding lives in its `TranspilerExtension::process_custom` implementation.
#[derive(Default)]
pub struct ModularTranspilerExtension;
56
/// Transpiler extension for custom instructions whose opcode is `OPCODE` and whose funct3 is
/// `COMPLEX_EXT_FIELD_FUNCT3`.
///
/// Stateless; all decoding lives in its `TranspilerExtension::process_custom` implementation.
#[derive(Default)]
pub struct Fp2TranspilerExtension;
59
60impl<F: PrimeField32> TranspilerExtension<F> for ModularTranspilerExtension {
61 fn process_custom(&self, instruction_stream: &[u32]) -> Option<TranspilerOutput<F>> {
62 if instruction_stream.is_empty() {
63 return None;
64 }
65 let instruction_u32 = instruction_stream[0];
66 let opcode = (instruction_u32 & 0x7f) as u8;
67 let funct3 = ((instruction_u32 >> 12) & 0b111) as u8;
68
69 if opcode != OPCODE {
70 return None;
71 }
72 if funct3 != MODULAR_ARITHMETIC_FUNCT3 {
73 return None;
74 }
75
76 let instruction = {
77 let dec_insn = RType::new(instruction_u32);
78 let base_funct7 =
79 (dec_insn.funct7 as u8) % ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS;
80 assert!(
81 Rv32ModularArithmeticOpcode::COUNT
82 <= ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize
83 );
84 let mod_idx = ((dec_insn.funct7 as u8)
85 / ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS)
86 as usize;
87 let mod_idx_shift = mod_idx * Rv32ModularArithmeticOpcode::COUNT;
88 if base_funct7 == ModArithBaseFunct7::SetupMod as u8 {
89 let local_opcode = match dec_insn.rs2 {
90 0 => Rv32ModularArithmeticOpcode::SETUP_ADDSUB,
91 1 => Rv32ModularArithmeticOpcode::SETUP_MULDIV,
92 2 => Rv32ModularArithmeticOpcode::SETUP_ISEQ,
93 _ => panic!("invalid opcode"),
94 };
95 if local_opcode == Rv32ModularArithmeticOpcode::SETUP_ISEQ && dec_insn.rd == 0 {
96 panic!("SETUP_ISEQ is not valid for rd = x0");
97 } else {
98 Some(Instruction::new(
99 VmOpcode::from_usize(
100 local_opcode.global_opcode().as_usize() + mod_idx_shift,
101 ),
102 F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rd),
103 F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
104 F::ZERO, F::ONE, F::TWO, F::ZERO,
108 F::ZERO,
109 ))
110 }
111 } else if base_funct7 == ModArithBaseFunct7::HintNonQr as u8 {
112 assert_eq!(dec_insn.rd, 0);
113 assert_eq!(dec_insn.rs1, 0);
114 assert_eq!(dec_insn.rs2, 0);
115 Some(Instruction::phantom(
116 PhantomDiscriminant(ModularPhantom::HintNonQr as u16),
117 F::ZERO,
118 F::ZERO,
119 mod_idx as u16,
120 ))
121 } else if base_funct7 == ModArithBaseFunct7::HintSqrt as u8 {
122 assert_eq!(dec_insn.rd, 0);
123 assert_eq!(dec_insn.rs2, 0);
124 Some(Instruction::phantom(
125 PhantomDiscriminant(ModularPhantom::HintSqrt as u16),
126 F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
127 F::ZERO,
128 mod_idx as u16,
129 ))
130 } else {
131 let global_opcode = match ModArithBaseFunct7::from_repr(base_funct7) {
132 Some(ModArithBaseFunct7::AddMod) => {
133 Rv32ModularArithmeticOpcode::ADD as usize
134 + Rv32ModularArithmeticOpcode::CLASS_OFFSET
135 }
136 Some(ModArithBaseFunct7::SubMod) => {
137 Rv32ModularArithmeticOpcode::SUB as usize
138 + Rv32ModularArithmeticOpcode::CLASS_OFFSET
139 }
140 Some(ModArithBaseFunct7::MulMod) => {
141 Rv32ModularArithmeticOpcode::MUL as usize
142 + Rv32ModularArithmeticOpcode::CLASS_OFFSET
143 }
144 Some(ModArithBaseFunct7::DivMod) => {
145 Rv32ModularArithmeticOpcode::DIV as usize
146 + Rv32ModularArithmeticOpcode::CLASS_OFFSET
147 }
148 Some(ModArithBaseFunct7::IsEqMod) => {
149 Rv32ModularArithmeticOpcode::IS_EQ as usize
150 + Rv32ModularArithmeticOpcode::CLASS_OFFSET
151 }
152 _ => unimplemented!(),
153 };
154 let global_opcode = global_opcode + mod_idx_shift;
155 let allow_rd_zero =
158 ModArithBaseFunct7::from_repr(base_funct7) != Some(ModArithBaseFunct7::IsEqMod);
159 Some(from_r_type(global_opcode, 2, &dec_insn, allow_rd_zero))
160 }
161 };
162 instruction.map(TranspilerOutput::one_to_one)
163 }
164}
165
166impl<F: PrimeField32> TranspilerExtension<F> for Fp2TranspilerExtension {
167 fn process_custom(&self, instruction_stream: &[u32]) -> Option<TranspilerOutput<F>> {
168 if instruction_stream.is_empty() {
169 return None;
170 }
171 let instruction_u32 = instruction_stream[0];
172 let opcode = (instruction_u32 & 0x7f) as u8;
173 let funct3 = ((instruction_u32 >> 12) & 0b111) as u8;
174
175 if opcode != OPCODE {
176 return None;
177 }
178 if funct3 != COMPLEX_EXT_FIELD_FUNCT3 {
179 return None;
180 }
181
182 let instruction = {
183 assert!(
184 Fp2Opcode::COUNT <= ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS as usize
185 );
186 let dec_insn = RType::new(instruction_u32);
187 let base_funct7 =
188 (dec_insn.funct7 as u8) % ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS;
189 let complex_idx_shift = ((dec_insn.funct7 as u8)
190 / ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS)
191 as usize
192 * Fp2Opcode::COUNT;
193
194 if base_funct7 == ComplexExtFieldBaseFunct7::Setup as u8 {
195 let local_opcode = match dec_insn.rs2 {
196 0 => Fp2Opcode::SETUP_ADDSUB,
197 1 => Fp2Opcode::SETUP_MULDIV,
198 _ => panic!("invalid opcode"),
199 };
200 Some(Instruction::new(
201 VmOpcode::from_usize(
202 local_opcode.global_opcode().as_usize() + complex_idx_shift,
203 ),
204 F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rd),
205 F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
206 F::ZERO, F::ONE, F::TWO, F::ZERO,
210 F::ZERO,
211 ))
212 } else {
213 let global_opcode = match ComplexExtFieldBaseFunct7::from_repr(base_funct7) {
214 Some(ComplexExtFieldBaseFunct7::Add) => {
215 Fp2Opcode::ADD as usize + Fp2Opcode::CLASS_OFFSET
216 }
217 Some(ComplexExtFieldBaseFunct7::Sub) => {
218 Fp2Opcode::SUB as usize + Fp2Opcode::CLASS_OFFSET
219 }
220 Some(ComplexExtFieldBaseFunct7::Mul) => {
221 Fp2Opcode::MUL as usize + Fp2Opcode::CLASS_OFFSET
222 }
223 Some(ComplexExtFieldBaseFunct7::Div) => {
224 Fp2Opcode::DIV as usize + Fp2Opcode::CLASS_OFFSET
225 }
226 _ => unimplemented!(),
227 };
228 let global_opcode = global_opcode + complex_idx_shift;
229 Some(from_r_type(global_opcode, 2, &dec_insn, true))
230 }
231 };
232 instruction.map(TranspilerOutput::one_to_one)
233 }
234}