openvm_algebra_transpiler/
lib.rs

1use openvm_algebra_guest::{
2    ComplexExtFieldBaseFunct7, ModArithBaseFunct7, COMPLEX_EXT_FIELD_FUNCT3,
3    MODULAR_ARITHMETIC_FUNCT3, OPCODE,
4};
5use openvm_instructions::{
6    instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, PhantomDiscriminant,
7    VmOpcode,
8};
9use openvm_instructions_derive::LocalOpcode;
10use openvm_stark_backend::p3_field::PrimeField32;
11use openvm_transpiler::{util::from_r_type, TranspilerExtension, TranspilerOutput};
12use rrs_lib::instruction_formats::RType;
13use strum::{EnumCount, EnumIter, FromRepr};
14
15#[derive(
16    Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode,
17)]
18#[opcode_offset = 0x500]
19#[repr(usize)]
20#[allow(non_camel_case_types)]
21pub enum Rv32ModularArithmeticOpcode {
22    ADD,
23    SUB,
24    SETUP_ADDSUB,
25    MUL,
26    DIV,
27    SETUP_MULDIV,
28    IS_EQ,
29    SETUP_ISEQ,
30}
31
32#[derive(Copy, Clone, Debug, PartialEq, Eq, FromRepr)]
33#[repr(u16)]
34pub enum ModularPhantom {
35    HintNonQr = 0x50,
36    HintSqrt = 0x51,
37}
38
39#[derive(
40    Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode,
41)]
42#[opcode_offset = 0x710]
43#[repr(usize)]
44#[allow(non_camel_case_types)]
45pub enum Fp2Opcode {
46    ADD,
47    SUB,
48    SETUP_ADDSUB,
49    MUL,
50    DIV,
51    SETUP_MULDIV,
52}
53
/// Stateless transpiler extension for the custom modular-arithmetic instructions
/// (those whose funct3 equals `MODULAR_ARITHMETIC_FUNCT3`).
#[derive(Default)]
pub struct ModularTranspilerExtension;
56
/// Stateless transpiler extension for the custom complex-extension-field instructions
/// (those whose funct3 equals `COMPLEX_EXT_FIELD_FUNCT3`).
#[derive(Default)]
pub struct Fp2TranspilerExtension;
59
60impl<F: PrimeField32> TranspilerExtension<F> for ModularTranspilerExtension {
61    fn process_custom(&self, instruction_stream: &[u32]) -> Option<TranspilerOutput<F>> {
62        if instruction_stream.is_empty() {
63            return None;
64        }
65        let instruction_u32 = instruction_stream[0];
66        let opcode = (instruction_u32 & 0x7f) as u8;
67        let funct3 = ((instruction_u32 >> 12) & 0b111) as u8;
68
69        if opcode != OPCODE {
70            return None;
71        }
72        if funct3 != MODULAR_ARITHMETIC_FUNCT3 {
73            return None;
74        }
75
76        let instruction = {
77            let dec_insn = RType::new(instruction_u32);
78            let base_funct7 =
79                (dec_insn.funct7 as u8) % ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS;
80            assert!(
81                Rv32ModularArithmeticOpcode::COUNT
82                    <= ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize
83            );
84            let mod_idx = ((dec_insn.funct7 as u8)
85                / ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS)
86                as usize;
87            let mod_idx_shift = mod_idx * Rv32ModularArithmeticOpcode::COUNT;
88            if base_funct7 == ModArithBaseFunct7::SetupMod as u8 {
89                let local_opcode = match dec_insn.rs2 {
90                    0 => Rv32ModularArithmeticOpcode::SETUP_ADDSUB,
91                    1 => Rv32ModularArithmeticOpcode::SETUP_MULDIV,
92                    2 => Rv32ModularArithmeticOpcode::SETUP_ISEQ,
93                    _ => panic!("invalid opcode"),
94                };
95                if local_opcode == Rv32ModularArithmeticOpcode::SETUP_ISEQ && dec_insn.rd == 0 {
96                    panic!("SETUP_ISEQ is not valid for rd = x0");
97                } else {
98                    Some(Instruction::new(
99                        VmOpcode::from_usize(
100                            local_opcode.global_opcode().as_usize() + mod_idx_shift,
101                        ),
102                        F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rd),
103                        F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
104                        F::ZERO, // rs2 = 0
105                        F::ONE,  // d_as = 1
106                        F::TWO,  // e_as = 2
107                        F::ZERO,
108                        F::ZERO,
109                    ))
110                }
111            } else if base_funct7 == ModArithBaseFunct7::HintNonQr as u8 {
112                assert_eq!(dec_insn.rd, 0);
113                assert_eq!(dec_insn.rs1, 0);
114                assert_eq!(dec_insn.rs2, 0);
115                Some(Instruction::phantom(
116                    PhantomDiscriminant(ModularPhantom::HintNonQr as u16),
117                    F::ZERO,
118                    F::ZERO,
119                    mod_idx as u16,
120                ))
121            } else if base_funct7 == ModArithBaseFunct7::HintSqrt as u8 {
122                assert_eq!(dec_insn.rd, 0);
123                assert_eq!(dec_insn.rs2, 0);
124                Some(Instruction::phantom(
125                    PhantomDiscriminant(ModularPhantom::HintSqrt as u16),
126                    F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
127                    F::ZERO,
128                    mod_idx as u16,
129                ))
130            } else {
131                let global_opcode = match ModArithBaseFunct7::from_repr(base_funct7) {
132                    Some(ModArithBaseFunct7::AddMod) => {
133                        Rv32ModularArithmeticOpcode::ADD as usize
134                            + Rv32ModularArithmeticOpcode::CLASS_OFFSET
135                    }
136                    Some(ModArithBaseFunct7::SubMod) => {
137                        Rv32ModularArithmeticOpcode::SUB as usize
138                            + Rv32ModularArithmeticOpcode::CLASS_OFFSET
139                    }
140                    Some(ModArithBaseFunct7::MulMod) => {
141                        Rv32ModularArithmeticOpcode::MUL as usize
142                            + Rv32ModularArithmeticOpcode::CLASS_OFFSET
143                    }
144                    Some(ModArithBaseFunct7::DivMod) => {
145                        Rv32ModularArithmeticOpcode::DIV as usize
146                            + Rv32ModularArithmeticOpcode::CLASS_OFFSET
147                    }
148                    Some(ModArithBaseFunct7::IsEqMod) => {
149                        Rv32ModularArithmeticOpcode::IS_EQ as usize
150                            + Rv32ModularArithmeticOpcode::CLASS_OFFSET
151                    }
152                    _ => unimplemented!(),
153                };
154                let global_opcode = global_opcode + mod_idx_shift;
155                // The only opcode in this extension which can write to rd is `IsEqMod`
156                // so we cannot allow rd to be zero in this case.
157                let allow_rd_zero =
158                    ModArithBaseFunct7::from_repr(base_funct7) != Some(ModArithBaseFunct7::IsEqMod);
159                Some(from_r_type(global_opcode, 2, &dec_insn, allow_rd_zero))
160            }
161        };
162        instruction.map(TranspilerOutput::one_to_one)
163    }
164}
165
166impl<F: PrimeField32> TranspilerExtension<F> for Fp2TranspilerExtension {
167    fn process_custom(&self, instruction_stream: &[u32]) -> Option<TranspilerOutput<F>> {
168        if instruction_stream.is_empty() {
169            return None;
170        }
171        let instruction_u32 = instruction_stream[0];
172        let opcode = (instruction_u32 & 0x7f) as u8;
173        let funct3 = ((instruction_u32 >> 12) & 0b111) as u8;
174
175        if opcode != OPCODE {
176            return None;
177        }
178        if funct3 != COMPLEX_EXT_FIELD_FUNCT3 {
179            return None;
180        }
181
182        let instruction = {
183            assert!(
184                Fp2Opcode::COUNT <= ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS as usize
185            );
186            let dec_insn = RType::new(instruction_u32);
187            let base_funct7 =
188                (dec_insn.funct7 as u8) % ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS;
189            let complex_idx_shift = ((dec_insn.funct7 as u8)
190                / ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS)
191                as usize
192                * Fp2Opcode::COUNT;
193
194            if base_funct7 == ComplexExtFieldBaseFunct7::Setup as u8 {
195                let local_opcode = match dec_insn.rs2 {
196                    0 => Fp2Opcode::SETUP_ADDSUB,
197                    1 => Fp2Opcode::SETUP_MULDIV,
198                    _ => panic!("invalid opcode"),
199                };
200                Some(Instruction::new(
201                    VmOpcode::from_usize(
202                        local_opcode.global_opcode().as_usize() + complex_idx_shift,
203                    ),
204                    F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rd),
205                    F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
206                    F::ZERO, // rs2 = 0
207                    F::ONE,  // d_as = 1
208                    F::TWO,  // e_as = 2
209                    F::ZERO,
210                    F::ZERO,
211                ))
212            } else {
213                let global_opcode = match ComplexExtFieldBaseFunct7::from_repr(base_funct7) {
214                    Some(ComplexExtFieldBaseFunct7::Add) => {
215                        Fp2Opcode::ADD as usize + Fp2Opcode::CLASS_OFFSET
216                    }
217                    Some(ComplexExtFieldBaseFunct7::Sub) => {
218                        Fp2Opcode::SUB as usize + Fp2Opcode::CLASS_OFFSET
219                    }
220                    Some(ComplexExtFieldBaseFunct7::Mul) => {
221                        Fp2Opcode::MUL as usize + Fp2Opcode::CLASS_OFFSET
222                    }
223                    Some(ComplexExtFieldBaseFunct7::Div) => {
224                        Fp2Opcode::DIV as usize + Fp2Opcode::CLASS_OFFSET
225                    }
226                    _ => unimplemented!(),
227                };
228                let global_opcode = global_opcode + complex_idx_shift;
229                Some(from_r_type(global_opcode, 2, &dec_insn, true))
230            }
231        };
232        instruction.map(TranspilerOutput::one_to_one)
233    }
234}