use std::borrow::{Borrow, BorrowMut};

use openvm_circuit::arch::{
    AdapterAirContext, AdapterRuntimeContext, Result, VmAdapterInterface, VmCoreAir, VmCoreChip,
};
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_instructions::{instruction::Instruction, LocalOpcode};
use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::{AirBuilder, BaseAir},
    p3_field::{Field, FieldAlgebra, PrimeField32},
    rap::BaseAirWithPublicValues,
};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_big_array::BigArray;

use crate::adapters::LoadStoreInstruction;

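/// Internal case index for each supported (opcode, shift) pair.
/// Each variant indexes one entry of the `opcode_flags` vector built in `eval`.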
#[derive(Debug, Clone, Copy)]
enum InstructionOpcode {
    LoadW0,
    LoadHu0,
    LoadHu2,
    LoadBu0,
    LoadBu1,
    LoadBu2,
    LoadBu3,
    StoreW0,
    StoreH0,
    StoreH2,
    StoreB0,
    StoreB1,
    StoreB2,
    StoreB3,
}

use InstructionOpcode::*;

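/// Core columns for the load/store chip.
///
/// `write_data` is constrained from `read_data` and `prev_data` according to the
/// (opcode, shift) case encoded in `flags`: loads zero-extend the selected
/// byte/halfword/word, stores merge the shifted `read_data` into `prev_data`.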
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct LoadStoreCoreCols<T, const NUM_CELLS: usize> {
    pub flags: [T; 4],
    pub is_valid: T,
    pub is_load: T,

    pub read_data: [T; NUM_CELLS],
    pub prev_data: [T; NUM_CELLS],
    pub write_data: [T; NUM_CELLS],
}

#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "F: Serialize + DeserializeOwned")]
pub struct LoadStoreCoreRecord<F, const NUM_CELLS: usize> {
    pub opcode: Rv32LoadStoreOpcode,
    pub shift: u32,
    #[serde(with = "BigArray")]
    pub read_data: [F; NUM_CELLS],
    #[serde(with = "BigArray")]
    pub prev_data: [F; NUM_CELLS],
    #[serde(with = "BigArray")]
    pub write_data: [F; NUM_CELLS],
}

#[derive(Debug, Clone)]
pub struct LoadStoreCoreAir<const NUM_CELLS: usize> {
    pub offset: usize,
}

impl<F: Field, const NUM_CELLS: usize> BaseAir<F> for LoadStoreCoreAir<NUM_CELLS> {
    fn width(&self) -> usize {
        LoadStoreCoreCols::<F, NUM_CELLS>::width()
    }
}

impl<F: Field, const NUM_CELLS: usize> BaseAirWithPublicValues<F> for LoadStoreCoreAir<NUM_CELLS> {}

impl<AB, I, const NUM_CELLS: usize> VmCoreAir<AB, I> for LoadStoreCoreAir<NUM_CELLS>
where
    AB: InteractionBuilder,
    I: VmAdapterInterface<AB::Expr>,
    I::Reads: From<([AB::Var; NUM_CELLS], [AB::Expr; NUM_CELLS])>,
    I::Writes: From<[[AB::Expr; NUM_CELLS]; 1]>,
    I::ProcessedInstruction: From<LoadStoreInstruction<AB::Expr>>,
{
    fn eval(
        &self,
        builder: &mut AB,
        local_core: &[AB::Var],
        _from_pc: AB::Var,
    ) -> AdapterAirContext<AB::Expr, I> {
        let cols: &LoadStoreCoreCols<AB::Var, NUM_CELLS> = (*local_core).borrow();
        let LoadStoreCoreCols::<AB::Var, NUM_CELLS> {
            read_data,
            prev_data,
            write_data,
            flags,
            is_valid,
            is_load,
        } = *cols;

        let get_expr_12 = |x: &AB::Expr| (x.clone() - AB::Expr::ONE) * (x.clone() - AB::Expr::TWO);

        builder.assert_bool(is_valid);
        let sum = flags.iter().fold(AB::Expr::ZERO, |acc, &flag| {
            builder.assert_zero(flag * get_expr_12(&flag.into()));
            acc + flag
        });
        builder.assert_zero(sum.clone() * get_expr_12(&sum));
        builder.when(get_expr_12(&sum)).assert_zero(is_valid);

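        // Each of the 14 (opcode, shift) cases is encoded with the 4 flags:
        // either one flag is 2, one flag is 1 (with all others 0), or two distinct
        // flags are 1. The expressions below are degree-2 indicators, one per case.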
        let inv_2 = AB::F::from_canonical_u32(2).inverse();
        let mut opcode_flags = vec![];
        for flag in flags {
            opcode_flags.push(flag * (flag - AB::F::ONE) * inv_2);
        }
        for flag in flags {
            opcode_flags.push(flag * (sum.clone() - AB::F::TWO) * AB::F::NEG_ONE);
        }
        (0..4).for_each(|i| {
            ((i + 1)..4).for_each(|j| opcode_flags.push(flags[i] * flags[j]));
        });

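        // Sums the indicators of the given cases, i.e. "is this row one of these cases?".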
        let opcode_when = |idxs: &[InstructionOpcode]| -> AB::Expr {
            idxs.iter().fold(AB::Expr::ZERO, |acc, &idx| {
                acc + opcode_flags[idx as usize].clone()
            })
        };

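        // `is_load` must equal the sum of all load-case indicators,
        // and a load row is necessarily valid.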
        builder.assert_eq(
            is_load,
            opcode_when(&[LoadW0, LoadHu0, LoadHu2, LoadBu0, LoadBu1, LoadBu2, LoadBu3]),
        );
        builder.when(is_load).assert_one(is_valid);

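        // Constrain write_data cell by cell:
        // - loads place the shifted byte/halfword/word from read_data into the low
        //   cells and zero-extend the rest;
        // - stores copy the low cells of read_data into the cells addressed by the
        //   shift and keep prev_data everywhere else.
        // At most one case indicator is nonzero, so at most one term of each sum survives.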
        for (i, cell) in write_data.iter().enumerate() {
            let expected_load_val = if i == 0 {
                opcode_when(&[LoadW0, LoadHu0, LoadBu0]) * read_data[0]
                    + opcode_when(&[LoadBu1]) * read_data[1]
                    + opcode_when(&[LoadHu2, LoadBu2]) * read_data[2]
                    + opcode_when(&[LoadBu3]) * read_data[3]
            } else if i < NUM_CELLS / 2 {
                opcode_when(&[LoadW0, LoadHu0]) * read_data[i]
                    + opcode_when(&[LoadHu2]) * read_data[i + 2]
            } else {
                opcode_when(&[LoadW0]) * read_data[i]
            };

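            // For stores, read_data is shifted up into the addressed cells while the
            // remaining cells keep their previous memory contents.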
            let expected_store_val = if i == 0 {
                opcode_when(&[StoreW0, StoreH0, StoreB0]) * read_data[i]
                    + opcode_when(&[StoreH2, StoreB1, StoreB2, StoreB3]) * prev_data[i]
            } else if i == 1 {
                opcode_when(&[StoreB1]) * read_data[i - 1]
                    + opcode_when(&[StoreW0, StoreH0]) * read_data[i]
                    + opcode_when(&[StoreH2, StoreB0, StoreB2, StoreB3]) * prev_data[i]
            } else if i == 2 {
                opcode_when(&[StoreH2, StoreB2]) * read_data[i - 2]
                    + opcode_when(&[StoreW0]) * read_data[i]
                    + opcode_when(&[StoreH0, StoreB0, StoreB1, StoreB3]) * prev_data[i]
            } else if i == 3 {
                opcode_when(&[StoreB3]) * read_data[i - 3]
                    + opcode_when(&[StoreH2]) * read_data[i - 2]
                    + opcode_when(&[StoreW0]) * read_data[i]
                    + opcode_when(&[StoreH0, StoreB0, StoreB1, StoreB2]) * prev_data[i]
            } else {
                opcode_when(&[StoreW0]) * read_data[i]
                    + opcode_when(&[StoreB0, StoreB1, StoreB2, StoreB3]) * prev_data[i]
                    + opcode_when(&[StoreH0])
                        * if i < NUM_CELLS / 2 {
                            read_data[i]
                        } else {
                            prev_data[i]
                        }
                    + opcode_when(&[StoreH2])
                        * if i - 2 < NUM_CELLS / 2 {
                            read_data[i - 2]
                        } else {
                            prev_data[i]
                        }
            };
            let expected_val = expected_load_val + expected_store_val;
            builder.assert_eq(*cell, expected_val);
        }

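        // Recover the RV32 load/store opcode from the case indicators.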
        let expected_opcode = opcode_when(&[LoadW0]) * AB::Expr::from_canonical_u8(LOADW as u8)
            + opcode_when(&[LoadHu0, LoadHu2]) * AB::Expr::from_canonical_u8(LOADHU as u8)
            + opcode_when(&[LoadBu0, LoadBu1, LoadBu2, LoadBu3])
                * AB::Expr::from_canonical_u8(LOADBU as u8)
            + opcode_when(&[StoreW0]) * AB::Expr::from_canonical_u8(STOREW as u8)
            + opcode_when(&[StoreH0, StoreH2]) * AB::Expr::from_canonical_u8(STOREH as u8)
            + opcode_when(&[StoreB0, StoreB1, StoreB2, StoreB3])
                * AB::Expr::from_canonical_u8(STOREB as u8);
        let expected_opcode = VmCoreAir::<AB, I>::expr_to_global_expr(self, expected_opcode);

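        // Shift amounts for the current case, forwarded to the adapter via
        // LoadStoreInstruction.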
        let load_shift_amount = opcode_when(&[LoadBu1]) * AB::Expr::ONE
            + opcode_when(&[LoadHu2, LoadBu2]) * AB::Expr::TWO
            + opcode_when(&[LoadBu3]) * AB::Expr::from_canonical_u32(3);

        let store_shift_amount = opcode_when(&[StoreB1]) * AB::Expr::ONE
            + opcode_when(&[StoreH2, StoreB2]) * AB::Expr::TWO
            + opcode_when(&[StoreB3]) * AB::Expr::from_canonical_u32(3);

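        // Reads are (prev_data, read_data); the single write is the constrained write_data.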
        AdapterAirContext {
            to_pc: None,
            reads: (prev_data, read_data.map(|x| x.into())).into(),
            writes: [write_data.map(|x| x.into())].into(),
            instruction: LoadStoreInstruction {
                is_valid: is_valid.into(),
                opcode: expected_opcode,
                is_load: is_load.into(),
                load_shift_amount,
                store_shift_amount,
            }
            .into(),
        }
    }

    fn start_offset(&self) -> usize {
        self.offset
    }
}

#[derive(Debug)]
pub struct LoadStoreCoreChip<const NUM_CELLS: usize> {
    pub air: LoadStoreCoreAir<NUM_CELLS>,
}

impl<const NUM_CELLS: usize> LoadStoreCoreChip<NUM_CELLS> {
    pub fn new(offset: usize) -> Self {
        Self {
            air: LoadStoreCoreAir { offset },
        }
    }
}

impl<F: PrimeField32, I: VmAdapterInterface<F>, const NUM_CELLS: usize> VmCoreChip<F, I>
    for LoadStoreCoreChip<NUM_CELLS>
where
    I::Reads: Into<([[F; NUM_CELLS]; 2], F)>,
    I::Writes: From<[[F; NUM_CELLS]; 1]>,
{
    type Record = LoadStoreCoreRecord<F, NUM_CELLS>;
    type Air = LoadStoreCoreAir<NUM_CELLS>;

    #[allow(clippy::type_complexity)]
    fn execute_instruction(
        &self,
        instruction: &Instruction<F>,
        _from_pc: u32,
        reads: I::Reads,
    ) -> Result<(AdapterRuntimeContext<F, I>, Self::Record)> {
        let local_opcode =
            Rv32LoadStoreOpcode::from_usize(instruction.opcode.local_opcode_idx(self.air.offset));

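        // The adapter's reads decompose into ([prev_data, read_data], shift_amount);
        // prev_data is only needed by stores, which must preserve the untouched cells.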
        let (reads, shift_amount) = reads.into();
        let shift = shift_amount.as_canonical_u32();
        let prev_data = reads[0];
        let read_data = reads[1];
        let write_data = run_write_data(local_opcode, read_data, prev_data, shift);
        let output = AdapterRuntimeContext::without_pc([write_data]);

        Ok((
            output,
            LoadStoreCoreRecord {
                opcode: local_opcode,
                shift,
                prev_data,
                read_data,
                write_data,
            },
        ))
    }

    fn get_opcode_name(&self, opcode: usize) -> String {
        format!(
            "{:?}",
            Rv32LoadStoreOpcode::from_usize(opcode - self.air.offset)
        )
    }

    fn generate_trace_row(&self, row_slice: &mut [F], record: Self::Record) {
        let core_cols: &mut LoadStoreCoreCols<F, NUM_CELLS> = row_slice.borrow_mut();
        let opcode = record.opcode;
        let flags = &mut core_cols.flags;
        *flags = [F::ZERO; 4];
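        // Encode the (opcode, shift) case into the 4 flag columns, mirroring the
        // indicator encoding used in the AIR: a single flag set to 2, a single flag
        // set to 1, or a pair of flags set to 1.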
        match (opcode, record.shift) {
            (LOADW, 0) => flags[0] = F::TWO,
            (LOADHU, 0) => flags[1] = F::TWO,
            (LOADHU, 2) => flags[2] = F::TWO,
            (LOADBU, 0) => flags[3] = F::TWO,

            (LOADBU, 1) => flags[0] = F::ONE,
            (LOADBU, 2) => flags[1] = F::ONE,
            (LOADBU, 3) => flags[2] = F::ONE,
            (STOREW, 0) => flags[3] = F::ONE,

            (STOREH, 0) => (flags[0], flags[1]) = (F::ONE, F::ONE),
            (STOREH, 2) => (flags[0], flags[2]) = (F::ONE, F::ONE),
            (STOREB, 0) => (flags[0], flags[3]) = (F::ONE, F::ONE),
            (STOREB, 1) => (flags[1], flags[2]) = (F::ONE, F::ONE),
            (STOREB, 2) => (flags[1], flags[3]) = (F::ONE, F::ONE),
            (STOREB, 3) => (flags[2], flags[3]) = (F::ONE, F::ONE),
            _ => unreachable!(),
        };
        core_cols.prev_data = record.prev_data;
        core_cols.read_data = record.read_data;
        core_cols.is_valid = F::ONE;
        core_cols.is_load = F::from_bool([LOADW, LOADHU, LOADBU].contains(&opcode));
        core_cols.write_data = record.write_data;
    }

    fn air(&self) -> &Self::Air {
        &self.air
    }
}

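/// Computes the data written back (to the register for loads, to memory for
/// stores) from the freshly read data, the previous data at the destination,
/// and the byte shift. Panics on (opcode, shift) combinations that correspond
/// to unaligned accesses.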
pub(super) fn run_write_data<F: PrimeField32, const NUM_CELLS: usize>(
    opcode: Rv32LoadStoreOpcode,
    read_data: [F; NUM_CELLS],
    prev_data: [F; NUM_CELLS],
    shift: u32,
) -> [F; NUM_CELLS] {
    let shift = shift as usize;
    let mut write_data = read_data;
    match (opcode, shift) {
        (LOADW, 0) => (),
        (LOADBU, 0) | (LOADBU, 1) | (LOADBU, 2) | (LOADBU, 3) => {
            for cell in write_data.iter_mut().take(NUM_CELLS).skip(1) {
                *cell = F::ZERO;
            }
            write_data[0] = read_data[shift];
        }
        (LOADHU, 0) | (LOADHU, 2) => {
            for cell in write_data.iter_mut().take(NUM_CELLS).skip(NUM_CELLS / 2) {
                *cell = F::ZERO;
            }
            for (i, cell) in write_data.iter_mut().take(NUM_CELLS / 2).enumerate() {
                *cell = read_data[i + shift];
            }
        }
        (STOREW, 0) => (),
        (STOREB, 0) | (STOREB, 1) | (STOREB, 2) | (STOREB, 3) => {
            write_data = prev_data;
            write_data[shift] = read_data[0];
        }
        (STOREH, 0) | (STOREH, 2) => {
            write_data = prev_data;
            write_data[shift..(NUM_CELLS / 2 + shift)]
                .copy_from_slice(&read_data[..(NUM_CELLS / 2)]);
        }
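        // Any other (opcode, shift) combination would be an unaligned access,
        // which this core does not support.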
        _ => unreachable!(
            "unaligned memory access not supported by this execution environment: {opcode:?}, shift: {shift}"
        ),
    };
    write_data
}