openvm_rv32im_circuit/jal_lui/
core.rs
1use std::{
2 array,
3 borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::arch::{
7 AdapterAirContext, AdapterRuntimeContext, ImmInstruction, Result, VmAdapterInterface,
8 VmCoreAir, VmCoreChip,
9};
10use openvm_circuit_primitives::bitwise_op_lookup::{
11 BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip,
12};
13use openvm_circuit_primitives_derive::AlignedBorrow;
14use openvm_instructions::{
15 instruction::Instruction,
16 program::{DEFAULT_PC_STEP, PC_BITS},
17 LocalOpcode,
18};
19use openvm_rv32im_transpiler::Rv32JalLuiOpcode::{self, *};
20use openvm_stark_backend::{
21 interaction::InteractionBuilder,
22 p3_air::{AirBuilder, BaseAir},
23 p3_field::{Field, FieldAlgebra, PrimeField32},
24 rap::BaseAirWithPublicValues,
25};
26use serde::{Deserialize, Serialize};
27
28use crate::adapters::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS, RV_J_TYPE_IMM_BITS};
29
/// Trace columns owned by the JAL/LUI core.
///
/// A row is active when exactly one of `is_jal` / `is_lui` is 1; every constraint in the
/// AIR is gated on those flags, so rows with both flags 0 are inactive.
///
/// NOTE(review): this is a `#[repr(C)]` struct borrowed via `AlignedBorrow`, so field order
/// presumably defines the column order — do not reorder fields.
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct Rv32JalLuiCoreCols<T> {
    /// The instruction immediate (operand `c`). For JAL it is the signed jump offset
    /// encoded as a field element; for LUI it is the value shifted into `rd` by 12 bits.
    pub imm: T,
    /// Value written to `rd`, as little-endian limbs of `RV32_CELL_BITS` bits each.
    pub rd_data: [T; RV32_REGISTER_NUM_LIMBS],
    /// 1 when this row executes JAL, 0 otherwise.
    pub is_jal: T,
    /// 1 when this row executes LUI, 0 otherwise.
    pub is_lui: T,
}
38
/// AIR enforcing the JAL/LUI core constraints on `Rv32JalLuiCoreCols` rows.
#[derive(Debug, Clone)]
pub struct Rv32JalLuiCoreAir {
    /// Bitwise-lookup bus used for the `rd` limb range checks and the JAL top-limb XOR check.
    pub bus: BitwiseOperationLookupBus,
}
43
44impl<F: Field> BaseAir<F> for Rv32JalLuiCoreAir {
45 fn width(&self) -> usize {
46 Rv32JalLuiCoreCols::<F>::width()
47 }
48}
49
// Relies entirely on the trait's default methods (presumably: this AIR exposes no public
// values of its own — confirm against the trait definition).
impl<F: Field> BaseAirWithPublicValues<F> for Rv32JalLuiCoreAir {}
51
/// Core constraints for RV32 `JAL` and `LUI`.
///
/// The core performs no register reads (`reads` is built from an empty array). It performs
/// one register write of `rd_data`. The immediate and the computed `to_pc` are handed to the
/// adapter through the returned `AdapterAirContext`.
impl<AB, I> VmCoreAir<AB, I> for Rv32JalLuiCoreAir
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<[[AB::Expr; 0]; 0]>,
    I::Writes: From<[[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1]>,
    I::ProcessedInstruction: From<ImmInstruction<AB::Expr>>,
{
    /// Evaluates the core constraints on one row and builds the adapter context.
    ///
    /// A row is active when `is_jal + is_lui` is 1. Both flags and their sum are constrained
    /// boolean, so at most one flag is set. On an active row:
    /// - pairs of `rd_data` limbs are range-checked to `RV32_CELL_BITS` bits on the bitwise bus;
    /// - LUI: `rd_data` must recompose to `imm * 2^12`, and the returned `to_pc` is
    ///   `from_pc + DEFAULT_PC_STEP`;
    /// - JAL: `rd_data` must recompose to `from_pc + DEFAULT_PC_STEP`, and its top limb
    ///   (`rd[3]`, assuming four limbs) is bounded so the value fits in `PC_BITS` bits.
    ///   The returned `to_pc` is `from_pc + imm`.
    ///
    /// NOTE(review): the order of the constraints and bus interactions below is part of the
    /// AIR definition (presumably it affects keygen output) — avoid reordering.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &Rv32JalLuiCoreCols<AB::Var> = (*local_core).borrow();
        // `rd` holds the destination value, one little-endian limb per element.
        let Rv32JalLuiCoreCols::<AB::Var> {
            imm,
            rd_data: rd,
            is_jal,
            is_lui,
        } = *cols;

        // Both selectors are boolean and so is their sum, so at most one is set.
        // The sum marks an active row.
        builder.assert_bool(is_lui);
        builder.assert_bool(is_jal);
        let is_valid = is_lui + is_jal;
        builder.assert_bool(is_valid.clone());
        // LUI writes imm << 12, so the lowest limb must be zero. The next 4 bits are forced
        // to zero by the 2^(12 - RV32_CELL_BITS) factor in the LUI check below.
        builder.when(is_lui).assert_zero(rd[0]);

        // Range-check the rd limbs to RV32_CELL_BITS bits each, two limbs per lookup.
        for i in 0..RV32_REGISTER_NUM_LIMBS / 2 {
            self.bus
                .send_range(rd[i * 2], rd[i * 2 + 1])
                .eval(builder, is_valid.clone());
        }

        // For JAL, rd holds a pc and must fit in PC_BITS bits. So the top limb may only use
        // its low `last_limb_bits` bits. `additional_bits` is a mask with every bit above
        // that set. Since x XOR mask == x + mask exactly when x and mask share no set bit,
        // this lookup forces those high bits of rd[3] to zero.
        let last_limb_bits = PC_BITS - RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
        let additional_bits = (last_limb_bits..RV32_CELL_BITS).fold(0, |acc, x| acc + (1 << x));
        let additional_bits = AB::F::from_canonical_u32(additional_bits);
        self.bus
            .send_xor(rd[3], additional_bits, rd[3] + additional_bits)
            .eval(builder, is_jal);

        // rd[1] + rd[2] * 2^RV32_CELL_BITS + ... : rd with its lowest limb dropped,
        // shifted down by RV32_CELL_BITS.
        let intermed_val = rd
            .iter()
            .skip(1)
            .enumerate()
            .fold(AB::Expr::ZERO, |acc, (i, &val)| {
                acc + val * AB::Expr::from_canonical_u32(1 << (i * RV32_CELL_BITS))
            });

        // LUI: (rd >> RV32_CELL_BITS) == imm * 2^(12 - RV32_CELL_BITS).
        // Together with rd[0] == 0, this gives rd == imm * 2^12.
        builder.when(is_lui).assert_eq(
            intermed_val.clone(),
            imm * AB::F::from_canonical_u32(1 << (12 - RV32_CELL_BITS)),
        );

        // JAL: the fully recomposed rd must equal from_pc + DEFAULT_PC_STEP.
        let intermed_val = rd[0] + intermed_val * AB::Expr::from_canonical_u32(1 << RV32_CELL_BITS);
        builder.when(is_jal).assert_eq(
            intermed_val,
            from_pc + AB::F::from_canonical_u32(DEFAULT_PC_STEP),
        );

        // Next pc: fall through for LUI; jump by the (signed, field-encoded) immediate for JAL.
        let to_pc = from_pc + is_lui * AB::F::from_canonical_u32(DEFAULT_PC_STEP) + is_jal * imm;

        // Choose the local opcode by flag and map it through `expr_to_global_expr`.
        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(
            self,
            is_lui * AB::F::from_canonical_u32(LUI as u32)
                + is_jal * AB::F::from_canonical_u32(JAL as u32),
        );

        // No register reads; a single register write of rd's limbs.
        AdapterAirContext {
            to_pc: Some(to_pc),
            reads: [].into(),
            writes: [rd.map(|x| x.into())].into(),
            instruction: ImmInstruction {
                is_valid,
                opcode: expected_opcode,
                immediate: imm.into(),
            }
            .into(),
        }
    }

    /// Offset of the `Rv32JalLuiOpcode` class within the global opcode space.
    fn start_offset(&self) -> usize {
        Rv32JalLuiOpcode::CLASS_OFFSET
    }
}
141
/// Per-instruction execution record produced by `execute_instruction` and consumed by
/// `generate_trace_row`.
///
/// NOTE(review): serde-serialized; field names and order are part of the encoded form.
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct Rv32JalLuiCoreRecord<F: Field> {
    /// Limbs written to `rd`, in the same layout as `Rv32JalLuiCoreCols::rd_data`.
    pub rd_data: [F; RV32_REGISTER_NUM_LIMBS],
    /// Raw immediate taken from the instruction's `c` operand (not sign-decoded).
    pub imm: F,
    /// Whether the executed opcode was JAL.
    pub is_jal: bool,
    /// Whether the executed opcode was LUI.
    pub is_lui: bool,
}
151
/// Execution and trace-generation chip for the JAL/LUI core.
pub struct Rv32JalLuiCoreChip {
    /// The AIR that rows produced by this chip are checked against.
    pub air: Rv32JalLuiCoreAir,
    /// Shared lookup chip that receives the range-check and XOR requests issued during
    /// execution, matching the lookups the AIR sends on `air.bus`.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}
156
157impl Rv32JalLuiCoreChip {
158 pub fn new(bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>) -> Self {
159 Self {
160 air: Rv32JalLuiCoreAir {
161 bus: bitwise_lookup_chip.bus(),
162 },
163 bitwise_lookup_chip,
164 }
165 }
166}
167
168impl<F: PrimeField32, I: VmAdapterInterface<F>> VmCoreChip<F, I> for Rv32JalLuiCoreChip
169where
170 I::Writes: From<[[F; RV32_REGISTER_NUM_LIMBS]; 1]>,
171{
172 type Record = Rv32JalLuiCoreRecord<F>;
173 type Air = Rv32JalLuiCoreAir;
174
175 #[allow(clippy::type_complexity)]
176 fn execute_instruction(
177 &self,
178 instruction: &Instruction<F>,
179 from_pc: u32,
180 _reads: I::Reads,
181 ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
182 let local_opcode = Rv32JalLuiOpcode::from_usize(
183 instruction
184 .opcode
185 .local_opcode_idx(Rv32JalLuiOpcode::CLASS_OFFSET),
186 );
187 let imm = instruction.c;
188
189 let signed_imm = match local_opcode {
190 JAL => {
191 (imm + F::from_canonical_u32(1 << (RV_J_TYPE_IMM_BITS - 1))).as_canonical_u32()
193 as i32
194 - (1 << (RV_J_TYPE_IMM_BITS - 1))
195 }
196 LUI => imm.as_canonical_u32() as i32,
197 };
198 let (to_pc, rd_data) = run_jal_lui(local_opcode, from_pc, signed_imm);
199
200 for i in 0..(RV32_REGISTER_NUM_LIMBS / 2) {
201 self.bitwise_lookup_chip
202 .request_range(rd_data[i * 2], rd_data[i * 2 + 1]);
203 }
204
205 if local_opcode == JAL {
206 let last_limb_bits = PC_BITS - RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
207 let additional_bits = (last_limb_bits..RV32_CELL_BITS).fold(0, |acc, x| acc + (1 << x));
208 self.bitwise_lookup_chip
209 .request_xor(rd_data[3], additional_bits);
210 }
211
212 let rd_data = rd_data.map(F::from_canonical_u32);
213
214 let output = AdapterRuntimeContext {
215 to_pc: Some(to_pc),
216 writes: [rd_data].into(),
217 };
218
219 Ok((
220 output,
221 Rv32JalLuiCoreRecord {
222 rd_data,
223 imm,
224 is_jal: local_opcode == JAL,
225 is_lui: local_opcode == LUI,
226 },
227 ))
228 }
229
230 fn get_opcode_name(&self, opcode: usize) -> String {
231 format!(
232 "{:?}",
233 Rv32JalLuiOpcode::from_usize(opcode - Rv32JalLuiOpcode::CLASS_OFFSET)
234 )
235 }
236
237 fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
238 let core_cols: &mut Rv32JalLuiCoreCols<F> = row_slice.borrow_mut();
239 core_cols.rd_data = record.rd_data;
240 core_cols.imm = record.imm;
241 core_cols.is_jal = F::from_bool(record.is_jal);
242 core_cols.is_lui = F::from_bool(record.is_lui);
243 }
244
245 fn air(&self) -> &Self::Air {
246 &self.air
247 }
248}
249
250pub(super) fn run_jal_lui(
252 opcode: Rv32JalLuiOpcode,
253 pc: u32,
254 imm: i32,
255) -> (u32, [u32; RV32_REGISTER_NUM_LIMBS]) {
256 match opcode {
257 JAL => {
258 let rd_data = array::from_fn(|i| {
259 ((pc + DEFAULT_PC_STEP) >> (8 * i)) & ((1 << RV32_CELL_BITS) - 1)
260 });
261 let next_pc = pc as i32 + imm;
262 assert!(next_pc >= 0);
263 (next_pc as u32, rd_data)
264 }
265 LUI => {
266 let imm = imm as u32;
267 let rd = imm << 12;
268 let rd_data =
269 array::from_fn(|i| (rd >> (RV32_CELL_BITS * i)) & ((1 << RV32_CELL_BITS) - 1));
270 (pc + DEFAULT_PC_STEP, rd_data)
271 }
272 }
273}