1use std::{
2 array::from_fn,
3 borrow::{Borrow, BorrowMut},
4};
5
6use itertools::izip;
7use openvm_circuit::{
8 arch::{
9 get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
10 BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir,
11 },
12 system::memory::{
13 offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord},
14 online::TracingMemory,
15 MemoryAddress, MemoryAuxColsFactory,
16 },
17};
18use openvm_circuit_primitives::{
19 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
20 AlignedBytesBorrow,
21};
22use openvm_circuit_primitives_derive::AlignedBorrow;
23use openvm_instructions::{
24 instruction::Instruction,
25 program::DEFAULT_PC_STEP,
26 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
27};
28use openvm_rv32im_circuit::adapters::{tracing_read, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};
29use openvm_stark_backend::{
30 interaction::InteractionBuilder,
31 p3_air::BaseAir,
32 p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
33};
34
/// Trace columns for the heap-branch adapter: `NUM_READS` register operands are read,
/// each register value is interpreted as a heap pointer, and `READ_SIZE` cells are read
/// from memory at each pointer. The adapter performs no writes.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct Rv32HeapBranchAdapterCols<T, const NUM_READS: usize, const READ_SIZE: usize> {
    /// Execution state (pc, timestamp) at the start of the instruction.
    pub from_state: ExecutionState<T>,

    /// Register-file addresses of the operand registers (instruction operands a, b).
    pub rs_ptr: [T; NUM_READS],
    /// Little-endian limbs of each operand register's value; each value is a heap pointer.
    pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
    /// Memory-bridge auxiliary columns proving the register reads.
    pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],

    /// Memory-bridge auxiliary columns proving the heap reads at the pointers in `rs_val`.
    pub heap_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
}
50
/// AIR for the heap-branch adapter. Constrains `NUM_READS` register reads, a range check
/// that each resulting heap pointer fits in `address_bits`, `NUM_READS` heap reads of
/// `READ_SIZE` cells each, and the execution-bridge transition (no memory writes).
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32HeapBranchAdapterAir<const NUM_READS: usize, const READ_SIZE: usize> {
    pub(super) execution_bridge: ExecutionBridge,
    pub(super) memory_bridge: MemoryBridge,
    /// Bitwise-operation lookup bus used to range-check the most-significant limb
    /// of each pointer value.
    pub bus: BitwiseOperationLookupBus,
    /// Maximum number of bits allowed in a heap pointer.
    address_bits: usize,
}
58
impl<F: Field, const NUM_READS: usize, const READ_SIZE: usize> BaseAir<F>
    for Rv32HeapBranchAdapterAir<NUM_READS, READ_SIZE>
{
    /// Width is the flattened column count of `Rv32HeapBranchAdapterCols`
    /// (provided by the `AlignedBorrow` derive).
    fn width(&self) -> usize {
        Rv32HeapBranchAdapterCols::<F, NUM_READS, READ_SIZE>::width()
    }
}
66
impl<AB: InteractionBuilder, const NUM_READS: usize, const READ_SIZE: usize> VmAdapterAir<AB>
    for Rv32HeapBranchAdapterAir<NUM_READS, READ_SIZE>
{
    // NUM_READS reads of READ_SIZE cells each, zero writes, immediate-style instruction.
    type Interface =
        BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, NUM_READS, 0, READ_SIZE, 0>;

    /// Emits all adapter constraints for one trace row. All memory and execution
    /// interactions are gated on `ctx.instruction.is_valid`, so padding rows
    /// contribute nothing.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let cols: &Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = local.borrow();
        let timestamp = cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Returns the timestamp for the next memory access and bumps the running
        // delta, so each access below gets a distinct, strictly increasing timestamp.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_usize(timestamp_delta - 1)
        };

        // Address-space constants: d = register file, e = main memory.
        let d = AB::F::from_u32(RV32_REGISTER_AS);
        let e = AB::F::from_u32(RV32_MEMORY_AS);

        // Constrain the register reads that fetch the heap pointers.
        for (ptr, data, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux) {
            self.memory_bridge
                .read(MemoryAddress::new(d, ptr), data, timestamp_pp(), aux)
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // The most-significant limb of each pointer value must be range-checked so
        // the full pointer fits in `address_bits`.
        let need_range_check: Vec<AB::Var> = cols
            .rs_val
            .iter()
            .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1])
            .collect();

        // Shifting the high limb left by this factor maps the `address_bits` bound
        // onto the lookup bus's native RV32_CELL_BITS range.
        let limb_shift =
            AB::F::from_usize(1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits));

        // The lookup bus checks two values per interaction; pad an odd trailing
        // element with zero (which trivially passes the range check).
        for pair in need_range_check.chunks(2) {
            self.bus
                .send_range(
                    pair[0] * limb_shift,
                    pair.get(1).map(|x| (*x).into()).unwrap_or(AB::Expr::ZERO) * limb_shift,
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Recompose each pointer from its little-endian limbs (fold from the most
        // significant limb down, multiplying by the limb base each step).
        let heap_ptr = cols.rs_val.map(|r| {
            r.iter().rev().fold(AB::Expr::ZERO, |acc, limb| {
                acc * AB::F::from_u32(1 << RV32_CELL_BITS) + (*limb)
            })
        });
        // Constrain the heap reads at the recomposed pointers; the data read is
        // exposed to the core AIR through `ctx.reads`.
        for (ptr, data, aux) in izip!(heap_ptr, ctx.reads, &cols.heap_read_aux) {
            self.memory_bridge
                .read(MemoryAddress::new(e, ptr), data, timestamp_pp(), aux)
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Tie the row to the instruction stream: operands [a, b, imm, d, e], the
        // starting state, the total timestamp advance, and the (possibly branching)
        // next pc. Missing operands (NUM_READS < 2) are encoded as zero.
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    cols.rs_ptr
                        .first()
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    cols.rs_ptr
                        .get(1)
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    ctx.instruction.immediate,
                    d.into(),
                    e.into(),
                ],
                cols.from_state,
                AB::F::from_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid);
    }

    /// Returns the column holding the row's starting program counter.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = local.borrow();
        cols.from_state.pc
    }
}
163
/// Execution-time record for one heap-branch instruction; serialized into the trace
/// row buffer and later expanded into `Rv32HeapBranchAdapterCols` by the filler.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32HeapBranchAdapterRecord<const NUM_READS: usize> {
    /// Program counter at the start of the instruction.
    pub from_pc: u32,
    /// Memory timestamp at the start of the instruction.
    pub from_timestamp: u32,

    /// Register-file addresses that were read (operands a, b).
    pub rs_ptr: [u32; NUM_READS],
    /// Register values read, i.e. the heap pointers.
    pub rs_vals: [u32; NUM_READS],

    /// Previous-access timestamps for the register reads.
    pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS],
    /// Previous-access timestamps for the heap reads.
    pub heap_read_aux: [MemoryReadAuxRecord; NUM_READS],
}
176
/// Runtime executor for the heap-branch adapter: performs the register and heap reads
/// and records them for later trace filling.
#[derive(Clone, Copy)]
pub struct Rv32HeapBranchAdapterExecutor<const NUM_READS: usize, const READ_SIZE: usize> {
    /// Maximum number of bits allowed in a heap pointer (used in debug bounds checks).
    pub pointer_max_bits: usize,
}
181
/// Trace filler for the heap-branch adapter: expands execution records into trace rows
/// and feeds the pointer range checks into the shared bitwise lookup chip.
#[derive(derive_new::new)]
pub struct Rv32HeapBranchAdapterFiller<const NUM_READS: usize, const READ_SIZE: usize> {
    /// Maximum number of bits allowed in a heap pointer; must match the AIR's `address_bits`.
    pub pointer_max_bits: usize,
    /// Shared lookup chip that records the high-limb range-check requests.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}
187
impl<const NUM_READS: usize, const READ_SIZE: usize>
    Rv32HeapBranchAdapterExecutor<NUM_READS, READ_SIZE>
{
    /// Creates an executor for pointers of at most `pointer_max_bits` bits.
    ///
    /// # Panics
    /// Panics if `NUM_READS > 2` (the instruction format carries at most two register
    /// operands), or if `pointer_max_bits` is so small that the unused high-limb bits
    /// would exceed one limb, making the single-limb range check insufficient.
    pub fn new(pointer_max_bits: usize) -> Self {
        assert!(NUM_READS <= 2);
        assert!(
            RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS,
            "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check"
        );
        Self { pointer_max_bits }
    }
}
200
impl<F: PrimeField32, const NUM_READS: usize, const READ_SIZE: usize> AdapterTraceExecutor<F>
    for Rv32HeapBranchAdapterExecutor<NUM_READS, READ_SIZE>
{
    const WIDTH: usize = Rv32HeapBranchAdapterCols::<F, NUM_READS, READ_SIZE>::width();
    // One READ_SIZE-byte chunk per operand register.
    type ReadData = [[u8; READ_SIZE]; NUM_READS];
    // Branch adapter: no memory writes.
    type WriteData = ();
    type RecordMut<'a> = &'a mut Rv32HeapBranchAdapterRecord<NUM_READS>;

    /// Snapshots pc and the current memory timestamp into the record.
    fn start(pc: u32, memory: &TracingMemory, adapter_record: &mut Self::RecordMut<'_>) {
        adapter_record.from_pc = pc;
        adapter_record.from_timestamp = memory.timestamp;
    }

    /// Performs the adapter's reads: first the operand registers (addresses taken
    /// from instruction operands `a`/`b`), then the heap data at the pointers those
    /// registers hold. Read order matters: it fixes the per-access timestamps that
    /// the filler later reproduces (registers at +i, heap at +NUM_READS+i).
    fn read(
        &self,
        memory: &mut TracingMemory,
        instruction: &Instruction<F>,
        record: &mut Self::RecordMut<'_>,
    ) -> Self::ReadData {
        let Instruction { a, b, d, e, .. } = *instruction;

        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);

        // Read each operand register; the value is a little-endian u32 heap pointer.
        record.rs_vals = from_fn(|i| {
            // Operand 0 comes from `a`, operand 1 (if any) from `b`.
            record.rs_ptr[i] = if i == 0 { a } else { b }.as_canonical_u32();
            u32::from_le_bytes(tracing_read(
                memory,
                RV32_REGISTER_AS,
                record.rs_ptr[i],
                &mut record.rs_read_aux[i].prev_timestamp,
            ))
        });

        // Read READ_SIZE bytes of heap data at each pointer.
        from_fn(|i| {
            // The whole READ_SIZE span must stay within the pointer_max_bits address range.
            debug_assert!(
                record.rs_vals[i] as usize + READ_SIZE - 1 < (1 << self.pointer_max_bits)
            );
            tracing_read(
                memory,
                RV32_MEMORY_AS,
                record.rs_vals[i],
                &mut record.heap_read_aux[i].prev_timestamp,
            )
        })
    }

    /// No-op: the branch adapter never writes to memory.
    fn write(
        &self,
        _memory: &mut TracingMemory,
        _instruction: &Instruction<F>,
        _data: Self::WriteData,
        _record: &mut Self::RecordMut<'_>,
    ) {
    }
}
260
impl<F: PrimeField32, const NUM_READS: usize, const READ_SIZE: usize> AdapterTraceFiller<F>
    for Rv32HeapBranchAdapterFiller<NUM_READS, READ_SIZE>
{
    const WIDTH: usize = Rv32HeapBranchAdapterCols::<F, NUM_READS, READ_SIZE>::width();

    /// Expands the packed execution record (stored in the prefix of `adapter_row`)
    /// into the full column layout, and issues the pointer range-check requests.
    ///
    /// The record view and the column view alias the SAME row buffer, so columns are
    /// written strictly back-to-front (highest field offsets first, arrays iterated in
    /// reverse): each record field must be fully consumed before the bytes it occupies
    /// are overwritten by column data.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        // SAFETY-adjacent: `record` borrows the row prefix; see ordering note above.
        let record: &Rv32HeapBranchAdapterRecord<NUM_READS> =
            unsafe { get_record_from_slice(&mut adapter_row, ()) };
        let cols: &mut Rv32HeapBranchAdapterCols<F, NUM_READS, READ_SIZE> =
            adapter_row.borrow_mut();

        debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
        // Shift that maps the high limb onto the lookup chip's RV32_CELL_BITS range,
        // mirroring `limb_shift` in the AIR.
        let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits;
        // Bit offset of the most-significant limb within a u32 pointer value.
        const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
        // One lookup request covers both pointers; a missing second operand is
        // encoded as 0, matching the zero-padding in the AIR's chunked range check.
        self.bitwise_lookup_chip.request_range(
            (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits,
            if NUM_READS > 1 {
                (record.rs_vals[1] >> MSL_SHIFT) << limb_shift_bits
            } else {
                0
            },
        );

        // Heap reads occurred after the register reads, at timestamps
        // from_timestamp + NUM_READS + i. Filled first: highest column offsets.
        for i in (0..NUM_READS).rev() {
            mem_helper.fill(
                record.heap_read_aux[i].prev_timestamp,
                record.from_timestamp + (i + NUM_READS) as u32,
                cols.heap_read_aux[i].as_mut(),
            );
        }

        // Register reads occurred at timestamps from_timestamp + i.
        for i in (0..NUM_READS).rev() {
            mem_helper.fill(
                record.rs_read_aux[i].prev_timestamp,
                record.from_timestamp + i as u32,
                cols.rs_read_aux[i].as_mut(),
            );
        }

        // Decompose each recorded pointer value into little-endian byte limbs.
        cols.rs_val
            .iter_mut()
            .rev()
            .zip(record.rs_vals.iter().rev())
            .for_each(|(col, record)| {
                *col = record.to_le_bytes().map(F::from_u8);
            });

        cols.rs_ptr
            .iter_mut()
            .rev()
            .zip(record.rs_ptr.iter().rev())
            .for_each(|(col, record)| {
                *col = F::from_u32(*record);
            });

        // `from_state` occupies the lowest offsets; written last.
        cols.from_state.timestamp = F::from_u32(record.from_timestamp);
        cols.from_state.pc = F::from_u32(record.from_pc);
    }
}