// openvm_rv32im_circuit/jalr/core.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::{
7    arch::*,
8    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
9};
10use openvm_circuit_primitives::{
11    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
12    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
13    AlignedBytesBorrow,
14};
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_instructions::{
17    instruction::Instruction,
18    program::{DEFAULT_PC_STEP, PC_BITS},
19    LocalOpcode,
20};
21use openvm_rv32im_transpiler::Rv32JalrOpcode::{self, *};
22use openvm_stark_backend::{
23    interaction::InteractionBuilder,
24    p3_air::{AirBuilder, BaseAir},
25    p3_field::{Field, FieldAlgebra, PrimeField32},
26    rap::BaseAirWithPublicValues,
27};
28
29use crate::adapters::{
30    Rv32JalrAdapterExecutor, Rv32JalrAdapterFiller, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
31};
32
33#[repr(C)]
34#[derive(Debug, Clone, AlignedBorrow)]
35pub struct Rv32JalrCoreCols<T> {
36    pub imm: T,
37    pub rs1_data: [T; RV32_REGISTER_NUM_LIMBS],
38    // To save a column, we only store the 3 most significant limbs of `rd_data`
39    // the least significant limb can be derived using from_pc and the other limbs
40    pub rd_data: [T; RV32_REGISTER_NUM_LIMBS - 1],
41    pub is_valid: T,
42
43    pub to_pc_least_sig_bit: T,
44    /// These are the limbs of `to_pc * 2`.
45    pub to_pc_limbs: [T; 2],
46    pub imm_sign: T,
47}
48
49#[derive(Debug, Clone, derive_new::new)]
50pub struct Rv32JalrCoreAir {
51    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
52    pub range_bus: VariableRangeCheckerBus,
53}
54
55impl<F: Field> BaseAir<F> for Rv32JalrCoreAir {
56    fn width(&self) -> usize {
57        Rv32JalrCoreCols::<F>::width()
58    }
59}
60
61impl<F: Field> BaseAirWithPublicValues<F> for Rv32JalrCoreAir {}
62
63impl<AB, I> VmCoreAir<AB, I> for Rv32JalrCoreAir
64where
65    AB: InteractionBuilder,
66    I: VmAdapterInterface<AB::Expr>,
67    I::Reads: From<[[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1]>,
68    I::Writes: From<[[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1]>,
69    I::ProcessedInstruction: From<SignedImmInstruction<AB::Expr>>,
70{
71    fn eval(
72        &self,
73        builder: &mut AB,
74        local_core: &[AB::Var],
75        from_pc: AB::Var,
76    ) -> AdapterAirContext<AB::Expr, I> {
77        let cols: &Rv32JalrCoreCols<AB::Var> = (*local_core).borrow();
78        let Rv32JalrCoreCols::<AB::Var> {
79            imm,
80            rs1_data: rs1,
81            rd_data: rd,
82            is_valid,
83            imm_sign,
84            to_pc_least_sig_bit,
85            to_pc_limbs,
86        } = *cols;
87
88        builder.assert_bool(is_valid);
89
90        // composed is the composition of 3 most significant limbs of rd
91        let composed = rd
92            .iter()
93            .enumerate()
94            .fold(AB::Expr::ZERO, |acc, (i, &val)| {
95                acc + val * AB::Expr::from_canonical_u32(1 << ((i + 1) * RV32_CELL_BITS))
96            });
97
98        let least_sig_limb = from_pc + AB::F::from_canonical_u32(DEFAULT_PC_STEP) - composed;
99
100        // rd_data is the final decomposition of `from_pc + DEFAULT_PC_STEP` we need.
101        // The range check on `least_sig_limb` also ensures that `rd_data` correctly represents
102        // `from_pc + DEFAULT_PC_STEP`. Specifically, if `rd_data` does not match the
103        // expected limb, then `least_sig_limb` becomes the real `least_sig_limb` plus the
104        // difference between `composed` and the three most significant limbs of `from_pc +
105        // DEFAULT_PC_STEP`. In that case, `least_sig_limb` >= 2^RV32_CELL_BITS.
106        let rd_data = array::from_fn(|i| {
107            if i == 0 {
108                least_sig_limb.clone()
109            } else {
110                rd[i - 1].into().clone()
111            }
112        });
113
114        // Constrain rd_data
115        // Assumes only from_pc in [0,2^PC_BITS) is allowed by program bus
116        self.bitwise_lookup_bus
117            .send_range(rd_data[0].clone(), rd_data[1].clone())
118            .eval(builder, is_valid);
119        self.range_bus
120            .range_check(rd_data[2].clone(), RV32_CELL_BITS)
121            .eval(builder, is_valid);
122        self.range_bus
123            .range_check(rd_data[3].clone(), PC_BITS - RV32_CELL_BITS * 3)
124            .eval(builder, is_valid);
125
126        builder.assert_bool(imm_sign);
127
128        // Constrain to_pc_least_sig_bit + 2 * to_pc_limbs = rs1 + imm as a i32 addition with 2
129        // limbs RISC-V spec explicitly sets the least significant bit of `to_pc` to 0
130        let rs1_limbs_01 = rs1[0] + rs1[1] * AB::F::from_canonical_u32(1 << RV32_CELL_BITS);
131        let rs1_limbs_23 = rs1[2] + rs1[3] * AB::F::from_canonical_u32(1 << RV32_CELL_BITS);
132        let inv = AB::F::from_canonical_u32(1 << 16).inverse();
133
134        builder.assert_bool(to_pc_least_sig_bit);
135        let carry = (rs1_limbs_01 + imm - to_pc_limbs[0] * AB::F::TWO - to_pc_least_sig_bit) * inv;
136        builder.when(is_valid).assert_bool(carry.clone());
137
138        let imm_extend_limb = imm_sign * AB::F::from_canonical_u32((1 << 16) - 1);
139        let carry = (rs1_limbs_23 + imm_extend_limb + carry - to_pc_limbs[1]) * inv;
140        builder.when(is_valid).assert_bool(carry);
141
142        // preventing to_pc overflow
143        self.range_bus
144            .range_check(to_pc_limbs[1], PC_BITS - 16)
145            .eval(builder, is_valid);
146        self.range_bus
147            .range_check(to_pc_limbs[0], 15)
148            .eval(builder, is_valid);
149        let to_pc =
150            to_pc_limbs[0] * AB::F::TWO + to_pc_limbs[1] * AB::F::from_canonical_u32(1 << 16);
151
152        let expected_opcode = VmCoreAir::<AB, I>::opcode_to_global_expr(self, JALR);
153
154        AdapterAirContext {
155            to_pc: Some(to_pc),
156            reads: [rs1.map(|x| x.into())].into(),
157            writes: [rd_data].into(),
158            instruction: SignedImmInstruction {
159                is_valid: is_valid.into(),
160                opcode: expected_opcode,
161                immediate: imm.into(),
162                imm_sign: imm_sign.into(),
163            }
164            .into(),
165        }
166    }
167
168    fn start_offset(&self) -> usize {
169        Rv32JalrOpcode::CLASS_OFFSET
170    }
171}
172
173#[repr(C)]
174#[derive(AlignedBytesBorrow, Debug)]
175pub struct Rv32JalrCoreRecord {
176    pub imm: u16,
177    pub from_pc: u32,
178    pub rs1_val: u32,
179    pub imm_sign: bool,
180}
181
182#[derive(Clone, Copy, derive_new::new)]
183pub struct Rv32JalrExecutor<A = Rv32JalrAdapterExecutor> {
184    adapter: A,
185}
186
187#[derive(Clone)]
188pub struct Rv32JalrFiller<A = Rv32JalrAdapterFiller> {
189    adapter: A,
190    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
191    pub range_checker_chip: SharedVariableRangeCheckerChip,
192}
193
194impl<A> Rv32JalrFiller<A> {
195    pub fn new(
196        adapter: A,
197        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
198        range_checker_chip: SharedVariableRangeCheckerChip,
199    ) -> Self {
200        assert!(range_checker_chip.range_max_bits() >= 16);
201        Self {
202            adapter,
203            bitwise_lookup_chip,
204            range_checker_chip,
205        }
206    }
207}
208
209impl<F, A, RA> PreflightExecutor<F, RA> for Rv32JalrExecutor<A>
210where
211    F: PrimeField32,
212    A: 'static
213        + AdapterTraceExecutor<
214            F,
215            ReadData = [u8; RV32_REGISTER_NUM_LIMBS],
216            WriteData = [u8; RV32_REGISTER_NUM_LIMBS],
217        >,
218    for<'buf> RA: RecordArena<
219        'buf,
220        EmptyAdapterCoreLayout<F, A>,
221        (A::RecordMut<'buf>, &'buf mut Rv32JalrCoreRecord),
222    >,
223{
224    fn get_opcode_name(&self, opcode: usize) -> String {
225        format!(
226            "{:?}",
227            Rv32JalrOpcode::from_usize(opcode - Rv32JalrOpcode::CLASS_OFFSET)
228        )
229    }
230
231    fn execute(
232        &self,
233        state: VmStateMut<F, TracingMemory, RA>,
234        instruction: &Instruction<F>,
235    ) -> Result<(), ExecutionError> {
236        let Instruction { opcode, c, g, .. } = *instruction;
237
238        debug_assert_eq!(
239            opcode.local_opcode_idx(Rv32JalrOpcode::CLASS_OFFSET),
240            JALR as usize
241        );
242
243        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());
244
245        A::start(*state.pc, state.memory, &mut adapter_record);
246
247        core_record.rs1_val = u32::from_le_bytes(self.adapter.read(
248            state.memory,
249            instruction,
250            &mut adapter_record,
251        ));
252
253        core_record.imm = c.as_canonical_u32() as u16;
254        core_record.imm_sign = g.is_one();
255        core_record.from_pc = *state.pc;
256
257        let (to_pc, rd_data) = run_jalr(
258            core_record.from_pc,
259            core_record.rs1_val,
260            core_record.imm,
261            core_record.imm_sign,
262        );
263
264        self.adapter
265            .write(state.memory, instruction, rd_data, &mut adapter_record);
266
267        // RISC-V spec explicitly sets the least significant bit of `to_pc` to 0
268        *state.pc = to_pc & !1;
269
270        Ok(())
271    }
272}
273impl<F, A> TraceFiller<F> for Rv32JalrFiller<A>
274where
275    F: PrimeField32,
276    A: 'static + AdapterTraceFiller<F>,
277{
278    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
279        // SAFETY: row_slice is guaranteed by the caller to have at least A::WIDTH +
280        // Rv32JalrCoreCols::width() elements
281        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
282        self.adapter.fill_trace_row(mem_helper, adapter_row);
283        // SAFETY: core_row contains a valid Rv32JalrCoreRecord written by the executor
284        // during trace generation
285        let record: &Rv32JalrCoreRecord = unsafe { get_record_from_slice(&mut core_row, ()) };
286
287        let core_row: &mut Rv32JalrCoreCols<F> = core_row.borrow_mut();
288
289        let (to_pc, rd_data) =
290            run_jalr(record.from_pc, record.rs1_val, record.imm, record.imm_sign);
291        let to_pc_limbs = [(to_pc & ((1 << 16) - 1)) >> 1, to_pc >> 16];
292        self.range_checker_chip.add_count(to_pc_limbs[0], 15);
293        self.range_checker_chip
294            .add_count(to_pc_limbs[1], PC_BITS - 16);
295        self.bitwise_lookup_chip
296            .request_range(rd_data[0] as u32, rd_data[1] as u32);
297
298        self.range_checker_chip
299            .add_count(rd_data[2] as u32, RV32_CELL_BITS);
300        self.range_checker_chip
301            .add_count(rd_data[3] as u32, PC_BITS - RV32_CELL_BITS * 3);
302
303        // Write in reverse order
304        core_row.imm_sign = F::from_bool(record.imm_sign);
305        core_row.to_pc_limbs = to_pc_limbs.map(F::from_canonical_u32);
306        core_row.to_pc_least_sig_bit = F::from_bool(to_pc & 1 == 1);
307        // fill_trace_row is called only on valid rows
308        core_row.is_valid = F::ONE;
309        core_row.rs1_data = record.rs1_val.to_le_bytes().map(F::from_canonical_u8);
310        core_row
311            .rd_data
312            .iter_mut()
313            .rev()
314            .zip(rd_data.iter().skip(1).rev())
315            .for_each(|(dst, src)| {
316                *dst = F::from_canonical_u8(*src);
317            });
318        core_row.imm = F::from_canonical_u16(record.imm);
319    }
320}
321
322// returns (to_pc, rd_data)
323#[inline(always)]
324pub(super) fn run_jalr(pc: u32, rs1: u32, imm: u16, imm_sign: bool) -> (u32, [u8; 4]) {
325    let to_pc = rs1.wrapping_add(imm as u32 + (imm_sign as u32 * 0xffff0000));
326    assert!(to_pc < (1 << PC_BITS));
327    (to_pc, pc.wrapping_add(DEFAULT_PC_STEP).to_le_bytes())
328}