use std::{
    array,
    borrow::{Borrow, BorrowMut},
    fmt::Debug,
};

use openvm_circuit::{
    arch::*,
    system::memory::{online::TracingMemory, MemoryAuxColsFactory},
};
use openvm_circuit_primitives::{AlignedBorrow, AlignedBytesBorrow};
use openvm_instructions::{
    instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode,
};
use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::{AirBuilder, BaseAir},
    p3_field::{Field, FieldAlgebra, PrimeField32},
    rap::BaseAirWithPublicValues,
};

use crate::adapters::{LoadStoreInstruction, Rv32LoadStoreAdapterFiller};

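/// One variant per (opcode, byte-shift) pair handled by this chip. The discriminant
/// doubles as an index into the `opcode_flags` indicators built in `eval`, so the
/// variant order must match that construction.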
#[derive(Debug, Clone, Copy)]
enum InstructionOpcode {
    LoadW0 = 0,
    LoadHu0,
    LoadHu2,
    LoadBu0,
    LoadBu1,
    LoadBu2,
    LoadBu3,
    StoreW0,
    StoreH0,
    StoreH2,
    StoreB0,
    StoreB1,
    StoreB2,
    StoreB3,
}

use InstructionOpcode::*;

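/// LoadStore core columns and AIR: handles byte/halfword-to-word conversions and
/// unsigned extension, constraining `write_data` from `read_data` and `prev_data`
/// so that each (opcode, shift) pair behaves as its own instruction.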
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct LoadStoreCoreCols<T, const NUM_CELLS: usize> {
    /// Four cells, each constrained to {0, 1, 2}, that together encode the 14
    /// (opcode, shift) variants; see `eval` for the decoding.
    pub flags: [T; 4],
    /// Dedicated columns so that `is_valid` and `is_load` stay degree 1.
    pub is_valid: T,
    pub is_load: T,

    pub read_data: [T; NUM_CELLS],
    pub prev_data: [T; NUM_CELLS],
    /// Constrained from `read_data` and `prev_data` according to the opcode and
    /// the shift amount.
    pub write_data: [T; NUM_CELLS],
}

#[derive(Debug, Clone, derive_new::new)]
pub struct LoadStoreCoreAir<const NUM_CELLS: usize> {
    pub offset: usize,
}

impl<F: Field, const NUM_CELLS: usize> BaseAir<F> for LoadStoreCoreAir<NUM_CELLS> {
    fn width(&self) -> usize {
        LoadStoreCoreCols::<F, NUM_CELLS>::width()
    }
}

impl<F: Field, const NUM_CELLS: usize> BaseAirWithPublicValues<F> for LoadStoreCoreAir<NUM_CELLS> {}

impl<AB, I, const NUM_CELLS: usize> VmCoreAir<AB, I> for LoadStoreCoreAir<NUM_CELLS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<([AB::Var; NUM_CELLS], [AB::Expr; NUM_CELLS])>,
    I::Writes: From<[[AB::Expr; NUM_CELLS]; 1]>,
    I::ProcessedInstruction: From<LoadStoreInstruction<AB::Expr>>,
{
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &LoadStoreCoreCols<AB::Var, NUM_CELLS> = (*local_core).borrow();
        let LoadStoreCoreCols::<AB::Var, NUM_CELLS> {
            read_data,
            prev_data,
            write_data,
            flags,
            is_valid,
            is_load,
        } = *cols;

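        // (x - 1) * (x - 2): vanishes exactly when x is 1 or 2; used to range-check
        // each flag and the flags' sum.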
        let get_expr_12 = |x: &AB::Expr| (x.clone() - AB::Expr::ONE) * (x.clone() - AB::Expr::TWO);

        builder.assert_bool(is_valid);
        // Each flag is constrained to {0, 1, 2}, and so is the sum of all four flags.
        let sum = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_zero(flag * get_expr_12(&flag.into()));
            acc + flag
        });
        builder.assert_zero(sum.clone() * get_expr_12(&sum));
        // get_expr_12(sum) is nonzero only when sum == 0, which marks a dummy row;
        // dummy rows must have is_valid = 0.
        builder.when(get_expr_12(&sum)).assert_zero(is_valid);

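        // Decode the 14 (opcode, shift) variants from the four flag cells:
        //   - variants 0..4 ("flag == 2"): indicator flag * (flag - 1) / 2,
        //   - variants 4..8 ("exactly one flag == 1"): indicator flag * (2 - sum),
        //   - variants 8..14 ("flags i and j both 1", i < j): indicator
        //     flags[i] * flags[j].
        // Each indicator is 1 on its own variant and 0 on every other valid row.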
        let inv_2 = AB::F::from_canonical_u32(2).inverse();
        let mut opcode_flags = vec![];
        for flag in flags {
            opcode_flags.push(flag * (flag - AB::F::ONE) * inv_2);
        }
        for flag in flags {
            opcode_flags.push(flag * (sum.clone() - AB::F::TWO) * AB::F::NEG_ONE);
        }
        (0..4).for_each(|i| {
            ((i + 1)..4).for_each(|j| opcode_flags.push(flags[i] * flags[j]));
        });

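        // Sums the indicators of the listed variants: a selector that is 1 exactly
        // when the row encodes one of them.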
        let opcode_when = |idxs: &[InstructionOpcode]| -> AB::Expr {
            idxs.iter().fold(AB::Expr::ZERO, |acc, &idx| {
                acc + opcode_flags[idx as usize].clone()
            })
        };

        builder.assert_eq(
            is_load,
            opcode_when(&[LoadW0, LoadHu0, LoadHu2, LoadBu0, LoadBu1, LoadBu2, LoadBu3]),
        );
        builder.when(is_load).assert_one(is_valid);
        for (i, cell) in write_data.iter().enumerate() {
            // For loads, write_data[i] is the matching byte of read_data (offset by
            // the variant's shift) or zero.
            let expected_load_val = if i == 0 {
                opcode_when(&[LoadW0, LoadHu0, LoadBu0]) * read_data[0]
                    + opcode_when(&[LoadBu1]) * read_data[1]
                    + opcode_when(&[LoadHu2, LoadBu2]) * read_data[2]
                    + opcode_when(&[LoadBu3]) * read_data[3]
            } else if i < NUM_CELLS / 2 {
                opcode_when(&[LoadW0, LoadHu0]) * read_data[i]
                    + opcode_when(&[LoadHu2]) * read_data[i + 2]
            } else {
                opcode_when(&[LoadW0]) * read_data[i]
            };

            // For stores, write_data[i] is either a (shifted) byte of read_data or
            // the untouched prev_data[i].
            let expected_store_val = if i == 0 {
                opcode_when(&[StoreW0, StoreH0, StoreB0]) * read_data[i]
                    + opcode_when(&[StoreH2, StoreB1, StoreB2, StoreB3]) * prev_data[i]
            } else if i == 1 {
                opcode_when(&[StoreB1]) * read_data[i - 1]
                    + opcode_when(&[StoreW0, StoreH0]) * read_data[i]
                    + opcode_when(&[StoreH2, StoreB0, StoreB2, StoreB3]) * prev_data[i]
            } else if i == 2 {
                opcode_when(&[StoreH2, StoreB2]) * read_data[i - 2]
                    + opcode_when(&[StoreW0]) * read_data[i]
                    + opcode_when(&[StoreH0, StoreB0, StoreB1, StoreB3]) * prev_data[i]
            } else if i == 3 {
                opcode_when(&[StoreB3]) * read_data[i - 3]
                    + opcode_when(&[StoreH2]) * read_data[i - 2]
                    + opcode_when(&[StoreW0]) * read_data[i]
                    + opcode_when(&[StoreH0, StoreB0, StoreB1, StoreB2]) * prev_data[i]
            } else {
                opcode_when(&[StoreW0]) * read_data[i]
                    + opcode_when(&[StoreB0, StoreB1, StoreB2, StoreB3]) * prev_data[i]
                    + opcode_when(&[StoreH0])
                        * if i < NUM_CELLS / 2 {
                            read_data[i]
                        } else {
                            prev_data[i]
                        }
                    + opcode_when(&[StoreH2])
                        * if i - 2 < NUM_CELLS / 2 {
                            read_data[i - 2]
                        } else {
                            prev_data[i]
                        }
            };
            let expected_val = expected_load_val + expected_store_val;
            builder.assert_eq(*cell, expected_val);
        }

        let expected_opcode = opcode_when(&[LoadW0]) * AB::Expr::from_canonical_u8(LOADW as u8)
            + opcode_when(&[LoadHu0, LoadHu2]) * AB::Expr::from_canonical_u8(LOADHU as u8)
            + opcode_when(&[LoadBu0, LoadBu1, LoadBu2, LoadBu3])
                * AB::Expr::from_canonical_u8(LOADBU as u8)
            + opcode_when(&[StoreW0]) * AB::Expr::from_canonical_u8(STOREW as u8)
            + opcode_when(&[StoreH0, StoreH2]) * AB::Expr::from_canonical_u8(STOREH as u8)
            + opcode_when(&[StoreB0, StoreB1, StoreB2, StoreB3])
                * AB::Expr::from_canonical_u8(STOREB as u8);
        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(self, expected_opcode);

        let load_shift_amount = opcode_when(&[LoadBu1]) * AB::Expr::ONE
            + opcode_when(&[LoadHu2, LoadBu2]) * AB::Expr::TWO
            + opcode_when(&[LoadBu3]) * AB::Expr::from_canonical_u32(3);

        let store_shift_amount = opcode_when(&[StoreB1]) * AB::Expr::ONE
            + opcode_when(&[StoreH2, StoreB2]) * AB::Expr::TWO
            + opcode_when(&[StoreB3]) * AB::Expr::from_canonical_u32(3);

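        // The reconstructed opcode and shift amounts are passed to the adapter
        // through LoadStoreInstruction.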
        AdapterAirContext {
            to_pc: None,
            reads: (prev_data, read_data.map(|x| x.into())).into(),
            writes: [write_data.map(|x| x.into())].into(),
            instruction: LoadStoreInstruction {
                is_valid: is_valid.into(),
                opcode: expected_opcode,
                is_load: is_load.into(),
                load_shift_amount,
                store_shift_amount,
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        self.offset
    }
}

#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct LoadStoreCoreRecord<const NUM_CELLS: usize> {
    pub local_opcode: u8,
    pub shift_amount: u8,
    pub read_data: [u8; NUM_CELLS],
    pub prev_data: [u32; NUM_CELLS],
}

#[derive(Clone, Copy, derive_new::new)]
pub struct LoadStoreExecutor<A, const NUM_CELLS: usize> {
    adapter: A,
    pub offset: usize,
}

#[derive(Clone, derive_new::new)]
pub struct LoadStoreFiller<
    A = Rv32LoadStoreAdapterFiller,
    const NUM_CELLS: usize = RV32_REGISTER_NUM_LIMBS,
> {
    adapter: A,
    pub offset: usize,
}

impl<F, A, RA, const NUM_CELLS: usize> PreflightExecutor<F, RA> for LoadStoreExecutor<A, NUM_CELLS>
where
    F: PrimeField32,
    A: 'static
        + AdapterTraceExecutor<
            F,
            ReadData = (([u32; NUM_CELLS], [u8; NUM_CELLS]), u8),
            WriteData = [u32; NUM_CELLS],
        >,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (A::RecordMut<'buf>, &'buf mut LoadStoreCoreRecord<NUM_CELLS>),
    >,
{
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!(
            "{:?}",
            Rv32LoadStoreOpcode::from_usize(opcode - self.offset)
        )
    }

    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let Instruction { opcode, .. } = instruction;

        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());

        A::start(*state.pc, state.memory, &mut adapter_record);

        // The adapter returns ((prev_data, read_data), shift_amount).
        (
            (core_record.prev_data, core_record.read_data),
            core_record.shift_amount,
        ) = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record);

        let local_opcode = Rv32LoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset));
        core_record.local_opcode = local_opcode as u8;

        let write_data = run_write_data(
            local_opcode,
            core_record.read_data,
            core_record.prev_data,
            core_record.shift_amount as usize,
        );
        self.adapter
            .write(state.memory, instruction, write_data, &mut adapter_record);

        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);

        Ok(())
    }
}

impl<F, A, const NUM_CELLS: usize> TraceFiller<F> for LoadStoreFiller<A, NUM_CELLS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
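        // SAFETY: `row_slice` is assumed to be a full trace row, at least A::WIDTH
        // columns wide, so the unchecked split cannot go out of bounds.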
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        self.adapter.fill_trace_row(mem_helper, adapter_row);
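        // SAFETY: `execute` is assumed to have serialized a LoadStoreCoreRecord at
        // the start of this buffer; it is read out before the columns overwrite it.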
        let record: &LoadStoreCoreRecord<NUM_CELLS> =
            unsafe { get_record_from_slice(&mut core_row, ()) };
        let core_row: &mut LoadStoreCoreCols<F, NUM_CELLS> = core_row.borrow_mut();

        let opcode = Rv32LoadStoreOpcode::from_usize(record.local_opcode as usize);
        let shift = record.shift_amount;

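        // `record` aliases `core_row`, so columns are filled from the highest offset
        // (write_data) downward, each after the record fields it depends on are read.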
        let write_data = run_write_data(opcode, record.read_data, record.prev_data, shift as usize);
        core_row.write_data = write_data.map(F::from_canonical_u32);
        core_row.prev_data = record.prev_data.map(F::from_canonical_u32);
        core_row.read_data = record.read_data.map(F::from_canonical_u8);
        core_row.is_load = F::from_bool([LOADW, LOADHU, LOADBU].contains(&opcode));
        core_row.is_valid = F::ONE;
        let flags = &mut core_row.flags;
        *flags = [F::ZERO; 4];
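        // Re-encode the (opcode, shift) pair into the flag cells, mirroring the
        // InstructionOpcode indicators decoded in `eval`.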
        match (opcode, shift) {
            (LOADW, 0) => flags[0] = F::TWO,
            (LOADHU, 0) => flags[1] = F::TWO,
            (LOADHU, 2) => flags[2] = F::TWO,
            (LOADBU, 0) => flags[3] = F::TWO,

            (LOADBU, 1) => flags[0] = F::ONE,
            (LOADBU, 2) => flags[1] = F::ONE,
            (LOADBU, 3) => flags[2] = F::ONE,
            (STOREW, 0) => flags[3] = F::ONE,

            (STOREH, 0) => (flags[0], flags[1]) = (F::ONE, F::ONE),
            (STOREH, 2) => (flags[0], flags[2]) = (F::ONE, F::ONE),
            (STOREB, 0) => (flags[0], flags[3]) = (F::ONE, F::ONE),
            (STOREB, 1) => (flags[1], flags[2]) = (F::ONE, F::ONE),
            (STOREB, 2) => (flags[1], flags[3]) = (F::ONE, F::ONE),
            (STOREB, 3) => (flags[2], flags[3]) = (F::ONE, F::ONE),
            _ => unreachable!(),
        };
    }
}

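/// Computes the written word for a given (opcode, shift) pair from the bytes read
/// and the previous memory value; shared by preflight execution and trace filling.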
#[inline(always)]
pub(super) fn run_write_data<const NUM_CELLS: usize>(
    opcode: Rv32LoadStoreOpcode,
    read_data: [u8; NUM_CELLS],
    prev_data: [u32; NUM_CELLS],
    shift: usize,
) -> [u32; NUM_CELLS] {
    match (opcode, shift) {
        // Word ops copy the full read value.
        (LOADW, 0) | (STOREW, 0) => read_data.map(|x| x as u32),
        // Byte load: zero-extend the selected byte into cell 0.
        (LOADBU, 0) | (LOADBU, 1) | (LOADBU, 2) | (LOADBU, 3) => {
            let mut write_data = [0; NUM_CELLS];
            write_data[0] = read_data[shift] as u32;
            write_data
        }
        // Halfword load: zero-extend the selected half into the low cells.
        (LOADHU, 0) | (LOADHU, 2) => {
            let mut write_data = [0; NUM_CELLS];
            for (i, cell) in write_data.iter_mut().take(NUM_CELLS / 2).enumerate() {
                *cell = read_data[i + shift] as u32;
            }
            write_data
        }
        // Byte store: overwrite only the shifted cell of the previous memory value.
        (STOREB, 0) | (STOREB, 1) | (STOREB, 2) | (STOREB, 3) => {
            let mut write_data = prev_data;
            write_data[shift] = read_data[0] as u32;
            write_data
        }
        // Halfword store: overwrite the shifted half, keep the rest of prev_data.
        (STOREH, 0) | (STOREH, 2) => array::from_fn(|i| {
            if i >= shift && i < (NUM_CELLS / 2 + shift) {
                read_data[i - shift] as u32
            } else {
                prev_data[i]
            }
        }),
        _ => unreachable!(
            "unaligned memory access not supported by this execution environment: {opcode:?}, shift: {shift}"
        ),
    }
}