1use std::{
2 array,
3 borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::{
7 arch::*,
8 system::memory::{online::TracingMemory, MemoryAuxColsFactory},
9};
10use openvm_circuit_primitives::{
11 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
12 var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
13 AlignedBytesBorrow,
14};
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_instructions::{
17 instruction::Instruction,
18 program::{DEFAULT_PC_STEP, PC_BITS},
19 LocalOpcode,
20};
21use openvm_rv32im_transpiler::Rv32JalrOpcode::{self, *};
22use openvm_stark_backend::{
23 interaction::InteractionBuilder,
24 p3_air::{AirBuilder, BaseAir},
25 p3_field::{Field, FieldAlgebra, PrimeField32},
26 rap::BaseAirWithPublicValues,
27};
28
29use crate::adapters::{
30 Rv32JalrAdapterExecutor, Rv32JalrAdapterFiller, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
31};
32
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
/// Core trace columns for the RV32 JALR instruction.
pub struct Rv32JalrCoreCols<T> {
    /// The 16-bit immediate operand (its sign bit is carried separately in `imm_sign`).
    pub imm: T,
    /// Little-endian limbs of the rs1 register value, `RV32_CELL_BITS` bits each.
    pub rs1_data: [T; RV32_REGISTER_NUM_LIMBS],
    /// The most-significant limbs of rd (the link address `from_pc + DEFAULT_PC_STEP`).
    /// The least-significant limb is not stored; the AIR derives it from `from_pc`.
    pub rd_data: [T; RV32_REGISTER_NUM_LIMBS - 1],
    /// Boolean flag: 1 for a real JALR row, 0 for padding.
    pub is_valid: T,

    /// Bit 0 of `rs1 + sign_extend(imm)`; JALR discards this bit from the jump target.
    pub to_pc_least_sig_bit: T,
    /// Jump target split into two limbs: bits 1..16 and bits 16..PC_BITS.
    pub to_pc_limbs: [T; 2],
    /// Boolean sign bit of the immediate, used for 32-bit sign extension.
    pub imm_sign: T,
}
48
#[derive(Debug, Clone, derive_new::new)]
/// AIR for the JALR core; holds the buses used for range checks on rd and to_pc limbs.
pub struct Rv32JalrCoreAir {
    /// Bus for paired 8-bit range checks via the bitwise-operation lookup.
    pub bitwise_lookup_bus: BitwiseOperationLookupBus,
    /// Bus for variable-width range checks.
    pub range_bus: VariableRangeCheckerBus,
}
54
impl<F: Field> BaseAir<F> for Rv32JalrCoreAir {
    /// Number of core trace columns, i.e. the width of [`Rv32JalrCoreCols`].
    fn width(&self) -> usize {
        Rv32JalrCoreCols::<F>::width()
    }
}
60
// This AIR exposes no public values; the trait's default (empty) impl suffices.
impl<F: Field> BaseAirWithPublicValues<F> for Rv32JalrCoreAir {}
62
63impl<AB, I> VmCoreAir<AB, I> for Rv32JalrCoreAir
64where
65 AB: InteractionBuilder,
66 I: VmAdapterInterface<AB::Expr>,
67 I::Reads: From<[[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1]>,
68 I::Writes: From<[[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1]>,
69 I::ProcessedInstruction: From<SignedImmInstruction<AB::Expr>>,
70{
71 fn eval(
72 &self,
73 builder: &mut AB,
74 local_core: &[AB::Var],
75 from_pc: AB::Var,
76 ) -> AdapterAirContext<AB::Expr, I> {
77 let cols: &Rv32JalrCoreCols<AB::Var> = (*local_core).borrow();
78 let Rv32JalrCoreCols::<AB::Var> {
79 imm,
80 rs1_data: rs1,
81 rd_data: rd,
82 is_valid,
83 imm_sign,
84 to_pc_least_sig_bit,
85 to_pc_limbs,
86 } = *cols;
87
88 builder.assert_bool(is_valid);
89
90 let composed = rd
92 .iter()
93 .enumerate()
94 .fold(AB::Expr::ZERO, |acc, (i, &val)| {
95 acc + val * AB::Expr::from_canonical_u32(1 << ((i + 1) * RV32_CELL_BITS))
96 });
97
98 let least_sig_limb = from_pc + AB::F::from_canonical_u32(DEFAULT_PC_STEP) - composed;
99
100 let rd_data = array::from_fn(|i| {
107 if i == 0 {
108 least_sig_limb.clone()
109 } else {
110 rd[i - 1].into().clone()
111 }
112 });
113
114 self.bitwise_lookup_bus
117 .send_range(rd_data[0].clone(), rd_data[1].clone())
118 .eval(builder, is_valid);
119 self.range_bus
120 .range_check(rd_data[2].clone(), RV32_CELL_BITS)
121 .eval(builder, is_valid);
122 self.range_bus
123 .range_check(rd_data[3].clone(), PC_BITS - RV32_CELL_BITS * 3)
124 .eval(builder, is_valid);
125
126 builder.assert_bool(imm_sign);
127
128 let rs1_limbs_01 = rs1[0] + rs1[1] * AB::F::from_canonical_u32(1 << RV32_CELL_BITS);
131 let rs1_limbs_23 = rs1[2] + rs1[3] * AB::F::from_canonical_u32(1 << RV32_CELL_BITS);
132 let inv = AB::F::from_canonical_u32(1 << 16).inverse();
133
134 builder.assert_bool(to_pc_least_sig_bit);
135 let carry = (rs1_limbs_01 + imm - to_pc_limbs[0] * AB::F::TWO - to_pc_least_sig_bit) * inv;
136 builder.when(is_valid).assert_bool(carry.clone());
137
138 let imm_extend_limb = imm_sign * AB::F::from_canonical_u32((1 << 16) - 1);
139 let carry = (rs1_limbs_23 + imm_extend_limb + carry - to_pc_limbs[1]) * inv;
140 builder.when(is_valid).assert_bool(carry);
141
142 self.range_bus
144 .range_check(to_pc_limbs[1], PC_BITS - 16)
145 .eval(builder, is_valid);
146 self.range_bus
147 .range_check(to_pc_limbs[0], 15)
148 .eval(builder, is_valid);
149 let to_pc =
150 to_pc_limbs[0] * AB::F::TWO + to_pc_limbs[1] * AB::F::from_canonical_u32(1 << 16);
151
152 let expected_opcode = VmCoreAir::<AB, I>::opcode_to_global_expr(self, JALR);
153
154 AdapterAirContext {
155 to_pc: Some(to_pc),
156 reads: [rs1.map(|x| x.into())].into(),
157 writes: [rd_data].into(),
158 instruction: SignedImmInstruction {
159 is_valid: is_valid.into(),
160 opcode: expected_opcode,
161 immediate: imm.into(),
162 imm_sign: imm_sign.into(),
163 }
164 .into(),
165 }
166 }
167
168 fn start_offset(&self) -> usize {
169 Rv32JalrOpcode::CLASS_OFFSET
170 }
171}
172
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
/// Minimal per-row execution record captured during preflight; the trace filler
/// re-runs `run_jalr` from these fields to reconstruct the full row.
pub struct Rv32JalrCoreRecord {
    /// The 16-bit immediate operand.
    pub imm: u16,
    /// Program counter at the time this instruction executed.
    pub from_pc: u32,
    /// Full 32-bit value read from rs1.
    pub rs1_val: u32,
    /// Sign bit of the immediate (true means negative / sign-extend with ones).
    pub imm_sign: bool,
}
181
#[derive(Clone, Copy, derive_new::new)]
/// Preflight executor for JALR, generic over the adapter that performs register I/O.
pub struct Rv32JalrExecutor<A = Rv32JalrAdapterExecutor> {
    adapter: A,
}
186
#[derive(Clone)]
/// Trace filler for JALR rows; owns the lookup chips used to emit the same
/// range checks the AIR expects (rd limbs and to_pc limbs).
pub struct Rv32JalrFiller<A = Rv32JalrAdapterFiller> {
    adapter: A,
    /// Paired 8-bit range checks for rd's two low limbs.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
    /// Variable-width range checks for the remaining rd and to_pc limbs.
    pub range_checker_chip: SharedVariableRangeCheckerChip,
}
193
194impl<A> Rv32JalrFiller<A> {
195 pub fn new(
196 adapter: A,
197 bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
198 range_checker_chip: SharedVariableRangeCheckerChip,
199 ) -> Self {
200 assert!(range_checker_chip.range_max_bits() >= 16);
201 Self {
202 adapter,
203 bitwise_lookup_chip,
204 range_checker_chip,
205 }
206 }
207}
208
impl<F, A, RA> PreflightExecutor<F, RA> for Rv32JalrExecutor<A>
where
    F: PrimeField32,
    A: 'static
        + AdapterTraceExecutor<
            F,
            ReadData = [u8; RV32_REGISTER_NUM_LIMBS],
            WriteData = [u8; RV32_REGISTER_NUM_LIMBS],
        >,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (A::RecordMut<'buf>, &'buf mut Rv32JalrCoreRecord),
    >,
{
    /// Human-readable opcode name for debugging/logging.
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!(
            "{:?}",
            Rv32JalrOpcode::from_usize(opcode - Rv32JalrOpcode::CLASS_OFFSET)
        )
    }

    /// Executes one JALR instruction: reads rs1, computes the jump target and
    /// link address, writes the link address to rd, and updates the pc.
    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        // Operand c carries the 16-bit immediate; g carries its sign flag.
        let Instruction { opcode, c, g, .. } = *instruction;

        debug_assert_eq!(
            opcode.local_opcode_idx(Rv32JalrOpcode::CLASS_OFFSET),
            JALR as usize
        );

        // Allocate the (adapter, core) record pair in the record arena.
        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());

        A::start(*state.pc, state.memory, &mut adapter_record);

        // Read rs1 through the adapter; the 4 little-endian bytes form the 32-bit value.
        core_record.rs1_val = u32::from_le_bytes(self.adapter.read(
            state.memory,
            instruction,
            &mut adapter_record,
        ));

        // Truncation to u16 is intentional: the immediate is a 16-bit operand.
        core_record.imm = c.as_canonical_u32() as u16;
        // g == 1 marks a negative immediate (sign-extend with ones).
        core_record.imm_sign = g.is_one();
        core_record.from_pc = *state.pc;

        let (to_pc, rd_data) = run_jalr(
            core_record.from_pc,
            core_record.rs1_val,
            core_record.imm,
            core_record.imm_sign,
        );

        // Write the link address (pc + 4) to rd through the adapter.
        self.adapter
            .write(state.memory, instruction, rd_data, &mut adapter_record);

        // JALR clears the least-significant bit of the jump target.
        *state.pc = to_pc & !1;

        Ok(())
    }
}
impl<F, A> TraceFiller<F> for Rv32JalrFiller<A>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    /// Expands a preflight record into a full trace row and feeds the lookup
    /// chips with the same range checks the AIR will perform on this row.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        // SAFETY: caller is expected to pass a row at least A::WIDTH wide
        // (adapter columns first, core columns after) — TODO confirm with trait contract.
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        self.adapter.fill_trace_row(mem_helper, adapter_row);
        // SAFETY: the record bytes were written into this same buffer during preflight;
        // `record` aliases the front of `core_row`, so it must be fully consumed before
        // the corresponding cells are overwritten below.
        let record: &Rv32JalrCoreRecord = unsafe { get_record_from_slice(&mut core_row, ()) };

        let core_row: &mut Rv32JalrCoreCols<F> = core_row.borrow_mut();

        // Recompute the jump target and link address from the record.
        let (to_pc, rd_data) =
            run_jalr(record.from_pc, record.rs1_val, record.imm, record.imm_sign);
        // Limb 0 holds bits 1..16 (bit 0 is tracked separately); limb 1 holds bits 16..PC_BITS.
        let to_pc_limbs = [(to_pc & ((1 << 16) - 1)) >> 1, to_pc >> 16];
        self.range_checker_chip.add_count(to_pc_limbs[0], 15);
        self.range_checker_chip
            .add_count(to_pc_limbs[1], PC_BITS - 16);
        // Mirror the AIR's range checks on the rd limbs.
        self.bitwise_lookup_chip
            .request_range(rd_data[0] as u32, rd_data[1] as u32);

        self.range_checker_chip
            .add_count(rd_data[2] as u32, RV32_CELL_BITS);
        self.range_checker_chip
            .add_count(rd_data[3] as u32, PC_BITS - RV32_CELL_BITS * 3);

        // Columns are written in reverse field order so the record bytes at the
        // front of the buffer are not clobbered before the reads above/below finish.
        core_row.imm_sign = F::from_bool(record.imm_sign);
        core_row.to_pc_limbs = to_pc_limbs.map(F::from_canonical_u32);
        core_row.to_pc_least_sig_bit = F::from_bool(to_pc & 1 == 1);
        core_row.is_valid = F::ONE;
        core_row.rs1_data = record.rs1_val.to_le_bytes().map(F::from_canonical_u8);
        // Only the 3 most-significant limbs of rd are stored; limb 0 is derived in the AIR.
        core_row
            .rd_data
            .iter_mut()
            .rev()
            .zip(rd_data.iter().skip(1).rev())
            .for_each(|(dst, src)| {
                *dst = F::from_canonical_u8(*src);
            });
        core_row.imm = F::from_canonical_u16(record.imm);
    }
}
321
322#[inline(always)]
324pub(super) fn run_jalr(pc: u32, rs1: u32, imm: u16, imm_sign: bool) -> (u32, [u8; 4]) {
325 let to_pc = rs1.wrapping_add(imm as u32 + (imm_sign as u32 * 0xffff0000));
326 assert!(to_pc < (1 << PC_BITS));
327 (to_pc, pc.wrapping_add(DEFAULT_PC_STEP).to_le_bytes())
328}