openvm_algebra_transpiler/
lib.rs

1use openvm_algebra_guest::{
2    ComplexExtFieldBaseFunct7, ModArithBaseFunct7, COMPLEX_EXT_FIELD_FUNCT3,
3    MODULAR_ARITHMETIC_FUNCT3, OPCODE,
4};
5use openvm_instructions::{
6    instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode, VmOpcode,
7};
8use openvm_instructions_derive::LocalOpcode;
9use openvm_stark_backend::p3_field::PrimeField32;
10use openvm_transpiler::{util::from_r_type, TranspilerExtension, TranspilerOutput};
11use rrs_lib::instruction_formats::RType;
12use strum::{EnumCount, EnumIter, FromRepr};
13
14#[derive(
15    Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode,
16)]
17#[opcode_offset = 0x500]
18#[repr(usize)]
19#[allow(non_camel_case_types)]
20pub enum Rv32ModularArithmeticOpcode {
21    ADD,
22    SUB,
23    SETUP_ADDSUB,
24    MUL,
25    DIV,
26    SETUP_MULDIV,
27    IS_EQ,
28    SETUP_ISEQ,
29}
30
31#[derive(
32    Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, EnumCount, EnumIter, FromRepr, LocalOpcode,
33)]
34#[opcode_offset = 0x710]
35#[repr(usize)]
36#[allow(non_camel_case_types)]
37pub enum Fp2Opcode {
38    ADD,
39    SUB,
40    SETUP_ADDSUB,
41    MUL,
42    DIV,
43    SETUP_MULDIV,
44}
45
/// Transpiler extension that lowers the guest's custom modular-arithmetic RISC-V
/// instructions into `Rv32ModularArithmeticOpcode` VM instructions.
///
/// Stateless, so it is trivially `Copy` and cheap to construct via `Default`.
#[derive(Default, Debug, Clone, Copy)]
pub struct ModularTranspilerExtension;
48
/// Transpiler extension that lowers the guest's custom complex-extension-field (Fp2)
/// RISC-V instructions into `Fp2Opcode` VM instructions.
///
/// Stateless, so it is trivially `Copy` and cheap to construct via `Default`.
#[derive(Default, Debug, Clone, Copy)]
pub struct Fp2TranspilerExtension;
51
52impl<F: PrimeField32> TranspilerExtension<F> for ModularTranspilerExtension {
53    fn process_custom(&self, instruction_stream: &[u32]) -> Option<TranspilerOutput<F>> {
54        if instruction_stream.is_empty() {
55            return None;
56        }
57        let instruction_u32 = instruction_stream[0];
58        let opcode = (instruction_u32 & 0x7f) as u8;
59        let funct3 = ((instruction_u32 >> 12) & 0b111) as u8;
60
61        if opcode != OPCODE {
62            return None;
63        }
64        if funct3 != MODULAR_ARITHMETIC_FUNCT3 {
65            return None;
66        }
67
68        let instruction = {
69            let dec_insn = RType::new(instruction_u32);
70            let base_funct7 =
71                (dec_insn.funct7 as u8) % ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS;
72            assert!(
73                Rv32ModularArithmeticOpcode::COUNT
74                    <= ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS as usize
75            );
76            let mod_idx_shift = ((dec_insn.funct7 as u8)
77                / ModArithBaseFunct7::MODULAR_ARITHMETIC_MAX_KINDS)
78                as usize
79                * Rv32ModularArithmeticOpcode::COUNT;
80            if base_funct7 == ModArithBaseFunct7::SetupMod as u8 {
81                let local_opcode = match dec_insn.rs2 {
82                    0 => Rv32ModularArithmeticOpcode::SETUP_ADDSUB,
83                    1 => Rv32ModularArithmeticOpcode::SETUP_MULDIV,
84                    2 => Rv32ModularArithmeticOpcode::SETUP_ISEQ,
85                    _ => panic!("invalid opcode"),
86                };
87                if local_opcode == Rv32ModularArithmeticOpcode::SETUP_ISEQ && dec_insn.rd == 0 {
88                    panic!("SETUP_ISEQ is not valid for rd = x0");
89                } else {
90                    Some(Instruction::new(
91                        VmOpcode::from_usize(
92                            local_opcode.global_opcode().as_usize() + mod_idx_shift,
93                        ),
94                        F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rd),
95                        F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
96                        F::ZERO, // rs2 = 0
97                        F::ONE,  // d_as = 1
98                        F::TWO,  // e_as = 2
99                        F::ZERO,
100                        F::ZERO,
101                    ))
102                }
103            } else {
104                let global_opcode = match ModArithBaseFunct7::from_repr(base_funct7) {
105                    Some(ModArithBaseFunct7::AddMod) => {
106                        Rv32ModularArithmeticOpcode::ADD as usize
107                            + Rv32ModularArithmeticOpcode::CLASS_OFFSET
108                    }
109                    Some(ModArithBaseFunct7::SubMod) => {
110                        Rv32ModularArithmeticOpcode::SUB as usize
111                            + Rv32ModularArithmeticOpcode::CLASS_OFFSET
112                    }
113                    Some(ModArithBaseFunct7::MulMod) => {
114                        Rv32ModularArithmeticOpcode::MUL as usize
115                            + Rv32ModularArithmeticOpcode::CLASS_OFFSET
116                    }
117                    Some(ModArithBaseFunct7::DivMod) => {
118                        Rv32ModularArithmeticOpcode::DIV as usize
119                            + Rv32ModularArithmeticOpcode::CLASS_OFFSET
120                    }
121                    Some(ModArithBaseFunct7::IsEqMod) => {
122                        Rv32ModularArithmeticOpcode::IS_EQ as usize
123                            + Rv32ModularArithmeticOpcode::CLASS_OFFSET
124                    }
125                    _ => unimplemented!(),
126                };
127                let global_opcode = global_opcode + mod_idx_shift;
128                // The only opcode in this extension which can write to rd is `IsEqMod`
129                // so we cannot allow rd to be zero in this case.
130                let allow_rd_zero =
131                    ModArithBaseFunct7::from_repr(base_funct7) != Some(ModArithBaseFunct7::IsEqMod);
132                Some(from_r_type(global_opcode, 2, &dec_insn, allow_rd_zero))
133            }
134        };
135        instruction.map(TranspilerOutput::one_to_one)
136    }
137}
138
139impl<F: PrimeField32> TranspilerExtension<F> for Fp2TranspilerExtension {
140    fn process_custom(&self, instruction_stream: &[u32]) -> Option<TranspilerOutput<F>> {
141        if instruction_stream.is_empty() {
142            return None;
143        }
144        let instruction_u32 = instruction_stream[0];
145        let opcode = (instruction_u32 & 0x7f) as u8;
146        let funct3 = ((instruction_u32 >> 12) & 0b111) as u8;
147
148        if opcode != OPCODE {
149            return None;
150        }
151        if funct3 != COMPLEX_EXT_FIELD_FUNCT3 {
152            return None;
153        }
154
155        let instruction = {
156            assert!(
157                Fp2Opcode::COUNT <= ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS as usize
158            );
159            let dec_insn = RType::new(instruction_u32);
160            let base_funct7 =
161                (dec_insn.funct7 as u8) % ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS;
162            let complex_idx_shift = ((dec_insn.funct7 as u8)
163                / ComplexExtFieldBaseFunct7::COMPLEX_EXT_FIELD_MAX_KINDS)
164                as usize
165                * Fp2Opcode::COUNT;
166
167            if base_funct7 == ComplexExtFieldBaseFunct7::Setup as u8 {
168                let local_opcode = match dec_insn.rs2 {
169                    0 => Fp2Opcode::SETUP_ADDSUB,
170                    1 => Fp2Opcode::SETUP_MULDIV,
171                    _ => panic!("invalid opcode"),
172                };
173                Some(Instruction::new(
174                    VmOpcode::from_usize(
175                        local_opcode.global_opcode().as_usize() + complex_idx_shift,
176                    ),
177                    F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rd),
178                    F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
179                    F::ZERO, // rs2 = 0
180                    F::ONE,  // d_as = 1
181                    F::TWO,  // e_as = 2
182                    F::ZERO,
183                    F::ZERO,
184                ))
185            } else {
186                let global_opcode = match ComplexExtFieldBaseFunct7::from_repr(base_funct7) {
187                    Some(ComplexExtFieldBaseFunct7::Add) => {
188                        Fp2Opcode::ADD as usize + Fp2Opcode::CLASS_OFFSET
189                    }
190                    Some(ComplexExtFieldBaseFunct7::Sub) => {
191                        Fp2Opcode::SUB as usize + Fp2Opcode::CLASS_OFFSET
192                    }
193                    Some(ComplexExtFieldBaseFunct7::Mul) => {
194                        Fp2Opcode::MUL as usize + Fp2Opcode::CLASS_OFFSET
195                    }
196                    Some(ComplexExtFieldBaseFunct7::Div) => {
197                        Fp2Opcode::DIV as usize + Fp2Opcode::CLASS_OFFSET
198                    }
199                    _ => unimplemented!(),
200                };
201                let global_opcode = global_opcode + complex_idx_shift;
202                Some(from_r_type(global_opcode, 2, &dec_insn, true))
203            }
204        };
205        instruction.map(TranspilerOutput::one_to_one)
206    }
207}