1use std::{
2 array,
3 borrow::{Borrow, BorrowMut},
4 fmt::Debug,
5};
6
7use openvm_circuit::{
8 arch::*,
9 system::memory::{online::TracingMemory, MemoryAuxColsFactory},
10};
11use openvm_circuit_primitives::{AlignedBorrow, AlignedBytesBorrow};
12use openvm_instructions::{
13 instruction::Instruction, program::DEFAULT_PC_STEP, riscv::RV32_REGISTER_NUM_LIMBS, LocalOpcode,
14};
15use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
16use openvm_stark_backend::{
17 interaction::InteractionBuilder,
18 p3_air::{AirBuilder, BaseAir},
19 p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
20 rap::BaseAirWithPublicValues,
21};
22
23use crate::adapters::{LoadStoreInstruction, Rv32LoadStoreAdapterFiller};
24
/// One variant per supported (local opcode, byte-shift) combination.
/// The trailing digit is the byte offset of the access within the 4-byte word.
///
/// NOTE: the discriminants (declaration order) are used directly as indices
/// into the `opcode_flags` vector built in `VmCoreAir::eval`, and the same
/// encoding is reproduced by `LoadStoreFiller::fill_trace_row` — keep all
/// three in sync if anything changes here.
#[derive(Debug, Clone, Copy)]
enum InstructionOpcode {
    LoadW0,
    LoadHu0,
    LoadHu2,
    LoadBu0,
    LoadBu1,
    LoadBu2,
    LoadBu3,
    StoreW0,
    StoreH0,
    StoreH2,
    StoreB0,
    StoreB1,
    StoreB2,
    StoreB3,
}

use InstructionOpcode::*;
44
/// Core trace columns for the load/store chip.
///
/// `#[repr(C)]` + `AlignedBorrow`: field order defines the column layout,
/// so it must not be reordered.
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct LoadStoreCoreCols<T, const NUM_CELLS: usize> {
    // Encodes which of the 14 `InstructionOpcode` variants this row is:
    // each flag is constrained to {0, 1, 2} and their sum to {1, 2} on valid
    // rows. See the indicator construction in `eval` for the exact decoding.
    pub flags: [T; 4],
    // Boolean; 0 marks a padding/dummy row.
    pub is_valid: T,
    // Boolean; 1 iff the opcode is LOADW/LOADHU/LOADBU.
    pub is_load: T,

    // Data read by the adapter this cycle (bytes of the accessed word).
    pub read_data: [T; NUM_CELLS],
    // Previous contents of the write target; needed so partial stores
    // (STOREH/STOREB) can keep the untouched cells unchanged.
    pub prev_data: [T; NUM_CELLS],
    // Value the adapter writes back, constrained cell-by-cell in `eval`.
    pub write_data: [T; NUM_CELLS],
}
63
/// Core AIR for the load/store chip.
#[derive(Debug, Clone, derive_new::new)]
pub struct LoadStoreCoreAir<const NUM_CELLS: usize> {
    // Global offset of this opcode class; local opcode indices are relative
    // to it (see `start_offset` / `expr_to_global_expr` in `eval`).
    pub offset: usize,
}
68
impl<F: Field, const NUM_CELLS: usize> BaseAir<F> for LoadStoreCoreAir<NUM_CELLS> {
    /// Trace width is exactly the column-struct width.
    fn width(&self) -> usize {
        LoadStoreCoreCols::<F, NUM_CELLS>::width()
    }
}
74
// This AIR exposes no public values (default trait impl).
impl<F: Field, const NUM_CELLS: usize> BaseAirWithPublicValues<F> for LoadStoreCoreAir<NUM_CELLS> {}
76
impl<AB, I, const NUM_CELLS: usize> VmCoreAir<AB, I> for LoadStoreCoreAir<NUM_CELLS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<([AB::Var; NUM_CELLS], [AB::Expr; NUM_CELLS])>,
    I::Writes: From<[[AB::Expr; NUM_CELLS]; 1]>,
    I::ProcessedInstruction: From<LoadStoreInstruction<AB::Expr>>,
{
    /// Constrains one load/store row: decodes the 4 `flags` columns into 14
    /// mutually exclusive (opcode, shift) indicators, then constrains
    /// `write_data` cell-by-cell as the correct combination of `read_data`
    /// and `prev_data` for that indicator, and forwards the recovered opcode
    /// and shift amounts to the adapter.
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &LoadStoreCoreCols<AB::Var, NUM_CELLS> = (*local_core).borrow();
        let LoadStoreCoreCols::<AB::Var, NUM_CELLS> {
            read_data,
            prev_data,
            write_data,
            flags,
            is_valid,
            is_load,
        } = *cols;

        // (x - 1)(x - 2): zero iff x ∈ {1, 2}; reused for both the per-flag
        // and the sum range checks below.
        let get_expr_12 = |x: &AB::Expr| (x.clone() - AB::Expr::ONE) * (x.clone() - AB::Expr::TWO);

        builder.assert_bool(is_valid);
        // Each flag is constrained to {0, 1, 2}; accumulate their sum.
        let sum = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_zero(flag * get_expr_12(&flag.into()));
            acc + flag
        });
        // The sum must be in {0, 1, 2}, and sum == 0 (all-zero flags, e.g. a
        // padding row) forces is_valid == 0. Valid rows therefore have
        // sum ∈ {1, 2}.
        builder.assert_zero(sum.clone() * get_expr_12(&sum));
        builder.when(get_expr_12(&sum)).assert_zero(is_valid);

        // Build one boolean indicator per `InstructionOpcode` variant, pushed
        // in declaration order so `variant as usize` indexes this vector.
        // On a valid row exactly one indicator evaluates to 1.
        let inv_2 = AB::F::from_u32(2).inverse();
        let mut opcode_flags = vec![];
        // Indicators 0..4: flag_k == 2 (f(f-1)/2 is 1 iff f == 2).
        for flag in flags {
            opcode_flags.push(flag * (flag - AB::F::ONE) * inv_2);
        }
        // Indicators 4..8: flag_k == 1 with sum == 1 (f * (2 - sum) is 1 iff
        // f == 1 and sum == 1).
        for flag in flags {
            opcode_flags.push(flag * (sum.clone() - AB::F::TWO) * AB::F::NEG_ONE);
        }
        // Indicators 8..14: flag_i == 1 and flag_j == 1 for each pair i < j
        // (sum == 2 case), in order (0,1),(0,2),(0,3),(1,2),(1,3),(2,3).
        (0..4).for_each(|i| {
            ((i + 1)..4).for_each(|j| opcode_flags.push(flags[i] * flags[j]));
        });

        // Sum of the indicators for the listed variants; 1 iff the row's
        // (opcode, shift) is among them.
        let opcode_when = |idxs: &[InstructionOpcode]| -> AB::Expr {
            idxs.iter().fold(AB::Expr::ZERO, |acc, &idx| {
                acc + opcode_flags[idx as usize].clone()
            })
        };

        // is_load is exactly the sum of the seven load indicators, and any
        // load row must be valid.
        builder.assert_eq(
            is_load,
            opcode_when(&[LoadW0, LoadHu0, LoadHu2, LoadBu0, LoadBu1, LoadBu2, LoadBu3]),
        );
        builder.when(is_load).assert_one(is_valid);

        // Constrain each write_data cell. At most one of the load/store
        // contributions below is active, so their sum is the expected value.
        // Loads move the selected bytes of read_data down to the low cells;
        // cells with no contributing term are forced to 0 (zero-extension).
        // Stores splice read_data bytes into prev_data at the shift offset.
        for (i, cell) in write_data.iter().enumerate() {
            let expected_load_val = if i == 0 {
                // Cell 0 receives the byte at the load's shift offset.
                opcode_when(&[LoadW0, LoadHu0, LoadBu0]) * read_data[0]
                    + opcode_when(&[LoadBu1]) * read_data[1]
                    + opcode_when(&[LoadHu2, LoadBu2]) * read_data[2]
                    + opcode_when(&[LoadBu3]) * read_data[3]
            } else if i < NUM_CELLS / 2 {
                // Low half (excluding cell 0): populated by LOADW and LOADHU.
                opcode_when(&[LoadW0, LoadHu0]) * read_data[i]
                    + opcode_when(&[LoadHu2]) * read_data[i + 2]
            } else {
                // High half: only a full-word load writes here.
                opcode_when(&[LoadW0]) * read_data[i]
            };

            let expected_store_val = if i == 0 {
                opcode_when(&[StoreW0, StoreH0, StoreB0]) * read_data[i]
                    + opcode_when(&[StoreH2, StoreB1, StoreB2, StoreB3]) * prev_data[i]
            } else if i == 1 {
                opcode_when(&[StoreB1]) * read_data[i - 1]
                    + opcode_when(&[StoreW0, StoreH0]) * read_data[i]
                    + opcode_when(&[StoreH2, StoreB0, StoreB2, StoreB3]) * prev_data[i]
            } else if i == 2 {
                opcode_when(&[StoreH2, StoreB2]) * read_data[i - 2]
                    + opcode_when(&[StoreW0]) * read_data[i]
                    + opcode_when(&[StoreH0, StoreB0, StoreB1, StoreB3]) * prev_data[i]
            } else if i == 3 {
                opcode_when(&[StoreB3]) * read_data[i - 3]
                    + opcode_when(&[StoreH2]) * read_data[i - 2]
                    + opcode_when(&[StoreW0]) * read_data[i]
                    + opcode_when(&[StoreH0, StoreB0, StoreB1, StoreB2]) * prev_data[i]
            } else {
                // Generic case for NUM_CELLS > 4: byte stores only ever touch
                // cells 0..4, so here they always keep prev_data.
                opcode_when(&[StoreW0]) * read_data[i]
                    + opcode_when(&[StoreB0, StoreB1, StoreB2, StoreB3]) * prev_data[i]
                    + opcode_when(&[StoreH0])
                        * if i < NUM_CELLS / 2 {
                            read_data[i]
                        } else {
                            prev_data[i]
                        }
                    + opcode_when(&[StoreH2])
                        * if i - 2 < NUM_CELLS / 2 {
                            read_data[i - 2]
                        } else {
                            prev_data[i]
                        }
            };
            let expected_val = expected_load_val + expected_store_val;
            builder.assert_eq(*cell, expected_val);
        }

        // Recover the local opcode from the indicators, then lift it to the
        // global opcode index using this AIR's offset.
        let expected_opcode = opcode_when(&[LoadW0]) * AB::Expr::from_u8(LOADW as u8)
            + opcode_when(&[LoadHu0, LoadHu2]) * AB::Expr::from_u8(LOADHU as u8)
            + opcode_when(&[LoadBu0, LoadBu1, LoadBu2, LoadBu3]) * AB::Expr::from_u8(LOADBU as u8)
            + opcode_when(&[StoreW0]) * AB::Expr::from_u8(STOREW as u8)
            + opcode_when(&[StoreH0, StoreH2]) * AB::Expr::from_u8(STOREH as u8)
            + opcode_when(&[StoreB0, StoreB1, StoreB2, StoreB3]) * AB::Expr::from_u8(STOREB as u8);
        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(self, expected_opcode);

        // Byte-shift amounts recovered from the indicators (0 when no
        // matching variant is active).
        let load_shift_amount = opcode_when(&[LoadBu1]) * AB::Expr::ONE
            + opcode_when(&[LoadHu2, LoadBu2]) * AB::Expr::TWO
            + opcode_when(&[LoadBu3]) * AB::Expr::from_u32(3);

        let store_shift_amount = opcode_when(&[StoreB1]) * AB::Expr::ONE
            + opcode_when(&[StoreH2, StoreB2]) * AB::Expr::TWO
            + opcode_when(&[StoreB3]) * AB::Expr::from_u32(3);

        AdapterAirContext {
            to_pc: None,
            reads: (prev_data, read_data.map(|x| x.into())).into(),
            writes: [write_data.map(|x| x.into())].into(),
            instruction: LoadStoreInstruction {
                is_valid: is_valid.into(),
                opcode: expected_opcode,
                is_load: is_load.into(),
                load_shift_amount,
                store_shift_amount,
            }
            .into(),
        }
    }

    /// Global opcode offset of this instruction class.
    fn start_offset(&self) -> usize {
        self.offset
    }
}
237
/// Execution record captured by `LoadStoreExecutor::execute` and consumed by
/// `LoadStoreFiller::fill_trace_row`.
///
/// `#[repr(C)]` + `AlignedBytesBorrow`: the byte layout is relied upon when
/// the record is reinterpreted in-place from the row buffer — do not reorder
/// fields.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct LoadStoreCoreRecord<const NUM_CELLS: usize> {
    // `Rv32LoadStoreOpcode` stored as its discriminant.
    pub local_opcode: u8,
    // Byte shift of the access within the word (0..4).
    pub shift_amount: u8,
    // Bytes read this cycle.
    pub read_data: [u8; NUM_CELLS],
    // Previous contents of the write target (u32 per cell, unlike read_data).
    pub prev_data: [u32; NUM_CELLS],
}
247
/// Preflight executor for load/store instructions; pairs an adapter with the
/// opcode-class offset.
#[derive(Clone, Copy, derive_new::new)]
pub struct LoadStoreExecutor<A, const NUM_CELLS: usize> {
    adapter: A,
    pub offset: usize,
}
253
/// Trace filler for load/store rows; defaults to the RV32 adapter filler and
/// register limb count.
#[derive(Clone, derive_new::new)]
pub struct LoadStoreFiller<
    A = Rv32LoadStoreAdapterFiller,
    const NUM_CELLS: usize = RV32_REGISTER_NUM_LIMBS,
> {
    adapter: A,
    pub offset: usize,
}
262
impl<F, A, RA, const NUM_CELLS: usize> PreflightExecutor<F, RA> for LoadStoreExecutor<A, NUM_CELLS>
where
    F: PrimeField32,
    A: 'static
        + AdapterTraceExecutor<
            F,
            ReadData = (([u32; NUM_CELLS], [u8; NUM_CELLS]), u8),
            WriteData = [u32; NUM_CELLS],
        >,
    for<'buf> RA: RecordArena<
        'buf,
        EmptyAdapterCoreLayout<F, A>,
        (A::RecordMut<'buf>, &'buf mut LoadStoreCoreRecord<NUM_CELLS>),
    >,
{
    /// Debug name: the local `Rv32LoadStoreOpcode` variant for this global
    /// opcode index.
    fn get_opcode_name(&self, opcode: usize) -> String {
        format!(
            "{:?}",
            Rv32LoadStoreOpcode::from_usize(opcode - self.offset)
        )
    }

    /// Executes one load/store instruction and records it for trace filling.
    ///
    /// Order matters: the adapter record is allocated and started, the adapter
    /// read populates (prev_data, read_data) and the shift amount, the written
    /// value is computed by `run_write_data`, and only then is the adapter
    /// write performed and the pc advanced by one instruction.
    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let Instruction { opcode, .. } = instruction;

        // Allocate paired (adapter, core) records from the arena.
        let (mut adapter_record, core_record) = state.ctx.alloc(EmptyAdapterCoreLayout::new());

        A::start(*state.pc, state.memory, &mut adapter_record);

        // Adapter read yields ((prev_data, read_data), shift_amount).
        (
            (core_record.prev_data, core_record.read_data),
            core_record.shift_amount,
        ) = self
            .adapter
            .read(state.memory, instruction, &mut adapter_record);

        let local_opcode = Rv32LoadStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset));
        core_record.local_opcode = local_opcode as u8;

        // Compute the value to write back (same helper the filler uses, so
        // execution and trace generation agree by construction).
        let write_data = run_write_data(
            local_opcode,
            core_record.read_data,
            core_record.prev_data,
            core_record.shift_amount as usize,
        );
        self.adapter
            .write(state.memory, instruction, write_data, &mut adapter_record);

        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);

        Ok(())
    }
}
320
impl<F, A, const NUM_CELLS: usize> TraceFiller<F> for LoadStoreFiller<A, NUM_CELLS>
where
    F: PrimeField32,
    A: 'static + AdapterTraceFiller<F>,
{
    /// Expands one `LoadStoreCoreRecord` (stored in-place at the start of the
    /// core portion of the row) into the full `LoadStoreCoreCols` layout.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, row_slice: &mut [F]) {
        // SAFETY: caller guarantees row_slice is at least A::WIDTH wide.
        let (adapter_row, mut core_row) = unsafe { row_slice.split_at_mut_unchecked(A::WIDTH) };
        self.adapter.fill_trace_row(mem_helper, adapter_row);
        // SAFETY: the record was serialized into this same buffer by the
        // executor; it aliases the row, so all reads from `record` below must
        // happen before the corresponding bytes are overwritten via `core_row`
        // — NOTE(review): the columns are written back-to-front for this
        // reason; confirm against the record/column layouts before reordering.
        let record: &LoadStoreCoreRecord<NUM_CELLS> =
            unsafe { get_record_from_slice(&mut core_row, ()) };
        let core_row: &mut LoadStoreCoreCols<F, NUM_CELLS> = core_row.borrow_mut();

        let opcode = Rv32LoadStoreOpcode::from_usize(record.local_opcode as usize);
        let shift = record.shift_amount;

        // Recompute write_data exactly as execution did.
        let write_data = run_write_data(opcode, record.read_data, record.prev_data, shift as usize);
        core_row.write_data = write_data.map(F::from_u32);
        core_row.prev_data = record.prev_data.map(F::from_u32);
        core_row.read_data = record.read_data.map(F::from_u8);
        core_row.is_load = F::from_bool([LOADW, LOADHU, LOADBU].contains(&opcode));
        core_row.is_valid = F::ONE;
        // Encode (opcode, shift) into the 4 flag columns, mirroring the
        // indicator decoding in `eval`:
        //   flags[k] == 2            -> indicator k        (first four variants)
        //   flags[k] == 1, sum == 1  -> indicator 4 + k
        //   flags[i] == flags[j] == 1 -> the pair indicator for (i, j)
        let flags = &mut core_row.flags;
        *flags = [F::ZERO; 4];
        match (opcode, shift) {
            (LOADW, 0) => flags[0] = F::TWO,
            (LOADHU, 0) => flags[1] = F::TWO,
            (LOADHU, 2) => flags[2] = F::TWO,
            (LOADBU, 0) => flags[3] = F::TWO,

            (LOADBU, 1) => flags[0] = F::ONE,
            (LOADBU, 2) => flags[1] = F::ONE,
            (LOADBU, 3) => flags[2] = F::ONE,
            (STOREW, 0) => flags[3] = F::ONE,

            (STOREH, 0) => (flags[0], flags[1]) = (F::ONE, F::ONE),
            (STOREH, 2) => (flags[0], flags[2]) = (F::ONE, F::ONE),
            (STOREB, 0) => (flags[0], flags[3]) = (F::ONE, F::ONE),
            (STOREB, 1) => (flags[1], flags[2]) = (F::ONE, F::ONE),
            (STOREB, 2) => (flags[1], flags[3]) = (F::ONE, F::ONE),
            (STOREB, 3) => (flags[2], flags[3]) = (F::ONE, F::ONE),
            _ => unreachable!(),
        };
    }
}
370
371#[inline(always)]
373pub(super) fn run_write_data<const NUM_CELLS: usize>(
374 opcode: Rv32LoadStoreOpcode,
375 read_data: [u8; NUM_CELLS],
376 prev_data: [u32; NUM_CELLS],
377 shift: usize,
378) -> [u32; NUM_CELLS] {
379 match (opcode, shift) {
380 (LOADW, 0) => {
381 read_data.map(|x| x as u32)
382 },
383 (LOADBU, 0) | (LOADBU, 1) | (LOADBU, 2) | (LOADBU, 3) => {
384 let mut wrie_data = [0; NUM_CELLS];
385 wrie_data[0] = read_data[shift] as u32;
386 wrie_data
387 }
388 (LOADHU, 0) | (LOADHU, 2) => {
389 let mut write_data = [0; NUM_CELLS];
390 for (i, cell) in write_data.iter_mut().take(NUM_CELLS / 2).enumerate() {
391 *cell = read_data[i + shift] as u32;
392 }
393 write_data
394 }
395 (STOREW, 0) => {
396 read_data.map(|x| x as u32)
397 },
398 (STOREB, 0) | (STOREB, 1) | (STOREB, 2) | (STOREB, 3) => {
399 let mut write_data = prev_data;
400 write_data[shift] = read_data[0] as u32;
401 write_data
402 }
403 (STOREH, 0) | (STOREH, 2) => {
404 array::from_fn(|i| {
405 if i >= shift && i < (NUM_CELLS / 2 + shift){
406 read_data[i - shift] as u32
407 } else {
408 prev_data[i]
409 }
410 })
411 }
412 _ => unreachable!(
416 "unaligned memory access not supported by this execution environment: {opcode:?}, shift: {shift}"
417 ),
418 }
419}