use std::{
    array::from_fn,
    borrow::{Borrow, BorrowMut},
    iter::{once, zip},
    marker::PhantomData,
};

use itertools::izip;
use openvm_circuit::{
    arch::{
        AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, ExecutionBus, ExecutionState,
        Result, VecHeapAdapterInterface, VmAdapterAir, VmAdapterChip, VmAdapterInterface,
    },
    system::{
        memory::{
            offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
            MemoryAddress, MemoryController, OfflineMemory, RecordId,
        },
        program::ProgramBus,
    },
};
use openvm_circuit_primitives::bitwise_op_lookup::{
    BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip,
};
use openvm_circuit_primitives_derive::AlignedBorrow;
use openvm_instructions::{
    instruction::Instruction,
    program::DEFAULT_PC_STEP,
    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
};
use openvm_rv32im_circuit::adapters::{
    abstract_compose, read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
};
use openvm_stark_backend::{
    interaction::InteractionBuilder,
    p3_air::BaseAir,
    p3_field::{Field, FieldAlgebra, PrimeField32},
};
use serde::{Deserialize, Serialize};
use serde_with::serde_as;

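/// Adapter chip for RISC-V instructions that operate on vectors stored on the heap.
///
/// For each instruction, the adapter:
/// * reads `NUM_READS` (at most 2) source pointer registers `rs` and one destination
///   pointer register `rd` from the register address space (`RV32_REGISTER_AS`),
/// * reads `BLOCKS_PER_READ` consecutive blocks of `READ_SIZE` cells from the heap
///   (`RV32_MEMORY_AS`) starting at each `rs` pointer,
/// * writes `BLOCKS_PER_WRITE` consecutive blocks of `WRITE_SIZE` cells to the heap
///   starting at the `rd` pointer.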
#[derive(Clone)]
pub struct Rv32VecHeapAdapterChip<
    F: Field,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    pub air:
        Rv32VecHeapAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>,
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
    _marker: PhantomData<F>,
}

impl<
        F: PrimeField32,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCKS_PER_WRITE: usize,
        const READ_SIZE: usize,
        const WRITE_SIZE: usize,
    >
    Rv32VecHeapAdapterChip<F, NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>
{
    pub fn new(
        execution_bus: ExecutionBus,
        program_bus: ProgramBus,
        memory_bridge: MemoryBridge,
        address_bits: usize,
        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
    ) -> Self {
        assert!(NUM_READS <= 2);
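        // The bits above `address_bits` must all lie in the most significant register
        // limb, so a single range check on that limb (see `eval`) bounds the whole
        // pointer. For RV32 (4 limbs of RV32_CELL_BITS = 8 bits) this means
        // `address_bits` must exceed 24.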
        assert!(
            RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS,
            "address_bits={address_bits} needs to be large enough for high limb range check"
        );
        Self {
            air: Rv32VecHeapAdapterAir {
                execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
                memory_bridge,
                bus: bitwise_lookup_chip.bus(),
                address_bits,
            },
            bitwise_lookup_chip,
            _marker: PhantomData,
        }
    }
}

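/// Record of everything read during `preprocess`: the record ids of the register reads
/// (`rs`, `rd`), the composed value of `rd`, and the record ids of the heap block reads.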
#[repr(C)]
#[serde_as]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(bound = "F: Field")]
pub struct Rv32VecHeapReadRecord<
    F: Field,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const READ_SIZE: usize,
> {
    #[serde_as(as = "[_; NUM_READS]")]
    pub rs: [RecordId; NUM_READS],
    pub rd: RecordId,

    pub rd_val: F,

    #[serde_as(as = "[[_; BLOCKS_PER_READ]; NUM_READS]")]
    pub reads: [[RecordId; BLOCKS_PER_READ]; NUM_READS],
}

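/// Record of the instruction's starting execution state and the record ids of the heap
/// block writes performed in `postprocess`.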
#[repr(C)]
#[serde_as]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Rv32VecHeapWriteRecord<const BLOCKS_PER_WRITE: usize, const WRITE_SIZE: usize> {
    pub from_state: ExecutionState<u32>,
    #[serde_as(as = "[_; BLOCKS_PER_WRITE]")]
    pub writes: [RecordId; BLOCKS_PER_WRITE],
}

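/// Trace columns for one adapter row: the starting execution state, the register
/// pointers and their limb values, and the memory auxiliary columns for every register
/// read, heap block read, and heap block write.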
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct Rv32VecHeapAdapterCols<
    T,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    pub from_state: ExecutionState<T>,

    pub rs_ptr: [T; NUM_READS],
    pub rd_ptr: T,

    pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
    pub rd_val: [T; RV32_REGISTER_NUM_LIMBS],

    pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
    pub rd_read_aux: MemoryReadAuxCols<T>,

    pub reads_aux: [[MemoryReadAuxCols<T>; BLOCKS_PER_READ]; NUM_READS],
    pub writes_aux: [MemoryWriteAuxCols<T, WRITE_SIZE>; BLOCKS_PER_WRITE],
}

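/// AIR for [Rv32VecHeapAdapterChip]. Constrains the register and heap accesses through
/// `memory_bridge`, the pointer range checks through the bitwise lookup `bus`, and the
/// instruction fetch and PC update through `execution_bridge`.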
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32VecHeapAdapterAir<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    pub(super) execution_bridge: ExecutionBridge,
    pub(super) memory_bridge: MemoryBridge,
    pub bus: BitwiseOperationLookupBus,
    address_bits: usize,
}

impl<
        F: Field,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCKS_PER_WRITE: usize,
        const READ_SIZE: usize,
        const WRITE_SIZE: usize,
    > BaseAir<F>
    for Rv32VecHeapAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>
{
    fn width(&self) -> usize {
        Rv32VecHeapAdapterCols::<
            F,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        >::width()
    }
}

impl<
        AB: InteractionBuilder,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCKS_PER_WRITE: usize,
        const READ_SIZE: usize,
        const WRITE_SIZE: usize,
    > VmAdapterAir<AB>
    for Rv32VecHeapAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>
{
    type Interface = VecHeapAdapterInterface<
        AB::Expr,
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >;

    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let cols: &Rv32VecHeapAdapterCols<
            _,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = local.borrow();
        let timestamp = cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
        };

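        // Each `memory_bridge` interaction consumes one timestamp; `timestamp_pp`
        // hands out `from_state.timestamp + k` for the k-th access and accumulates the
        // total in `timestamp_delta` for the execution bridge below.
        //
        // Constrain the reads of the `rs` pointer registers and the `rd` pointer
        // register from address space RV32_REGISTER_AS.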
        for (ptr, val, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux).chain(once((
            cols.rd_ptr,
            cols.rd_val,
            &cols.rd_read_aux,
        ))) {
            self.memory_bridge
                .read(
                    MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), ptr),
                    val,
                    timestamp_pp(),
                    aux,
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

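        // Range check the most significant limb of every pointer. Multiplying by
        // `limb_shift` means the check only passes when that limb carries at most
        // `address_bits - RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1)` bits, which,
        // together with the byte-sized lower limbs, bounds each pointer below
        // `2^address_bits`. The `rd_val` limb is appended twice so it always lands in a
        // complete pair (the bitwise lookup checks limbs two at a time and
        // `chunks_exact(2)` drops any leftover element).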
        let need_range_check: Vec<AB::Var> = cols
            .rs_val
            .iter()
            .chain(std::iter::repeat(&cols.rd_val).take(2))
            .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1])
            .collect();

        let limb_shift = AB::F::from_canonical_usize(
            1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits),
        );

        for pair in need_range_check.chunks_exact(2) {
            self.bus
                .send_range(pair[0] * limb_shift, pair[1] * limb_shift)
                .eval(builder, ctx.instruction.is_valid.clone());
        }

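        // Compose the little-endian register limbs into single field-element heap
        // addresses, then constrain the `BLOCKS_PER_READ` block reads starting at each
        // `rs` pointer.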
        let rd_val_f: AB::Expr = abstract_compose(cols.rd_val);
        let rs_val_f: [AB::Expr; NUM_READS] = cols.rs_val.map(abstract_compose);

        let e = AB::F::from_canonical_u32(RV32_MEMORY_AS);
        for (address, reads, reads_aux) in izip!(rs_val_f, ctx.reads, &cols.reads_aux,) {
            for (i, (read, aux)) in zip(reads, reads_aux).enumerate() {
                self.memory_bridge
                    .read(
                        MemoryAddress::new(
                            e,
                            address.clone() + AB::Expr::from_canonical_usize(i * READ_SIZE),
                        ),
                        read,
                        timestamp_pp(),
                        aux,
                    )
                    .eval(builder, ctx.instruction.is_valid.clone());
            }
        }

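        // Constrain the `BLOCKS_PER_WRITE` block writes starting at the `rd` pointer.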
        for (i, (write, aux)) in zip(ctx.writes, &cols.writes_aux).enumerate() {
            self.memory_bridge
                .write(
                    MemoryAddress::new(
                        e,
                        rd_val_f.clone() + AB::Expr::from_canonical_usize(i * WRITE_SIZE),
                    ),
                    write,
                    timestamp_pp(),
                    aux,
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

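        // Tie the row to the instruction: operands are (a = rd_ptr, b = rs_ptr[0] or 0,
        // c = rs_ptr[1] or 0, d = register AS, e = memory AS); the timestamp advances
        // by `timestamp_delta` and the PC by DEFAULT_PC_STEP unless `ctx.to_pc`
        // overrides it.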
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    cols.rd_ptr.into(),
                    cols.rs_ptr
                        .first()
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    cols.rs_ptr
                        .get(1)
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    e.into(),
                ],
                cols.from_state,
                AB::F::from_canonical_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid.clone());
    }

    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32VecHeapAdapterCols<
            _,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = local.borrow();
        cols.from_state.pc
    }
}

impl<
        F: PrimeField32,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCKS_PER_WRITE: usize,
        const READ_SIZE: usize,
        const WRITE_SIZE: usize,
    > VmAdapterChip<F>
    for Rv32VecHeapAdapterChip<
        F,
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >
{
    type ReadRecord = Rv32VecHeapReadRecord<F, NUM_READS, BLOCKS_PER_READ, READ_SIZE>;
    type WriteRecord = Rv32VecHeapWriteRecord<BLOCKS_PER_WRITE, WRITE_SIZE>;
    type Air =
        Rv32VecHeapAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>;
    type Interface = VecHeapAdapterInterface<
        F,
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >;

    fn preprocess(
        &mut self,
        memory: &mut MemoryController<F>,
        instruction: &Instruction<F>,
    ) -> Result<(
        <Self::Interface as VmAdapterInterface<F>>::Reads,
        Self::ReadRecord,
    )> {
        let Instruction { a, b, c, d, e, .. } = *instruction;

        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);

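        // Read the pointer registers: operands b and c name the source pointers and
        // operand a names the destination pointer, all in address space d.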
        let mut rs_vals = [0; NUM_READS];
        let rs_records: [_; NUM_READS] = from_fn(|i| {
            let addr = if i == 0 { b } else { c };
            let (record, val) = read_rv32_register(memory, d, addr);
            rs_vals[i] = val;
            record
        });
        let (rd_record, rd_val) = read_rv32_register(memory, d, a);

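        // Read BLOCKS_PER_READ contiguous blocks from the heap starting at each source
        // pointer, asserting that the accesses stay below 2^address_bits (the AIR's
        // range checks enforce the same bound in the proof).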
        let read_records = rs_vals.map(|address| {
            assert!(
                address as usize + READ_SIZE * BLOCKS_PER_READ - 1 < (1 << self.air.address_bits)
            );
            from_fn(|i| {
                memory.read::<READ_SIZE>(e, F::from_canonical_u32(address + (i * READ_SIZE) as u32))
            })
        });
        let read_data = read_records.map(|r| r.map(|x| x.1));
        assert!(rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 < (1 << self.air.address_bits));

        let record = Rv32VecHeapReadRecord {
            rs: rs_records,
            rd: rd_record,
            rd_val: F::from_canonical_u32(rd_val),
            reads: read_records.map(|r| r.map(|x| x.0)),
        };

        Ok((read_data, record))
    }

    fn postprocess(
        &mut self,
        memory: &mut MemoryController<F>,
        instruction: &Instruction<F>,
        from_state: ExecutionState<u32>,
        output: AdapterRuntimeContext<F, Self::Interface>,
        read_record: &Self::ReadRecord,
    ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
        let e = instruction.e;
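        // Write the output blocks to the heap at rd_val, rd_val + WRITE_SIZE, ...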
        let mut i = 0;
        let writes = output.writes.map(|write| {
            let (record_id, _) = memory.write(
                e,
                read_record.rd_val + F::from_canonical_u32((i * WRITE_SIZE) as u32),
                write,
            );
            i += 1;
            record_id
        });

        Ok((
            ExecutionState {
                pc: from_state.pc + DEFAULT_PC_STEP,
                timestamp: memory.timestamp(),
            },
            Self::WriteRecord { from_state, writes },
        ))
    }

    fn generate_trace_row(
        &self,
        row_slice: &mut [F],
        read_record: Self::ReadRecord,
        write_record: Self::WriteRecord,
        memory: &OfflineMemory<F>,
    ) {
        vec_heap_generate_trace_row_impl(
            row_slice,
            &read_record,
            &write_record,
            self.bitwise_lookup_chip.clone(),
            self.air.address_bits,
            memory,
        )
    }

    fn air(&self) -> &Self::Air {
        &self.air
    }
}

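/// Fills one adapter trace row from the runtime records: copies pointers and register
/// limbs, generates the memory auxiliary columns, and issues the same bitwise range
/// check requests on the pointers' most significant limbs that the AIR sends.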
pub(super) fn vec_heap_generate_trace_row_impl<
    F: PrimeField32,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
>(
    row_slice: &mut [F],
    read_record: &Rv32VecHeapReadRecord<F, NUM_READS, BLOCKS_PER_READ, READ_SIZE>,
    write_record: &Rv32VecHeapWriteRecord<BLOCKS_PER_WRITE, WRITE_SIZE>,
    bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
    address_bits: usize,
    memory: &OfflineMemory<F>,
) {
    let aux_cols_factory = memory.aux_cols_factory();
    let row_slice: &mut Rv32VecHeapAdapterCols<
        F,
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    > = row_slice.borrow_mut();
    row_slice.from_state = write_record.from_state.map(F::from_canonical_u32);

    let rd = memory.record_by_id(read_record.rd);
    let rs = read_record
        .rs
        .into_iter()
        .map(|r| memory.record_by_id(r))
        .collect::<Vec<_>>();

    row_slice.rd_ptr = rd.pointer;
    row_slice.rd_val.copy_from_slice(rd.data_slice());

    for (i, r) in rs.iter().enumerate() {
        row_slice.rs_ptr[i] = r.pointer;
        row_slice.rs_val[i].copy_from_slice(r.data_slice());
        aux_cols_factory.generate_read_aux(r, &mut row_slice.rs_read_aux[i]);
    }

    aux_cols_factory.generate_read_aux(rd, &mut row_slice.rd_read_aux);

    for (i, reads) in read_record.reads.iter().enumerate() {
        for (j, &x) in reads.iter().enumerate() {
            let record = memory.record_by_id(x);
            aux_cols_factory.generate_read_aux(record, &mut row_slice.reads_aux[i][j]);
        }
    }

    for (i, &w) in write_record.writes.iter().enumerate() {
        let record = memory.record_by_id(w);
        aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i]);
    }

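    // Mirror the AIR's range-check interactions so the bitwise lookup multiplicities
    // balance: the same most-significant limbs, shifted by the same amount.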
    let need_range_check: Vec<u32> = rs
        .iter()
        .chain(std::iter::repeat(&rd).take(2))
        .map(|record| {
            record
                .data_at(RV32_REGISTER_NUM_LIMBS - 1)
                .as_canonical_u32()
        })
        .collect();
    debug_assert!(address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
    let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits;
    for pair in need_range_check.chunks_exact(2) {
        bitwise_lookup_chip.request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits);
    }
}