1use std::{
2 array,
3 borrow::{Borrow, BorrowMut},
4};
5
6use openvm_circuit::{
7 arch::*,
8 system::memory::{online::TracingMemory, MemoryAuxColsFactory},
9};
10use openvm_circuit_primitives::{
11 utils::select,
12 var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
13 AlignedBytesBorrow,
14};
15use openvm_circuit_primitives_derive::AlignedBorrow;
16use openvm_instructions::{
17 instruction::Instruction,
18 program::DEFAULT_PC_STEP,
19 riscv::{RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS},
20 LocalOpcode,
21};
22use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
23use openvm_stark_backend::{
24 interaction::InteractionBuilder,
25 p3_air::BaseAir,
26 p3_field::{Field, FieldAlgebra, PrimeField32},
27 rap::BaseAirWithPublicValues,
28};
29
30use crate::adapters::{LoadStoreInstruction, Rv32LoadStoreAdapterFiller};
31
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct LoadSignExtendCoreCols<T, const NUM_CELLS: usize> {
    /// Set when the opcode is LOADB and bit 0 of the shift amount is 0 (even byte offset).
    pub opcode_loadb_flag0: T,
    /// Set when the opcode is LOADB and bit 0 of the shift amount is 1 (odd byte offset).
    pub opcode_loadb_flag1: T,
    /// Set when the opcode is LOADH.
    pub opcode_loadh_flag: T,

    /// Bit 1 of the shift amount (boolean); selects whether the read data was
    /// rotated left by 2 limbs before being stored in `shifted_read_data`.
    pub shift_most_sig_bit: T,
    /// Most significant bit of the limb being sign-extended (boolean); drives the
    /// all-ones/all-zeros fill mask for the upper limbs of the result.
    pub data_most_sig_bit: T,

    /// The raw read data rotated left by `shift & 2`, so the loaded limbs start at
    /// index 0 (or 1, for an odd-offset LOADB).
    pub shifted_read_data: [T; NUM_CELLS],
    /// Previous contents of the destination register, forwarded to the adapter reads.
    pub prev_data: [T; NUM_CELLS],
}
53
/// Core AIR for sign-extending loads (LOADB/LOADH). `NUM_CELLS` is the number of
/// limbs per register word; `LIMB_BITS` is the bit width of each limb.
#[derive(Debug, Clone, derive_new::new)]
pub struct LoadSignExtendCoreAir<const NUM_CELLS: usize, const LIMB_BITS: usize> {
    /// Bus to the variable range checker, used to prove the claimed most
    /// significant bit of the sign-extended limb.
    pub range_bus: VariableRangeCheckerBus,
}
58
impl<F: Field, const NUM_CELLS: usize, const LIMB_BITS: usize> BaseAir<F>
    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
{
    // The core trace width is exactly the flattened size of the column struct.
    fn width(&self) -> usize {
        LoadSignExtendCoreCols::<F, NUM_CELLS>::width()
    }
}
66
// This AIR exposes no public values; the trait's default (empty) implementation suffices.
impl<F: Field, const NUM_CELLS: usize, const LIMB_BITS: usize> BaseAirWithPublicValues<F>
    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
{
}
71
impl<AB, I, const NUM_CELLS: usize, const LIMB_BITS: usize> VmCoreAir<AB, I>
    for LoadSignExtendCoreAir<NUM_CELLS, LIMB_BITS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<([AB::Var; NUM_CELLS], [AB::Expr; NUM_CELLS])>,
    I::Writes: From<[[AB::Expr; NUM_CELLS]; 1]>,
    I::ProcessedInstruction: From<LoadStoreInstruction<AB::Expr>>,
{
    /// Constrains the sign-extension semantics of LOADB/LOADH over the rotated
    /// read data and returns the adapter-facing context (reads, write, instruction).
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &LoadSignExtendCoreCols<AB::Var, NUM_CELLS> = (*local_core).borrow();
        let LoadSignExtendCoreCols::<AB::Var, NUM_CELLS> {
            shifted_read_data,
            prev_data,
            opcode_loadb_flag0: is_loadb0,
            opcode_loadb_flag1: is_loadb1,
            opcode_loadh_flag: is_loadh,
            data_most_sig_bit,
            shift_most_sig_bit,
        } = *cols;

        let flags = [is_loadb0, is_loadb1, is_loadh];

        // Each flag is constrained boolean, and their sum is constrained boolean
        // below, so at most one flag is set; the sum doubles as the validity
        // indicator for the row.
        let is_valid = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_bool(flag);
            acc + flag
        });

        builder.assert_bool(is_valid.clone());
        builder.assert_bool(data_most_sig_bit);
        builder.assert_bool(shift_most_sig_bit);

        // Reconstruct the global opcode from the flags: both LOADB flags map to
        // LOADB (the shift parity they encode is carried via load_shift_amount).
        let expected_opcode = (is_loadb0 + is_loadb1) * AB::F::from_canonical_u8(LOADB as u8)
            + is_loadh * AB::F::from_canonical_u8(LOADH as u8)
            + AB::Expr::from_canonical_usize(Rv32LoadStoreOpcode::CLASS_OFFSET);

        // All-ones limb when the sign bit is set, all-zeros otherwise.
        let limb_mask = data_most_sig_bit * AB::Expr::from_canonical_u32((1 << LIMB_BITS) - 1);

        // Sign-extended write value, built from the rotated read data:
        //   limb 0:           the loaded limb (shifted index 0, or 1 for odd LOADB);
        //   limbs 1..N/2:     half-word payload for LOADH, sign fill for LOADB;
        //   limbs N/2..N:     sign fill for every opcode.
        let write_data: [AB::Expr; NUM_CELLS] = array::from_fn(|i| {
            if i == 0 {
                (is_loadh + is_loadb0) * shifted_read_data[i].into()
                    + is_loadb1 * shifted_read_data[i + 1].into()
            } else if i < NUM_CELLS / 2 {
                shifted_read_data[i] * is_loadh + (is_loadb0 + is_loadb1) * limb_mask.clone()
            } else {
                limb_mask.clone()
            }
        });

        // The limb whose top bit determines the sign extension (one term per flag).
        let most_sig_limb = shifted_read_data[0] * is_loadb0
            + shifted_read_data[1] * is_loadb1
            + shifted_read_data[NUM_CELLS / 2 - 1] * is_loadh;

        // Prove data_most_sig_bit really is that limb's top bit: subtracting
        // msb * 2^(LIMB_BITS-1) must leave a (LIMB_BITS-1)-bit value.
        self.range_bus
            .range_check(
                most_sig_limb
                    - data_most_sig_bit * AB::Expr::from_canonical_u32(1 << (LIMB_BITS - 1)),
                LIMB_BITS - 1,
            )
            .eval(builder, is_valid.clone());

        // Undo the rotation applied by the trace filler: when shift_most_sig_bit
        // is set, the original read data is the shifted data rotated right by 2.
        let read_data = array::from_fn(|i| {
            select(
                shift_most_sig_bit,
                shifted_read_data[(i + NUM_CELLS - 2) % NUM_CELLS],
                shifted_read_data[i],
            )
        });
        // Full byte shift = 2 * (bit 1) + (bit 0, which is encoded by is_loadb1).
        let load_shift_amount = shift_most_sig_bit * AB::Expr::TWO + is_loadb1;

        AdapterAirContext {
            to_pc: None,
            reads: (prev_data, read_data).into(),
            writes: [write_data].into(),
            instruction: LoadStoreInstruction {
                is_valid: is_valid.clone(),
                opcode: expected_opcode,
                is_load: is_valid,
                load_shift_amount,
                store_shift_amount: AB::Expr::ZERO,
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        Rv32LoadStoreOpcode::CLASS_OFFSET
    }
}
172
/// Execution record written by the executor and consumed by the trace filler.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct LoadSignExtendCoreRecord<const NUM_CELLS: usize> {
    /// True for LOADB, false for LOADH.
    pub is_byte: bool,
    /// Byte offset of the access within the word, as returned by the adapter read.
    pub shift_amount: u8,
    /// Raw (unrotated) limbs read from memory.
    pub read_data: [u8; NUM_CELLS],
    /// Previous contents of the destination register.
    pub prev_data: [u8; NUM_CELLS],
}
181
/// Preflight executor for sign-extending loads; `A` is the memory adapter.
#[derive(Clone, Copy, derive_new::new)]
pub struct LoadSignExtendExecutor<A, const NUM_CELLS: usize, const LIMB_BITS: usize> {
    adapter: A,
}
186
/// Trace filler for sign-extending loads. Defaults target RV32 (4 limbs of 8 bits).
#[derive(Clone, derive_new::new)]
pub struct LoadSignExtendFiller<
    A = Rv32LoadStoreAdapterFiller,
    const NUM_CELLS: usize = RV32_REGISTER_NUM_LIMBS,
    const LIMB_BITS: usize = RV32_CELL_BITS,
> {
    adapter: A,
    /// Shared range checker chip; counts the range checks that the AIR's
    /// `range_bus` emits for the sign-extended limb.
    pub range_checker_chip: SharedVariableRangeCheckerChip,
}
196
impl<F, A, RA, const NUM_CELLS: usize, const LIMB_BITS: usize> PreflightExecutor<F, RA>
    for LoadSignExtendExecutor<A, NUM_CELLS, LIMB_BITS>
where
    F: PrimeField32,
    A: 'static
        + AdapterTraceExecutor<
            F,
            ReadData = (([u32; NUM_CELLS], [u8; NUM_CELLS]), u8),
            WriteData = [u32; NUM_CELLS],
        >,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (
            A::RecordMut<'buf>,
            &'buf mut LoadSignExtendCoreRecord<NUM_CELLS>,
        ),
    >,
{
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!(
            "{:?}",
            Rv32LoadStoreOpcode::from_usize(opcode - Rv32LoadStoreOpcode::CLASS_OFFSET)
        )
    }

    /// Executes one sign-extending load: reads via the adapter, records the inputs
    /// for trace filling, computes the sign-extended value, writes it back, and
    /// advances the pc by one instruction.
    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let Instruction { opcode, .. } = instruction;

        let local_opcode = Rv32LoadStoreOpcode::from_usize(
            opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
        );

        // Allocate the adapter record and core record together in the arena; the
        // trace filler later reads the core record back out of the same row buffer.
        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());

        A::start(*state.pc, state.memory, &mut adapter_record);

        // Adapter read yields ((prev_register_data, read_data), shift_amount),
        // per the ReadData associated-type bound above.
        let tmp = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record);

        // NOTE(review): only LOADB is checked; any opcode other than LOADB/LOADH
        // reaching this executor panics inside run_write_data_sign_extend.
        core_record.is_byte = local_opcode == LOADB;
        core_record.prev_data = tmp.0 .0.map(|x| x as u8);
        core_record.read_data = tmp.0 .1;
        core_record.shift_amount = tmp.1;

        let write_data = run_write_data_sign_extend(
            local_opcode,
            core_record.read_data,
            core_record.shift_amount as usize,
        );

        self.adapter.write(
            state.memory,
            instruction,
            write_data.map(u32::from),
            &mut adapter_record,
        );

        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);

        Ok(())
    }
}
265
266impl<F, A, const NUM_CELLS: usize, const LIMB_BITS: usize> TraceFiller<F>
267 for LoadSignExtendFiller<A, NUM_CELLS, LIMB_BITS>
268where
269 F: PrimeField32,
270 A: 'static + AdapterTraceFiller<F>,
271{
272 fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
273 let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
276 self.adapter.fill_trace_row(mem_helper, adapter_row);
277 let record: &LoadSignExtendCoreRecord<NUM_CELLS> =
280 unsafe { get_record_from_slice(&mut core_row, ()) };
281
282 let core_row: &mut LoadSignExtendCoreCols<F, NUM_CELLS> = core_row.borrow_mut();
283
284 let shift = record.shift_amount;
285 let most_sig_limb = if record.is_byte {
286 record.read_data[shift as usize]
287 } else {
288 record.read_data[NUM_CELLS / 2 - 1 + shift as usize]
289 };
290
291 let most_sig_bit = most_sig_limb & (1 << 7);
292 self.range_checker_chip
293 .add_count((most_sig_limb - most_sig_bit) as u32, 7);
294
295 core_row.prev_data = record.prev_data.map(F::from_canonical_u8);
296 core_row.shifted_read_data = record.read_data.map(F::from_canonical_u8);
297 core_row.shifted_read_data.rotate_left((shift & 2) as usize);
298
299 core_row.data_most_sig_bit = F::from_bool(most_sig_bit != 0);
300 core_row.shift_most_sig_bit = F::from_bool(shift & 2 == 2);
301 core_row.opcode_loadh_flag = F::from_bool(!record.is_byte);
302 core_row.opcode_loadb_flag1 = F::from_bool(record.is_byte && ((shift & 1) == 1));
303 core_row.opcode_loadb_flag0 = F::from_bool(record.is_byte && ((shift & 1) == 0));
304 }
305}
306
307#[inline(always)]
309pub(super) fn run_write_data_sign_extend<const NUM_CELLS: usize>(
310 opcode: Rv32LoadStoreOpcode,
311 read_data: [u8; NUM_CELLS],
312 shift: usize,
313) -> [u8; NUM_CELLS] {
314 match (opcode, shift) {
315 (LOADH, 0) | (LOADH, 2) => {
316 let ext = (read_data[NUM_CELLS / 2 - 1 + shift] >> 7) * u8::MAX;
317 array::from_fn(|i| {
318 if i < NUM_CELLS / 2 {
319 read_data[i + shift]
320 } else {
321 ext
322 }
323 })
324 }
325 (LOADB, 0) | (LOADB, 1) | (LOADB, 2) | (LOADB, 3) => {
326 let ext = (read_data[shift] >> 7) * u8::MAX;
327 array::from_fn(|i| {
328 if i == 0 {
329 read_data[i + shift]
330 } else {
331 ext
332 }
333 })
334 }
335 _ => unreachable!(
339 "unaligned memory access not supported by this execution environment: {opcode:?}, shift: {shift}"
340 ),
341 }
342}