1use std::{
2 array,
3 borrow::{Borrow, BorrowMut},
4 marker::PhantomData,
5};
6
7use openvm_circuit::{
8 arch::{
9 AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, ExecutionBus, ExecutionState,
10 Result, VmAdapterAir, VmAdapterChip, VmAdapterInterface,
11 },
12 system::{
13 memory::{
14 offline_checker::{
15 MemoryBaseAuxCols, MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols,
16 },
17 MemoryAddress, MemoryController, OfflineMemory, RecordId,
18 },
19 program::ProgramBus,
20 },
21};
22use openvm_circuit_primitives::{
23 utils::{not, select},
24 var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
25};
26use openvm_circuit_primitives_derive::AlignedBorrow;
27use openvm_instructions::{
28 instruction::Instruction,
29 program::DEFAULT_PC_STEP,
30 riscv::{RV32_IMM_AS, RV32_REGISTER_AS},
31 LocalOpcode,
32};
33use openvm_rv32im_transpiler::Rv32LoadStoreOpcode::{self, *};
34use openvm_stark_backend::{
35 interaction::InteractionBuilder,
36 p3_air::{AirBuilder, BaseAir},
37 p3_field::{Field, FieldAlgebra, PrimeField32},
38};
39use serde::{Deserialize, Serialize};
40
41use super::{compose, RV32_REGISTER_NUM_LIMBS};
42use crate::adapters::RV32_CELL_BITS;
43
/// Instruction data handed to the load/store adapter AIR; it is the
/// `ProcessedInstruction` of `Rv32LoadStoreAdapterAirInterface`.
pub struct LoadStoreInstruction<T> {
    /// Multiplicity the adapter AIR uses for its register/memory reads and for the
    /// execution interaction.
    pub is_valid: T,
    /// Opcode expression forwarded to the execution bridge.
    pub opcode: T,
    /// Selects between the load layout and the store layout of the read/write addresses.
    pub is_load: T,

    /// Amount subtracted from the pointer of the data read in the adapter AIR.
    /// NOTE(review): presumably zero for stores — confirm against the core AIR.
    pub load_shift_amount: T,
    /// Amount subtracted from the pointer of the data write in the adapter AIR.
    /// NOTE(review): presumably zero for loads — confirm against the core AIR.
    pub store_shift_amount: T,
}
65
/// Zero-sized marker type naming the adapter interface used at execution time
/// (it is the `Interface` of the `VmAdapterChip` implementation in this file).
pub struct Rv32LoadStoreAdapterRuntimeInterface<T>(PhantomData<T>);
/// Execution-time interface types of the load/store adapter.
impl<T> VmAdapterInterface<T> for Rv32LoadStoreAdapterRuntimeInterface<T> {
    /// Two limb arrays plus one extra value. `preprocess` fills this with
    /// `[prev_data, read_data]` and the shift amount (`ptr % 4`).
    type Reads = ([[T; RV32_REGISTER_NUM_LIMBS]; 2], T);
    /// A single limb array to be written.
    type Writes = [[T; RV32_REGISTER_NUM_LIMBS]; 1];
    /// Nothing extra is passed alongside the instruction at execution time.
    type ProcessedInstruction = ();
}
/// Zero-sized marker type naming the adapter interface used by the AIR
/// (it is the `Interface` of the `VmAdapterAir` implementation in this file).
pub struct Rv32LoadStoreAdapterAirInterface<AB: InteractionBuilder>(PhantomData<AB>);
82
/// AIR-side interface types of the load/store adapter.
impl<AB: InteractionBuilder> VmAdapterInterface<AB::Expr> for Rv32LoadStoreAdapterAirInterface<AB> {
    /// `.0` is a limb array of trace variables that `eval` passes to
    /// `MemoryWriteAuxCols::from_base` (presumably the data overwritten by the write);
    /// `.1` is a limb array of expressions used as the data of the second memory read.
    type Reads = (
        [AB::Var; RV32_REGISTER_NUM_LIMBS],
        [AB::Expr; RV32_REGISTER_NUM_LIMBS],
    );
    /// A single limb array of expressions used as the data of the memory write.
    type Writes = [[AB::Expr; RV32_REGISTER_NUM_LIMBS]; 1];
    /// Validity flag, opcode, load selector and the two shift amounts.
    type ProcessedInstruction = LoadStoreInstruction<AB::Expr>;
}
92
/// Adapter chip for the RV32 load/store opcodes: owns the adapter AIR and the shared
/// range checker used when generating trace rows.
pub struct Rv32LoadStoreAdapterChip<F: Field> {
    /// AIR whose constraints the generated rows must satisfy.
    pub air: Rv32LoadStoreAdapterAir,
    /// Shared range checker; `generate_trace_row` records the memory-pointer limb
    /// range checks on it.
    pub range_checker_chip: SharedVariableRangeCheckerChip,
    _marker: PhantomData<F>,
}
101
102impl<F: PrimeField32> Rv32LoadStoreAdapterChip<F> {
103 pub fn new(
104 execution_bus: ExecutionBus,
105 program_bus: ProgramBus,
106 memory_bridge: MemoryBridge,
107 pointer_max_bits: usize,
108 range_checker_chip: SharedVariableRangeCheckerChip,
109 ) -> Self {
110 assert!(range_checker_chip.range_max_bits() >= 15);
111 Self {
112 air: Rv32LoadStoreAdapterAir {
113 execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
114 memory_bridge,
115 range_bus: range_checker_chip.bus(),
116 pointer_max_bits,
117 },
118 range_checker_chip,
119 _marker: PhantomData,
120 }
121 }
122}
123
/// Data captured by `preprocess` and consumed later by `postprocess` and
/// `generate_trace_row`.
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct Rv32LoadStoreReadRecord<F: Field> {
    /// Memory record id of the `rs1` register read.
    pub rs1_record: RecordId,
    /// Memory record id of the second read (memory for loads, a register for stores).
    pub read: RecordId,
    /// Instruction operand `b`: pointer of the `rs1` register.
    pub rs1_ptr: F,
    /// Instruction operand `c`: the immediate.
    pub imm: F,
    /// Instruction operand `g`: sign of the immediate.
    pub imm_sign: F,
    /// Instruction operand `e`: address space of the memory access.
    pub mem_as: F,
    /// Low and high limb (each masked with `0xffff`) of the unaligned pointer
    /// `rs1 + sign-extended imm`.
    pub mem_ptr_limbs: [u32; 2],
    /// The unaligned pointer modulo 4.
    pub shift_amount: u32,
}
138
/// Data captured by `postprocess` and consumed later by `generate_trace_row`.
#[repr(C)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct Rv32LoadStoreWriteRecord<F: Field> {
    /// Memory record id of the write, or `RecordId(usize::MAX)` when the instruction's
    /// `f` operand was zero and no write was performed.
    pub write_id: RecordId,
    /// Execution state (pc and timestamp) the instruction started from.
    pub from_state: ExecutionState<u32>,
    /// Instruction operand `a`: written register for loads, read register for stores.
    pub rd_rs2_ptr: F,
}
149
/// Trace columns owned by the load/store adapter. The field order defines the column
/// layout (`#[repr(C)]` + `AlignedBorrow`).
#[repr(C)]
#[derive(Debug, Clone, AlignedBorrow)]
pub struct Rv32LoadStoreAdapterCols<T> {
    /// Execution state (pc and timestamp) before the instruction.
    pub from_state: ExecutionState<T>,
    /// Pointer of the `rs1` register (instruction operand `b`).
    pub rs1_ptr: T,
    /// Limbs read from `rs1`.
    pub rs1_data: [T; RV32_REGISTER_NUM_LIMBS],
    /// Auxiliary columns of the `rs1` read.
    pub rs1_aux_cols: MemoryReadAuxCols<T>,

    /// Instruction operand `a`: register written by loads / read by stores.
    pub rd_rs2_ptr: T,
    /// Auxiliary columns of the second read.
    pub read_data_aux: MemoryReadAuxCols<T>,
    /// Immediate (instruction operand `c`).
    pub imm: T,
    /// Sign of the immediate (instruction operand `g`); constrained boolean on valid rows.
    pub imm_sign: T,
    /// Low and high limb of `rs1 + sign-extended imm`.
    pub mem_ptr_limbs: [T; 2],
    /// Address space of the memory access (instruction operand `e`).
    pub mem_as: T,
    /// Base auxiliary columns of the write; `eval` completes them with data supplied
    /// through the adapter context (`ctx.reads.0`).
    pub write_base_aux: MemoryBaseAuxCols<T>,
    /// 1 when the row performs a write, 0 otherwise (instruction operand `f`).
    pub needs_write: T,
}
176
/// AIR of the load/store adapter. `derive_new::new` generates a constructor taking the
/// fields in declaration order.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32LoadStoreAdapterAir {
    /// Bridge used for the register/memory reads and the write.
    pub(super) memory_bridge: MemoryBridge,
    /// Bridge used for the execution interaction.
    pub(super) execution_bridge: ExecutionBridge,
    /// Bus used to range check the memory-pointer limbs.
    pub range_bus: VariableRangeCheckerBus,
    /// Maximum bit length of a memory pointer; bounds the high pointer limb in `eval`
    /// and the pointer in `preprocess`.
    pointer_max_bits: usize,
}
184
185impl<F: Field> BaseAir<F> for Rv32LoadStoreAdapterAir {
186 fn width(&self) -> usize {
187 Rv32LoadStoreAdapterCols::<F>::width()
188 }
189}
190
/// Constraints of the load/store adapter: one `rs1` read, one data read, one optional
/// data write and the execution interaction.
impl<AB: InteractionBuilder> VmAdapterAir<AB> for Rv32LoadStoreAdapterAir {
    type Interface = Rv32LoadStoreAdapterAirInterface<AB>;

    /// Adds the adapter constraints and interactions for one row.
    ///
    /// `local` is borrowed as `Rv32LoadStoreAdapterCols`; `ctx` carries the instruction
    /// flags, the read/write data and an optional `to_pc`.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let local_cols: &Rv32LoadStoreAdapterCols<AB::Var> = local.borrow();

        let timestamp: AB::Var = local_cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Yields `timestamp + (number of earlier calls)` and bumps the counter, so each
        // memory access below gets the next consecutive timestamp.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1)
        };

        let is_load = ctx.instruction.is_load;
        let is_valid = ctx.instruction.is_valid;
        let load_shift_amount = ctx.instruction.load_shift_amount;
        let store_shift_amount = ctx.instruction.store_shift_amount;
        let shift_amount = load_shift_amount.clone() + store_shift_amount.clone();

        let write_count = local_cols.needs_write;

        // `needs_write` is boolean and may only be set on valid rows.
        builder.assert_bool(write_count);
        builder.when(write_count).assert_one(is_valid.clone());

        // Active when `is_valid - needs_write` is nonzero, i.e. a valid row without a
        // write (assuming `is_valid` is boolean — TODO confirm it is constrained
        // elsewhere): such a row must be a load with `rd_rs2_ptr == 0`.
        builder
            .when(is_valid.clone() - write_count)
            .assert_one(is_load.clone());
        builder
            .when(is_valid.clone() - write_count)
            .assert_zero(local_cols.rd_rs2_ptr);

        // Read rs1 from the register address space.
        self.memory_bridge
            .read(
                MemoryAddress::new(
                    AB::F::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.rs1_ptr,
                ),
                local_cols.rs1_data,
                timestamp_pp(),
                &local_cols.rs1_aux_cols,
            )
            .eval(builder, is_valid.clone());

        // Constrain mem_ptr = rs1 + sign-extended imm as an addition over two limbs of
        // `2 * RV32_CELL_BITS` bits each, with boolean carries.
        let limbs_01 = local_cols.rs1_data[0]
            + local_cols.rs1_data[1] * AB::F::from_canonical_u32(1 << RV32_CELL_BITS);
        let limbs_23 = local_cols.rs1_data[2]
            + local_cols.rs1_data[3] * AB::F::from_canonical_u32(1 << RV32_CELL_BITS);

        let inv = AB::F::from_canonical_u32(1 << (RV32_CELL_BITS * 2)).inverse();
        // Carry out of the low limb.
        let carry = (limbs_01 + local_cols.imm - local_cols.mem_ptr_limbs[0]) * inv;

        builder.when(is_valid.clone()).assert_bool(carry.clone());

        builder
            .when(is_valid.clone())
            .assert_bool(local_cols.imm_sign);
        // High limb of the sign-extended immediate: all ones when `imm_sign` is set.
        let imm_extend_limb =
            local_cols.imm_sign * AB::F::from_canonical_u32((1 << (RV32_CELL_BITS * 2)) - 1);
        // Carry out of the high limb; only required to be boolean.
        let carry = (limbs_23 + imm_extend_limb + carry - local_cols.mem_ptr_limbs[1]) * inv;
        builder.when(is_valid.clone()).assert_bool(carry.clone());

        // Range check the pointer limbs. The low limb minus the shift amount, divided by
        // 4 in the field, is checked against `2 * RV32_CELL_BITS - 2` bits; this also
        // forces `limb[0] - shift_amount` to be a multiple of 4.
        self.range_bus
            .range_check(
                (local_cols.mem_ptr_limbs[0] - shift_amount)
                    * AB::F::from_canonical_u32(4).inverse(),
                RV32_CELL_BITS * 2 - 2,
            )
            .eval(builder, is_valid.clone());
        // The high limb is bounded so the whole pointer fits in `pointer_max_bits`.
        self.range_bus
            .range_check(
                local_cols.mem_ptr_limbs[1],
                self.pointer_max_bits - RV32_CELL_BITS * 2,
            )
            .eval(builder, is_valid.clone());

        let mem_ptr = local_cols.mem_ptr_limbs[0]
            + local_cols.mem_ptr_limbs[1] * AB::F::from_canonical_u32(1 << (RV32_CELL_BITS * 2));

        let is_store = is_valid.clone() - is_load.clone();
        // NOTE(review): presumably `assert_tern` restricts its argument to {0, 1, 2},
        // i.e. `mem_as` in {0, 1, 2} for loads and {2, 3, 4} for stores — confirm.
        builder.assert_tern(local_cols.mem_as - is_store * AB::Expr::TWO);
        // Rows that are not valid must have `mem_as == 0`.
        builder
            .when(not::<AB::Expr>(is_valid.clone()))
            .assert_zero(local_cols.mem_as);

        // Address space of the data read: `mem_as` for loads, registers for stores.
        let read_as = select::<AB::Expr>(
            is_load.clone(),
            local_cols.mem_as,
            AB::F::from_canonical_u32(RV32_REGISTER_AS),
        );

        // Pointer of the data read: `mem_ptr` for loads, `rd_rs2_ptr` for stores. The
        // load shift amount is subtracted outside the `select`.
        let read_ptr = select::<AB::Expr>(is_load.clone(), mem_ptr.clone(), local_cols.rd_rs2_ptr)
            - load_shift_amount;

        self.memory_bridge
            .read(
                MemoryAddress::new(read_as, read_ptr),
                ctx.reads.1,
                timestamp_pp(),
                &local_cols.read_data_aux,
            )
            .eval(builder, is_valid.clone());

        // Complete the write aux columns with the data supplied in `ctx.reads.0`.
        let write_aux_cols = MemoryWriteAuxCols::from_base(local_cols.write_base_aux, ctx.reads.0);

        // Address space of the write: registers for loads, `mem_as` for stores.
        let write_as = select::<AB::Expr>(
            is_load.clone(),
            AB::F::from_canonical_u32(RV32_REGISTER_AS),
            local_cols.mem_as,
        );

        // Pointer of the write: `rd_rs2_ptr` for loads, `mem_ptr` for stores, minus the
        // store shift amount.
        let write_ptr = select::<AB::Expr>(is_load.clone(), local_cols.rd_rs2_ptr, mem_ptr.clone())
            - store_shift_amount;

        // The write is enabled by `needs_write`, not by `is_valid`.
        self.memory_bridge
            .write(
                MemoryAddress::new(write_as, write_ptr),
                ctx.writes[0].clone(),
                timestamp_pp(),
                &write_aux_cols,
            )
            .eval(builder, write_count);

        // Default next pc is `from_pc + DEFAULT_PC_STEP` unless the context overrides it.
        let to_pc = ctx
            .to_pc
            .unwrap_or(local_cols.from_state.pc + AB::F::from_canonical_u32(DEFAULT_PC_STEP));
        // Operands in instruction order a..g: rd_rs2_ptr, rs1_ptr, imm, register address
        // space, mem_as, needs_write, imm_sign.
        self.execution_bridge
            .execute(
                ctx.instruction.opcode,
                [
                    local_cols.rd_rs2_ptr.into(),
                    local_cols.rs1_ptr.into(),
                    local_cols.imm.into(),
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.mem_as.into(),
                    local_cols.needs_write.into(),
                    local_cols.imm_sign.into(),
                ],
                local_cols.from_state,
                ExecutionState {
                    pc: to_pc,
                    // One timestamp per `timestamp_pp()` call above.
                    timestamp: timestamp + AB::F::from_canonical_usize(timestamp_delta),
                },
            )
            .eval(builder, is_valid);
    }

    /// Returns the `from_state.pc` column of the row.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let local_cols: &Rv32LoadStoreAdapterCols<AB::Var> = local.borrow();
        local_cols.from_state.pc
    }
}
362
363impl<F: PrimeField32> VmAdapterChip<F> for Rv32LoadStoreAdapterChip<F> {
364 type ReadRecord = Rv32LoadStoreReadRecord<F>;
365 type WriteRecord = Rv32LoadStoreWriteRecord<F>;
366 type Air = Rv32LoadStoreAdapterAir;
367 type Interface = Rv32LoadStoreAdapterRuntimeInterface<F>;
368
369 #[allow(clippy::type_complexity)]
370 fn preprocess(
371 &mut self,
372 memory: &mut MemoryController<F>,
373 instruction: &Instruction<F>,
374 ) -> Result<(
375 <Self::Interface as VmAdapterInterface<F>>::Reads,
376 Self::ReadRecord,
377 )> {
378 let Instruction {
379 opcode,
380 a,
381 b,
382 c,
383 d,
384 e,
385 g,
386 ..
387 } = *instruction;
388 debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
389 debug_assert!(e.as_canonical_u32() != RV32_IMM_AS);
390
391 let local_opcode = Rv32LoadStoreOpcode::from_usize(
392 opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
393 );
394 let rs1_record = memory.read::<RV32_REGISTER_NUM_LIMBS>(d, b);
395
396 let rs1_val = compose(rs1_record.1);
397 let imm = c.as_canonical_u32();
398 let imm_sign = g.as_canonical_u32();
399 let imm_extended = imm + imm_sign * 0xffff0000;
400
401 let ptr_val = rs1_val.wrapping_add(imm_extended);
402 let shift_amount = ptr_val % 4;
403 assert!(
404 ptr_val < (1 << self.air.pointer_max_bits),
405 "ptr_val: {ptr_val} = rs1_val: {rs1_val} + imm_extended: {imm_extended} >= 2 ** {}",
406 self.air.pointer_max_bits
407 );
408
409 let mem_ptr_limbs = array::from_fn(|i| ((ptr_val >> (i * (RV32_CELL_BITS * 2))) & 0xffff));
410
411 let ptr_val = ptr_val - shift_amount;
412 let read_record = match local_opcode {
413 LOADW | LOADB | LOADH | LOADBU | LOADHU => {
414 memory.read::<RV32_REGISTER_NUM_LIMBS>(e, F::from_canonical_u32(ptr_val))
415 }
416 STOREW | STOREH | STOREB => memory.read::<RV32_REGISTER_NUM_LIMBS>(d, a),
417 };
418
419 let prev_data = match local_opcode {
421 STOREW | STOREH | STOREB => array::from_fn(|i| {
422 memory.unsafe_read_cell(e, F::from_canonical_usize(ptr_val as usize + i))
423 }),
424 LOADW | LOADB | LOADH | LOADBU | LOADHU => {
425 array::from_fn(|i| memory.unsafe_read_cell(d, a + F::from_canonical_usize(i)))
426 }
427 };
428
429 Ok((
430 (
431 [prev_data, read_record.1],
432 F::from_canonical_u32(shift_amount),
433 ),
434 Self::ReadRecord {
435 rs1_record: rs1_record.0,
436 rs1_ptr: b,
437 read: read_record.0,
438 imm: c,
439 imm_sign: g,
440 shift_amount,
441 mem_ptr_limbs,
442 mem_as: e,
443 },
444 ))
445 }
446
447 fn postprocess(
448 &mut self,
449 memory: &mut MemoryController<F>,
450 instruction: &Instruction<F>,
451 from_state: ExecutionState<u32>,
452 output: AdapterRuntimeContext<F, Self::Interface>,
453 read_record: &Self::ReadRecord,
454 ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
455 let Instruction {
456 opcode,
457 a,
458 d,
459 e,
460 f: enabled,
461 ..
462 } = *instruction;
463
464 let local_opcode = Rv32LoadStoreOpcode::from_usize(
465 opcode.local_opcode_idx(Rv32LoadStoreOpcode::CLASS_OFFSET),
466 );
467
468 let write_id = if enabled != F::ZERO {
469 let (record_id, _) = match local_opcode {
470 STOREW | STOREH | STOREB => {
471 let ptr = read_record.mem_ptr_limbs[0]
472 + read_record.mem_ptr_limbs[1] * (1 << (RV32_CELL_BITS * 2));
473 memory.write(e, F::from_canonical_u32(ptr & 0xfffffffc), output.writes[0])
474 }
475 LOADW | LOADB | LOADH | LOADBU | LOADHU => memory.write(d, a, output.writes[0]),
476 };
477 record_id
478 } else {
479 memory.increment_timestamp();
480 RecordId(usize::MAX)
482 };
483
484 Ok((
485 ExecutionState {
486 pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP),
487 timestamp: memory.timestamp(),
488 },
489 Self::WriteRecord {
490 from_state,
491 write_id,
492 rd_rs2_ptr: a,
493 },
494 ))
495 }
496
497 fn generate_trace_row(
498 &self,
499 row_slice: &mut [F],
500 read_record: Self::ReadRecord,
501 write_record: Self::WriteRecord,
502 memory: &OfflineMemory<F>,
503 ) {
504 self.range_checker_chip.add_count(
505 (read_record.mem_ptr_limbs[0] - read_record.shift_amount) / 4,
506 RV32_CELL_BITS * 2 - 2,
507 );
508 self.range_checker_chip.add_count(
509 read_record.mem_ptr_limbs[1],
510 self.air.pointer_max_bits - RV32_CELL_BITS * 2,
511 );
512
513 let aux_cols_factory = memory.aux_cols_factory();
514 let adapter_cols: &mut Rv32LoadStoreAdapterCols<_> = row_slice.borrow_mut();
515 adapter_cols.from_state = write_record.from_state.map(F::from_canonical_u32);
516 let rs1 = memory.record_by_id(read_record.rs1_record);
517 adapter_cols.rs1_data.copy_from_slice(rs1.data_slice());
518 aux_cols_factory.generate_read_aux(rs1, &mut adapter_cols.rs1_aux_cols);
519 adapter_cols.rs1_ptr = read_record.rs1_ptr;
520 adapter_cols.rd_rs2_ptr = write_record.rd_rs2_ptr;
521 let read = memory.record_by_id(read_record.read);
522 aux_cols_factory.generate_read_aux(read, &mut adapter_cols.read_data_aux);
523 adapter_cols.imm = read_record.imm;
524 adapter_cols.imm_sign = read_record.imm_sign;
525 adapter_cols.mem_ptr_limbs = read_record.mem_ptr_limbs.map(F::from_canonical_u32);
526 adapter_cols.mem_as = read_record.mem_as;
527 if write_record.write_id.0 != usize::MAX {
528 let write = memory.record_by_id(write_record.write_id);
529 aux_cols_factory.generate_base_aux(write, &mut adapter_cols.write_base_aux);
530 adapter_cols.needs_write = F::ONE;
531 }
532 }
533
534 fn air(&self) -> &Self::Air {
535 &self.air
536 }
537}