// openvm_rv32im_circuit/jalr/core.rs

1use std::{
2    array,
3    borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::arch::{
7    AdapterAirContext, AdapterRuntimeContext, Result, SignedImmInstruction, VmAdapterInterface,
8    VmCoreAir, VmCoreChip,
9};
10use openvm_circuit_primitives::{
11    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
12    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
13};
14use openvm_circuit_primitives_derive::AlignedBorrow;
15use openvm_instructions::{
16    instruction::Instruction,
17    program::{DEFAULT_PC_STEP, PC_BITS},
18    LocalOpcode,
19};
20use openvm_rv32im_transpiler::Rv32JalrOpcode::{self, *};
21use openvm_stark_backend::{
22    interaction::InteractionBuilder,
23    p3_air::{AirBuilder, BaseAir},
24    p3_field::{Field, FieldAlgebra, PrimeField32},
25    rap::BaseAirWithPublicValues,
26};
27use serde::{Deserialize, Serialize};
28
29use crate::adapters::{compose, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};
30
31const RV32_LIMB_MAX: u32 = (1 << RV32_CELL_BITS) - 1;
32
33#[repr(C)]
34#[derive(Debug, Clone, AlignedBorrow)]
35pub struct Rv32JalrCoreCols<T> {
36    pub imm: T,
37    pub rs1_data: [T; RV32_REGISTER_NUM_LIMBS],
38    // To save a column, we only store the 3 most significant limbs of `rd_data`
39    // the least significant limb can be derived using from_pc and the other limbs
40    pub rd_data: [T; RV32_REGISTER_NUM_LIMBS - 1],
41    pub is_valid: T,
42
43    pub to_pc_least_sig_bit: T,
44    /// These are the limbs of `to_pc * 2`.
45    pub to_pc_limbs: [T; 2],
46    pub imm_sign: T,
47}
48
49#[repr(C)]
50#[derive(Serialize, Deserialize)]
51pub struct Rv32JalrCoreRecord<F> {
52    pub imm: F,
53    pub rs1_data: [F; RV32_REGISTER_NUM_LIMBS],
54    pub rd_data: [F; RV32_REGISTER_NUM_LIMBS - 1],
55    pub to_pc_least_sig_bit: F,
56    pub to_pc_limbs: [u32; 2],
57    pub imm_sign: F,
58}
59
60#[derive(Debug, Clone)]
61pub struct Rv32JalrCoreAir {
62    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
63    pub range_bus: VariableRangeCheckerBus,
64}
65
66impl<F: Field> BaseAir<F> for Rv32JalrCoreAir {
67    fn width(&self) -> usize {
68        Rv32JalrCoreCols::<F>::width()
69    }
70}
71
72impl<F: Field> BaseAirWithPublicValues<F> for Rv32JalrCoreAir {}
73
74impl<AB, I> VmCoreAir<AB, I> for Rv32JalrCoreAir
75where
76    AB: InteractionBuilder,
77    I: VmAdapterInterface<AB::Expr>,
78    I::Reads: From<[[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1]>,
79    I::Writes: From<[[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1]>,
80    I::ProcessedInstruction: From<SignedImmInstruction<AB::Expr>>,
81{
82    fn eval(
83        &self,
84        builder: &mut AB,
85        local_core: &[AB::Var],
86        from_pc: AB::Var,
87    ) -> AdapterAirContext<AB::Expr, I> {
88        let cols: &Rv32JalrCoreCols<AB::Var> = (*local_core).borrow();
89        let Rv32JalrCoreCols::<AB::Var> {
90            imm,
91            rs1_data: rs1,
92            rd_data: rd,
93            is_valid,
94            imm_sign,
95            to_pc_least_sig_bit,
96            to_pc_limbs,
97        } = *cols;
98
99        builder.assert_bool(is_valid);
100
101        // composed is the composition of 3 most significant limbs of rd
102        let composed = rd
103            .iter()
104            .enumerate()
105            .fold(AB::Expr::ZERO, |acc, (i, &val)| {
106                acc + val * AB::Expr::from_canonical_u32(1 << ((i + 1) * RV32_CELL_BITS))
107            });
108
109        let least_sig_limb = from_pc + AB::F::from_canonical_u32(DEFAULT_PC_STEP) - composed;
110
111        // rd_data is the final decomposition of `from_pc + DEFAULT_PC_STEP` we need.
112        // The range check on `least_sig_limb` also ensures that `rd_data` correctly represents `from_pc + DEFAULT_PC_STEP`.
113        // Specifically, if `rd_data` does not match the expected limb, then `least_sig_limb` becomes
114        // the real `least_sig_limb` plus the difference between `composed` and the three most significant limbs of `from_pc + DEFAULT_PC_STEP`.
115        // In that case, `least_sig_limb` >= 2^RV32_CELL_BITS.
116        let rd_data = array::from_fn(|i| {
117            if i == 0 {
118                least_sig_limb.clone()
119            } else {
120                rd[i - 1].into().clone()
121            }
122        });
123
124        // Constrain rd_data
125        // Assumes only from_pc in [0,2^PC_BITS) is allowed by program bus
126        self.bitwise_lookup_bus
127            .send_range(rd_data[0].clone(), rd_data[1].clone())
128            .eval(builder, is_valid);
129        self.range_bus
130            .range_check(rd_data[2].clone(), RV32_CELL_BITS)
131            .eval(builder, is_valid);
132        self.range_bus
133            .range_check(rd_data[3].clone(), PC_BITS - RV32_CELL_BITS * 3)
134            .eval(builder, is_valid);
135
136        builder.assert_bool(imm_sign);
137
138        // Constrain to_pc_least_sig_bit + 2 * to_pc_limbs = rs1 + imm as a i32 addition with 2 limbs
139        // RISC-V spec explicitly sets the least significant bit of `to_pc` to 0
140        let rs1_limbs_01 = rs1[0] + rs1[1] * AB::F::from_canonical_u32(1 << RV32_CELL_BITS);
141        let rs1_limbs_23 = rs1[2] + rs1[3] * AB::F::from_canonical_u32(1 << RV32_CELL_BITS);
142        let inv = AB::F::from_canonical_u32(1 << 16).inverse();
143
144        builder.assert_bool(to_pc_least_sig_bit);
145        let carry = (rs1_limbs_01 + imm - to_pc_limbs[0] * AB::F::TWO - to_pc_least_sig_bit) * inv;
146        builder.when(is_valid).assert_bool(carry.clone());
147
148        let imm_extend_limb = imm_sign * AB::F::from_canonical_u32((1 << 16) - 1);
149        let carry = (rs1_limbs_23 + imm_extend_limb + carry - to_pc_limbs[1]) * inv;
150        builder.when(is_valid).assert_bool(carry);
151
152        // preventing to_pc overflow
153        self.range_bus
154            .range_check(to_pc_limbs[1], PC_BITS - 16)
155            .eval(builder, is_valid);
156        self.range_bus
157            .range_check(to_pc_limbs[0], 15)
158            .eval(builder, is_valid);
159        let to_pc =
160            to_pc_limbs[0] * AB::F::TWO + to_pc_limbs[1] * AB::F::from_canonical_u32(1 << 16);
161
162        let expected_opcode = VmCoreAir::<AB, I>::opcode_to_global_expr(self, JALR);
163
164        AdapterAirContext {
165            to_pc: Some(to_pc),
166            reads: [rs1.map(|x| x.into())].into(),
167            writes: [rd_data].into(),
168            instruction: SignedImmInstruction {
169                is_valid: is_valid.into(),
170                opcode: expected_opcode,
171                immediate: imm.into(),
172                imm_sign: imm_sign.into(),
173            }
174            .into(),
175        }
176    }
177
178    fn start_offset(&self) -> usize {
179        Rv32JalrOpcode::CLASS_OFFSET
180    }
181}
182
183pub struct Rv32JalrCoreChip {
184    pub air: Rv32JalrCoreAir,
185    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
186    pub range_checker_chip: SharedVariableRangeCheckerChip,
187}
188
189impl Rv32JalrCoreChip {
190    pub fn new(
191        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
192        range_checker_chip: SharedVariableRangeCheckerChip,
193    ) -> Self {
194        assert!(range_checker_chip.range_max_bits() >= 16);
195        Self {
196            air: Rv32JalrCoreAir {
197                bitwise_lookup_bus: bitwise_lookup_chip.bus(),
198                range_bus: range_checker_chip.bus(),
199            },
200            bitwise_lookup_chip,
201            range_checker_chip,
202        }
203    }
204}
205
206impl<F: PrimeField32, I: VmAdapterInterface<F>> VmCoreChip<F, I> for Rv32JalrCoreChip
207where
208    I::Reads: Into<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>,
209    I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>,
210{
211    type Record = Rv32JalrCoreRecord<F>;
212    type Air = Rv32JalrCoreAir;
213
214    #[allow(clippy::type_complexity)]
215    fn execute_instruction(
216        &self,
217        instruction: &Instruction<F>,
218        from_pc: u32,
219        reads: I::Reads,
220    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
221        let Instruction { opcode, c, g, .. } = *instruction;
222        let local_opcode =
223            Rv32JalrOpcode::from_usize(opcode.local_opcode_idx(Rv32JalrOpcode::CLASS_OFFSET));
224
225        let imm = c.as_canonical_u32();
226        let imm_sign = g.as_canonical_u32();
227        let imm_extended = imm + imm_sign * 0xffff0000;
228
229        let rs1 = reads.into()[0];
230        let rs1_val = compose(rs1);
231
232        let (to_pc, rd_data) = run_jalr(local_opcode, from_pc, imm_extended, rs1_val);
233
234        self.bitwise_lookup_chip
235            .request_range(rd_data[0], rd_data[1]);
236        self.range_checker_chip
237            .add_count(rd_data[2], RV32_CELL_BITS);
238        self.range_checker_chip
239            .add_count(rd_data[3], PC_BITS - RV32_CELL_BITS * 3);
240
241        let mask = (1 << 15) - 1;
242        let to_pc_least_sig_bit = rs1_val.wrapping_add(imm_extended) & 1;
243
244        let to_pc_limbs = array::from_fn(|i| ((to_pc >> (1 + i * 15)) & mask));
245
246        let rd_data = rd_data.map(F::from_canonical_u32);
247
248        let output = AdapterRuntimeContext {
249            to_pc: Some(to_pc),
250            writes: [rd_data].into(),
251        };
252
253        Ok((
254            output,
255            Rv32JalrCoreRecord {
256                imm: c,
257                rd_data: array::from_fn(|i| rd_data[i + 1]),
258                rs1_data: rs1,
259                to_pc_least_sig_bit: F::from_canonical_u32(to_pc_least_sig_bit),
260                to_pc_limbs,
261                imm_sign: g,
262            },
263        ))
264    }
265
266    fn get_opcode_name(&self, opcode: usize) -> String {
267        format!(
268            "{:?}",
269            Rv32JalrOpcode::from_usize(opcode - Rv32JalrOpcode::CLASS_OFFSET)
270        )
271    }
272
273    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
274        self.range_checker_chip.add_count(record.to_pc_limbs[0], 15);
275        self.range_checker_chip.add_count(record.to_pc_limbs[1], 14);
276
277        let core_cols: &mut Rv32JalrCoreCols<F> = row_slice.borrow_mut();
278        core_cols.imm = record.imm;
279        core_cols.rd_data = record.rd_data;
280        core_cols.rs1_data = record.rs1_data;
281        core_cols.to_pc_least_sig_bit = record.to_pc_least_sig_bit;
282        core_cols.to_pc_limbs = record.to_pc_limbs.map(F::from_canonical_u32);
283        core_cols.imm_sign = record.imm_sign;
284        core_cols.is_valid = F::ONE;
285    }
286
287    fn air(&self) -> &Self::Air {
288        &self.air
289    }
290}
291
292// returns (to_pc, rd_data)
293pub(super) fn run_jalr(
294    _opcode: Rv32JalrOpcode,
295    pc: u32,
296    imm: u32,
297    rs1: u32,
298) -> (u32, [u32; RV32_REGISTER_NUM_LIMBS]) {
299    let to_pc = rs1.wrapping_add(imm);
300    let to_pc = to_pc - (to_pc & 1);
301    assert!(to_pc < (1 << PC_BITS));
302    (
303        to_pc,
304        array::from_fn(|i: usize| ((pc + DEFAULT_PC_STEP) >> (RV32_CELL_BITS * i)) & RV32_LIMB_MAX),
305    )
306}