1use std::{
2 borrow::{Borrow, BorrowMut},
3 marker::PhantomData,
4};
5
6use openvm_circuit::{
7 arch::{
8 get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
9 ExecutionBridge, ExecutionState, VmAdapterAir, VmAdapterInterface,
10 },
11 system::{
12 memory::{
13 offline_checker::{
14 MemoryBaseAuxCols, MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord,
15 MemoryWriteAuxCols,
16 },
17 online::TracingMemory,
18 MemoryAddress, MemoryAuxColsFactory,
19 },
20 native_adapter::util::{memory_read_native, timed_write_native},
21 },
22};
23use openvm_circuit_primitives::{
24 utils::{not, select},
25 var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
26 AlignedBytesBorrow,
27};
28use openvm_circuit_primitives_derive::AlignedBorrow;
29use openvm_instructions::{
30 instruction::Instruction,
31 program::DEFAULT_PC_STEP,
32 riscv::{RV32_IMM_AS, RV32_MEMORY_AS, RV32_REGISTER_AS},
33 LocalOpcode, NATIVE_AS,
34};
35use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
36use openvm_stark_backend::{
37 interaction::InteractionBuilder,
38 p3_air::{AirBuilder, BaseAir},
39 p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
40};
41
42use super::RV32_REGISTER_NUM_LIMBS;
43use crate::adapters::{memory_read, timed_write, tracing_read, RV32_CELL_BITS};
44
/// Instruction-level signals the core chip passes to the load/store adapter AIR.
///
/// The adapter performs every register and memory access of an RV32 load or store. It
/// therefore needs to know whether the row is a load or a store, and by how much the
/// 4-byte-aligned access pointer is offset from `rs1 + sign_extend(imm)`.
pub struct LoadStoreInstruction<T> {
    /// Row enable flag. It gates the adapter's reads, range checks and execution-bridge
    /// interaction; the write is gated by `needs_write`, which is constrained to imply it.
    pub is_valid: T,
    /// Opcode value forwarded to the execution bridge.
    pub opcode: T,
    /// Selects load addressing (1) or store addressing (0) in the adapter AIR.
    pub is_load: T,

    /// Subtracted from the read pointer in the adapter AIR.
    /// Kept separate from `store_shift_amount` so that each pointer subtracts its own shift
    /// outside the `select`. NOTE(review): presumably this equals the shift amount for
    /// loads and 0 for stores — confirm in the core chip.
    pub load_shift_amount: T,
    /// Subtracted from the write pointer in the adapter AIR.
    /// NOTE(review): presumably 0 for loads and the shift amount for stores — confirm in
    /// the core chip.
    pub store_shift_amount: T,
}
67
/// Type-level marker that selects the load/store adapter's read/write/instruction shapes
/// for the AIR builder `AB`; it carries no data.
pub struct Rv32LoadStoreAdapterAirInterface<AB: InteractionBuilder>(PhantomData<AB>);
69
/// Shapes of the data exchanged between the core chip and the load/store adapter AIR.
impl<AB: InteractionBuilder> VmAdapterInterface<AB::Expr> for Rv32LoadStoreAdapterAirInterface<AB> {
    /// `(prev_data, read_data)`:
    /// - `.0` is the previous content of the written location; the adapter combines it with
    ///   its base aux columns into `MemoryWriteAuxCols`;
    /// - `.1` is the data the adapter reads.
    type Reads = (
        [AB::Var; RV32_REGISTER_NUM_LIMBS],
        [AB::Expr; RV32_REGISTER_NUM_LIMBS],
    );
    /// Exactly one write of `RV32_REGISTER_NUM_LIMBS` limbs.
    type Writes = [[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1];
    type ProcessedInstruction = LoadStoreInstruction<AB::Expr>;
}
79
/// Trace columns of the RV32 load/store adapter.
///
/// `#[repr(C)]` + `AlignedBorrow`: the field order *is* the column layout, and the trace
/// filler overlays a `Rv32LoadStoreAdapterRecord` on the same buffer. Do not reorder fields.
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct Rv32LoadStoreAdapterCols<T> {
    /// Program counter and timestamp at which the instruction starts.
    pub from_state: ExecutionState<T>,
    /// Register pointer of `rs1` (instruction operand `b`).
    pub rs1_ptr: T,
    /// Byte limbs of the `rs1` value, least significant first.
    pub rs1_data: [T; RV32_REGISTER_NUM_LIMBS],
    pub rs1_aux_cols: MemoryReadAuxCols<T>,

    /// Operand `a`: `rd` for loads, `rs2` for stores. Constrained to 0 when the row is
    /// valid but performs no write.
    pub rd_rs2_ptr: T,
    /// Aux columns of the data read (memory for loads, `rs2` for stores).
    pub read_data_aux: MemoryReadAuxCols<T>,
    /// Low 16 bits of the immediate (operand `c`).
    pub imm: T,
    /// Immediate sign bit (operand `g`); when 1, the upper 16 bits of the extended
    /// immediate are all ones.
    pub imm_sign: T,
    /// Intermediate pointer `rs1 + sign_extend(imm)` as two 16-bit limbs, used to check
    /// the addition limb by limb.
    pub mem_ptr_limbs: [T; 2],
    /// Address space of the memory side of the access (operand `e`).
    pub mem_as: T,
    /// Base aux columns of the write. The previous data is supplied by the core chip to
    /// form a complete `MemoryWriteAuxCols`.
    pub write_base_aux: MemoryBaseAuxCols<T>,
    /// 1 if the instruction performs its write (operand `f`). It may only be 0 on a valid
    /// row for a load whose `rd_rs2_ptr` is 0 (`x0`).
    pub needs_write: T,
}
106
/// AIR of the RV32 load/store adapter.
///
/// Field order matters: `derive_new` generates
/// `new(memory_bridge, execution_bridge, range_bus, pointer_max_bits)`.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32LoadStoreAdapterAir {
    pub(super) memory_bridge: MemoryBridge,
    pub(super) execution_bridge: ExecutionBridge,
    /// Range-checker bus used to bound the memory pointer limbs.
    pub range_bus: VariableRangeCheckerBus,
    /// The AIR range checks the pointer so that it is below `2^pointer_max_bits`.
    pointer_max_bits: usize,
}
114
115impl<F: Field> BaseAir<F> for Rv32LoadStoreAdapterAir {
116 fn width(&self) -> usize {
117 Rv32LoadStoreAdapterCols::<F>::width()
118 }
119}
120
impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32LoadStoreAdapterAir {
    type Interface = Rv32LoadStoreAdapterAirInterface<AB>;

    /// Constrains one row of the load/store adapter.
    ///
    /// The three memory interactions use consecutive timestamps starting at
    /// `from_state.timestamp`, in this order:
    /// 1. the `rs1` read;
    /// 2. the data read, at `select(is_load, mem_ptr, rd_rs2_ptr) - load_shift_amount`;
    /// 3. the write, at `select(is_load, rd_rs2_ptr, mem_ptr) - store_shift_amount`, gated
    ///    by `needs_write`.
    ///
    /// The execution bridge then advances the timestamp by 3.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let local_cols: &Rv32LoadStoreAdapterCols<AB::Var> = local.borrow();

        // Successive `timestamp_pp()` calls yield `timestamp + 0`, `timestamp + 1`, ...
        let timestamp: AB::Var = local_cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::Expr::from_usize(timestamp_delta - 1)
        };

        let is_load = ctx.instruction.is_load;
        let is_valid = ctx.instruction.is_valid;
        let load_shift_amount = ctx.instruction.load_shift_amount;
        let store_shift_amount = ctx.instruction.store_shift_amount;
        let shift_amount = load_shift_amount.clone() + store_shift_amount.clone();

        let write_count = local_cols.needs_write;

        // `needs_write` is boolean and may only be 1 on a valid row, so the write is never
        // enabled on padding rows.
        builder.assert_bool(write_count);
        builder.when(write_count).assert_one(is_valid.clone());

        // A valid row that skips the write (`is_valid - needs_write == 1`) must be a load
        // into register 0 (`rd_rs2_ptr == 0`).
        builder
            .when(is_valid.clone() - write_count)
            .assert_one(is_load.clone());
        builder
            .when(is_valid.clone() - write_count)
            .assert_zero(local_cols.rd_rs2_ptr);

        // Interaction 1: read rs1 from the register address space.
        self.memory_bridge
            .read(
                MemoryAddress::new(AB::F::from_u32(RV32_REGISTER_AS), local_cols.rs1_ptr),
                local_cols.rs1_data,
                timestamp_pp(),
                &local_cols.rs1_aux_cols,
            )
            .eval(builder, is_valid.clone());

        // Constrain mem_ptr = rs1 + sign_extend(imm) (mod 2^32) as a two-limb addition with
        // boolean carries. Each limb packs two byte cells of rs1.
        let limbs_01 =
            local_cols.rs1_data[0] + local_cols.rs1_data[1] * AB::F::from_u32(1 << RV32_CELL_BITS);
        let limbs_23 =
            local_cols.rs1_data[2] + local_cols.rs1_data[3] * AB::F::from_u32(1 << RV32_CELL_BITS);

        let inv = AB::F::from_u32(1 << (RV32_CELL_BITS * 2)).inverse();
        // Low limb: (limbs_01 + imm - mem_ptr_limbs[0]) / 2^(2*CELL_BITS) must be a 0/1 carry.
        let carry = (limbs_01 + local_cols.imm - local_cols.mem_ptr_limbs[0]) * inv;

        builder.when(is_valid.clone()).assert_bool(carry.clone());

        builder
            .when(is_valid.clone())
            .assert_bool(local_cols.imm_sign);
        // Upper limb of the sign-extended immediate: all ones if imm_sign, else 0.
        let imm_extend_limb =
            local_cols.imm_sign * AB::F::from_u32((1 << (RV32_CELL_BITS * 2)) - 1);
        // High limb with the incoming carry. The outgoing carry only has to be boolean and is
        // otherwise dropped, which yields wrapping 32-bit addition.
        let carry = (limbs_23 + imm_extend_limb + carry - local_cols.mem_ptr_limbs[1]) * inv;
        builder.when(is_valid.clone()).assert_bool(carry.clone());

        // Bound the pointer:
        // - (mem_ptr_limbs[0] - shift_amount) / 4 must fit in 2*CELL_BITS - 2 bits, so the
        //   aligned low limb is a multiple of 4 below 2^(2*CELL_BITS);
        // - mem_ptr_limbs[1] must fit in pointer_max_bits - 2*CELL_BITS bits.
        self.range_bus
            .range_check(
                (local_cols.mem_ptr_limbs[0] - shift_amount) * AB::F::from_u32(4).inverse(),
                RV32_CELL_BITS * 2 - 2,
            )
            .eval(builder, is_valid.clone());
        self.range_bus
            .range_check(
                local_cols.mem_ptr_limbs[1],
                self.pointer_max_bits - RV32_CELL_BITS * 2,
            )
            .eval(builder, is_valid.clone());

        let mem_ptr = local_cols.mem_ptr_limbs[0]
            + local_cols.mem_ptr_limbs[1] * AB::F::from_u32(1 << (RV32_CELL_BITS * 2));

        // With is_store = is_valid - is_load, `mem_as - 2 * is_store` must lie in {0, 1, 2}.
        // On a valid row this means mem_as is in {0, 1, 2} for loads and {2, 3, 4} for stores.
        let is_store = is_valid.clone() - is_load.clone();
        builder.assert_tern(local_cols.mem_as - is_store * AB::Expr::TWO);
        // Padding rows must have mem_as = 0.
        builder
            .when(not::<AB::Expr>(is_valid.clone()))
            .assert_zero(local_cols.mem_as);

        // read_as: mem_as for loads, the register address space for stores.
        let read_as = select::<AB::Expr>(
            is_load.clone(),
            local_cols.mem_as,
            AB::F::from_u32(RV32_REGISTER_AS),
        );

        // read_ptr: mem_ptr for loads, rd_rs2_ptr for stores, minus load_shift_amount.
        // The shift is subtracted outside the `select` instead of being selected per branch.
        // NOTE(review): presumably this keeps the constraint degree low because the shift
        // amounts from the core chip are themselves higher degree — confirm there.
        let read_ptr = select::<AB::Expr>(is_load.clone(), mem_ptr.clone(), local_cols.rd_rs2_ptr)
            - load_shift_amount;

        // Interaction 2: the data read.
        self.memory_bridge
            .read(
                MemoryAddress::new(read_as, read_ptr),
                ctx.reads.1,
                timestamp_pp(),
                &local_cols.read_data_aux,
            )
            .eval(builder, is_valid.clone());

        // The previous contents of the written location are provided by the core (`reads.0`).
        let write_aux_cols = MemoryWriteAuxCols::from_base(local_cols.write_base_aux, ctx.reads.0);

        // write_as: the register address space for loads, mem_as for stores.
        let write_as = select::<AB::Expr>(
            is_load.clone(),
            AB::F::from_u32(RV32_REGISTER_AS),
            local_cols.mem_as,
        );

        // write_ptr: rd_rs2_ptr for loads, mem_ptr for stores, minus store_shift_amount.
        let write_ptr = select::<AB::Expr>(is_load.clone(), local_cols.rd_rs2_ptr, mem_ptr.clone())
            - store_shift_amount;

        // Interaction 3: the write, enabled by `needs_write` rather than `is_valid`. Its
        // timestamp slot is consumed unconditionally, so every row advances the timestamp by 3.
        self.memory_bridge
            .write(
                MemoryAddress::new(write_as, write_ptr),
                ctx.writes[0].clone(),
                timestamp_pp(),
                &write_aux_cols,
            )
            .eval(builder, write_count);

        let to_pc = ctx
            .to_pc
            .unwrap_or(local_cols.from_state.pc + AB::F::from_u32(DEFAULT_PC_STEP));
        // Instruction operands a..g: rd/rs2 ptr, rs1 ptr, imm, register AS, mem AS,
        // needs_write (f), imm_sign (g).
        self.execution_bridge
            .execute(
                ctx.instruction.opcode,
                [
                    local_cols.rd_rs2_ptr.into(),
                    local_cols.rs1_ptr.into(),
                    local_cols.imm.into(),
                    AB::Expr::from_u32(RV32_REGISTER_AS),
                    local_cols.mem_as.into(),
                    local_cols.needs_write.into(),
                    local_cols.imm_sign.into(),
                ],
                local_cols.from_state,
                ExecutionState {
                    pc: to_pc,
                    timestamp: timestamp + AB::F::from_usize(timestamp_delta),
                },
            )
            .eval(builder, is_valid);
    }

    /// Returns the row's starting program counter column.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let local_cols: &Rv32LoadStoreAdapterCols<AB::Var> = local.borrow();
        local_cols.from_state.pc
    }
}
290
/// Execution record of one load/store. The executor fills it; the trace filler later
/// expands it into `Rv32LoadStoreAdapterCols`.
///
/// `#[repr(C)]`: the filler reads this record out of the same buffer it then overwrites
/// with columns, so the field layout must not change.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32LoadStoreAdapterRecord {
    pub from_pc: u32,
    pub from_timestamp: u32,

    pub rs1_ptr: u32,
    /// Value of `rs1`, read at `from_timestamp`.
    pub rs1_val: u32,
    pub rs1_aux_record: MemoryReadAuxRecord,

    /// `rd` (loads) / `rs2` (stores) pointer, or `u32::MAX` when the write is disabled.
    pub rd_rs2_ptr: u32,
    pub read_data_aux: MemoryReadAuxRecord,
    /// Low 16 bits of the immediate.
    pub imm: u16,
    /// When set, the immediate is extended with `0xffff0000`.
    pub imm_sign: bool,

    /// Address space of the memory side of the access.
    pub mem_as: u8,

    /// Previous timestamp of the written location; only meaningful when the write happened.
    pub write_prev_timestamp: u32,
}
310
/// Execution-time half of the load/store adapter.
///
/// It reads `rs1` and forms the pointer `rs1 + sign_extend(imm)`:
/// - loads read memory at that pointer rounded down to 4 bytes and write `rd`;
/// - stores read `rs2` and write memory at the rounded-down pointer.
#[derive(Clone, Copy, derive_new::new)]
pub struct Rv32LoadStoreAdapterExecutor {
    /// The aligned pointer must be below `2^pointer_max_bits` (asserted in `read`).
    pointer_max_bits: usize,
}
318
/// Trace-generation half of the load/store adapter. It expands a
/// `Rv32LoadStoreAdapterRecord` into `Rv32LoadStoreAdapterCols` and adds the pointer-limb
/// range-check counts.
#[derive(derive_new::new)]
pub struct Rv32LoadStoreAdapterFiller {
    /// NOTE(review): presumably must equal the AIR's `pointer_max_bits` so the range-check
    /// counts added here match the AIR's lookups — confirm at construction site.
    pointer_max_bits: usize,
    pub range_checker_chip: SharedVariableRangeCheckerChip,
}
324
325impl<F> AdapterTraceExecutor<F> for Rv32LoadStoreAdapterExecutor
326where
327 F: PrimeField32,
328{
329 const WIDTH: usize = size_of::<Rv32LoadStoreAdapterCols<u8>>();
330 type ReadData = (
331 (
332 [u32; RV32_REGISTER_NUM_LIMBS],
333 [u8; RV32_REGISTER_NUM_LIMBS],
334 ),
335 u8,
336 );
337 type WriteData = [u32; RV32_REGISTER_NUM_LIMBS];
338 type RecordMut<'a> = &'a mut Rv32LoadStoreAdapterRecord;
339
340 #[inline(always)]
341 fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
342 record.from_pc = pc;
343 record.from_timestamp = memory.timestamp;
344 }
345
346 #[inline(always)]
347 fn read(
348 &self,
349 memory: &mut TracingMemory,
350 instruction: &Instruction<F>,
351 record: &mut Self::RecordMut<'_>,
352 ) -> Self::ReadData {
353 let &Instruction {
354 opcode,
355 a,
356 b,
357 c,
358 d,
359 e,
360 g,
361 ..
362 } = instruction;
363
364 debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
365
366 let local_opcode = Rv32LoadStoreOpcode::from_usize(
367 opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
368 );
369
370 record.rs1_ptr = b.as_canonical_u32();
371 record.rs1_val = u32::from_le_bytes(tracing_read(
372 memory,
373 RV32_REGISTER_AS,
374 record.rs1_ptr,
375 &mut record.rs1_aux_record.prev_timestamp,
376 ));
377
378 record.imm = c.as_canonical_u32() as u16;
379 record.imm_sign = g.is_one();
380 let imm_extended = record.imm as u32 + record.imm_sign as u32 * 0xffff0000;
381
382 let ptr_val = record.rs1_val.wrapping_add(imm_extended);
383 let shift_amount = ptr_val & 3;
384 let ptr_val = ptr_val - shift_amount;
385
386 assert!(
387 ptr_val < (1 << self.pointer_max_bits),
388 "ptr_val: {ptr_val} = rs1_val: {} + imm_extended: {imm_extended} >= 2 ** {}",
389 record.rs1_val,
390 self.pointer_max_bits
391 );
392
393 let (read_data, prev_data) = match local_opcode {
396 LOADW | LOADB | LOADH | LOADBU | LOADHU => {
397 debug_assert_eq!(e, F::from_u32(RV32_MEMORY_AS));
398 record.mem_as = RV32_MEMORY_AS as u8;
399 let read_data = tracing_read(
400 memory,
401 RV32_MEMORY_AS,
402 ptr_val,
403 &mut record.read_data_aux.prev_timestamp,
404 );
405 let prev_data = memory_read(memory.data(), RV32_REGISTER_AS, a.as_canonical_u32())
406 .map(u32::from);
407 (read_data, prev_data)
408 }
409 STOREW | STOREH | STOREB => {
410 let e = e.as_canonical_u32();
411 debug_assert_ne!(e, RV32_IMM_AS);
412 debug_assert_ne!(e, RV32_REGISTER_AS);
413 record.mem_as = e as u8;
414 let read_data = tracing_read(
415 memory,
416 RV32_REGISTER_AS,
417 a.as_canonical_u32(),
418 &mut record.read_data_aux.prev_timestamp,
419 );
420 let prev_data = if e == NATIVE_AS {
421 memory_read_native(memory.data(), ptr_val).map(|x: F| x.as_canonical_u32())
422 } else {
423 memory_read(memory.data(), e, ptr_val).map(u32::from)
424 };
425 (read_data, prev_data)
426 }
427 };
428
429 ((prev_data, read_data), shift_amount as u8)
430 }
431
432 #[inline(always)]
433 fn write(
434 &self,
435 memory: &mut TracingMemory,
436 instruction: &Instruction<F>,
437 data: Self::WriteData,
438 record: &mut Self::RecordMut<'_>,
439 ) {
440 let &Instruction {
441 opcode,
442 a,
443 d,
444 e,
445 f: enabled,
446 ..
447 } = instruction;
448
449 debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
450 debug_assert_ne!(e.as_canonical_u32(), RV32_IMM_AS);
451 debug_assert_ne!(e.as_canonical_u32(), RV32_REGISTER_AS);
452
453 let local_opcode = Rv32LoadStoreOpcode::from_usize(
454 opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
455 );
456
457 if enabled != F::ZERO {
458 record.rd_rs2_ptr = a.as_canonical_u32();
459
460 record.write_prev_timestamp = match local_opcode {
461 STOREW | STOREH | STOREB => {
462 let imm_extended = record.imm as u32 + record.imm_sign as u32 * 0xffff0000;
463 let ptr = record.rs1_val.wrapping_add(imm_extended) & !3;
464
465 if record.mem_as == 4 {
466 timed_write_native(memory, ptr, data.map(F::from_u32)).0
467 } else {
468 timed_write(memory, record.mem_as as u32, ptr, data.map(|x| x as u8)).0
469 }
470 }
471 LOADW | LOADB | LOADH | LOADBU | LOADHU => {
472 timed_write(
473 memory,
474 RV32_REGISTER_AS,
475 record.rd_rs2_ptr,
476 data.map(|x| x as u8),
477 )
478 .0
479 }
480 };
481 } else {
482 record.rd_rs2_ptr = u32::MAX;
483 memory.increment_timestamp();
484 };
485 }
486}
487
impl<F: PrimeField32> AdapterTraceFiller<F> for Rv32LoadStoreAdapterFiller {
    const WIDTH: usize = size_of::<Rv32LoadStoreAdapterCols<u8>>();

    /// Expands the `Rv32LoadStoreAdapterRecord` stored at the start of `adapter_row` into
    /// `Rv32LoadStoreAdapterCols` in place, and adds the range-check counts for the pointer
    /// limbs that the AIR looks up.
    #[inline(always)]
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        // NOTE(review): the lookups below use 14 bits and `pointer_max_bits - 16` bits; the
        // 15-bit floor presumably assumes pointer_max_bits <= 31 — confirm.
        debug_assert!(self.range_checker_chip.range_max_bits() >= 15);

        // SAFETY: `record` and the column view below alias the same bytes. The record is read
        // out of `adapter_row`, which is then overwritten with field elements. That is only
        // sound if every record field is read before the column overlapping it is written.
        // The statements below therefore fill columns from the last field
        // (`needs_write`) back to the first (`from_state.pc`), reading each needed record
        // field first. Do not reorder them.
        // NOTE(review): presumably this also relies on `F`'s alignment being at least the
        // record's — confirm against `get_record_from_slice`.
        let record: &Rv32LoadStoreAdapterRecord =
            unsafe { get_record_from_slice(&mut adapter_row, ()) };
        let adapter_row: &mut Rv32LoadStoreAdapterCols<F> = adapter_row.borrow_mut();

        // The executor encodes a skipped write as `rd_rs2_ptr == u32::MAX`.
        let needs_write = record.rd_rs2_ptr != u32::MAX;
        adapter_row.needs_write = F::from_bool(needs_write);

        if needs_write {
            // The write is the third memory access of the instruction.
            mem_helper.fill(
                record.write_prev_timestamp,
                record.from_timestamp + 2,
                &mut adapter_row.write_base_aux,
            );
        } else {
            mem_helper.fill_zero(&mut adapter_row.write_base_aux);
        }

        adapter_row.mem_as = F::from_u8(record.mem_as);
        // Unaligned pointer rs1 + sign_extend(imm); the AIR subtracts the shift amount itself.
        let ptr = record
            .rs1_val
            .wrapping_add(record.imm as u32 + record.imm_sign as u32 * 0xffff0000);

        // Mirror the AIR's range checks. The low limb is checked as `limb >> 2`, which equals
        // `(limb - shift_amount) / 4` when shift_amount == ptr & 3 (as computed in `read`).
        let ptr_limbs = [ptr & 0xffff, ptr >> 16];
        self.range_checker_chip
            .add_count(ptr_limbs[0] >> 2, RV32_CELL_BITS * 2 - 2);
        self.range_checker_chip
            .add_count(ptr_limbs[1], self.pointer_max_bits - 16);
        adapter_row.mem_ptr_limbs = ptr_limbs.map(F::from_u32);

        adapter_row.imm_sign = F::from_bool(record.imm_sign);
        adapter_row.imm = F::from_u16(record.imm);

        // The data read is the second memory access.
        mem_helper.fill(
            record.read_data_aux.prev_timestamp,
            record.from_timestamp + 1,
            adapter_row.read_data_aux.as_mut(),
        );
        // The AIR requires rd_rs2_ptr == 0 when the write is skipped.
        adapter_row.rd_rs2_ptr = if record.rd_rs2_ptr != u32::MAX {
            F::from_u32(record.rd_rs2_ptr)
        } else {
            F::ZERO
        };

        // The rs1 read is the first memory access.
        mem_helper.fill(
            record.rs1_aux_record.prev_timestamp,
            record.from_timestamp,
            adapter_row.rs1_aux_cols.as_mut(),
        );

        adapter_row.rs1_data = record.rs1_val.to_le_bytes().map(F::from_u8);
        adapter_row.rs1_ptr = F::from_u32(record.rs1_ptr);

        adapter_row.from_state.timestamp = F::from_u32(record.from_timestamp);
        adapter_row.from_state.pc = F::from_u32(record.from_pc);
    }
}
555}