openvm_bigint_transpiler/
lib.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
use openvm_bigint_guest::{Int256Funct7, BEQ256_FUNCT3, INT256_FUNCT3, OPCODE};
use openvm_instructions::{
    instruction::Instruction, riscv::RV32_REGISTER_NUM_LIMBS, utils::isize_to_field, UsizeOpcode,
    VmOpcode,
};
use openvm_instructions_derive::UsizeOpcode;
use openvm_rv32im_transpiler::{
    BaseAluOpcode, BranchEqualOpcode, BranchLessThanOpcode, LessThanOpcode, MulOpcode, ShiftOpcode,
};
use openvm_stark_backend::p3_field::PrimeField32;
use openvm_transpiler::{util::from_r_type, TranspilerExtension};
use rrs_lib::instruction_formats::{BType, RType};
use strum::IntoEnumIterator;

// =================================================================================================
// Intrinsics: 256-bit Integers
// =================================================================================================

/// [`BaseAluOpcode`] lifted to operate on 256-bit integer operands.
///
/// The `opcode_offset` attribute (`0x400`) is consumed by the `UsizeOpcode` derive; the
/// transpiler adds `default_offset()` to a local `BaseAluOpcode` index to form the global opcode.
#[derive(Clone, Copy, Debug, UsizeOpcode)]
#[opcode_offset = 0x400]
pub struct Rv32BaseAlu256Opcode(pub BaseAluOpcode);

impl Rv32BaseAlu256Opcode {
    /// Iterates over every [`BaseAluOpcode`] variant, each wrapped in this type.
    pub fn iter() -> impl Iterator<Item = Self> {
        BaseAluOpcode::iter().map(Rv32BaseAlu256Opcode)
    }
}

/// [`ShiftOpcode`] lifted to operate on 256-bit integer operands.
///
/// The `opcode_offset` attribute (`0x405`) is consumed by the `UsizeOpcode` derive; the
/// transpiler adds `default_offset()` to a local `ShiftOpcode` index to form the global opcode.
#[derive(Clone, Copy, Debug, UsizeOpcode)]
#[opcode_offset = 0x405]
pub struct Rv32Shift256Opcode(pub ShiftOpcode);

impl Rv32Shift256Opcode {
    /// Iterates over every [`ShiftOpcode`] variant, each wrapped in this type.
    pub fn iter() -> impl Iterator<Item = Self> {
        ShiftOpcode::iter().map(Rv32Shift256Opcode)
    }
}

/// [`LessThanOpcode`] lifted to operate on 256-bit integer operands.
///
/// The `opcode_offset` attribute (`0x408`) is consumed by the `UsizeOpcode` derive; the
/// transpiler adds `default_offset()` to a local `LessThanOpcode` index to form the global opcode.
#[derive(Clone, Copy, Debug, UsizeOpcode)]
#[opcode_offset = 0x408]
pub struct Rv32LessThan256Opcode(pub LessThanOpcode);

impl Rv32LessThan256Opcode {
    /// Iterates over every [`LessThanOpcode`] variant, each wrapped in this type.
    pub fn iter() -> impl Iterator<Item = Self> {
        LessThanOpcode::iter().map(Rv32LessThan256Opcode)
    }
}

/// [`BranchEqualOpcode`] lifted to compare 256-bit integer operands.
///
/// The `opcode_offset` attribute (`0x420`) is consumed by the `UsizeOpcode` derive; the
/// transpiler adds `default_offset()` to a local `BranchEqualOpcode` index to form the global
/// opcode.
#[derive(Clone, Copy, Debug, UsizeOpcode)]
#[opcode_offset = 0x420]
pub struct Rv32BranchEqual256Opcode(pub BranchEqualOpcode);

impl Rv32BranchEqual256Opcode {
    /// Iterates over every [`BranchEqualOpcode`] variant, each wrapped in this type.
    pub fn iter() -> impl Iterator<Item = Self> {
        BranchEqualOpcode::iter().map(Rv32BranchEqual256Opcode)
    }
}

/// [`BranchLessThanOpcode`] lifted to compare 256-bit integer operands.
///
/// The `opcode_offset` attribute (`0x425`) is consumed by the `UsizeOpcode` derive.
#[derive(Clone, Copy, Debug, UsizeOpcode)]
#[opcode_offset = 0x425]
pub struct Rv32BranchLessThan256Opcode(pub BranchLessThanOpcode);

impl Rv32BranchLessThan256Opcode {
    /// Iterates over every [`BranchLessThanOpcode`] variant, each wrapped in this type.
    pub fn iter() -> impl Iterator<Item = Self> {
        BranchLessThanOpcode::iter().map(Rv32BranchLessThan256Opcode)
    }
}

/// [`MulOpcode`] lifted to operate on 256-bit integer operands.
///
/// The `opcode_offset` attribute (`0x450`) is consumed by the `UsizeOpcode` derive; the
/// transpiler adds `default_offset()` to a local `MulOpcode` index to form the global opcode.
#[derive(Clone, Copy, Debug, UsizeOpcode)]
#[opcode_offset = 0x450]
pub struct Rv32Mul256Opcode(pub MulOpcode);

impl Rv32Mul256Opcode {
    /// Iterates over every [`MulOpcode`] variant, each wrapped in this type.
    pub fn iter() -> impl Iterator<Item = Self> {
        MulOpcode::iter().map(Rv32Mul256Opcode)
    }
}

/// Transpiler extension that lowers the custom 256-bit integer intrinsics (the `Int256Funct7`
/// R-type operations and the 256-bit branch-equal) into `Instruction`s.
///
/// It carries no state, so it is cheap to copy; construct it as `Int256TranspilerExtension`
/// or via `Default::default()`.
#[derive(Clone, Copy, Debug, Default)]
pub struct Int256TranspilerExtension;

impl<F: PrimeField32> TranspilerExtension<F> for Int256TranspilerExtension {
    /// Transpiles one 256-bit integer intrinsic from the head of `instruction_stream`.
    ///
    /// Only the first 32-bit word is inspected. It is accepted when its low 7 bits equal
    /// `OPCODE` and its `funct3` field (bits 12..15) is either:
    /// - `INT256_FUNCT3`: an R-type word whose `funct7` selects an `Int256Funct7` operation
    ///   (base ALU, shift, less-than or multiply), lowered via `from_r_type`; or
    /// - `BEQ256_FUNCT3`: a B-type 256-bit branch-if-equal.
    ///
    /// Returns `Some((instruction, 1))` on success, where `1` is the number of words consumed.
    /// Returns `None` for an empty stream, a different opcode or `funct3`, or a `funct7` that
    /// does not name a known `Int256Funct7` operation. Declining lets the caller try other
    /// extensions or report the word as unrecognized, instead of panicking mid-transpile.
    fn process_custom(&self, instruction_stream: &[u32]) -> Option<(Instruction<F>, usize)> {
        let instruction_u32 = *instruction_stream.first()?;
        let opcode = (instruction_u32 & 0x7f) as u8;
        let funct3 = ((instruction_u32 >> 12) & 0b111) as u8;

        if opcode != OPCODE {
            return None;
        }

        let instruction = match funct3 {
            INT256_FUNCT3 => {
                let dec_insn = RType::new(instruction_u32);
                let global_opcode = match Int256Funct7::from_repr(dec_insn.funct7 as u8) {
                    Some(Int256Funct7::Add) => {
                        BaseAluOpcode::ADD as usize + Rv32BaseAlu256Opcode::default_offset()
                    }
                    Some(Int256Funct7::Sub) => {
                        BaseAluOpcode::SUB as usize + Rv32BaseAlu256Opcode::default_offset()
                    }
                    Some(Int256Funct7::Xor) => {
                        BaseAluOpcode::XOR as usize + Rv32BaseAlu256Opcode::default_offset()
                    }
                    Some(Int256Funct7::Or) => {
                        BaseAluOpcode::OR as usize + Rv32BaseAlu256Opcode::default_offset()
                    }
                    Some(Int256Funct7::And) => {
                        BaseAluOpcode::AND as usize + Rv32BaseAlu256Opcode::default_offset()
                    }
                    Some(Int256Funct7::Sll) => {
                        ShiftOpcode::SLL as usize + Rv32Shift256Opcode::default_offset()
                    }
                    Some(Int256Funct7::Srl) => {
                        ShiftOpcode::SRL as usize + Rv32Shift256Opcode::default_offset()
                    }
                    Some(Int256Funct7::Sra) => {
                        ShiftOpcode::SRA as usize + Rv32Shift256Opcode::default_offset()
                    }
                    Some(Int256Funct7::Slt) => {
                        LessThanOpcode::SLT as usize + Rv32LessThan256Opcode::default_offset()
                    }
                    Some(Int256Funct7::Sltu) => {
                        LessThanOpcode::SLTU as usize + Rv32LessThan256Opcode::default_offset()
                    }
                    Some(Int256Funct7::Mul) => {
                        MulOpcode::MUL as usize + Rv32Mul256Opcode::default_offset()
                    }
                    // Unknown `funct7`: decline instead of panicking so the transpiler can
                    // surface a proper "unrecognized instruction" error.
                    _ => return None,
                };
                // NOTE(review): `2` is presumably the address space of the 256-bit operands;
                // confirm against `from_r_type`'s signature.
                from_r_type(global_opcode, 2, &dec_insn)
            }
            BEQ256_FUNCT3 => {
                let dec_insn = BType::new(instruction_u32);
                Instruction::new(
                    VmOpcode::from_usize(
                        BranchEqualOpcode::BEQ as usize
                            + Rv32BranchEqual256Opcode::default_offset(),
                    ),
                    // Register indices are scaled by the per-register limb count.
                    F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs1),
                    F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS * dec_insn.rs2),
                    // Signed branch offset from the B-type immediate.
                    isize_to_field(dec_insn.imm as isize),
                    // NOTE(review): presumably the register and 256-bit memory address
                    // spaces respectively — confirm against the BEQ256 chip.
                    F::ONE,
                    F::TWO,
                    F::ZERO,
                    F::ZERO,
                )
            }
            // Same custom opcode but a `funct3` owned by some other extension.
            _ => return None,
        };
        Some((instruction, 1))
    }
}