// openvm_rv32im_circuit/jal_lui/core.rs

1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::{
4    arch::*,
5    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
6};
7use openvm_circuit_primitives::{
8    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
9    AlignedBytesBorrow,
10};
11use openvm_circuit_primitives_derive::AlignedBorrow;
12use openvm_instructions::{
13    instruction::Instruction,
14    program::{DEFAULT_PC_STEP, PC_BITS},
15    LocalOpcode,
16};
17use openvm_rv32im_transpiler::Rv32JalLuiOpcode::{self, *};
18use openvm_stark_backend::{
19    interaction::InteractionBuilder,
20    p3_air::{AirBuilder, BaseAir},
21    p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
22    rap::BaseAirWithPublicValues,
23};
24
25use crate::adapters::{
26    Rv32CondRdWriteAdapterExecutor, Rv32CondRdWriteAdapterFiller, RV32_CELL_BITS,
27    RV32_REGISTER_NUM_LIMBS, RV_J_TYPE_IMM_BITS,
28};
29
30pub(super) const ADDITIONAL_BITS: u32 = 0b11000000;
31
32#[repr(C)]
33#[derive(Debug, Clone, AlignedBorrow)]
34pub struct Rv32JalLuiCoreCols<T> {
35    pub imm: T,
36    pub rd_data: [T; RV32_REGISTER_NUM_LIMBS],
37    pub is_jal: T,
38    pub is_lui: T,
39}
40
41#[derive(Debug, Clone, Copy, derive_new::new)]
42pub struct Rv32JalLuiCoreAir {
43    pub bus: BitwiseOperationLookupBus,
44}
45
46impl<F: Field> BaseAir<F> for Rv32JalLuiCoreAir {
47    fn width(&self) -> usize {
48        Rv32JalLuiCoreCols::<F>::width()
49    }
50}
51
52impl<F: Field> BaseAirWithPublicValues<F> for Rv32JalLuiCoreAir {}
53
54impl<AB, I> VmCoreAir<AB, I> for Rv32JalLuiCoreAir
55where
56    AB: InteractionBuilder,
57    I: VmAdapterInterface<AB::Expr>,
58    I::Reads: From<[[AB::Expr; 0]; 0]>,
59    I::Writes: From<[[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1]>,
60    I::ProcessedInstruction: From<ImmInstruction<AB::Expr>>,
61{
62    fn eval(
63        &self,
64        builder: &mut AB,
65        local_core: &[AB::Var],
66        from_pc: AB::Var,
67    ) -> AdapterAirContext<AB::Expr, I> {
68        let cols: &Rv32JalLuiCoreCols<AB::Var> = (*local_core).borrow();
69        let Rv32JalLuiCoreCols::<AB::Var> {
70            imm,
71            rd_data: rd,
72            is_jal,
73            is_lui,
74        } = *cols;
75
76        builder.assert_bool(is_lui);
77        builder.assert_bool(is_jal);
78        let is_valid = is_lui + is_jal;
79        builder.assert_bool(is_valid.clone());
80        builder.when(is_lui).assert_zero(rd[0]);
81
82        for i in 0..RV32_REGISTER_NUM_LIMBS / 2 {
83            self.bus
84                .send_range(rd[i * 2], rd[i * 2 + 1])
85                .eval(builder, is_valid.clone());
86        }
87
88        // In case of JAL constrain that last limb has at most [last_limb_bits] bits
89
90        let last_limb_bits = PC_BITS - RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
91        let additional_bits = (last_limb_bits..RV32_CELL_BITS).fold(0, |acc, x| acc + (1 << x));
92        let additional_bits = AB::F::from_u32(additional_bits);
93        self.bus
94            .send_xor(rd[3], additional_bits, rd[3] + additional_bits)
95            .eval(builder, is_jal);
96
97        let intermed_val = rd
98            .iter()
99            .skip(1)
100            .enumerate()
101            .fold(AB::Expr::ZERO, |acc, (i, &val)| {
102                acc + val * AB::Expr::from_u32(1 << (i * RV32_CELL_BITS))
103            });
104
105        // Constrain that imm * 2^4 is the correct composition of intermed_val in case of LUI
106        builder.when(is_lui).assert_eq(
107            intermed_val.clone(),
108            imm * AB::F::from_u32(1 << (12 - RV32_CELL_BITS)),
109        );
110
111        let intermed_val = rd[0] + intermed_val * AB::Expr::from_u32(1 << RV32_CELL_BITS);
112        // Constrain that from_pc + DEFAULT_PC_STEP is the correct composition of intermed_val in
113        // case of JAL
114        builder
115            .when(is_jal)
116            .assert_eq(intermed_val, from_pc + AB::F::from_u32(DEFAULT_PC_STEP));
117
118        let to_pc = from_pc + is_lui * AB::F::from_u32(DEFAULT_PC_STEP) + is_jal * imm;
119
120        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(
121            self,
122            is_lui * AB::F::from_u32(LUI as u32) + is_jal * AB::F::from_u32(JAL as u32),
123        );
124
125        AdapterAirContext {
126            to_pc: Some(to_pc),
127            reads: [].into(),
128            writes: [rd.map(|x| x.into())].into(),
129            instruction: ImmInstruction {
130                is_valid,
131                opcode: expected_opcode,
132                immediate: imm.into(),
133            }
134            .into(),
135        }
136    }
137
138    fn start_offset(&self) -> usize {
139        Rv32JalLuiOpcode::CLASS_OFFSET
140    }
141}
142
143#[repr(C)]
144#[derive(AlignedBytesBorrow, Debug)]
145pub struct Rv32JalLuiCoreRecord {
146    pub imm: u32,
147    pub rd_data: [u8; RV32_REGISTER_NUM_LIMBS],
148    pub is_jal: bool,
149}
150
151#[derive(Clone, Copy, derive_new::new)]
152pub struct Rv32JalLuiExecutor<A = Rv32CondRdWriteAdapterExecutor> {
153    pub adapter: A,
154}
155
156#[derive(Clone, derive_new::new)]
157pub struct Rv32JalLuiFiller<A = Rv32CondRdWriteAdapterFiller> {
158    adapter: A,
159    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
160}
161
162impl<F, A, RA> PreflightExecutor<F, RA> for Rv32JalLuiExecutor<A>
163where
164    F: PrimeField32,
165    A: 'static
166        + for<'a> AdapterTraceExecutor<F, ReadData = (), WriteData = [u8; RV32_REGISTER_NUM_LIMBS]>,
167    for<'buf> RA: RecordArena<
168        'buf,
169        EmptyAdapterCoreLayout<F, A>,
170        (A::RecordMut<'buf>, &'buf mut Rv32JalLuiCoreRecord),
171    >,
172{
173    fn get_opcode_name(&self, opcode: usize) -> String {
174        format!(
175            "{:?}",
176            Rv32JalLuiOpcode::from_usize(opcode - Rv32JalLuiOpcode::CLASS_OFFSET)
177        )
178    }
179
180    fn execute(
181        &self,
182        state: VmStateMut<F, TracingMemory, RA>,
183        instruction: &Instruction<F>,
184    ) -> Result<(), ExecutionError> {
185        let &Instruction { opcode, c: imm, .. } = instruction;
186
187        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());
188
189        A::start(*state.pc, state.memory, &mut adapter_record);
190
191        let is_jal = opcode.local_opcode_idx(Rv32JalLuiOpcode::CLASS_OFFSET) == JAL as usize;
192        let signed_imm = get_signed_imm(is_jal, imm);
193
194        let (to_pc, rd_data) = run_jal_lui(is_jal, *state.pc, signed_imm);
195
196        core_record.imm = imm.as_canonical_u32();
197        core_record.rd_data = rd_data;
198        core_record.is_jal = is_jal;
199
200        self.adapter
201            .write(state.memory, instruction, rd_data, &mut adapter_record);
202
203        *state.pc = to_pc;
204
205        Ok(())
206    }
207}
208
impl<F, A> TraceFiller<F> for Rv32JalLuiFiller<A>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    /// Expands one row in place.
    ///
    /// The adapter part of `row_slice` is filled by the adapter filler. The core part is rebuilt
    /// from the `Rv32JalLuiCoreRecord` read out of the same core memory. The same byte-range and
    /// XOR lookups that the core AIR sends are requested from the bitwise lookup chip.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
        // Rv32JalLuiCoreCols::width() elements
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        self.adapter.fill_trace_row(mem_helper, adapter_row);
        // SAFETY: core_row contains a valid Rv32JalLuiCoreRecord written by the executor
        // during trace generation
        // NOTE: `record` is read out of the same `core_row` memory that is reinterpreted as
        // `Rv32JalLuiCoreCols` on the next line, so the two views alias each other.
        let record: &Rv32JalLuiCoreRecord = unsafe { get_record_from_slice(&mut core_row, ()) };
        let core_row: &mut Rv32JalLuiCoreCols<F> = core_row.borrow_mut();

        // Mirror the AIR's `send_range` on each pair of rd limbs.
        for pair in record.rd_data.chunks_exact(2) {
            self.bitwise_lookup_chip
                .request_range(pair[0] as u32, pair[1] as u32);
        }
        // Mirror the AIR's JAL-only `send_xor` on the last rd limb with the high-bits mask.
        if record.is_jal {
            self.bitwise_lookup_chip
                .request_xor(record.rd_data[3] as u32, ADDITIONAL_BITS);
        }

        // Writing in reverse order: because `record` aliases the start of the core row, columns
        // are assigned from the last field to the first, presumably so that each record field
        // is read before the column covering its bytes is overwritten. Do not reorder these
        // assignments.
        core_row.is_lui = F::from_bool(!record.is_jal);
        core_row.is_jal = F::from_bool(record.is_jal);
        core_row.rd_data = record.rd_data.map(F::from_u8);
        core_row.imm = F::from_u32(record.imm);
    }
}
240
241// returns the canonical signed representation of the immediate
242// `imm` can be "negative" as a field element
243pub(super) fn get_signed_imm<F: PrimeField32>(is_jal: bool, imm: F) -> i32 {
244    let imm_f = imm.as_canonical_u32();
245    if is_jal {
246        if imm_f < (1 << (RV_J_TYPE_IMM_BITS - 1)) {
247            imm_f as i32
248        } else {
249            let neg_imm_f = F::ORDER_U32 - imm_f;
250            debug_assert!(neg_imm_f < (1 << (RV_J_TYPE_IMM_BITS - 1)));
251            -(neg_imm_f as i32)
252        }
253    } else {
254        imm_f as i32
255    }
256}
257
258// returns (to_pc, rd_data)
259#[inline(always)]
260pub(super) fn run_jal_lui(is_jal: bool, pc: u32, imm: i32) -> (u32, [u8; RV32_REGISTER_NUM_LIMBS]) {
261    if is_jal {
262        let rd_data = (pc + DEFAULT_PC_STEP).to_le_bytes();
263        let next_pc = pc as i32 + imm;
264        debug_assert!(next_pc >= 0);
265        (next_pc as u32, rd_data)
266    } else {
267        let imm = imm as u32;
268        let rd = imm << 12;
269        (pc + DEFAULT_PC_STEP, rd.to_le_bytes())
270    }
271}