1use std::{
2 array::from_fn,
3 borrow::{Borrow, BorrowMut},
4 iter::{once, zip},
5};
6
7use itertools::izip;
8use openvm_circuit::{
9 arch::{
10 get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
11 ExecutionBridge, ExecutionState, VecHeapAdapterInterface, VmAdapterAir,
12 },
13 system::memory::{
14 offline_checker::{
15 MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols,
16 MemoryWriteBytesAuxRecord,
17 },
18 online::TracingMemory,
19 MemoryAddress, MemoryAuxColsFactory,
20 },
21};
22use openvm_circuit_primitives::{
23 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
24 AlignedBytesBorrow,
25};
26use openvm_circuit_primitives_derive::AlignedBorrow;
27use openvm_instructions::{
28 instruction::Instruction,
29 program::DEFAULT_PC_STEP,
30 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
31};
32use openvm_rv32im_circuit::adapters::{
33 abstract_compose, tracing_read, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
34};
35use openvm_stark_backend::{
36 interaction::InteractionBuilder,
37 p3_air::BaseAir,
38 p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
39};
40
/// Trace columns for the RV32 vec-heap adapter.
///
/// Per instruction the adapter performs, in timestamp order:
/// 1. `NUM_READS` register reads (`rs`) plus one register read (`rd`) to obtain heap pointers,
/// 2. `BLOCKS_PER_READ` heap reads of `READ_SIZE` cells starting at each `rs` value,
/// 3. `BLOCKS_PER_WRITE` heap writes of `WRITE_SIZE` cells starting at the `rd` value.
///
/// NOTE(review): `repr(C)` + `AlignedBorrow` make the field order define the column
/// layout, and the trace filler relies on it matching `Rv32VecHeapAdapterRecord` —
/// do not reorder fields.
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct Rv32VecHeapAdapterCols<
    T,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Execution state (pc, timestamp) at the start of the instruction.
    pub from_state: ExecutionState<T>,

    /// Register pointers of the `rs` (source) registers.
    pub rs_ptr: [T; NUM_READS],
    /// Register pointer of the `rd` (destination) register.
    pub rd_ptr: T,

    /// Little-endian limbs of each `rs` register value (used as a heap address).
    pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
    /// Little-endian limbs of the `rd` register value (used as a heap address).
    pub rd_val: [T; RV32_REGISTER_NUM_LIMBS],

    /// Offline-checker aux columns for the `rs` register reads.
    pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
    /// Offline-checker aux columns for the `rd` register read.
    pub rd_read_aux: MemoryReadAuxCols<T>,

    /// Aux columns for the heap block reads: one per block, per source pointer.
    pub reads_aux: [[MemoryReadAuxCols<T>; BLOCKS_PER_READ]; NUM_READS],
    /// Aux columns (including previous data) for the heap block writes.
    pub writes_aux: [MemoryWriteAuxCols<T, WRITE_SIZE>; BLOCKS_PER_WRITE],
}
72
/// AIR for the vec-heap adapter: constrains the register reads, the pointer
/// range checks, the heap block reads/writes, and the execution-bus interaction.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32VecHeapAdapterAir<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    pub(super) execution_bridge: ExecutionBridge,
    pub(super) memory_bridge: MemoryBridge,
    /// Bitwise-lookup bus used to range-check the most-significant limb of each pointer.
    pub bus: BitwiseOperationLookupBus,
    /// Maximum number of bits a heap pointer may occupy; must be at most
    /// `RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS` (see the `limb_shift` computation in `eval`).
    address_bits: usize,
}
88
89impl<
90 F: Field,
91 const NUM_READS: usize,
92 const BLOCKS_PER_READ: usize,
93 const BLOCKS_PER_WRITE: usize,
94 const READ_SIZE: usize,
95 const WRITE_SIZE: usize,
96 > BaseAir<F>
97 for Rv32VecHeapAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>
98{
99 fn width(&self) -> usize {
100 Rv32VecHeapAdapterCols::<
101 F,
102 NUM_READS,
103 BLOCKS_PER_READ,
104 BLOCKS_PER_WRITE,
105 READ_SIZE,
106 WRITE_SIZE,
107 >::width()
108 }
109}
110
impl<
        AB: InteractionBuilder,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCKS_PER_WRITE: usize,
        const READ_SIZE: usize,
        const WRITE_SIZE: usize,
    > VmAdapterAir<AB>
    for Rv32VecHeapAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>
{
    type Interface = VecHeapAdapterInterface<
        AB::Expr,
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >;

    /// Emits all adapter constraints. The memory interactions below are issued in a
    /// fixed order (rs register reads, rd register read, heap block reads, heap block
    /// writes) and each consumes one timestamp via `timestamp_pp`; this order must
    /// match the executor's `tracing_read`/`tracing_write` order exactly.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let cols: &Rv32VecHeapAdapterCols<
            _,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = local.borrow();
        let timestamp = cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Returns the timestamp for the next memory interaction, then advances the delta.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_usize(timestamp_delta - 1)
        };

        // Read the `rs` pointer registers, then the `rd` pointer register, each
        // yielding the limbs of a heap address.
        for (ptr, val, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux).chain(once((
            cols.rd_ptr,
            cols.rd_val,
            &cols.rd_read_aux,
        ))) {
            self.memory_bridge
                .read(
                    MemoryAddress::new(AB::F::from_u32(RV32_REGISTER_AS), ptr),
                    val,
                    timestamp_pp(),
                    aux,
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Most-significant limb of every pointer value. `rd_val` is appended twice so
        // the list can be consumed in pairs by the two-argument range-check below;
        // `chunks_exact(2)` drops a leftover element when the total count is odd.
        let need_range_check: Vec<AB::Var> = cols
            .rs_val
            .iter()
            .chain(std::iter::repeat_n(&cols.rd_val, 2))
            .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1])
            .collect();

        // Shifting the most-significant limb left by this factor and range-checking the
        // result against `RV32_CELL_BITS` bits enforces that the full composed pointer
        // fits in `self.address_bits` bits.
        let limb_shift =
            AB::F::from_usize(1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits));

        for pair in need_range_check.chunks_exact(2) {
            self.bus
                .send_range(pair[0] * limb_shift, pair[1] * limb_shift)
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Compose the little-endian register limbs into single heap-address expressions.
        let rd_val_f: AB::Expr = abstract_compose(cols.rd_val);
        let rs_val_f: [AB::Expr; NUM_READS] = cols.rs_val.map(abstract_compose);

        let e = AB::F::from_u32(RV32_MEMORY_AS);
        // Heap reads: `BLOCKS_PER_READ` consecutive blocks of `READ_SIZE` cells
        // starting at each `rs` address.
        for (address, reads, reads_aux) in izip!(rs_val_f, ctx.reads, &cols.reads_aux,) {
            for (i, (read, aux)) in zip(reads, reads_aux).enumerate() {
                self.memory_bridge
                    .read(
                        MemoryAddress::new(
                            e,
                            address.clone() + AB::Expr::from_usize(i * READ_SIZE),
                        ),
                        read,
                        timestamp_pp(),
                        aux,
                    )
                    .eval(builder, ctx.instruction.is_valid.clone());
            }
        }

        // Heap writes: `BLOCKS_PER_WRITE` consecutive blocks of `WRITE_SIZE` cells
        // starting at the `rd` address.
        for (i, (write, aux)) in zip(ctx.writes, &cols.writes_aux).enumerate() {
            self.memory_bridge
                .write(
                    MemoryAddress::new(e, rd_val_f.clone() + AB::Expr::from_usize(i * WRITE_SIZE)),
                    write,
                    timestamp_pp(),
                    aux,
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Operands are [a = rd_ptr, b = rs_ptr[0], c = rs_ptr[1], d = register AS,
        // e = memory AS]; missing rs pointers are zero-filled.
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    cols.rd_ptr.into(),
                    cols.rs_ptr
                        .first()
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    cols.rs_ptr
                        .get(1)
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    AB::Expr::from_u32(RV32_REGISTER_AS),
                    e.into(),
                ],
                cols.from_state,
                AB::F::from_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid.clone());
    }

    /// Returns the `from_state.pc` column of this row.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32VecHeapAdapterCols<
            _,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = local.borrow();
        cols.from_state.pc
    }
}
262
/// Execution record written by `Rv32VecHeapAdapterExecutor` and consumed by
/// `Rv32VecHeapAdapterFiller` to populate `Rv32VecHeapAdapterCols`.
///
/// NOTE(review): `repr(C)` + `AlignedBytesBorrow` fix the byte layout, and the filler
/// reads this record in-place from the same buffer as the column struct — keep the
/// field order aligned with `Rv32VecHeapAdapterCols`.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32VecHeapAdapterRecord<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Program counter at the start of the instruction.
    pub from_pc: u32,
    /// Memory timestamp at the start of the instruction.
    pub from_timestamp: u32,

    /// Register pointers of the `rs` registers.
    pub rs_ptrs: [u32; NUM_READS],
    /// Register pointer of the `rd` register.
    pub rd_ptr: u32,

    /// Heap addresses read from the `rs` registers (little-endian composed).
    pub rs_vals: [u32; NUM_READS],
    /// Heap address read from the `rd` register (little-endian composed).
    pub rd_val: u32,

    /// Previous-timestamp records for the `rs` register reads.
    pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS],
    /// Previous-timestamp record for the `rd` register read.
    pub rd_read_aux: MemoryReadAuxRecord,

    /// Previous-timestamp records for the heap block reads.
    pub reads_aux: [[MemoryReadAuxRecord; BLOCKS_PER_READ]; NUM_READS],
    /// Previous timestamp and previous data for the heap block writes.
    pub writes_aux: [MemoryWriteBytesAuxRecord<WRITE_SIZE>; BLOCKS_PER_WRITE],
}
288
/// Runtime executor for the vec-heap adapter: performs the register reads, heap
/// reads, and heap writes against `TracingMemory` while recording them into a
/// `Rv32VecHeapAdapterRecord`.
#[derive(derive_new::new, Clone, Copy)]
pub struct Rv32VecHeapAdapterExecutor<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Maximum number of bits a heap pointer may occupy; used only in debug
    /// assertions that accessed addresses stay in range.
    pointer_max_bits: usize,
}
299
/// Trace filler for the vec-heap adapter: converts a `Rv32VecHeapAdapterRecord`
/// (stored in-place in the trace row) into `Rv32VecHeapAdapterCols` and issues the
/// pointer range-check lookups matching the AIR's `send_range` interactions.
#[derive(derive_new::new)]
pub struct Rv32VecHeapAdapterFiller<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Maximum number of bits a heap pointer may occupy; mirrors the AIR's `address_bits`.
    pointer_max_bits: usize,
    /// Lookup chip receiving the range-check requests for pointer most-significant limbs.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}
311
impl<
        F: PrimeField32,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCKS_PER_WRITE: usize,
        const READ_SIZE: usize,
        const WRITE_SIZE: usize,
    > AdapterTraceExecutor<F>
    for Rv32VecHeapAdapterExecutor<
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >
{
    const WIDTH: usize = Rv32VecHeapAdapterCols::<
        F,
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >::width();
    type ReadData = [[[u8; READ_SIZE]; BLOCKS_PER_READ]; NUM_READS];
    type WriteData = [[u8; WRITE_SIZE]; BLOCKS_PER_WRITE];
    type RecordMut<'a> = &'a mut Rv32VecHeapAdapterRecord<
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >;

    /// Snapshots the starting pc and memory timestamp into the record.
    #[inline(always)]
    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
        record.from_pc = pc;
        record.from_timestamp = memory.timestamp;
    }

    /// Performs all reads for one instruction, in the same timestamp order the AIR
    /// constrains: `rs` register reads, the `rd` register read, then the heap block
    /// reads. Returns the heap data read per block per source pointer.
    fn read(
        &self,
        memory: &mut TracingMemory,
        instruction: &Instruction<F>,
        record: &mut &mut Rv32VecHeapAdapterRecord<
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        >,
    ) -> Self::ReadData {
        let &Instruction { a, b, c, d, e, .. } = instruction;

        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);

        // Operand `b` is the first source register pointer, `c` the second; with
        // NUM_READS > 2 every pointer after the first would reuse `c` (presumably
        // only NUM_READS <= 2 is used in practice — confirm against instantiations).
        record.rs_vals = from_fn(|i| {
            record.rs_ptrs[i] = if i == 0 { b } else { c }.as_canonical_u32();
            u32::from_le_bytes(tracing_read(
                memory,
                RV32_REGISTER_AS,
                record.rs_ptrs[i],
                &mut record.rs_read_aux[i].prev_timestamp,
            ))
        });

        // Operand `a` is the destination register pointer.
        record.rd_ptr = a.as_canonical_u32();
        record.rd_val = u32::from_le_bytes(tracing_read(
            memory,
            RV32_REGISTER_AS,
            a.as_canonical_u32(),
            &mut record.rd_read_aux.prev_timestamp,
        ));

        // Heap reads: BLOCKS_PER_READ consecutive READ_SIZE-byte blocks per pointer.
        from_fn(|i| {
            // Last byte accessed must still be a valid pointer_max_bits address.
            debug_assert!(
                (record.rs_vals[i] + (READ_SIZE * BLOCKS_PER_READ - 1) as u32)
                    < (1 << self.pointer_max_bits) as u32
            );
            from_fn(|j| {
                tracing_read(
                    memory,
                    RV32_MEMORY_AS,
                    record.rs_vals[i] + (j * READ_SIZE) as u32,
                    &mut record.reads_aux[i][j].prev_timestamp,
                )
            })
        })
    }

    /// Writes `data` as `BLOCKS_PER_WRITE` consecutive `WRITE_SIZE`-byte blocks
    /// starting at the destination address (`record.rd_val`), recording the previous
    /// timestamps and previous data for the offline checker.
    fn write(
        &self,
        memory: &mut TracingMemory,
        instruction: &Instruction<F>,
        data: Self::WriteData,
        record: &mut &mut Rv32VecHeapAdapterRecord<
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        >,
    ) {
        debug_assert_eq!(instruction.e.as_canonical_u32(), RV32_MEMORY_AS);

        // Last byte written must still be a valid pointer_max_bits address.
        debug_assert!(
            record.rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1
                < (1 << self.pointer_max_bits)
        );

        #[allow(clippy::needless_range_loop)]
        for i in 0..BLOCKS_PER_WRITE {
            tracing_write(
                memory,
                RV32_MEMORY_AS,
                record.rd_val + (i * WRITE_SIZE) as u32,
                data[i],
                &mut record.writes_aux[i].prev_timestamp,
                &mut record.writes_aux[i].prev_data,
            );
        }
    }
}
438
impl<
        F: PrimeField32,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCKS_PER_WRITE: usize,
        const READ_SIZE: usize,
        const WRITE_SIZE: usize,
    > AdapterTraceFiller<F>
    for Rv32VecHeapAdapterFiller<
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >
{
    const WIDTH: usize = Rv32VecHeapAdapterCols::<
        F,
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >::width();

    /// Expands the in-place `Rv32VecHeapAdapterRecord` into `Rv32VecHeapAdapterCols`.
    ///
    /// NOTE(review): `record` and `cols` alias the SAME row buffer (`record` is read
    /// via the unsafe `get_record_from_slice`, `cols` via `borrow_mut` on the same
    /// slice). The columns are therefore filled back-to-front — writes_aux first,
    /// from_state last — so every record field is consumed before the column write
    /// overwrites its bytes. Preserve this order.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        // SAFETY (as in the original): the row buffer currently holds a valid record
        // written by the executor; reads below must precede the overlapping writes.
        let record: &Rv32VecHeapAdapterRecord<
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = unsafe { get_record_from_slice(&mut adapter_row, ()) };

        let cols: &mut Rv32VecHeapAdapterCols<
            F,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = adapter_row.borrow_mut();

        debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
        // Shift applied to each most-significant limb so that range-checking the
        // shifted value enforces pointer < 2^pointer_max_bits (mirrors the AIR).
        let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits;
        // Bit offset of the most-significant limb within a composed register value.
        const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
        // Pairing must match the AIR's `chunks_exact(2)` over [rs limbs..., rd, rd]:
        // two pairs when NUM_READS > 1, a single (rs0, rd) pair otherwise.
        if NUM_READS > 1 {
            self.bitwise_lookup_chip.request_range(
                (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits,
                (record.rs_vals[1] >> MSL_SHIFT) << limb_shift_bits,
            );
            self.bitwise_lookup_chip.request_range(
                (record.rd_val >> MSL_SHIFT) << limb_shift_bits,
                (record.rd_val >> MSL_SHIFT) << limb_shift_bits,
            );
        } else {
            self.bitwise_lookup_chip.request_range(
                (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits,
                (record.rd_val >> MSL_SHIFT) << limb_shift_bits,
            );
        }

        // Total memory interactions per row; must equal the AIR's timestamp_delta.
        let timestamp_delta = NUM_READS + 1 + NUM_READS * BLOCKS_PER_READ + BLOCKS_PER_WRITE;
        // Counts DOWN from the final timestamp, matching the reverse fill order.
        let mut timestamp = record.from_timestamp + timestamp_delta as u32;
        let mut timestamp_mm = || {
            timestamp -= 1;
            timestamp
        };

        // Heap writes were the last interactions, so they are filled first.
        record
            .writes_aux
            .iter()
            .rev()
            .zip(cols.writes_aux.iter_mut().rev())
            .for_each(|(write, cols_write)| {
                cols_write.set_prev_data(write.prev_data.map(F::from_u8));
                mem_helper.fill(write.prev_timestamp, timestamp_mm(), cols_write.as_mut());
            });

        // Heap block reads, in reverse block and reverse pointer order.
        record
            .reads_aux
            .iter()
            .zip(cols.reads_aux.iter_mut())
            .rev()
            .for_each(|(reads, cols_reads)| {
                reads
                    .iter()
                    .zip(cols_reads.iter_mut())
                    .rev()
                    .for_each(|(read, cols_read)| {
                        mem_helper.fill(read.prev_timestamp, timestamp_mm(), cols_read.as_mut());
                    });
            });

        // The rd register read preceded the heap reads.
        mem_helper.fill(
            record.rd_read_aux.prev_timestamp,
            timestamp_mm(),
            cols.rd_read_aux.as_mut(),
        );

        // The rs register reads were the first interactions.
        record
            .rs_read_aux
            .iter()
            .zip(cols.rs_read_aux.iter_mut())
            .rev()
            .for_each(|(aux, cols_aux)| {
                mem_helper.fill(aux.prev_timestamp, timestamp_mm(), cols_aux.as_mut());
            });

        // Decompose the recorded u32 values back into little-endian field limbs.
        cols.rd_val = record.rd_val.to_le_bytes().map(F::from_u8);
        cols.rs_val
            .iter_mut()
            .rev()
            .zip(record.rs_vals.iter().rev())
            .for_each(|(cols_val, val)| {
                *cols_val = val.to_le_bytes().map(F::from_u8);
            });
        cols.rd_ptr = F::from_u32(record.rd_ptr);
        cols.rs_ptr
            .iter_mut()
            .rev()
            .zip(record.rs_ptrs.iter().rev())
            .for_each(|(cols_ptr, ptr)| {
                *cols_ptr = F::from_u32(*ptr);
            });
        // from_state occupies the first columns, so it is written last.
        cols.from_state.timestamp = F::from_u32(record.from_timestamp);
        cols.from_state.pc = F::from_u32(record.from_pc);
    }
}