1use std::{
2 array,
3 borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::{
7 arch::*,
8 system::memory::{online::TracingMemory, MemoryAuxColsFactory},
9};
10use openvm_circuit_primitives::{
11 utils::select,
12 var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
13 AlignedBytesBorrow,
14};
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_instructions::{
17 instruction::Instruction,
18 program::DEFAULT_PC_STEP,
19 riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS},
20 LocalOpcode,
21};
22use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
23use openvm_stark_backend::{
24 interaction::InteractionBuilder,
25 p3_air::BaseAir,
26 p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
27 rap::BaseAirWithPublicValues,
28};
29
30use crate::adapters::{LoadStoreInstruction, Rv32LoadStoreAdapterFiller};
31
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
/// Core trace columns for sign-extending loads (LOADB/LOADH).
///
/// Exactly one of the three opcode flags is set on a valid row; all zero means
/// the row is invalid (padding).
pub struct LoadSignExtendCoreCols<T, const NUM_CELLS: usize> {
    // LOADB with an even byte shift (shift & 1 == 0); see `fill_trace_row`.
    pub opcode_loadb_flag0: T,
    // LOADB with an odd byte shift (shift & 1 == 1).
    pub opcode_loadb_flag1: T,
    // LOADH (half-word load); shift is 0 or 2 in this case.
    pub opcode_loadh_flag: T,

    // Bit 1 of the byte shift amount, i.e. whether the read was pre-rotated by 2 limbs.
    pub shift_most_sig_bit: T,
    // Top bit of the limb being sign-extended; drives the fill mask for the upper limbs.
    pub data_most_sig_bit: T,

    // The NUM_CELLS limbs read from memory, rotated left by (shift & 2) so the
    // AIR can address the significant limbs at fixed positions.
    pub shifted_read_data: [T; NUM_CELLS],
    // Previous contents of the destination register (needed by the adapter's read interface).
    pub prev_data: [T; NUM_CELLS],
}
53
#[derive(Debug, Clone, derive_new::new)]
/// Core AIR for sign-extending loads; parameterized by limb count and limb width.
pub struct LoadSignExtendCoreAir<const NUM_CELLS: usize, const LIMB_BITS: usize> {
    /// Bus used to range-check the sign-extended limb to `LIMB_BITS - 1` bits.
    pub range_bus: VariableRangeCheckerBus,
}
58
impl<F: Field, const NUM_CELLS: usize, const LIMB_BITS: usize> BaseAir<F>
    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
{
    // Trace width is exactly the number of fields in the column struct.
    fn width(&self) -> usize {
        LoadSignExtendCoreCols::<F, NUM_CELLS>::width()
    }
}
66
// Marker impl: this core AIR exposes no public values.
impl<F: Field, const NUM_CELLS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
{
}
71
impl<AB, I, const NUM_CELLS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<([AB::Var; NUM_CELLS], [AB::Expr; NUM_CELLS])>,
    I::Writes: From<[[AB::Expr; NUM_CELLS]; 1]>,
    I::ProcessedInstruction: From<LoadStoreInstruction<AB::Expr>>,
{
    /// Emits the constraints for a sign-extending load row and hands the
    /// reconstructed reads/writes/instruction back to the adapter.
    ///
    /// The trace stores `shifted_read_data` (the memory read rotated left by
    /// `shift & 2`); this function both constrains the sign extension on the
    /// shifted view and un-rotates it to recover the raw read for the adapter.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &LoadSignExtendCoreCols<AB::Var, NUM_CELLS> = (*local_core).borrow();
        let LoadSignExtendCoreCols::<AB::Var, NUM_CELLS> {
            shifted_read_data,
            prev_data,
            opcode_loadb_flag0: is_loadb0,
            opcode_loadb_flag1: is_loadb1,
            opcode_loadh_flag: is_loadh,
            data_most_sig_bit,
            shift_most_sig_bit,
        } = *cols;

        let flags = [is_loadb0, is_loadb1, is_loadh];

        // Each flag is boolean and their sum is the validity indicator
        // (a valid row sets exactly one flag; padding rows set none).
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag
        });

        builder.assert_bool(is_valid.clone());
        builder.assert_bool(data_most_sig_bit);
        builder.assert_bool(shift_most_sig_bit);

        // Both LOADB flags map to the LOADB opcode; offsets into the global opcode space.
        let expected_opcode = (is_loadb0 + is_loadb1) * AB::F::from_u8(LOADB as u8)
            + is_loadh * AB::F::from_u8(LOADH as u8)
            + AB::Expr::from_usize(Rv32LoadStoreOpcode::CLASS_OFFSET);

        // All-ones limb (2^LIMB_BITS - 1) when the sign bit is set, zero otherwise.
        let limb_mask = data_most_sig_bit * AB::Expr::from_u32((1 << LIMB_BITS) - 1);

        // Reconstruct the written register value from the shifted read:
        //   limb 0:             the loaded byte (position 0 for loadh/loadb0, 1 for loadb1)
        //   limbs 1..N/2-1:     data for loadh, sign fill for loadb
        //   limbs N/2..N-1:     sign fill for every opcode
        let write_data: [AB::Expr; NUM_CELLS] = array::from_fn(|i| {
            if i == 0 {
                (is_loadh + is_loadb0) * shifted_read_data[i].into()
                    + is_loadb1 * shifted_read_data[i + 1].into()
            } else if i < NUM_CELLS / 2 {
                shifted_read_data[i] * is_loadh + (is_loadb0 + is_loadb1) * limb_mask.clone()
            } else {
                limb_mask.clone()
            }
        });

        // The limb whose top bit is the sign: the loaded byte for LOADB,
        // the upper limb of the half-word for LOADH.
        let most_sig_limb = shifted_read_data[0] * is_loadb0
            + shifted_read_data[1] * is_loadb1
            + shifted_read_data[NUM_CELLS / 2 - 1] * is_loadh;

        // Proves data_most_sig_bit really is the top bit: after removing
        // msb * 2^(LIMB_BITS-1), the remainder must fit in LIMB_BITS - 1 bits.
        self.range_bus
            .range_check(
                most_sig_limb - data_most_sig_bit * AB::Expr::from_u32(1 << (LIMB_BITS - 1)),
                LIMB_BITS - 1,
            )
            .eval(builder, is_valid.clone());

        // Undo the rotate-left-by-2 applied during trace generation when
        // shift_most_sig_bit is set, recovering the raw memory read.
        let read_data = array::from_fn(|i| {
            select(
                shift_most_sig_bit,
                shifted_read_data[(i + NUM_CELLS - 2) % NUM_CELLS],
                shifted_read_data[i],
            )
        });
        // Full byte shift = 2 * (bit 1) + (bit 0, encoded by the loadb1 flag).
        let load_shift_amount = shift_most_sig_bit * AB::Expr::TWO + is_loadb1;

        AdapterAirContext {
            to_pc: None,
            reads: (prev_data, read_data).into(),
            writes: [write_data].into(),
            instruction: LoadStoreInstruction {
                is_valid: is_valid.clone(),
                opcode: expected_opcode,
                // Every opcode handled here is a load.
                is_load: is_valid,
                load_shift_amount,
                store_shift_amount: AB::Expr::ZERO,
            }
            .into(),
        }
    }

    // Opcodes handled by this AIR start at the load/store class offset.
    fn start_offset(&self) -> usize {
        Rv32LoadStoreOpcode::CLASS_OFFSET
    }
}
171
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
/// Execution record serialized into the trace row buffer by `execute`,
/// later consumed (in place) by `fill_trace_row`.
pub struct LoadSignExtendCoreRecord<const NUM_CELLS: usize> {
    // true for LOADB, false for LOADH.
    pub is_byte: bool,
    // Byte shift of the access within the aligned word (0..=3 for LOADB, 0 or 2 for LOADH).
    pub shift_amount: u8,
    // Raw limbs read from memory (un-rotated).
    pub read_data: [u8; NUM_CELLS],
    // Previous contents of the destination register.
    pub prev_data: [u8; NUM_CELLS],
}
180
#[derive(Clone, Copy, derive_new::new)]
/// Preflight executor for sign-extending loads, generic over the memory adapter.
pub struct LoadSignExtendExecutor<A, const NUM_CELLS: usize, const LIMB_BITS: usize> {
    adapter: A,
}
185
#[derive(Clone, derive_new::new)]
/// Trace filler for sign-extending loads; defaults target RV32 (4 limbs of 8 bits).
pub struct LoadSignExtendFiller<
    A = Rv32LoadStoreAdapterFiller,
    const NUM_CELLS: usize = RV32_REGISTER_NUM_LIMBS,
    const LIMB_BITS: usize = RV32_CELL_BITS,
> {
    adapter: A,
    /// Chip that records range-check counts matching the AIR's range-bus interactions.
    pub range_checker_chip: SharedVariableRangeCheckerChip,
}
195
impl<F, A, RA, const NUM_CELLS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
    for LoadSignExtendExecutor<A, NUM_CELLS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static
        + AdapterTraceExecutor<
            F,
            ReadData = (([u32; NUM_CELLS], [u8; NUM_CELLS]), u8),
            WriteData = [u32; NUM_CELLS],
        >,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (
            A::RecordMut<'buf>,
            &'buf mut LoadSignExtendCoreRecord<NUM_CELLS>,
        ),
    >,
{
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!(
            "{:?}",
            Rv32LoadStoreOpcode::from_usize(opcode - Rv32LoadStoreOpcode::CLASS_OFFSET)
        )
    }

    /// Executes one sign-extending load: reads through the adapter, records the
    /// inputs for trace filling, writes the sign-extended result, and advances the PC.
    ///
    /// Expects `local_opcode` to be LOADB or LOADH; anything else panics inside
    /// `run_write_data_sign_extend`.
    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let Instruction { opcode, .. } = instruction;

        let local_opcode = Rv32LoadStoreOpcode::from_usize(
            opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
        );

        // Allocate paired adapter/core records in the arena; the core record's
        // bytes live in the same buffer the trace filler later reads.
        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());

        A::start(*state.pc, state.memory, &mut adapter_record);

        let tmp = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record);

        core_record.is_byte = local_opcode == LOADB;
        // Adapter ReadData layout: ((prev register limbs as u32, loaded bytes), byte shift).
        core_record.prev_data = tmp.0 .0.map(|x| x as u8);
        core_record.read_data = tmp.0 .1;
        core_record.shift_amount = tmp.1;

        // Compute the sign-extended register value from the recorded inputs.
        let write_data = run_write_data_sign_extend(
            local_opcode,
            core_record.read_data,
            core_record.shift_amount as usize,
        );

        self.adapter.write(
            state.memory,
            instruction,
            write_data.map(u32::from),
            &mut adapter_record,
        );

        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);

        Ok(())
    }
}
264
265impl<F, A, const NUM_CELLS: usize, const LIMB_BITS: usize> TraceFiller<F>
266 for LoadSignExtendFiller<A, NUM_CELLS, LIMB_BITS>
267where
268 F: PrimeField32,
269 A: 'static + AdapterTraceFiller<F>,
270{
271 fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
272 let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
275 self.adapter.fill_trace_row(mem_helper, adapter_row);
276 let record: &LoadSignExtendCoreRecord<NUM_CELLS> =
279 unsafe { get_record_from_slice(&mut core_row, ()) };
280
281 let core_row: &mut LoadSignExtendCoreCols<F, NUM_CELLS> = core_row.borrow_mut();
282
283 let shift = record.shift_amount;
284 let most_sig_limb = if record.is_byte {
285 record.read_data[shift as usize]
286 } else {
287 record.read_data[NUM_CELLS / 2 - 1 + shift as usize]
288 };
289
290 let most_sig_bit = most_sig_limb & (1 << 7);
291 self.range_checker_chip
292 .add_count((most_sig_limb - most_sig_bit) as u32, 7);
293
294 core_row.prev_data = record.prev_data.map(F::from_u8);
295 core_row.shifted_read_data = record.read_data.map(F::from_u8);
296 core_row.shifted_read_data.rotate_left((shift & 2) as usize);
297
298 core_row.data_most_sig_bit = F::from_bool(most_sig_bit != 0);
299 core_row.shift_most_sig_bit = F::from_bool(shift & 2 == 2);
300 core_row.opcode_loadh_flag = F::from_bool(!record.is_byte);
301 core_row.opcode_loadb_flag1 = F::from_bool(record.is_byte && ((shift & 1) == 1));
302 core_row.opcode_loadb_flag0 = F::from_bool(record.is_byte && ((shift & 1) == 0));
303 }
304}
305
306#[inline(always)]
308pub(super) fn run_write_data_sign_extend<const NUM_CELLS: usize>(
309 opcode: Rv32LoadStoreOpcode,
310 read_data: [u8; NUM_CELLS],
311 shift: usize,
312) -> [u8; NUM_CELLS] {
313 match (opcode, shift) {
314 (LOADH, 0) | (LOADH, 2) => {
315 let ext = (read_data[NUM_CELLS / 2 - 1 + shift] >> 7) * u8::MAX;
316 array::from_fn(|i| {
317 if i < NUM_CELLS / 2 {
318 read_data[i + shift]
319 } else {
320 ext
321 }
322 })
323 }
324 (LOADB, 0) | (LOADB, 1) | (LOADB, 2) | (LOADB, 3) => {
325 let ext = (read_data[shift] >> 7) * u8::MAX;
326 array::from_fn(|i| {
327 if i == 0 {
328 read_data[i + shift]
329 } else {
330 ext
331 }
332 })
333 }
334 _ => unreachable!(
338 "unaligned memory access not supported by this execution environment: {opcode:?}, shift: {shift}"
339 ),
340 }
341}