1use std::{
2 array::from_fn,
3 borrow::{Borrow, BorrowMut},
4 iter::zip,
5 marker::PhantomData,
6};
7
8use itertools::izip;
9use openvm_circuit::{
10 arch::{
11 AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, ExecutionBus, ExecutionState,
12 Result, VecHeapTwoReadsAdapterInterface, VmAdapterAir, VmAdapterChip, VmAdapterInterface,
13 },
14 system::{
15 memory::{
16 offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
17 MemoryAddress, MemoryController, OfflineMemory, RecordId,
18 },
19 program::ProgramBus,
20 },
21};
22use openvm_circuit_primitives::bitwise_op_lookup::{
23 BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip,
24};
25use openvm_circuit_primitives_derive::AlignedBorrow;
26use openvm_instructions::{
27 instruction::Instruction,
28 program::DEFAULT_PC_STEP,
29 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
30};
31use openvm_rv32im_circuit::adapters::{
32 abstract_compose, read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
33};
34use openvm_stark_backend::{
35 interaction::InteractionBuilder,
36 p3_air::BaseAir,
37 p3_field::{Field, FieldAlgebra, PrimeField32},
38};
39use serde::{Deserialize, Serialize};
40use serde_with::serde_as;
41
42pub struct Rv32VecHeapTwoReadsAdapterChip<
51 F: Field,
52 const BLOCKS_PER_READ1: usize,
53 const BLOCKS_PER_READ2: usize,
54 const BLOCKS_PER_WRITE: usize,
55 const READ_SIZE: usize,
56 const WRITE_SIZE: usize,
57> {
58 pub air: Rv32VecHeapTwoReadsAdapterAir<
59 BLOCKS_PER_READ1,
60 BLOCKS_PER_READ2,
61 BLOCKS_PER_WRITE,
62 READ_SIZE,
63 WRITE_SIZE,
64 >,
65 pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
66 _marker: PhantomData<F>,
67}
68
69impl<
70 F: PrimeField32,
71 const BLOCKS_PER_READ1: usize,
72 const BLOCKS_PER_READ2: usize,
73 const BLOCKS_PER_WRITE: usize,
74 const READ_SIZE: usize,
75 const WRITE_SIZE: usize,
76 >
77 Rv32VecHeapTwoReadsAdapterChip<
78 F,
79 BLOCKS_PER_READ1,
80 BLOCKS_PER_READ2,
81 BLOCKS_PER_WRITE,
82 READ_SIZE,
83 WRITE_SIZE,
84 >
85{
86 pub fn new(
87 execution_bus: ExecutionBus,
88 program_bus: ProgramBus,
89 memory_bridge: MemoryBridge,
90 address_bits: usize,
91 bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
92 ) -> Self {
93 assert!(
94 RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS,
95 "address_bits={address_bits} needs to be large enough for high limb range check"
96 );
97 Self {
98 air: Rv32VecHeapTwoReadsAdapterAir {
99 execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
100 memory_bridge,
101 bus: bitwise_lookup_chip.bus(),
102 address_bits,
103 },
104 bitwise_lookup_chip,
105 _marker: PhantomData,
106 }
107 }
108}
109
/// Read-phase record of [`Rv32VecHeapTwoReadsAdapterChip`].
///
/// It holds the ids of every memory access made by `preprocess`, so the trace row can be
/// rebuilt later from offline memory.
// Field order matters: the struct is `repr(C)` and serde-serialized, so do not reorder.
#[repr(C)]
#[serde_as]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct Rv32VecHeapTwoReadsReadRecord<
    F: Field,
    const BLOCKS_PER_READ1: usize,
    const BLOCKS_PER_READ2: usize,
    const READ_SIZE: usize,
> {
    /// Register read (operand `b`) that produced the first heap pointer.
    pub rs1: RecordId,
    /// Register read (operand `c`) that produced the second heap pointer.
    pub rs2: RecordId,
    /// Register read (operand `a`) that produced the destination heap pointer.
    pub rd: RecordId,

    /// Destination heap pointer, i.e. the value read from `rd`. It is the base address of the writes.
    pub rd_val: F,

    /// One record per `READ_SIZE`-cell block read at `rs1_val + i * READ_SIZE`.
    #[serde_as(as = "[_; BLOCKS_PER_READ1]")]
    pub reads1: [RecordId; BLOCKS_PER_READ1],
    /// One record per `READ_SIZE`-cell block read at `rs2_val + i * READ_SIZE`.
    #[serde_as(as = "[_; BLOCKS_PER_READ2]")]
    pub reads2: [RecordId; BLOCKS_PER_READ2],
}
133
/// Write-phase record of [`Rv32VecHeapTwoReadsAdapterChip`], produced by `postprocess`.
// Field order matters: the struct is `repr(C)` and serde-serialized, so do not reorder.
#[repr(C)]
#[serde_as]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Rv32VecHeapTwoReadsWriteRecord<const BLOCKS_PER_WRITE: usize, const WRITE_SIZE: usize> {
    /// Execution state (pc, timestamp) at the start of the instruction.
    pub from_state: ExecutionState<u32>,
    /// One record per `WRITE_SIZE`-cell block written at `rd_val + i * WRITE_SIZE`.
    #[serde_as(as = "[_; BLOCKS_PER_WRITE]")]
    pub writes: [RecordId; BLOCKS_PER_WRITE],
}
142
/// Trace columns of one adapter row.
// Field order is the column layout (`repr(C)` + `AlignedBorrow`), so do not reorder.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct Rv32VecHeapTwoReadsAdapterCols<
    T,
    const BLOCKS_PER_READ1: usize,
    const BLOCKS_PER_READ2: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Program counter and timestamp at the start of the instruction.
    pub from_state: ExecutionState<T>,

    /// Register pointers taken from instruction operands `b`, `c` and `a` respectively.
    pub rs1_ptr: T,
    pub rs2_ptr: T,
    pub rd_ptr: T,

    /// Limbs of the three register values, i.e. the heap pointers. The AIR recombines them
    /// with `abstract_compose` to form heap addresses.
    pub rs1_val: [T; RV32_REGISTER_NUM_LIMBS],
    pub rs2_val: [T; RV32_REGISTER_NUM_LIMBS],
    pub rd_val: [T; RV32_REGISTER_NUM_LIMBS],

    /// Auxiliary columns for the three register reads.
    pub rs1_read_aux: MemoryReadAuxCols<T>,
    pub rs2_read_aux: MemoryReadAuxCols<T>,
    pub rd_read_aux: MemoryReadAuxCols<T>,

    /// Auxiliary columns for each heap block read through `rs1`, then through `rs2`, and for
    /// each heap block written through `rd`.
    pub reads1_aux: [MemoryReadAuxCols<T>; BLOCKS_PER_READ1],
    pub reads2_aux: [MemoryReadAuxCols<T>; BLOCKS_PER_READ2],
    pub writes_aux: [MemoryWriteAuxCols<T, WRITE_SIZE>; BLOCKS_PER_WRITE],
}
171
/// AIR of [`Rv32VecHeapTwoReadsAdapterChip`].
///
/// It constrains the register reads, the pointer range checks, the heap reads/writes and the
/// execution-bus transition.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file; confirm it is
// still needed. Field order is also the argument order of the derived `new`.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32VecHeapTwoReadsAdapterAir<
    const BLOCKS_PER_READ1: usize,
    const BLOCKS_PER_READ2: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Bridge to the execution and program buses.
    pub(super) execution_bridge: ExecutionBridge,
    /// Bridge used for the register and heap memory accesses.
    pub(super) memory_bridge: MemoryBridge,
    /// Bitwise lookup bus used to range-check the most significant limb of each pointer.
    pub bus: BitwiseOperationLookupBus,
    /// Number of bits a heap pointer may occupy.
    address_bits: usize,
}
187
188impl<
189 F: Field,
190 const BLOCKS_PER_READ1: usize,
191 const BLOCKS_PER_READ2: usize,
192 const BLOCKS_PER_WRITE: usize,
193 const READ_SIZE: usize,
194 const WRITE_SIZE: usize,
195 > BaseAir<F>
196 for Rv32VecHeapTwoReadsAdapterAir<
197 BLOCKS_PER_READ1,
198 BLOCKS_PER_READ2,
199 BLOCKS_PER_WRITE,
200 READ_SIZE,
201 WRITE_SIZE,
202 >
203{
204 fn width(&self) -> usize {
205 Rv32VecHeapTwoReadsAdapterCols::<
206 F,
207 BLOCKS_PER_READ1,
208 BLOCKS_PER_READ2,
209 BLOCKS_PER_WRITE,
210 READ_SIZE,
211 WRITE_SIZE,
212 >::width()
213 }
214}
215
216impl<
217 AB: InteractionBuilder,
218 const BLOCKS_PER_READ1: usize,
219 const BLOCKS_PER_READ2: usize,
220 const BLOCKS_PER_WRITE: usize,
221 const READ_SIZE: usize,
222 const WRITE_SIZE: usize,
223 > VmAdapterAir<AB>
224 for Rv32VecHeapTwoReadsAdapterAir<
225 BLOCKS_PER_READ1,
226 BLOCKS_PER_READ2,
227 BLOCKS_PER_WRITE,
228 READ_SIZE,
229 WRITE_SIZE,
230 >
231{
232 type Interface = VecHeapTwoReadsAdapterInterface<
233 AB::Expr,
234 BLOCKS_PER_READ1,
235 BLOCKS_PER_READ2,
236 BLOCKS_PER_WRITE,
237 READ_SIZE,
238 WRITE_SIZE,
239 >;
240
241 fn eval(
242 &self,
243 builder: &mut AB,
244 local: &[AB::Var],
245 ctx: AdapterAirContext<AB::Expr, Self::Interface>,
246 ) {
247 let cols: &Rv32VecHeapTwoReadsAdapterCols<
248 _,
249 BLOCKS_PER_READ1,
250 BLOCKS_PER_READ2,
251 BLOCKS_PER_WRITE,
252 READ_SIZE,
253 WRITE_SIZE,
254 > = local.borrow();
255 let timestamp = cols.from_state.timestamp;
256 let mut timestamp_delta: usize = 0;
257 let mut timestamp_pp = || {
258 timestamp_delta += 1;
259 timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
260 };
261
262 let ptrs = [cols.rs1_ptr, cols.rs2_ptr, cols.rd_ptr];
263 let vals = [cols.rs1_val, cols.rs2_val, cols.rd_val];
264 let auxs = [&cols.rs1_read_aux, &cols.rs2_read_aux, &cols.rd_read_aux];
265
266 for (ptr, val, aux) in izip!(ptrs, vals, auxs) {
268 self.memory_bridge
269 .read(
270 MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), ptr),
271 val,
272 timestamp_pp(),
273 aux,
274 )
275 .eval(builder, ctx.instruction.is_valid.clone());
276 }
277
278 let need_range_check = [&cols.rs1_val, &cols.rs2_val, &cols.rd_val, &cols.rd_val]
280 .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1]);
281
282 let limb_shift = AB::F::from_canonical_usize(
284 1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits),
285 );
286
287 for pair in need_range_check.chunks_exact(2) {
291 self.bus
292 .send_range(pair[0] * limb_shift, pair[1] * limb_shift)
293 .eval(builder, ctx.instruction.is_valid.clone());
294 }
295
296 let rd_val_f: AB::Expr = abstract_compose(cols.rd_val);
297 let rs1_val_f: AB::Expr = abstract_compose(cols.rs1_val);
298 let rs2_val_f: AB::Expr = abstract_compose(cols.rs2_val);
299
300 let e = AB::F::from_canonical_u32(RV32_MEMORY_AS);
301 for (i, (read, aux)) in zip(ctx.reads.0, &cols.reads1_aux).enumerate() {
303 self.memory_bridge
304 .read(
305 MemoryAddress::new(
306 e,
307 rs1_val_f.clone() + AB::Expr::from_canonical_usize(i * READ_SIZE),
308 ),
309 read,
310 timestamp_pp(),
311 aux,
312 )
313 .eval(builder, ctx.instruction.is_valid.clone());
314 }
315 for (i, (read, aux)) in zip(ctx.reads.1, &cols.reads2_aux).enumerate() {
316 self.memory_bridge
317 .read(
318 MemoryAddress::new(
319 e,
320 rs2_val_f.clone() + AB::Expr::from_canonical_usize(i * READ_SIZE),
321 ),
322 read,
323 timestamp_pp(),
324 aux,
325 )
326 .eval(builder, ctx.instruction.is_valid.clone());
327 }
328
329 for (i, (write, aux)) in zip(ctx.writes, &cols.writes_aux).enumerate() {
331 self.memory_bridge
332 .write(
333 MemoryAddress::new(
334 e,
335 rd_val_f.clone() + AB::Expr::from_canonical_usize(i * WRITE_SIZE),
336 ),
337 write,
338 timestamp_pp(),
339 aux,
340 )
341 .eval(builder, ctx.instruction.is_valid.clone());
342 }
343
344 self.execution_bridge
345 .execute_and_increment_or_set_pc(
346 ctx.instruction.opcode,
347 [
348 cols.rd_ptr.into(),
349 cols.rs1_ptr.into(),
350 cols.rs2_ptr.into(),
351 AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
352 e.into(),
353 ],
354 cols.from_state,
355 AB::F::from_canonical_usize(timestamp_delta),
356 (DEFAULT_PC_STEP, ctx.to_pc),
357 )
358 .eval(builder, ctx.instruction.is_valid.clone());
359 }
360
361 fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
362 let cols: &Rv32VecHeapTwoReadsAdapterCols<
363 _,
364 BLOCKS_PER_READ1,
365 BLOCKS_PER_READ2,
366 BLOCKS_PER_WRITE,
367 READ_SIZE,
368 WRITE_SIZE,
369 > = local.borrow();
370 cols.from_state.pc
371 }
372}
373
374impl<
375 F: PrimeField32,
376 const BLOCKS_PER_READ1: usize,
377 const BLOCKS_PER_READ2: usize,
378 const BLOCKS_PER_WRITE: usize,
379 const READ_SIZE: usize,
380 const WRITE_SIZE: usize,
381 > VmAdapterChip<F>
382 for Rv32VecHeapTwoReadsAdapterChip<
383 F,
384 BLOCKS_PER_READ1,
385 BLOCKS_PER_READ2,
386 BLOCKS_PER_WRITE,
387 READ_SIZE,
388 WRITE_SIZE,
389 >
390{
391 type ReadRecord =
392 Rv32VecHeapTwoReadsReadRecord<F, BLOCKS_PER_READ1, BLOCKS_PER_READ2, READ_SIZE>;
393 type WriteRecord = Rv32VecHeapTwoReadsWriteRecord<BLOCKS_PER_WRITE, WRITE_SIZE>;
394 type Air = Rv32VecHeapTwoReadsAdapterAir<
395 BLOCKS_PER_READ1,
396 BLOCKS_PER_READ2,
397 BLOCKS_PER_WRITE,
398 READ_SIZE,
399 WRITE_SIZE,
400 >;
401 type Interface = VecHeapTwoReadsAdapterInterface<
402 F,
403 BLOCKS_PER_READ1,
404 BLOCKS_PER_READ2,
405 BLOCKS_PER_WRITE,
406 READ_SIZE,
407 WRITE_SIZE,
408 >;
409
410 fn preprocess(
411 &mut self,
412 memory: &mut MemoryController<F>,
413 instruction: &Instruction<F>,
414 ) -> Result<(
415 <Self::Interface as VmAdapterInterface<F>>::Reads,
416 Self::ReadRecord,
417 )> {
418 let Instruction { a, b, c, d, e, .. } = *instruction;
419
420 debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
421 debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);
422
423 let (rs1_record, rs1_val) = read_rv32_register(memory, d, b);
424 let (rs2_record, rs2_val) = read_rv32_register(memory, d, c);
425 let (rd_record, rd_val) = read_rv32_register(memory, d, a);
426
427 assert!(rs1_val as usize + READ_SIZE * BLOCKS_PER_READ1 - 1 < (1 << self.air.address_bits));
428 let read1_records = from_fn(|i| {
429 memory.read::<READ_SIZE>(e, F::from_canonical_u32(rs1_val + (i * READ_SIZE) as u32))
430 });
431 let read1_data = read1_records.map(|r| r.1);
432 assert!(rs2_val as usize + READ_SIZE * BLOCKS_PER_READ2 - 1 < (1 << self.air.address_bits));
433 let read2_records = from_fn(|i| {
434 memory.read::<READ_SIZE>(e, F::from_canonical_u32(rs2_val + (i * READ_SIZE) as u32))
435 });
436 let read2_data = read2_records.map(|r| r.1);
437 assert!(rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 < (1 << self.air.address_bits));
438
439 let record = Rv32VecHeapTwoReadsReadRecord {
440 rs1: rs1_record,
441 rs2: rs2_record,
442 rd: rd_record,
443 rd_val: F::from_canonical_u32(rd_val),
444 reads1: read1_records.map(|r| r.0),
445 reads2: read2_records.map(|r| r.0),
446 };
447
448 Ok(((read1_data, read2_data), record))
449 }
450
451 fn postprocess(
452 &mut self,
453 memory: &mut MemoryController<F>,
454 instruction: &Instruction<F>,
455 from_state: ExecutionState<u32>,
456 output: AdapterRuntimeContext<F, Self::Interface>,
457 read_record: &Self::ReadRecord,
458 ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
459 let e = instruction.e;
460 let mut i = 0;
461 let writes = output.writes.map(|write| {
462 let (record_id, _) = memory.write(
463 e,
464 read_record.rd_val + F::from_canonical_u32((i * WRITE_SIZE) as u32),
465 write,
466 );
467 i += 1;
468 record_id
469 });
470
471 Ok((
472 ExecutionState {
473 pc: from_state.pc + DEFAULT_PC_STEP,
474 timestamp: memory.timestamp(),
475 },
476 Self::WriteRecord { from_state, writes },
477 ))
478 }
479
480 fn generate_trace_row(
481 &self,
482 row_slice: &mut [F],
483 read_record: Self::ReadRecord,
484 write_record: Self::WriteRecord,
485 memory: &OfflineMemory<F>,
486 ) {
487 vec_heap_two_reads_generate_trace_row_impl(
488 row_slice,
489 &read_record,
490 &write_record,
491 self.bitwise_lookup_chip.clone(),
492 self.air.address_bits,
493 memory,
494 )
495 }
496
497 fn air(&self) -> &Self::Air {
498 &self.air
499 }
500}
501
502pub(super) fn vec_heap_two_reads_generate_trace_row_impl<
503 F: PrimeField32,
504 const BLOCKS_PER_READ1: usize,
505 const BLOCKS_PER_READ2: usize,
506 const BLOCKS_PER_WRITE: usize,
507 const READ_SIZE: usize,
508 const WRITE_SIZE: usize,
509>(
510 row_slice: &mut [F],
511 read_record: &Rv32VecHeapTwoReadsReadRecord<F, BLOCKS_PER_READ1, BLOCKS_PER_READ2, READ_SIZE>,
512 write_record: &Rv32VecHeapTwoReadsWriteRecord<BLOCKS_PER_WRITE, WRITE_SIZE>,
513 bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
514 address_bits: usize,
515 memory: &OfflineMemory<F>,
516) {
517 let aux_cols_factory = memory.aux_cols_factory();
518 let row_slice: &mut Rv32VecHeapTwoReadsAdapterCols<
519 F,
520 BLOCKS_PER_READ1,
521 BLOCKS_PER_READ2,
522 BLOCKS_PER_WRITE,
523 READ_SIZE,
524 WRITE_SIZE,
525 > = row_slice.borrow_mut();
526 row_slice.from_state = write_record.from_state.map(F::from_canonical_u32);
527
528 let rd = memory.record_by_id(read_record.rd);
529 let rs1 = memory.record_by_id(read_record.rs1);
530 let rs2 = memory.record_by_id(read_record.rs2);
531
532 row_slice.rd_ptr = rd.pointer;
533 row_slice.rs1_ptr = rs1.pointer;
534 row_slice.rs2_ptr = rs2.pointer;
535
536 row_slice.rd_val.copy_from_slice(rd.data_slice());
537 row_slice.rs1_val.copy_from_slice(rs1.data_slice());
538 row_slice.rs2_val.copy_from_slice(rs2.data_slice());
539
540 aux_cols_factory.generate_read_aux(rs1, &mut row_slice.rs1_read_aux);
541 aux_cols_factory.generate_read_aux(rs2, &mut row_slice.rs2_read_aux);
542 aux_cols_factory.generate_read_aux(rd, &mut row_slice.rd_read_aux);
543
544 for (i, r) in read_record.reads1.iter().enumerate() {
545 let record = memory.record_by_id(*r);
546 aux_cols_factory.generate_read_aux(record, &mut row_slice.reads1_aux[i]);
547 }
548
549 for (i, r) in read_record.reads2.iter().enumerate() {
550 let record = memory.record_by_id(*r);
551 aux_cols_factory.generate_read_aux(record, &mut row_slice.reads2_aux[i]);
552 }
553
554 for (i, w) in write_record.writes.iter().enumerate() {
555 let record = memory.record_by_id(*w);
556 aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i]);
557 }
558 let need_range_check = [
560 &read_record.rs1,
561 &read_record.rs2,
562 &read_record.rd,
563 &read_record.rd,
564 ]
565 .map(|record| {
566 memory
567 .record_by_id(*record)
568 .data_at(RV32_REGISTER_NUM_LIMBS - 1)
569 .as_canonical_u32()
570 });
571 debug_assert!(address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
572 let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits;
573 for pair in need_range_check.chunks_exact(2) {
574 bitwise_lookup_chip.request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits);
575 }
576}