1use std::{
2 borrow::{Borrow, BorrowMut},
3 marker::PhantomData,
4};
5
6use openvm_circuit::{
7 arch::{
8 get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
9 ExecutionBridge, ExecutionState, VmAdapterAir, VmAdapterInterface,
10 },
11 system::{
12 memory::{
13 offline_checker::{
14 MemoryBaseAuxCols, MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord,
15 MemoryWriteAuxCols,
16 },
17 online::TracingMemory,
18 MemoryAddress, MemoryAuxColsFactory,
19 },
20 native_adapter::util::{memory_read_native, timed_write_native},
21 },
22};
23use openvm_circuit_primitives::{
24 utils::{not, select},
25 var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
26 AlignedBytesBorrow,
27};
28use openvm_circuit_primitives_derive::AlignedBorrow;
29use openvm_instructions::{
30 instruction::Instruction,
31 program::DEFAULT_PC_STEP,
32 riscv::{RV32_IMM_AS, RV32_MEMORY_AS, RV32_REGISTER_AS},
33 LocalOpcode, NATIVE_AS,
34};
35use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
36use openvm_stark_backend::{
37 interaction::InteractionBuilder,
38 p3_air::{AirBuilder, BaseAir},
39 p3_field::{Field, FieldAlgebra, PrimeField32},
40};
41
42use super::RV32_REGISTER_NUM_LIMBS;
43use crate::adapters::{memory_read, timed_write, tracing_read, RV32_CELL_BITS};
44
/// Instruction-level signals the load/store core hands to the adapter AIR.
///
/// All fields are expressions supplied by the core chip; the adapter only reads them.
pub struct LoadStoreInstruction<T> {
    /// Nonzero on rows that carry a real instruction; gates every adapter constraint.
    pub is_valid: T,
    /// Opcode forwarded to the execution bus.
    pub opcode: T,
    /// Selects the load path (memory → register) over the store path (register → memory).
    pub is_load: T,

    /// Low-order byte offset subtracted from the read pointer (the load path's shift).
    pub load_shift_amount: T,
    /// Low-order byte offset subtracted from the write pointer (the store path's shift).
    pub store_shift_amount: T,
}
67
68pub struct Rv32LoadStoreAdapterAirInterface<AB: InteractionBuilder>(PhantomData<AB>);
69
70impl<AB: InteractionBuilder> VmAdapterInterface<AB::Expr> for Rv32LoadStoreAdapterAirInterface<AB> {
72 type Reads = (
73 [AB::Var; RV32_REGISTER_NUM_LIMBS],
74 [AB::Expr; RV32_REGISTER_NUM_LIMBS],
75 );
76 type Writes = [[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1];
77 type ProcessedInstruction = LoadStoreInstruction<AB::Expr>;
78}
79
/// Trace columns of the RV32 load/store adapter.
///
/// `#[repr(C)]`: field order is the column order. Do not reorder fields; the trace
/// filler writes these columns in a specific order.
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct Rv32LoadStoreAdapterCols<T> {
    /// pc and timestamp at which the instruction starts executing.
    pub from_state: ExecutionState<T>,
    /// Register pointer of `rs1` (operand `b`).
    pub rs1_ptr: T,
    /// Limbs of the value read from `rs1`, least-significant first.
    pub rs1_data: [T; RV32_REGISTER_NUM_LIMBS],
    /// Aux columns for the `rs1` register read.
    pub rs1_aux_cols: MemoryReadAuxCols<T>,

    /// Operand `a`: the destination register `rd` for loads, or the source register `rs2`
    /// for stores. Constrained to zero on valid rows that skip the write.
    pub rd_rs2_ptr: T,
    /// Aux columns for the data read (memory for loads, `rs2` for stores).
    pub read_data_aux: MemoryReadAuxCols<T>,
    /// Immediate operand `c`.
    pub imm: T,
    /// Immediate sign operand `g`. It is boolean; when set, the upper limb of the
    /// immediate is treated as all ones.
    pub imm_sign: T,
    /// Effective address `rs1 + sign_extend(imm)` as two limbs of `2 * RV32_CELL_BITS`
    /// bits each, low limb first.
    pub mem_ptr_limbs: [T; 2],
    /// Memory address space operand `e`. It is zero on invalid rows.
    pub mem_as: T,
    /// Base aux columns for the write; combined with the previous data in the AIR.
    pub write_base_aux: MemoryBaseAuxCols<T>,
    /// Boolean: 1 iff the write is performed (operand `f`).
    pub needs_write: T,
}
106
/// AIR for the RV32 load/store adapter.
///
/// It constrains the following:
/// - the `rs1` read;
/// - the effective-address computation;
/// - the data read;
/// - the optional write;
/// - the execution-bus interaction.
///
/// `derive_new::new` takes its arguments in field declaration order, so do not reorder
/// the fields.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32LoadStoreAdapterAir {
    pub(super) memory_bridge: MemoryBridge,
    pub(super) execution_bridge: ExecutionBridge,
    /// Bus used to range-check the effective-address limbs.
    pub range_bus: VariableRangeCheckerBus,
    /// Bit bound on effective addresses. The high address limb is range-checked to
    /// `pointer_max_bits - 2 * RV32_CELL_BITS` bits.
    pointer_max_bits: usize,
}
114
115impl<F: Field> BaseAir<F> for Rv32LoadStoreAdapterAir {
116 fn width(&self) -> usize {
117 Rv32LoadStoreAdapterCols::<F>::width()
118 }
119}
120
/// Constraints for the load/store adapter.
///
/// Memory accesses happen at consecutive timestamps starting at `from_state.timestamp`:
/// 1. the `rs1` read;
/// 2. the data read;
/// 3. the write, which is gated by `needs_write`.
///
/// The execution bus sees the timestamp advance by the number of accesses made here.
impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32LoadStoreAdapterAir {
    type Interface = Rv32LoadStoreAdapterAirInterface<AB>;

    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let local_cols: &Rv32LoadStoreAdapterCols<AB::Var> = local.borrow();

        // Each call yields `timestamp + k` for k = 0, 1, 2, ... and bumps the counter,
        // so the call order below fixes each access's timestamp.
        let timestamp: AB::Var = local_cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1)
        };

        let is_load = ctx.instruction.is_load;
        let is_valid = ctx.instruction.is_valid;
        let load_shift_amount = ctx.instruction.load_shift_amount;
        let store_shift_amount = ctx.instruction.store_shift_amount;
        let shift_amount = load_shift_amount.clone() + store_shift_amount.clone();

        let write_count = local_cols.needs_write;

        // needs_write is boolean and may only be set on valid rows.
        builder.assert_bool(write_count);
        builder.when(write_count).assert_one(is_valid.clone());

        // A valid row that skips the write must be a load with rd_rs2_ptr = 0.
        builder
            .when(is_valid.clone() - write_count)
            .assert_one(is_load.clone());
        builder
            .when(is_valid.clone() - write_count)
            .assert_zero(local_cols.rd_rs2_ptr);

        // Access 1: read rs1 from the register address space.
        self.memory_bridge
            .read(
                MemoryAddress::new(
                    AB::F::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.rs1_ptr,
                ),
                local_cols.rs1_data,
                timestamp_pp(),
                &local_cols.rs1_aux_cols,
            )
            .eval(builder, is_valid.clone());

        // Regroup the rs1 limbs into two limbs of 2 * RV32_CELL_BITS bits each.
        let limbs_01 = local_cols.rs1_data[0]
            + local_cols.rs1_data[1] * AB::F::from_canonical_u32(1 << RV32_CELL_BITS);
        let limbs_23 = local_cols.rs1_data[2]
            + local_cols.rs1_data[3] * AB::F::from_canonical_u32(1 << RV32_CELL_BITS);

        // mem_ptr_limbs = rs1 + sign_extend(imm), added limb by limb with boolean carries.
        // Low limb: (limbs_01 + imm - mem_ptr_limbs[0]) / 2^(2*RV32_CELL_BITS) must be 0 or 1.
        let inv = AB::F::from_canonical_u32(1 << (RV32_CELL_BITS * 2)).inverse();
        let carry = (limbs_01 + local_cols.imm - local_cols.mem_ptr_limbs[0]) * inv;

        builder.when(is_valid.clone()).assert_bool(carry.clone());

        // High limb: the immediate's upper limb is all ones when imm_sign is set. The final
        // carry is only required to be boolean and is otherwise dropped, so the sum wraps.
        builder
            .when(is_valid.clone())
            .assert_bool(local_cols.imm_sign);
        let imm_extend_limb =
            local_cols.imm_sign * AB::F::from_canonical_u32((1 << (RV32_CELL_BITS * 2)) - 1);
        let carry = (limbs_23 + imm_extend_limb + carry - local_cols.mem_ptr_limbs[1]) * inv;
        builder.when(is_valid.clone()).assert_bool(carry.clone());

        // (mem_ptr_limbs[0] - shift_amount) / 4 must fit in 2*RV32_CELL_BITS - 2 bits. This
        // forces the shifted low limb to be a multiple of 4 within the limb range, since a
        // non-multiple of 4 divided by 4 in the field would not be small.
        self.range_bus
            .range_check(
                (local_cols.mem_ptr_limbs[0] - shift_amount)
                    * AB::F::from_canonical_u32(4).inverse(),
                RV32_CELL_BITS * 2 - 2,
            )
            .eval(builder, is_valid.clone());
        // High limb < 2^(pointer_max_bits - 2*RV32_CELL_BITS), which bounds the address.
        self.range_bus
            .range_check(
                local_cols.mem_ptr_limbs[1],
                self.pointer_max_bits - RV32_CELL_BITS * 2,
            )
            .eval(builder, is_valid.clone());

        let mem_ptr = local_cols.mem_ptr_limbs[0]
            + local_cols.mem_ptr_limbs[1] * AB::F::from_canonical_u32(1 << (RV32_CELL_BITS * 2));

        // assert_tern restricts (mem_as - 2*is_store) to {0, 1, 2}:
        // - loads: mem_as is in {0, 1, 2};
        // - stores: mem_as is in {2, 3, 4};
        // - invalid rows: mem_as is forced to 0.
        let is_store = is_valid.clone() - is_load.clone();
        builder.assert_tern(local_cols.mem_as - is_store * AB::Expr::TWO);
        builder
            .when(not::<AB::Expr>(is_valid.clone()))
            .assert_zero(local_cols.mem_as);

        // Access 2 source:
        // - loads read memory at (mem_as, mem_ptr - load_shift_amount);
        // - stores read the source register at (REGISTER_AS, rd_rs2_ptr - load_shift_amount).
        //   NOTE(review): load_shift_amount is presumably 0 for stores; confirm in the core.
        let read_as = select::<AB::Expr>(
            is_load.clone(),
            local_cols.mem_as,
            AB::F::from_canonical_u32(RV32_REGISTER_AS),
        );

        let read_ptr = select::<AB::Expr>(is_load.clone(), mem_ptr.clone(), local_cols.rd_rs2_ptr)
            - load_shift_amount;

        self.memory_bridge
            .read(
                MemoryAddress::new(read_as, read_ptr),
                ctx.reads.1,
                timestamp_pp(),
                &local_cols.read_data_aux,
            )
            .eval(builder, is_valid.clone());

        // The write aux columns pair the stored base aux with the destination's previous
        // contents (ctx.reads.0).
        let write_aux_cols = MemoryWriteAuxCols::from_base(local_cols.write_base_aux, ctx.reads.0);

        // Access 3 destination:
        // - loads write the register rd;
        // - stores write memory at (mem_as, mem_ptr - store_shift_amount).
        let write_as = select::<AB::Expr>(
            is_load.clone(),
            AB::F::from_canonical_u32(RV32_REGISTER_AS),
            local_cols.mem_as,
        );

        let write_ptr = select::<AB::Expr>(is_load.clone(), local_cols.rd_rs2_ptr, mem_ptr.clone())
            - store_shift_amount;

        // The write interaction is enabled only when needs_write = 1, but timestamp_pp()
        // is called unconditionally so the timestamp delta is the same on every row.
        self.memory_bridge
            .write(
                MemoryAddress::new(write_as, write_ptr),
                ctx.writes[0].clone(),
                timestamp_pp(),
                &write_aux_cols,
            )
            .eval(builder, write_count);

        let to_pc = ctx
            .to_pc
            .unwrap_or(local_cols.from_state.pc + AB::F::from_canonical_u32(DEFAULT_PC_STEP));
        // Instruction operands, in order a..g:
        // rd_rs2_ptr, rs1_ptr, imm, REGISTER_AS, mem_as, needs_write, imm_sign.
        self.execution_bridge
            .execute(
                ctx.instruction.opcode,
                [
                    local_cols.rd_rs2_ptr.into(),
                    local_cols.rs1_ptr.into(),
                    local_cols.imm.into(),
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.mem_as.into(),
                    local_cols.needs_write.into(),
                    local_cols.imm_sign.into(),
                ],
                local_cols.from_state,
                ExecutionState {
                    pc: to_pc,
                    timestamp: timestamp + AB::F::from_canonical_usize(timestamp_delta),
                },
            )
            .eval(builder, is_valid);
    }

    /// Returns the `from_state.pc` column of the row.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let local_cols: &Rv32LoadStoreAdapterCols<AB::Var> = local.borrow();
        local_cols.from_state.pc
    }
}
294
/// Execution-time record for one load/store instruction. The trace filler later expands
/// it into [`Rv32LoadStoreAdapterCols`].
///
/// `#[repr(C)]`: the filler reads this record out of the row slice it then overwrites.
/// Keep the field order stable.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32LoadStoreAdapterRecord {
    /// pc and memory timestamp captured when the instruction starts.
    pub from_pc: u32,
    pub from_timestamp: u32,

    /// Register pointer of `rs1` (operand `b`), its value, and the previous timestamp of
    /// that register read.
    pub rs1_ptr: u32,
    pub rs1_val: u32,
    pub rs1_aux_record: MemoryReadAuxRecord,

    /// Operand `a` when the write is performed; the executor stores `u32::MAX` here when
    /// operand `f` is zero (no write).
    pub rd_rs2_ptr: u32,
    /// Previous timestamp of the data read (memory for loads, `rs2` for stores).
    pub read_data_aux: MemoryReadAuxRecord,
    /// Operand `c` truncated to 16 bits.
    pub imm: u16,
    /// Whether operand `g` is one; when set, the immediate is extended with 0xffff0000.
    pub imm_sign: bool,

    /// Memory address space used by the instruction.
    pub mem_as: u8,

    /// Previous timestamp at the write destination; only set when the write is performed.
    pub write_prev_timestamp: u32,
}
314
/// Executor half of the load/store adapter. It performs the traced memory accesses and
/// fills an [`Rv32LoadStoreAdapterRecord`].
#[derive(Clone, Copy, derive_new::new)]
pub struct Rv32LoadStoreAdapterExecutor {
    /// Aligned effective addresses must be below `2^pointer_max_bits`; `read` asserts this.
    pointer_max_bits: usize,
}
322
/// Trace-filler half of the load/store adapter. It expands a record into the adapter's
/// trace columns and registers the matching range-check counts.
///
/// `derive_new::new` takes `(pointer_max_bits, range_checker_chip)` in field order.
#[derive(derive_new::new)]
pub struct Rv32LoadStoreAdapterFiller {
    /// Bit bound on addresses; sizes the range check on the high address limb.
    pointer_max_bits: usize,
    /// Range checker that receives the address-limb counts.
    pub range_checker_chip: SharedVariableRangeCheckerChip,
}
328
329impl<F> AdapterTraceExecutor<F> for Rv32LoadStoreAdapterExecutor
330where
331 F: PrimeField32,
332{
333 const WIDTH: usize = size_of::<Rv32LoadStoreAdapterCols<u8>>();
334 type ReadData = (
335 (
336 [u32; RV32_REGISTER_NUM_LIMBS],
337 [u8; RV32_REGISTER_NUM_LIMBS],
338 ),
339 u8,
340 );
341 type WriteData = [u32; RV32_REGISTER_NUM_LIMBS];
342 type RecordMut<'a> = &'a mut Rv32LoadStoreAdapterRecord;
343
344 #[inline(always)]
345 fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
346 record.from_pc = pc;
347 record.from_timestamp = memory.timestamp;
348 }
349
350 #[inline(always)]
351 fn read(
352 &self,
353 memory: &mut TracingMemory,
354 instruction: &Instruction<F>,
355 record: &mut Self::RecordMut<'_>,
356 ) -> Self::ReadData {
357 let &Instruction {
358 opcode,
359 a,
360 b,
361 c,
362 d,
363 e,
364 g,
365 ..
366 } = instruction;
367
368 debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
369
370 let local_opcode = Rv32LoadStoreOpcode::from_usize(
371 opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
372 );
373
374 record.rs1_ptr = b.as_canonical_u32();
375 record.rs1_val = u32::from_le_bytes(tracing_read(
376 memory,
377 RV32_REGISTER_AS,
378 record.rs1_ptr,
379 &mut record.rs1_aux_record.prev_timestamp,
380 ));
381
382 record.imm = c.as_canonical_u32() as u16;
383 record.imm_sign = g.is_one();
384 let imm_extended = record.imm as u32 + record.imm_sign as u32 * 0xffff0000;
385
386 let ptr_val = record.rs1_val.wrapping_add(imm_extended);
387 let shift_amount = ptr_val & 3;
388 let ptr_val = ptr_val - shift_amount;
389
390 assert!(
391 ptr_val < (1 << self.pointer_max_bits),
392 "ptr_val: {ptr_val} = rs1_val: {} + imm_extended: {imm_extended} >= 2 ** {}",
393 record.rs1_val,
394 self.pointer_max_bits
395 );
396
397 let (read_data, prev_data) = match local_opcode {
400 LOADW | LOADB | LOADH | LOADBU | LOADHU => {
401 debug_assert_eq!(e, F::from_canonical_u32(RV32_MEMORY_AS));
402 record.mem_as = RV32_MEMORY_AS as u8;
403 let read_data = tracing_read(
404 memory,
405 RV32_MEMORY_AS,
406 ptr_val,
407 &mut record.read_data_aux.prev_timestamp,
408 );
409 let prev_data = memory_read(memory.data(), RV32_REGISTER_AS, a.as_canonical_u32())
410 .map(u32::from);
411 (read_data, prev_data)
412 }
413 STOREW | STOREH | STOREB => {
414 let e = e.as_canonical_u32();
415 debug_assert_ne!(e, RV32_IMM_AS);
416 debug_assert_ne!(e, RV32_REGISTER_AS);
417 record.mem_as = e as u8;
418 let read_data = tracing_read(
419 memory,
420 RV32_REGISTER_AS,
421 a.as_canonical_u32(),
422 &mut record.read_data_aux.prev_timestamp,
423 );
424 let prev_data = if e == NATIVE_AS {
425 memory_read_native(memory.data(), ptr_val).map(|x: F| x.as_canonical_u32())
426 } else {
427 memory_read(memory.data(), e, ptr_val).map(u32::from)
428 };
429 (read_data, prev_data)
430 }
431 };
432
433 ((prev_data, read_data), shift_amount as u8)
434 }
435
436 #[inline(always)]
437 fn write(
438 &self,
439 memory: &mut TracingMemory,
440 instruction: &Instruction<F>,
441 data: Self::WriteData,
442 record: &mut Self::RecordMut<'_>,
443 ) {
444 let &Instruction {
445 opcode,
446 a,
447 d,
448 e,
449 f: enabled,
450 ..
451 } = instruction;
452
453 debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
454 debug_assert_ne!(e.as_canonical_u32(), RV32_IMM_AS);
455 debug_assert_ne!(e.as_canonical_u32(), RV32_REGISTER_AS);
456
457 let local_opcode = Rv32LoadStoreOpcode::from_usize(
458 opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
459 );
460
461 if enabled != F::ZERO {
462 record.rd_rs2_ptr = a.as_canonical_u32();
463
464 record.write_prev_timestamp = match local_opcode {
465 STOREW | STOREH | STOREB => {
466 let imm_extended = record.imm as u32 + record.imm_sign as u32 * 0xffff0000;
467 let ptr = record.rs1_val.wrapping_add(imm_extended) & !3;
468
469 if record.mem_as == 4 {
470 timed_write_native(memory, ptr, data.map(F::from_canonical_u32)).0
471 } else {
472 timed_write(memory, record.mem_as as u32, ptr, data.map(|x| x as u8)).0
473 }
474 }
475 LOADW | LOADB | LOADH | LOADBU | LOADHU => {
476 timed_write(
477 memory,
478 RV32_REGISTER_AS,
479 record.rd_rs2_ptr,
480 data.map(|x| x as u8),
481 )
482 .0
483 }
484 };
485 } else {
486 record.rd_rs2_ptr = u32::MAX;
487 memory.increment_timestamp();
488 };
489 }
490}
491
/// Expands a [`Rv32LoadStoreAdapterRecord`] into one row of [`Rv32LoadStoreAdapterCols`].
impl<F: PrimeField32> AdapterTraceFiller<F> for Rv32LoadStoreAdapterFiller {
    const WIDTH: usize = size_of::<Rv32LoadStoreAdapterCols<u8>>();

    /// Fills the adapter columns of `adapter_row` from the record stored at its start.
    ///
    /// NOTE(review): the record is obtained from the same slice that is then overwritten
    /// as columns. The columns are filled roughly from last to first, which appears to be
    /// deliberate so each record field is read before the column covering it is written.
    /// Preserve this ordering when editing.
    #[inline(always)]
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        // NOTE(review): presumably ensures the range checker covers the limb checks
        // below. TODO confirm why the bound is 15.
        debug_assert!(self.range_checker_chip.range_max_bits() >= 15);

        // SAFETY (review): relies on the executor having written a
        // Rv32LoadStoreAdapterRecord at the start of this row; confirm against
        // `get_record_from_slice`'s contract.
        let record: &Rv32LoadStoreAdapterRecord =
            unsafe { get_record_from_slice(&mut adapter_row, ()) };
        let adapter_row: &mut Rv32LoadStoreAdapterCols<F> = adapter_row.borrow_mut();

        // The executor stores u32::MAX in rd_rs2_ptr when the write was skipped.
        let needs_write = record.rd_rs2_ptr != u32::MAX;
        adapter_row.needs_write = F::from_bool(needs_write);

        // The write happened at from_timestamp + 2, as the third access.
        if needs_write {
            mem_helper.fill(
                record.write_prev_timestamp,
                record.from_timestamp + 2,
                &mut adapter_row.write_base_aux,
            );
        } else {
            mem_helper.fill_zero(&mut adapter_row.write_base_aux);
        }

        adapter_row.mem_as = F::from_canonical_u8(record.mem_as);
        // Unaligned effective address: rs1 + sign-extended imm, wrapping modulo 2^32.
        let ptr = record
            .rs1_val
            .wrapping_add(record.imm as u32 + record.imm_sign as u32 * 0xffff0000);

        // Split into 16-bit limbs and record the range-check counts the AIR will look up:
        // - low limb >> 2 uses RV32_CELL_BITS*2 - 2 bits;
        // - high limb uses pointer_max_bits - 16 bits.
        let ptr_limbs = [ptr & 0xffff, ptr >> 16];
        self.range_checker_chip
            .add_count(ptr_limbs[0] >> 2, RV32_CELL_BITS * 2 - 2);
        self.range_checker_chip
            .add_count(ptr_limbs[1], self.pointer_max_bits - 16);
        adapter_row.mem_ptr_limbs = ptr_limbs.map(F::from_canonical_u32);

        adapter_row.imm_sign = F::from_bool(record.imm_sign);
        adapter_row.imm = F::from_canonical_u16(record.imm);

        // The data read happened at from_timestamp + 1, as the second access.
        mem_helper.fill(
            record.read_data_aux.prev_timestamp,
            record.from_timestamp + 1,
            adapter_row.read_data_aux.as_mut(),
        );
        // Replace the u32::MAX "no write" marker with 0, which is what the AIR requires
        // on valid rows that skip the write.
        adapter_row.rd_rs2_ptr = if record.rd_rs2_ptr != u32::MAX {
            F::from_canonical_u32(record.rd_rs2_ptr)
        } else {
            F::ZERO
        };

        // The rs1 read happened at from_timestamp, as the first access.
        mem_helper.fill(
            record.rs1_aux_record.prev_timestamp,
            record.from_timestamp,
            adapter_row.rs1_aux_cols.as_mut(),
        );

        adapter_row.rs1_data = record.rs1_val.to_le_bytes().map(F::from_canonical_u8);
        adapter_row.rs1_ptr = F::from_canonical_u32(record.rs1_ptr);

        adapter_row.from_state.timestamp = F::from_canonical_u32(record.from_timestamp);
        adapter_row.from_state.pc = F::from_canonical_u32(record.from_pc);
    }
}