openvm_rv32im_circuit/jal_lui/
core.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::{
4    arch::*,
5    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
6};
7use openvm_circuit_primitives::{
8    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
9    AlignedBytesBorrow,
10};
11use openvm_circuit_primitives_derive::AlignedBorrow;
12use openvm_instructions::{
13    instruction::Instruction,
14    program::{DEFAULT_PC_STEP, PC_BITS},
15    LocalOpcode,
16};
17use openvm_rv32im_transpiler::Rv32JalLuiOpcode::{self, *};
18use openvm_stark_backend::{
19    interaction::InteractionBuilder,
20    p3_air::{AirBuilder, BaseAir},
21    p3_field::{Field, FieldAlgebra, PrimeField32},
22    rap::BaseAirWithPublicValues,
23};
24
25use crate::adapters::{
26    Rv32CondRdWriteAdapterExecutor, Rv32CondRdWriteAdapterFiller, RV32_CELL_BITS,
27    RV32_REGISTER_NUM_LIMBS, RV_J_TYPE_IMM_BITS,
28};
29
/// Mask XOR-ed with the top `rd_data` limb when `Rv32JalLuiFiller::fill_trace_row`
/// requests the bitwise lookup for a JAL row.
///
/// `Rv32JalLuiCoreAir::eval` sends the matching lookup with a mask it derives from
/// `PC_BITS`, `RV32_CELL_BITS` and `RV32_REGISTER_NUM_LIMBS` (every bit from
/// `last_limb_bits` up to, but excluding, `RV32_CELL_BITS`), so this literal has to stay
/// equal to that derived value.
// NOTE(review): 0b11000000 equals the derived mask if PC_BITS = 30, RV32_CELL_BITS = 8 and
// RV32_REGISTER_NUM_LIMBS = 4 — confirm against the definitions of those constants.
pub(super) const ADDITIONAL_BITS: u32 = 0b11000000;
31
/// Core trace columns of a single JAL/LUI row.
///
/// The struct is `#[repr(C)]` and borrowed directly from a row slice, so the field order
/// is the column order.
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct Rv32JalLuiCoreCols<T> {
    /// Instruction immediate: added to `from_pc` on JAL rows, and tied to the upper limbs
    /// of `rd_data` on LUI rows.
    pub imm: T,
    /// Limbs of the value written to `rd`, least significant limb first.
    pub rd_data: [T; RV32_REGISTER_NUM_LIMBS],
    /// Boolean flag, 1 on JAL rows.
    pub is_jal: T,
    /// Boolean flag, 1 on LUI rows. `is_jal + is_lui` is constrained to be boolean, so at
    /// most one of the two flags is set.
    pub is_lui: T,
}
40
/// Core AIR for the JAL and LUI instructions.
#[derive(Debug, Clone, Copy, derive_new::new)]
pub struct Rv32JalLuiCoreAir {
    /// Bitwise-operation lookup bus on which the range and XOR interactions of `eval`
    /// are sent.
    pub bus: BitwiseOperationLookupBus,
}
45
46impl<F: Field> BaseAir<F> for Rv32JalLuiCoreAir {
47    fn width(&self) -> usize {
48        Rv32JalLuiCoreCols::<F>::width()
49    }
50}
51
// Empty impl: every method keeps the default provided by `BaseAirWithPublicValues`.
impl<F: Field> BaseAirWithPublicValues<F> for Rv32JalLuiCoreAir {}
53
/// Constraint system for the JAL/LUI core: no register reads, one register write
/// (`rd_data`), and an immediate-carrying instruction.
impl<AB, I> VmCoreAir<AB, I> for Rv32JalLuiCoreAir
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; 0]; 0]>,
    I::Writes: From<[[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<ImmInstruction<AB::Expr>>,
{
    /// Adds the core constraints for one row and returns the context handed to the adapter.
    ///
    /// * `builder` - constraint/interaction builder the assertions are recorded on.
    /// * `local_core` - the core columns of the row, laid out as [`Rv32JalLuiCoreCols`].
    /// * `from_pc` - program counter of the instruction on this row.
    ///
    /// The returned context carries `to_pc`, the single write `rd_data`, and the
    /// instruction description (`is_valid`, expected opcode, immediate).
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &Rv32JalLuiCoreCols<AB::Var> = (*local_core).borrow();
        let Rv32JalLuiCoreCols::<AB::Var> {
            imm,
            rd_data: rd,
            is_jal,
            is_lui,
        } = *cols;

        // Both flags are boolean and their sum is boolean, so a row is either JAL, LUI,
        // or neither (is_valid = 0).
        builder.assert_bool(is_lui);
        builder.assert_bool(is_jal);
        let is_valid = is_lui + is_jal;
        builder.assert_bool(is_valid.clone());
        // LUI: the lowest limb of rd is zero.
        builder.when(is_lui).assert_zero(rd[0]);

        // Range-check the rd limbs two at a time through the bitwise lookup bus; the
        // interaction is only counted on valid rows.
        for i in 0..RV32_REGISTER_NUM_LIMBS / 2 {
            self.bus
                .send_range(rd[i * 2], rd[i * 2 + 1])
                .eval(builder, is_valid.clone());
        }

        // In case of JAL constrain that last limb has at most [last_limb_bits] bits

        // `additional_bits` is the mask of all bits from `last_limb_bits` up to (excluding)
        // `RV32_CELL_BITS`. Claiming `rd[3] XOR mask == rd[3] + mask` can only hold when
        // rd[3] shares no set bit with the mask.
        let last_limb_bits = PC_BITS - RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
        let additional_bits = (last_limb_bits..RV32_CELL_BITS).fold(0, |acc, x| acc + (1 << x));
        let additional_bits = AB::F::from_canonical_u32(additional_bits);
        self.bus
            .send_xor(rd[3], additional_bits, rd[3] + additional_bits)
            .eval(builder, is_jal);

        // Compose rd[1..] into one value, rd[1] being the least significant limb.
        let intermed_val = rd
            .iter()
            .skip(1)
            .enumerate()
            .fold(AB::Expr::ZERO, |acc, (i, &val)| {
                acc + val * AB::Expr::from_canonical_u32(1 << (i * RV32_CELL_BITS))
            });

        // Constrain that imm * 2^(12 - RV32_CELL_BITS) is the correct composition of
        // intermed_val in case of LUI
        builder.when(is_lui).assert_eq(
            intermed_val.clone(),
            imm * AB::F::from_canonical_u32(1 << (12 - RV32_CELL_BITS)),
        );

        // Extend the composition with rd[0] so it now covers every limb of rd.
        let intermed_val = rd[0] + intermed_val * AB::Expr::from_canonical_u32(1 << RV32_CELL_BITS);
        // Constrain that from_pc + DEFAULT_PC_STEP is the correct composition of intermed_val in
        // case of JAL
        builder.when(is_jal).assert_eq(
            intermed_val,
            from_pc + AB::F::from_canonical_u32(DEFAULT_PC_STEP),
        );

        // LUI advances by DEFAULT_PC_STEP; JAL advances by the immediate.
        let to_pc = from_pc + is_lui * AB::F::from_canonical_u32(DEFAULT_PC_STEP) + is_jal * imm;

        // Local opcode selected by the flags, converted to a global opcode expression.
        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(
            self,
            is_lui * AB::F::from_canonical_u32(LUI as u32)
                + is_jal * AB::F::from_canonical_u32(JAL as u32),
        );

        AdapterAirContext {
            to_pc: Some(to_pc),
            reads: [].into(),
            writes: [rd.map(|x| x.into())].into(),
            instruction: ImmInstruction {
                is_valid,
                opcode: expected_opcode,
                immediate: imm.into(),
            }
            .into(),
        }
    }

    /// Offset of the `Rv32JalLuiOpcode` class, used to map local opcodes to global ones.
    fn start_offset(&self) -> usize {
        Rv32JalLuiOpcode::CLASS_OFFSET
    }
}
144
/// Execution record written by `Rv32JalLuiExecutor::execute` and later expanded into
/// [`Rv32JalLuiCoreCols`] by `Rv32JalLuiFiller::fill_trace_row`.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32JalLuiCoreRecord {
    /// Canonical `u32` representation of the instruction's immediate field element.
    pub imm: u32,
    /// Bytes written to `rd`, least significant byte first.
    pub rd_data: [u8; RV32_REGISTER_NUM_LIMBS],
    /// `true` for JAL, `false` for LUI (the filler sets `is_lui` to the negation).
    pub is_jal: bool,
}
152
/// Preflight executor for JAL/LUI, generic over the adapter that performs the `rd` write.
#[derive(Clone, Copy, derive_new::new)]
pub struct Rv32JalLuiExecutor<A = Rv32CondRdWriteAdapterExecutor> {
    /// Adapter used to record the register write during execution.
    pub adapter: A,
}
157
/// Trace filler for JAL/LUI rows: turns a recorded [`Rv32JalLuiCoreRecord`] into trace
/// columns and registers the corresponding bitwise lookups.
#[derive(Clone, derive_new::new)]
pub struct Rv32JalLuiFiller<A = Rv32CondRdWriteAdapterFiller> {
    /// Adapter filler responsible for the adapter part of each row.
    adapter: A,
    /// Lookup chip that receives the range and XOR requests made while filling rows.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}
163
164impl<F, A, RA> PreflightExecutor<F, RA> for Rv32JalLuiExecutor<A>
165where
166    F: PrimeField32,
167    A: 'static
168        + for<'a> AdapterTraceExecutor<F, ReadData = (), WriteData = [u8; RV32_REGISTER_NUM_LIMBS]>,
169    for<'buf> RA: RecordArena<
170        'buf,
171        EmptyAdapterCoreLayout<F, A>,
172        (A::RecordMut<'buf>, &'buf mut Rv32JalLuiCoreRecord),
173    >,
174{
175    fn get_opcode_name(&self, opcode: usize) -> String {
176        format!(
177            "{:?}",
178            Rv32JalLuiOpcode::from_usize(opcode - Rv32JalLuiOpcode::CLASS_OFFSET)
179        )
180    }
181
182    fn execute(
183        &self,
184        state: VmStateMut<F, TracingMemory, RA>,
185        instruction: &Instruction<F>,
186    ) -> Result<(), ExecutionError> {
187        let &Instruction { opcode, c: imm, .. } = instruction;
188
189        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());
190
191        A::start(*state.pc, state.memory, &mut adapter_record);
192
193        let is_jal = opcode.local_opcode_idx(Rv32JalLuiOpcode::CLASS_OFFSET) == JAL as usize;
194        let signed_imm = get_signed_imm(is_jal, imm);
195
196        let (to_pc, rd_data) = run_jal_lui(is_jal, *state.pc, signed_imm);
197
198        core_record.imm = imm.as_canonical_u32();
199        core_record.rd_data = rd_data;
200        core_record.is_jal = is_jal;
201
202        self.adapter
203            .write(state.memory, instruction, rd_data, &mut adapter_record);
204
205        *state.pc = to_pc;
206
207        Ok(())
208    }
209}
210
impl<F, A> TraceFiller<F> for Rv32JalLuiFiller<A>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    /// Fills one trace row in place.
    ///
    /// The first `A::WIDTH` cells of `row_slice` are handed to the adapter filler; the
    /// remaining cells are read as a [`Rv32JalLuiCoreRecord`] and then overwritten with
    /// the [`Rv32JalLuiCoreCols`] values. The bitwise lookups requested here mirror the
    /// interactions sent by `Rv32JalLuiCoreAir::eval`.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
        // Rv32JalLuiCoreCols::width() elements
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        self.adapter.fill_trace_row(mem_helper, adapter_row);
        // SAFETY: core_row contains a valid Rv32JalLuiCoreRecord written by the executor
        // during trace generation
        let record: &Rv32JalLuiCoreRecord = unsafe { get_record_from_slice(&mut core_row, ()) };
        let core_row: &mut Rv32JalLuiCoreCols<F> = core_row.borrow_mut();

        // One range request per pair of rd bytes.
        for pair in record.rd_data.chunks_exact(2) {
            self.bitwise_lookup_chip
                .request_range(pair[0] as u32, pair[1] as u32);
        }
        // JAL rows additionally request the XOR of the top byte with ADDITIONAL_BITS.
        if record.is_jal {
            self.bitwise_lookup_chip
                .request_xor(record.rd_data[3] as u32, ADDITIONAL_BITS);
        }

        // Writing in reverse order
        // NOTE(review): `record` was obtained from the same `core_row` memory, so the
        // columns are presumably written last-to-first to avoid overwriting record fields
        // before they are read — keep this order unless the layouts are re-checked.
        core_row.is_lui = F::from_bool(!record.is_jal);
        core_row.is_jal = F::from_bool(record.is_jal);
        core_row.rd_data = record.rd_data.map(F::from_canonical_u8);
        core_row.imm = F::from_canonical_u32(record.imm);
    }
}
242
243// returns the canonical signed representation of the immediate
244// `imm` can be "negative" as a field element
245pub(super) fn get_signed_imm<F: PrimeField32>(is_jal: bool, imm: F) -> i32 {
246    let imm_f = imm.as_canonical_u32();
247    if is_jal {
248        if imm_f < (1 << (RV_J_TYPE_IMM_BITS - 1)) {
249            imm_f as i32
250        } else {
251            let neg_imm_f = F::ORDER_U32 - imm_f;
252            debug_assert!(neg_imm_f < (1 << (RV_J_TYPE_IMM_BITS - 1)));
253            -(neg_imm_f as i32)
254        }
255    } else {
256        imm_f as i32
257    }
258}
259
260// returns (to_pc, rd_data)
261#[inline(always)]
262pub(super) fn run_jal_lui(is_jal: bool, pc: u32, imm: i32) -> (u32, [u8; RV32_REGISTER_NUM_LIMBS]) {
263    if is_jal {
264        let rd_data = (pc + DEFAULT_PC_STEP).to_le_bytes();
265        let next_pc = pc as i32 + imm;
266        debug_assert!(next_pc >= 0);
267        (next_pc as u32, rd_data)
268    } else {
269        let imm = imm as u32;
270        let rd = imm << 12;
271        (pc + DEFAULT_PC_STEP, rd.to_le_bytes())
272    }
273}