openvm_rv32im_circuit/jal_lui/core.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::arch::{
7    AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface,
8    VmCoreAir, VmCoreChip,
9};
10use openvm_circuit_primitives::bitwise_op_lookup::{
11    BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip,
12};
13use openvm_circuit_primitives_derive::AlignedBorrow;
14use openvm_instructions::{
15    instruction::Instruction,
16    program::{DEFAULT_PC_STEP, PC_BITS},
17    LocalOpcode,
18};
19use openvm_rv32im_transpiler::Rv32JalLuiOpcode::{self, *};
20use openvm_stark_backend::{
21    interaction::InteractionBuilder,
22    p3_air::{AirBuilder, BaseAir},
23    p3_field::{Field, FieldAlgebra, PrimeField32},
24    rap::BaseAirWithPublicValues,
25};
26use serde::{Deserialize, Serialize};
27
28use crate::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, RV_J_TYPE_IMM_BITS};
29
30#[repr(C)]
31#[derive(Debug, Clone, AlignedBorrow)]
32pub struct Rv32JalLuiCoreCols<T> {
33    pub imm: T,
34    pub rd_data: [T; RV32_REGISTER_NUM_LIMBS],
35    pub is_jal: T,
36    pub is_lui: T,
37}
38
39#[derive(Debug, Clone)]
40pub struct Rv32JalLuiCoreAir {
41    pub bus: BitwiseOperationLookupBus,
42}
43
44impl<F: Field> BaseAir<F> for Rv32JalLuiCoreAir {
45    fn width(&self) -> usize {
46        Rv32JalLuiCoreCols::<F>::width()
47    }
48}
49
50impl<F: Field> BaseAirWithPublicValues<F> for Rv32JalLuiCoreAir {}
51
52impl<AB, I> VmCoreAir<AB, I> for Rv32JalLuiCoreAir
53where
54    AB: InteractionBuilder,
55    I: VmAdapterInterface<AB::Expr>,
56    I::Reads: From<[[AB::Expr; 0]; 0]>,
57    I::Writes: From<[[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1]>,
58    I::ProcessedInstruction: From<ImmInstruction<AB::Expr>>,
59{
60    fn eval(
61        &self,
62        builder: &mut AB,
63        local_core: &[AB::Var],
64        from_pc: AB::Var,
65    ) -> AdapterAirContext<AB::Expr, I> {
66        let cols: &Rv32JalLuiCoreCols<AB::Var> = (*local_core).borrow();
67        let Rv32JalLuiCoreCols::<AB::Var> {
68            imm,
69            rd_data: rd,
70            is_jal,
71            is_lui,
72        } = *cols;
73
74        builder.assert_bool(is_lui);
75        builder.assert_bool(is_jal);
76        let is_valid = is_lui + is_jal;
77        builder.assert_bool(is_valid.clone());
78        builder.when(is_lui).assert_zero(rd[0]);
79
80        for i in 0..RV32_REGISTER_NUM_LIMBS / 2 {
81            self.bus
82                .send_range(rd[i * 2], rd[i * 2 + 1])
83                .eval(builder, is_valid.clone());
84        }
85
86        // In case of JAL constrain that last limb has at most [last_limb_bits] bits
87
88        let last_limb_bits = PC_BITS - RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
89        let additional_bits = (last_limb_bits..RV32_CELL_BITS).fold(0, |acc, x| acc + (1 << x));
90        let additional_bits = AB::F::from_canonical_u32(additional_bits);
91        self.bus
92            .send_xor(rd[3], additional_bits, rd[3] + additional_bits)
93            .eval(builder, is_jal);
94
95        let intermed_val = rd
96            .iter()
97            .skip(1)
98            .enumerate()
99            .fold(AB::Expr::ZERO, |acc, (i, &val)| {
100                acc + val * AB::Expr::from_canonical_u32(1 << (i * RV32_CELL_BITS))
101            });
102
103        // Constrain that imm * 2^4 is the correct composition of intermed_val in case of LUI
104        builder.when(is_lui).assert_eq(
105            intermed_val.clone(),
106            imm * AB::F::from_canonical_u32(1 << (12 - RV32_CELL_BITS)),
107        );
108
109        let intermed_val = rd[0] + intermed_val * AB::Expr::from_canonical_u32(1 << RV32_CELL_BITS);
110        // Constrain that from_pc + DEFAULT_PC_STEP is the correct composition of intermed_val in case of JAL
111        builder.when(is_jal).assert_eq(
112            intermed_val,
113            from_pc + AB::F::from_canonical_u32(DEFAULT_PC_STEP),
114        );
115
116        let to_pc = from_pc + is_lui * AB::F::from_canonical_u32(DEFAULT_PC_STEP) + is_jal * imm;
117
118        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(
119            self,
120            is_lui * AB::F::from_canonical_u32(LUI as u32)
121                + is_jal * AB::F::from_canonical_u32(JAL as u32),
122        );
123
124        AdapterAirContext {
125            to_pc: Some(to_pc),
126            reads: [].into(),
127            writes: [rd.map(|x| x.into())].into(),
128            instruction: ImmInstruction {
129                is_valid,
130                opcode: expected_opcode,
131                immediate: imm.into(),
132            }
133            .into(),
134        }
135    }
136
137    fn start_offset(&self) -> usize {
138        Rv32JalLuiOpcode::CLASS_OFFSET
139    }
140}
141
142#[repr(C)]
143#[derive(Debug, Clone, Serialize, Deserialize)]
144#[serde(bound = "F: Field")]
145pub struct Rv32JalLuiCoreRecord<F: Field> {
146    pub rd_data: [F; RV32_REGISTER_NUM_LIMBS],
147    pub imm: F,
148    pub is_jal: bool,
149    pub is_lui: bool,
150}
151
152pub struct Rv32JalLuiCoreChip {
153    pub air: Rv32JalLuiCoreAir,
154    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
155}
156
157impl Rv32JalLuiCoreChip {
158    pub fn new(bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>) -> Self {
159        Self {
160            air: Rv32JalLuiCoreAir {
161                bus: bitwise_lookup_chip.bus(),
162            },
163            bitwise_lookup_chip,
164        }
165    }
166}
167
168impl<F: PrimeField32, I: VmAdapterInterface<F>> VmCoreChip<F, I> for Rv32JalLuiCoreChip
169where
170    I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>,
171{
172    type Record = Rv32JalLuiCoreRecord<F>;
173    type Air = Rv32JalLuiCoreAir;
174
175    #[allow(clippy::type_complexity)]
176    fn execute_instruction(
177        &self,
178        instruction: &Instruction<F>,
179        from_pc: u32,
180        _reads: I::Reads,
181    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
182        let local_opcode = Rv32JalLuiOpcode::from_usize(
183            instruction
184                .opcode
185                .local_opcode_idx(Rv32JalLuiOpcode::CLASS_OFFSET),
186        );
187        let imm = instruction.c;
188
189        let signed_imm = match local_opcode {
190            JAL => {
191                // Note: signed_imm is a signed integer and imm is a field element
192                (imm + F::from_canonical_u32(1 << (RV_J_TYPE_IMM_BITS - 1))).as_canonical_u32()
193                    as i32
194                    - (1 << (RV_J_TYPE_IMM_BITS - 1))
195            }
196            LUI => imm.as_canonical_u32() as i32,
197        };
198        let (to_pc, rd_data) = run_jal_lui(local_opcode, from_pc, signed_imm);
199
200        for i in 0..(RV32_REGISTER_NUM_LIMBS / 2) {
201            self.bitwise_lookup_chip
202                .request_range(rd_data[i * 2], rd_data[i * 2 + 1]);
203        }
204
205        if local_opcode == JAL {
206            let last_limb_bits = PC_BITS - RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
207            let additional_bits = (last_limb_bits..RV32_CELL_BITS).fold(0, |acc, x| acc + (1 << x));
208            self.bitwise_lookup_chip
209                .request_xor(rd_data[3], additional_bits);
210        }
211
212        let rd_data = rd_data.map(F::from_canonical_u32);
213
214        let output = AdapterRuntimeContext {
215            to_pc: Some(to_pc),
216            writes: [rd_data].into(),
217        };
218
219        Ok((
220            output,
221            Rv32JalLuiCoreRecord {
222                rd_data,
223                imm,
224                is_jal: local_opcode == JAL,
225                is_lui: local_opcode == LUI,
226            },
227        ))
228    }
229
230    fn get_opcode_name(&self, opcode: usize) -> String {
231        format!(
232            "{:?}",
233            Rv32JalLuiOpcode::from_usize(opcode - Rv32JalLuiOpcode::CLASS_OFFSET)
234        )
235    }
236
237    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
238        let core_cols: &mut Rv32JalLuiCoreCols<F> = row_slice.borrow_mut();
239        core_cols.rd_data = record.rd_data;
240        core_cols.imm = record.imm;
241        core_cols.is_jal = F::from_bool(record.is_jal);
242        core_cols.is_lui = F::from_bool(record.is_lui);
243    }
244
245    fn air(&self) -> &Self::Air {
246        &self.air
247    }
248}
249
250// returns (to_pc, rd_data)
251pub(super) fn run_jal_lui(
252    opcode: Rv32JalLuiOpcode,
253    pc: u32,
254    imm: i32,
255) -> (u32, [u32; RV32_REGISTER_NUM_LIMBS]) {
256    match opcode {
257        JAL => {
258            let rd_data = array::from_fn(|i| {
259                ((pc + DEFAULT_PC_STEP) >> (8 * i)) & ((1 << RV32_CELL_BITS) - 1)
260            });
261            let next_pc = pc as i32 + imm;
262            assert!(next_pc >= 0);
263            (next_pc as u32, rd_data)
264        }
265        LUI => {
266            let imm = imm as u32;
267            let rd = imm << 12;
268            let rd_data =
269                array::from_fn(|i| (rd >> (RV32_CELL_BITS * i)) & ((1 << RV32_CELL_BITS) - 1));
270            (pc + DEFAULT_PC_STEP, rd_data)
271        }
272    }
273}