use std::{
    array::from_fn,
    borrow::{Borrow, BorrowMut},
    iter::{once, zip},
};

use itertools::izip;
use openvm_circuit::{
    arch::{
        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
        ExecutionBridge, ExecutionState, VecHeapAdapterInterface, VmAdapterAir,
    },
    system::memory::{
        offline_checker::{
            MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols,
            MemoryWriteBytesAuxRecord,
        },
        online::TracingMemory,
        MemoryAddress, MemoryAuxColsFactory,
    },
};
use openvm_circuit_primitives::{
    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
    AlignedBytesBorrow,
};
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
};
use openvm_rv32im_circuit::adapters::{
    abstract_compose, tracing_read, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
};
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::BaseAir,
    p3_field::{Field, FieldAlgebra, PrimeField32},
};

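/// Columns for the RV32 vec-heap adapter, which reads NUM_READS pointers from
/// registers, reads BLOCKS_PER_READ blocks of READ_SIZE cells from the heap at
/// each pointer, and writes BLOCKS_PER_WRITE blocks of WRITE_SIZE cells to the
/// heap at the pointer held in the `rd` register.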
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct Rv32VecHeapAdapterCols<
    T,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    pub from_state: ExecutionState<T>,

    pub rs_ptr: [T; NUM_READS],
    pub rd_ptr: T,

    pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
    pub rd_val: [T; RV32_REGISTER_NUM_LIMBS],

    pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
    pub rd_read_aux: MemoryReadAuxCols<T>,

    pub reads_aux: [[MemoryReadAuxCols<T>; BLOCKS_PER_READ]; NUM_READS],
    pub writes_aux: [MemoryWriteAuxCols<T, WRITE_SIZE>; BLOCKS_PER_WRITE],
}

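/// AIR for [`Rv32VecHeapAdapterCols`]: constrains the register reads, the heap
/// reads and writes, the pointer range checks, and the execution-state update.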
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32VecHeapAdapterAir<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    pub(super) execution_bridge: ExecutionBridge,
    pub(super) memory_bridge: MemoryBridge,
    pub bus: BitwiseOperationLookupBus,
    address_bits: usize,
}

impl<
        F: Field,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCKS_PER_WRITE: usize,
        const READ_SIZE: usize,
        const WRITE_SIZE: usize,
    > BaseAir<F>
    for Rv32VecHeapAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>
{
    fn width(&self) -> usize {
        Rv32VecHeapAdapterCols::<
            F,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        >::width()
    }
}

impl<
        AB: InteractionBuilder,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCKS_PER_WRITE: usize,
        const READ_SIZE: usize,
        const WRITE_SIZE: usize,
    > VmAdapterAir<AB>
    for Rv32VecHeapAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>
{
    type Interface = VecHeapAdapterInterface<
        AB::Expr,
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >;

    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let cols: &Rv32VecHeapAdapterCols<
            _,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = local.borrow();
        let timestamp = cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Returns the timestamp of the next memory access and advances the counter.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
        };

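        // Read the NUM_READS source-pointer registers, then the destination-pointer
        // register, each at its own timestamp.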
        for (ptr, val, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux).chain(once((
            cols.rd_ptr,
            cols.rd_val,
            &cols.rd_read_aux,
        ))) {
            self.memory_bridge
                .read(
                    MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), ptr),
                    val,
                    timestamp_pp(),
                    aux,
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

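        // Range-check the most significant limb of each register value so that the
        // composed pointers fit in `address_bits` bits. `rd_val` is appended twice so
        // the list has even length for any NUM_READS; `chunks_exact(2)` drops the
        // spare copy when the total is odd.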
        let need_range_check: Vec<AB::Var> = cols
            .rs_val
            .iter()
            .chain(std::iter::repeat_n(&cols.rd_val, 2))
            .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1])
            .collect();

        // Shift each limb so the lookup enforces value < 2^address_bits.
        let limb_shift = AB::F::from_canonical_usize(
            1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits),
        );

        // Each lookup interaction range-checks two limbs at once.
        for pair in need_range_check.chunks_exact(2) {
            self.bus
                .send_range(pair[0] * limb_shift, pair[1] * limb_shift)
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Compose the register limbs into full heap pointers.
        let rd_val_f: AB::Expr = abstract_compose(cols.rd_val);
        let rs_val_f: [AB::Expr; NUM_READS] = cols.rs_val.map(abstract_compose);

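        // Read BLOCKS_PER_READ contiguous blocks of READ_SIZE cells from the heap
        // (address space RV32_MEMORY_AS) at each source pointer.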
        let e = AB::F::from_canonical_u32(RV32_MEMORY_AS);
        for (address, reads, reads_aux) in izip!(rs_val_f, ctx.reads, &cols.reads_aux) {
            for (i, (read, aux)) in zip(reads, reads_aux).enumerate() {
                self.memory_bridge
                    .read(
                        MemoryAddress::new(
                            e,
                            address.clone() + AB::Expr::from_canonical_usize(i * READ_SIZE),
                        ),
                        read,
                        timestamp_pp(),
                        aux,
                    )
                    .eval(builder, ctx.instruction.is_valid.clone());
            }
        }

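        // Write BLOCKS_PER_WRITE contiguous blocks of WRITE_SIZE cells to the heap
        // at the destination pointer.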
        for (i, (write, aux)) in zip(ctx.writes, &cols.writes_aux).enumerate() {
            self.memory_bridge
                .write(
                    MemoryAddress::new(
                        e,
                        rd_val_f.clone() + AB::Expr::from_canonical_usize(i * WRITE_SIZE),
                    ),
                    write,
                    timestamp_pp(),
                    aux,
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

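        // Constrain the instruction fetch and state transition: the timestamp
        // advances by the total number of memory accesses, and the pc advances by
        // DEFAULT_PC_STEP unless ctx.to_pc overrides it.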
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    cols.rd_ptr.into(),
                    cols.rs_ptr
                        .first()
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    cols.rs_ptr
                        .get(1)
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    e.into(),
                ],
                cols.from_state,
                AB::F::from_canonical_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid.clone());
    }

    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32VecHeapAdapterCols<
            _,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = local.borrow();
        cols.from_state.pc
    }
}

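/// Execution record for the vec-heap adapter. It is written into the trace row
/// buffer as raw bytes and later expanded in-place into
/// [`Rv32VecHeapAdapterCols`] by the filler.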
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32VecHeapAdapterRecord<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    pub from_pc: u32,
    pub from_timestamp: u32,

    pub rs_ptrs: [u32; NUM_READS],
    pub rd_ptr: u32,

    pub rs_vals: [u32; NUM_READS],
    pub rd_val: u32,

    pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS],
    pub rd_read_aux: MemoryReadAuxRecord,

    pub reads_aux: [[MemoryReadAuxRecord; BLOCKS_PER_READ]; NUM_READS],
    pub writes_aux: [MemoryWriteBytesAuxRecord<WRITE_SIZE>; BLOCKS_PER_WRITE],
}

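/// Executor half of the adapter: performs the register and heap accesses
/// against [`TracingMemory`] and logs them into [`Rv32VecHeapAdapterRecord`].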
#[derive(derive_new::new, Clone, Copy)]
pub struct Rv32VecHeapAdapterExecutor<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    pointer_max_bits: usize,
}

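/// Filler half of the adapter: expands records into columns and sends the
/// pointer range checks to the shared bitwise lookup chip.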
#[derive(derive_new::new)]
pub struct Rv32VecHeapAdapterFiller<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    pointer_max_bits: usize,
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}

impl<
        F: PrimeField32,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCKS_PER_WRITE: usize,
        const READ_SIZE: usize,
        const WRITE_SIZE: usize,
    > AdapterTraceExecutor<F>
    for Rv32VecHeapAdapterExecutor<
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >
{
    const WIDTH: usize = Rv32VecHeapAdapterCols::<
        F,
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >::width();
    type ReadData = [[[u8; READ_SIZE]; BLOCKS_PER_READ]; NUM_READS];
    type WriteData = [[u8; WRITE_SIZE]; BLOCKS_PER_WRITE];
    type RecordMut<'a> = &'a mut Rv32VecHeapAdapterRecord<
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >;

    #[inline(always)]
    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
        record.from_pc = pc;
        record.from_timestamp = memory.timestamp;
    }

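    // Reads the pointer registers, then the heap blocks they point to, in the
    // same order as the AIR's timestamp schedule.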
    fn read(
        &self,
        memory: &mut TracingMemory,
        instruction: &Instruction<F>,
        record: &mut &mut Rv32VecHeapAdapterRecord<
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        >,
    ) -> Self::ReadData {
        let &Instruction { a, b, c, d, e, .. } = instruction;

        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);

        // Operand b holds the first source register pointer, c the second.
        record.rs_vals = from_fn(|i| {
            record.rs_ptrs[i] = if i == 0 { b } else { c }.as_canonical_u32();
            u32::from_le_bytes(tracing_read(
                memory,
                RV32_REGISTER_AS,
                record.rs_ptrs[i],
                &mut record.rs_read_aux[i].prev_timestamp,
            ))
        });

        record.rd_ptr = a.as_canonical_u32();
        record.rd_val = u32::from_le_bytes(tracing_read(
            memory,
            RV32_REGISTER_AS,
            a.as_canonical_u32(),
            &mut record.rd_read_aux.prev_timestamp,
        ));

        from_fn(|i| {
            // The whole read region must stay below 2^pointer_max_bits.
            debug_assert!(
                (record.rs_vals[i] + (READ_SIZE * BLOCKS_PER_READ - 1) as u32)
                    < (1 << self.pointer_max_bits) as u32
            );
            from_fn(|j| {
                tracing_read(
                    memory,
                    RV32_MEMORY_AS,
                    record.rs_vals[i] + (j * READ_SIZE) as u32,
                    &mut record.reads_aux[i][j].prev_timestamp,
                )
            })
        })
    }

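    // Writes BLOCKS_PER_WRITE blocks to the heap at the destination pointer
    // captured during `read`.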
    fn write(
        &self,
        memory: &mut TracingMemory,
        instruction: &Instruction<F>,
        data: Self::WriteData,
        record: &mut &mut Rv32VecHeapAdapterRecord<
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        >,
    ) {
        debug_assert_eq!(instruction.e.as_canonical_u32(), RV32_MEMORY_AS);

        debug_assert!(
            record.rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1
                < (1 << self.pointer_max_bits)
        );

        #[allow(clippy::needless_range_loop)]
        for i in 0..BLOCKS_PER_WRITE {
            tracing_write(
                memory,
                RV32_MEMORY_AS,
                record.rd_val + (i * WRITE_SIZE) as u32,
                data[i],
                &mut record.writes_aux[i].prev_timestamp,
                &mut record.writes_aux[i].prev_data,
            );
        }
    }
}

impl<
        F: PrimeField32,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCKS_PER_WRITE: usize,
        const READ_SIZE: usize,
        const WRITE_SIZE: usize,
    > AdapterTraceFiller<F>
    for Rv32VecHeapAdapterFiller<
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >
{
    const WIDTH: usize = Rv32VecHeapAdapterCols::<
        F,
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >::width();

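    // Expands the record into the adapter columns in-place. `record` and `cols`
    // alias the same row buffer, so fields are filled back-to-front to avoid
    // overwriting record data before it has been read.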
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        let record: &Rv32VecHeapAdapterRecord<
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = unsafe { get_record_from_slice(&mut adapter_row, ()) };

        let cols: &mut Rv32VecHeapAdapterCols<
            F,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = adapter_row.borrow_mut();

        debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
        // Request the pointer range checks, pairing most significant limbs exactly
        // as the AIR does in `eval`.
        let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits;
        const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
        if NUM_READS > 1 {
            self.bitwise_lookup_chip.request_range(
                (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits,
                (record.rs_vals[1] >> MSL_SHIFT) << limb_shift_bits,
            );
            self.bitwise_lookup_chip.request_range(
                (record.rd_val >> MSL_SHIFT) << limb_shift_bits,
                (record.rd_val >> MSL_SHIFT) << limb_shift_bits,
            );
        } else {
            self.bitwise_lookup_chip.request_range(
                (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits,
                (record.rd_val >> MSL_SHIFT) << limb_shift_bits,
            );
        }

        // Walk the timestamps backwards from the last access.
        let timestamp_delta = NUM_READS + 1 + NUM_READS * BLOCKS_PER_READ + BLOCKS_PER_WRITE;
        let mut timestamp = record.from_timestamp + timestamp_delta as u32;
        let mut timestamp_mm = || {
            timestamp -= 1;
            timestamp
        };

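        // Fill aux columns in reverse access order: writes, heap reads, rd register
        // read, then rs register reads.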
        record
            .writes_aux
            .iter()
            .rev()
            .zip(cols.writes_aux.iter_mut().rev())
            .for_each(|(write, cols_write)| {
                cols_write.set_prev_data(write.prev_data.map(F::from_canonical_u8));
                mem_helper.fill(write.prev_timestamp, timestamp_mm(), cols_write.as_mut());
            });

        record
            .reads_aux
            .iter()
            .zip(cols.reads_aux.iter_mut())
            .rev()
            .for_each(|(reads, cols_reads)| {
                reads
                    .iter()
                    .zip(cols_reads.iter_mut())
                    .rev()
                    .for_each(|(read, cols_read)| {
                        mem_helper.fill(read.prev_timestamp, timestamp_mm(), cols_read.as_mut());
                    });
            });

        mem_helper.fill(
            record.rd_read_aux.prev_timestamp,
            timestamp_mm(),
            cols.rd_read_aux.as_mut(),
        );

        record
            .rs_read_aux
            .iter()
            .zip(cols.rs_read_aux.iter_mut())
            .rev()
            .for_each(|(aux, cols_aux)| {
                mem_helper.fill(aux.prev_timestamp, timestamp_mm(), cols_aux.as_mut());
            });

        cols.rd_val = record.rd_val.to_le_bytes().map(F::from_canonical_u8);
        cols.rs_val
            .iter_mut()
            .rev()
            .zip(record.rs_vals.iter().rev())
            .for_each(|(cols_val, val)| {
                *cols_val = val.to_le_bytes().map(F::from_canonical_u8);
            });
        cols.rd_ptr = F::from_canonical_u32(record.rd_ptr);
        cols.rs_ptr
            .iter_mut()
            .rev()
            .zip(record.rs_ptrs.iter().rev())
            .for_each(|(cols_ptr, ptr)| {
                *cols_ptr = F::from_canonical_u32(*ptr);
            });
        cols.from_state.timestamp = F::from_canonical_u32(record.from_timestamp);
        cols.from_state.pc = F::from_canonical_u32(record.from_pc);
    }
}