// openvm_rv32im_circuit/jalr/core.rs
use std::{
    array,
    borrow::{Borrow, BorrowMut},
};

use openvm_circuit::arch::{
    AdapterAirContext, AdapterRuntimeContext, Result, SignedImmInstruction, VmAdapterInterface,
    VmCoreAir, VmCoreChip,
};
use openvm_circuit_primitives::{
    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
    var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
};
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::{DEFAULT_PC_STEP, PC_BITS},
    LocalOpcode,
};
use openvm_rv32im_transpiler::Rv32JalrOpcode::{self, *};
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::{AirBuilder, BaseAir},
    p3_field::{Field, FieldAlgebra, PrimeField32},
    rap::BaseAirWithPublicValues,
};
use serde::{Deserialize, Serialize};

use crate::adapters::{compose, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};

31const RV32_LIMB_MAX: u32 = (1 << RV32_CELL_BITS) - 1;
32
33#[repr(C)]
34#[derive(Debug, Clone, AlignedBorrow)]
35pub struct Rv32JalrCoreCols<T> {
36 pub imm: T,
37 pub rs1_data: [T; RV32_REGISTER_NUM_LIMBS],
38 pub rd_data: [T; RV32_REGISTER_NUM_LIMBS - 1],
41 pub is_valid: T,
42
43 pub to_pc_least_sig_bit: T,
44 pub to_pc_limbs: [T; 2],
46 pub imm_sign: T,
47}
48
49#[repr(C)]
50#[derive(Serialize, Deserialize)]
51pub struct Rv32JalrCoreRecord<F> {
52 pub imm: F,
53 pub rs1_data: [F; RV32_REGISTER_NUM_LIMBS],
54 pub rd_data: [F; RV32_REGISTER_NUM_LIMBS - 1],
55 pub to_pc_least_sig_bit: F,
56 pub to_pc_limbs: [u32; 2],
57 pub imm_sign: F,
58}
59
60#[derive(Debug, Clone)]
61pub struct Rv32JalrCoreAir {
62 pub bitwise_lookup_bus: BitwiseOperationLookupBus,
63 pub range_bus: VariableRangeCheckerBus,
64}
65
66impl<F: Field> BaseAir<F> for Rv32JalrCoreAir {
67 fn width(&self) -> usize {
68 Rv32JalrCoreCols::<F>::width()
69 }
70}
71
72impl<F: Field> BaseAirWithPublicValues<F> for Rv32JalrCoreAir {}
73
74impl<AB, I> VmCoreAir<AB, I> for Rv32JalrCoreAir
75where
76 AB: InteractionBuilder,
77 I: VmAdapterInterface<AB::Expr>,
78 I::Reads: From<[[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1]>,
79 I::Writes: From<[[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1]>,
80 I::ProcessedInstruction: From<SignedImmInstruction<AB::Expr>>,
81{
82 fn eval(
83 &self,
84 builder: &mut AB,
85 local_core: &[AB::Var],
86 from_pc: AB::Var,
87 ) -> AdapterAirContext<AB::Expr, I> {
88 let cols: &Rv32JalrCoreCols<AB::Var> = (*local_core).borrow();
89 let Rv32JalrCoreCols::<AB::Var> {
90 imm,
91 rs1_data: rs1,
92 rd_data: rd,
93 is_valid,
94 imm_sign,
95 to_pc_least_sig_bit,
96 to_pc_limbs,
97 } = *cols;
98
99 builder.assert_bool(is_valid);
100
101 let composed = rd
103 .iter()
104 .enumerate()
105 .fold(AB::Expr::ZERO, |acc, (i, &val)| {
106 acc + val * AB::Expr::from_canonical_u32(1 << ((i + 1) * RV32_CELL_BITS))
107 });
108
109 let least_sig_limb = from_pc + AB::F::from_canonical_u32(DEFAULT_PC_STEP) - composed;
110
111 let rd_data = array::from_fn(|i| {
117 if i == 0 {
118 least_sig_limb.clone()
119 } else {
120 rd[i - 1].into().clone()
121 }
122 });
123
124 self.bitwise_lookup_bus
127 .send_range(rd_data[0].clone(), rd_data[1].clone())
128 .eval(builder, is_valid);
129 self.range_bus
130 .range_check(rd_data[2].clone(), RV32_CELL_BITS)
131 .eval(builder, is_valid);
132 self.range_bus
133 .range_check(rd_data[3].clone(), PC_BITS - RV32_CELL_BITS * 3)
134 .eval(builder, is_valid);
135
136 builder.assert_bool(imm_sign);
137
138 let rs1_limbs_01 = rs1[0] + rs1[1] * AB::F::from_canonical_u32(1 << RV32_CELL_BITS);
141 let rs1_limbs_23 = rs1[2] + rs1[3] * AB::F::from_canonical_u32(1 << RV32_CELL_BITS);
142 let inv = AB::F::from_canonical_u32(1 << 16).inverse();
143
144 builder.assert_bool(to_pc_least_sig_bit);
145 let carry = (rs1_limbs_01 + imm - to_pc_limbs[0] * AB::F::TWO - to_pc_least_sig_bit) * inv;
146 builder.when(is_valid).assert_bool(carry.clone());
147
148 let imm_extend_limb = imm_sign * AB::F::from_canonical_u32((1 << 16) - 1);
149 let carry = (rs1_limbs_23 + imm_extend_limb + carry - to_pc_limbs[1]) * inv;
150 builder.when(is_valid).assert_bool(carry);
151
152 self.range_bus
154 .range_check(to_pc_limbs[1], PC_BITS - 16)
155 .eval(builder, is_valid);
156 self.range_bus
157 .range_check(to_pc_limbs[0], 15)
158 .eval(builder, is_valid);
159 let to_pc =
160 to_pc_limbs[0] * AB::F::TWO + to_pc_limbs[1] * AB::F::from_canonical_u32(1 << 16);
161
162 let expected_opcode = VmCoreAir::<AB, I>::opcode_to_global_expr(self, JALR);
163
164 AdapterAirContext {
165 to_pc: Some(to_pc),
166 reads: [rs1.map(|x| x.into())].into(),
167 writes: [rd_data].into(),
168 instruction: SignedImmInstruction {
169 is_valid: is_valid.into(),
170 opcode: expected_opcode,
171 immediate: imm.into(),
172 imm_sign: imm_sign.into(),
173 }
174 .into(),
175 }
176 }
177
178 fn start_offset(&self) -> usize {
179 Rv32JalrOpcode::CLASS_OFFSET
180 }
181}
182
183pub struct Rv32JalrCoreChip {
184 pub air: Rv32JalrCoreAir,
185 pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
186 pub range_checker_chip: SharedVariableRangeCheckerChip,
187}
188
189impl Rv32JalrCoreChip {
190 pub fn new(
191 bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
192 range_checker_chip: SharedVariableRangeCheckerChip,
193 ) -> Self {
194 assert!(range_checker_chip.range_max_bits() >= 16);
195 Self {
196 air: Rv32JalrCoreAir {
197 bitwise_lookup_bus: bitwise_lookup_chip.bus(),
198 range_bus: range_checker_chip.bus(),
199 },
200 bitwise_lookup_chip,
201 range_checker_chip,
202 }
203 }
204}
205
206impl<F: PrimeField32, I: VmAdapterInterface<F>> VmCoreChip<F, I> for Rv32JalrCoreChip
207where
208 I::Reads: Into<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>,
209 I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>,
210{
211 type Record = Rv32JalrCoreRecord<F>;
212 type Air = Rv32JalrCoreAir;
213
214 #[allow(clippy::type_complexity)]
215 fn execute_instruction(
216 &self,
217 instruction: &Instruction<F>,
218 from_pc: u32,
219 reads: I::Reads,
220 ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
221 let Instruction { opcode, c, g, .. } = *instruction;
222 let local_opcode =
223 Rv32JalrOpcode::from_usize(opcode.local_opcode_idx(Rv32JalrOpcode::CLASS_OFFSET));
224
225 let imm = c.as_canonical_u32();
226 let imm_sign = g.as_canonical_u32();
227 let imm_extended = imm + imm_sign * 0xffff0000;
228
229 let rs1 = reads.into()[0];
230 let rs1_val = compose(rs1);
231
232 let (to_pc, rd_data) = run_jalr(local_opcode, from_pc, imm_extended, rs1_val);
233
234 self.bitwise_lookup_chip
235 .request_range(rd_data[0], rd_data[1]);
236 self.range_checker_chip
237 .add_count(rd_data[2], RV32_CELL_BITS);
238 self.range_checker_chip
239 .add_count(rd_data[3], PC_BITS - RV32_CELL_BITS * 3);
240
241 let mask = (1 << 15) - 1;
242 let to_pc_least_sig_bit = rs1_val.wrapping_add(imm_extended) & 1;
243
244 let to_pc_limbs = array::from_fn(|i| ((to_pc >> (1 + i * 15)) & mask));
245
246 let rd_data = rd_data.map(F::from_canonical_u32);
247
248 let output = AdapterRuntimeContext {
249 to_pc: Some(to_pc),
250 writes: [rd_data].into(),
251 };
252
253 Ok((
254 output,
255 Rv32JalrCoreRecord {
256 imm: c,
257 rd_data: array::from_fn(|i| rd_data[i + 1]),
258 rs1_data: rs1,
259 to_pc_least_sig_bit: F::from_canonical_u32(to_pc_least_sig_bit),
260 to_pc_limbs,
261 imm_sign: g,
262 },
263 ))
264 }
265
266 fn get_opcode_name(&self, opcode: usize) -> String {
267 format!(
268 "{:?}",
269 Rv32JalrOpcode::from_usize(opcode - Rv32JalrOpcode::CLASS_OFFSET)
270 )
271 }
272
273 fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
274 self.range_checker_chip.add_count(record.to_pc_limbs[0], 15);
275 self.range_checker_chip.add_count(record.to_pc_limbs[1], 14);
276
277 let core_cols: &mut Rv32JalrCoreCols<F> = row_slice.borrow_mut();
278 core_cols.imm = record.imm;
279 core_cols.rd_data = record.rd_data;
280 core_cols.rs1_data = record.rs1_data;
281 core_cols.to_pc_least_sig_bit = record.to_pc_least_sig_bit;
282 core_cols.to_pc_limbs = record.to_pc_limbs.map(F::from_canonical_u32);
283 core_cols.imm_sign = record.imm_sign;
284 core_cols.is_valid = F::ONE;
285 }
286
287 fn air(&self) -> &Self::Air {
288 &self.air
289 }
290}
291
292pub(super) fn run_jalr(
294 _opcode: Rv32JalrOpcode,
295 pc: u32,
296 imm: u32,
297 rs1: u32,
298) -> (u32, [u32; RV32_REGISTER_NUM_LIMBS]) {
299 let to_pc = rs1.wrapping_add(imm);
300 let to_pc = to_pc - (to_pc & 1);
301 assert!(to_pc < (1 << PC_BITS));
302 (
303 to_pc,
304 array::from_fn(|i: usize| ((pc + DEFAULT_PC_STEP) >> (RV32_CELL_BITS * i)) & RV32_LIMB_MAX),
305 )
306}