1use std::{
2 array::from_fn,
3 borrow::{Borrow, BorrowMut},
4};
5
6use itertools::izip;
7use openvm_circuit::{
8 arch::{
9 get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
10 BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir,
11 },
12 system::memory::{
13 offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord},
14 online::TracingMemory,
15 MemoryAddress, MemoryAuxColsFactory,
16 },
17};
18use openvm_circuit_primitives::{
19 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
20 AlignedBytesBorrow,
21};
22use openvm_circuit_primitives_derive::AlignedBorrow;
23use openvm_instructions::{
24 instruction::Instruction,
25 program::DEFAULT_PC_STEP,
26 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
27};
28use openvm_rv32im_circuit::adapters::{tracing_read, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};
29use openvm_stark_backend::{
30 interaction::InteractionBuilder,
31 p3_air::BaseAir,
32 p3_field::{Field, FieldAlgebra, PrimeField32},
33};
34
/// Trace columns of the RV32 heap branch adapter.
///
/// As constrained by `Rv32HeapBranchAdapterAir::eval`, one row covers `NUM_READS` register
/// reads (address space `RV32_REGISTER_AS`) followed by `NUM_READS` reads from address space
/// `RV32_MEMORY_AS`, where each register value is used as the pointer of the matching memory
/// read.
///
/// `READ_SIZE` does not appear in any field; it is only carried as a const parameter of the
/// type.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct Rv32HeapBranchAdapterCols<T, const NUM_READS: usize, const READ_SIZE: usize> {
    /// Execution state (`pc` and `timestamp`) the instruction starts from.
    pub from_state: ExecutionState<T>,

    /// Register pointers from which the heap pointers are read.
    pub rs_ptr: [T; NUM_READS],
    /// Limbs of each register value, least significant limb first (`eval` recomposes the
    /// pointer by folding the limbs in reverse order).
    pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
    /// Auxiliary columns for the register reads.
    pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],

    /// Auxiliary columns for the heap reads performed at the pointers held in `rs_val`.
    pub heap_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
}
50
/// AIR of the RV32 heap branch adapter.
///
/// The constructor is generated by `derive_new::new` and takes the fields in declaration
/// order.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32HeapBranchAdapterAir<const NUM_READS: usize, const READ_SIZE: usize> {
    /// Bridge used in `eval` to constrain the instruction execution and the pc update.
    pub(super) execution_bridge: ExecutionBridge,
    /// Bridge used in `eval` to constrain the register reads and the heap reads.
    pub(super) memory_bridge: MemoryBridge,
    /// Lookup bus that receives the shifted most significant limb of every heap pointer.
    pub bus: BitwiseOperationLookupBus,
    /// Number of address bits; `eval` shifts the most significant pointer limb left by
    /// `RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits` before the range check.
    address_bits: usize,
}
58
59impl<F: Field, const NUM_READS: usize, const READ_SIZE: usize> BaseAir<F>
60 for Rv32HeapBranchAdapterAir<NUM_READS, READ_SIZE>
61{
62 fn width(&self) -> usize {
63 Rv32HeapBranchAdapterCols::<F, NUM_READS, READ_SIZE>::width()
64 }
65}
66
67impl<AB: InteractionBuilder, const NUM_READS: usize, const READ_SIZE: usize> VmAdapterAir<AB>
68 for Rv32HeapBranchAdapterAir<NUM_READS, READ_SIZE>
69{
70 type Interface =
71 BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, NUM_READS, 0, READ_SIZE, 0>;
72
73 fn eval(
74 &self,
75 builder: &mut AB,
76 local: &[AB::Var],
77 ctx: AdapterAirContext<AB::Expr, Self::Interface>,
78 ) {
79 let cols: &Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = local.borrow();
80 let timestamp = cols.from_state.timestamp;
81 let mut timestamp_delta: usize = 0;
82 let mut timestamp_pp = || {
83 timestamp_delta += 1;
84 timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
85 };
86
87 let d = AB::F::from_canonical_u32(RV32_REGISTER_AS);
88 let e = AB::F::from_canonical_u32(RV32_MEMORY_AS);
89
90 for (ptr, data, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux) {
91 self.memory_bridge
92 .read(MemoryAddress::new(d, ptr), data, timestamp_pp(), aux)
93 .eval(builder, ctx.instruction.is_valid.clone());
94 }
95
96 let need_range_check: Vec<AB::Var> = cols
102 .rs_val
103 .iter()
104 .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1])
105 .collect();
106
107 let limb_shift = AB::F::from_canonical_usize(
110 1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits),
111 );
112
113 for pair in need_range_check.chunks(2) {
117 self.bus
118 .send_range(
119 pair[0] * limb_shift,
120 pair.get(1).map(|x| (*x).into()).unwrap_or(AB::Expr::ZERO) * limb_shift, )
122 .eval(builder, ctx.instruction.is_valid.clone());
123 }
124
125 let heap_ptr = cols.rs_val.map(|r| {
126 r.iter().rev().fold(AB::Expr::ZERO, |acc, limb| {
127 acc * AB::F::from_canonical_u32(1 << RV32_CELL_BITS) + (*limb)
128 })
129 });
130 for (ptr, data, aux) in izip!(heap_ptr, ctx.reads, &cols.heap_read_aux) {
131 self.memory_bridge
132 .read(MemoryAddress::new(e, ptr), data, timestamp_pp(), aux)
133 .eval(builder, ctx.instruction.is_valid.clone());
134 }
135
136 self.execution_bridge
137 .execute_and_increment_or_set_pc(
138 ctx.instruction.opcode,
139 [
140 cols.rs_ptr
141 .first()
142 .map(|&x| x.into())
143 .unwrap_or(AB::Expr::ZERO),
144 cols.rs_ptr
145 .get(1)
146 .map(|&x| x.into())
147 .unwrap_or(AB::Expr::ZERO),
148 ctx.instruction.immediate,
149 d.into(),
150 e.into(),
151 ],
152 cols.from_state,
153 AB::F::from_canonical_usize(timestamp_delta),
154 (DEFAULT_PC_STEP, ctx.to_pc),
155 )
156 .eval(builder, ctx.instruction.is_valid);
157 }
158
159 fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
160 let cols: &Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = local.borrow();
161 cols.from_state.pc
162 }
163}
164
/// Execution record of the RV32 heap branch adapter.
///
/// Written by `Rv32HeapBranchAdapterExecutor` (`start` and `read`) and later expanded into
/// `Rv32HeapBranchAdapterCols` by `Rv32HeapBranchAdapterFiller::fill_trace_row`.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32HeapBranchAdapterRecord<const NUM_READS: usize> {
    /// Program counter passed to `start`.
    pub from_pc: u32,
    /// Memory timestamp observed in `start`, before any read of the instruction.
    pub from_timestamp: u32,

    /// Register pointers taken from instruction operands `a` (index 0) and `b` (other indices).
    pub rs_ptr: [u32; NUM_READS],
    /// Register values read at `rs_ptr`, decoded as little-endian `u32`; used as heap pointers.
    pub rs_vals: [u32; NUM_READS],

    /// Auxiliary data of the register reads (`prev_timestamp` is filled by `tracing_read`).
    pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS],
    /// Auxiliary data of the heap reads (`prev_timestamp` is filled by `tracing_read`).
    pub heap_read_aux: [MemoryReadAuxRecord; NUM_READS],
}
177
/// Executor side of the RV32 heap branch adapter: performs the register and heap reads of an
/// instruction and stores them in a `Rv32HeapBranchAdapterRecord`.
#[derive(Clone, Copy)]
pub struct Rv32HeapBranchAdapterExecutor<const NUM_READS: usize, const READ_SIZE: usize> {
    /// Bit width of valid heap addresses; `read` debug-asserts that the last byte of every
    /// heap access lies below `1 << pointer_max_bits`.
    pub pointer_max_bits: usize,
}
182
/// Trace-filling side of the RV32 heap branch adapter: turns a stored
/// `Rv32HeapBranchAdapterRecord` into a row of `Rv32HeapBranchAdapterCols`.
///
/// The constructor is generated by `derive_new::new` and takes the fields in declaration
/// order.
#[derive(derive_new::new)]
pub struct Rv32HeapBranchAdapterFiller<const NUM_READS: usize, const READ_SIZE: usize> {
    /// Bit width of valid heap addresses; determines the shift applied to the most
    /// significant pointer limb before it is range checked.
    pub pointer_max_bits: usize,
    /// Lookup chip on which `fill_trace_row` requests the pointer range checks.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}
188
189impl<const NUM_READS: usize, const READ_SIZE: usize>
190 Rv32HeapBranchAdapterExecutor<NUM_READS, READ_SIZE>
191{
192 pub fn new(pointer_max_bits: usize) -> Self {
193 assert!(NUM_READS <= 2);
194 assert!(
195 RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS,
196 "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check"
197 );
198 Self { pointer_max_bits }
199 }
200}
201
202impl<F: PrimeField32, const NUM_READS: usize, const READ_SIZE: usize> AdapterTraceExecutor<F>
203 for Rv32HeapBranchAdapterExecutor<NUM_READS, READ_SIZE>
204{
205 const WIDTH: usize = Rv32HeapBranchAdapterCols::<F, NUM_READS, READ_SIZE>::width();
206 type ReadData = [[u8; READ_SIZE]; NUM_READS];
207 type WriteData = ();
208 type RecordMut<'a> = &'a mut Rv32HeapBranchAdapterRecord<NUM_READS>;
209
210 fn start(pc: u32, memory: &TracingMemory, adapter_record: &mut Self::RecordMut<'_>) {
211 adapter_record.from_pc = pc;
212 adapter_record.from_timestamp = memory.timestamp;
213 }
214
215 fn read(
216 &self,
217 memory: &mut TracingMemory,
218 instruction: &Instruction<F>,
219 record: &mut Self::RecordMut<'_>,
220 ) -> Self::ReadData {
221 let Instruction { a, b, d, e, .. } = *instruction;
222
223 debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
224 debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);
225
226 record.rs_vals = from_fn(|i| {
228 record.rs_ptr[i] = if i == 0 { a } else { b }.as_canonical_u32();
229 u32::from_le_bytes(tracing_read(
230 memory,
231 RV32_REGISTER_AS,
232 record.rs_ptr[i],
233 &mut record.rs_read_aux[i].prev_timestamp,
234 ))
235 });
236
237 from_fn(|i| {
239 debug_assert!(
240 record.rs_vals[i] as usize + READ_SIZE - 1 < (1 << self.pointer_max_bits)
241 );
242 tracing_read(
243 memory,
244 RV32_MEMORY_AS,
245 record.rs_vals[i],
246 &mut record.heap_read_aux[i].prev_timestamp,
247 )
248 })
249 }
250
251 fn write(
252 &self,
253 _memory: &mut TracingMemory,
254 _instruction: &Instruction<F>,
255 _data: Self::WriteData,
256 _record: &mut Self::RecordMut<'_>,
257 ) {
258 }
260}
261
impl<F: PrimeField32, const NUM_READS: usize, const READ_SIZE: usize> AdapterTraceFiller<F>
    for Rv32HeapBranchAdapterFiller<NUM_READS, READ_SIZE>
{
    const WIDTH: usize = Rv32HeapBranchAdapterCols::<F, NUM_READS, READ_SIZE>::width();

    /// Expands the adapter record found in `adapter_row` into the adapter trace columns and
    /// requests the pointer range checks on the bitwise lookup chip.
    ///
    /// The record is obtained from `adapter_row` and the columns are written into that same
    /// slice, so the order of the statements below matters.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        // SAFETY: NOTE(review) - the contract of `get_record_from_slice` is not visible in
        // this file; presumably the caller must have stored a valid
        // `Rv32HeapBranchAdapterRecord` in `adapter_row` beforehand. Confirm.
        let record: &Rv32HeapBranchAdapterRecord<NUM_READS> =
            unsafe { get_record_from_slice(&mut adapter_row, ()) };
        let cols: &mut Rv32HeapBranchAdapterCols<F, NUM_READS, READ_SIZE> =
            adapter_row.borrow_mut();

        // Range checks: request the values that the AIR sends on the lookup bus, i.e. the most
        // significant limb of each pointer shifted left by the number of unused address bits.
        // These are issued before any column is written, while the record is still intact.
        debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
        let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits;
        // Right shift that isolates the most significant limb of a 32-bit pointer.
        const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
        self.bitwise_lookup_chip.request_range(
            (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits,
            // With a single read the second value of the pair is zero, as in the AIR.
            if NUM_READS > 1 {
                (record.rs_vals[1] >> MSL_SHIFT) << limb_shift_bits
            } else {
                0
            },
        );

        // NOTE(review): everything below is filled from the last field to the first and from
        // the last index to the first. Presumably `record` and `cols` alias the same buffer and
        // this order avoids overwriting record data that has not been read yet - confirm
        // before reordering anything.

        // Heap reads come after the NUM_READS register reads, so heap read `i` uses timestamp
        // `from_timestamp + NUM_READS + i`, matching the order of accesses in the AIR.
        for i in (0..NUM_READS).rev() {
            mem_helper.fill(
                record.heap_read_aux[i].prev_timestamp,
                record.from_timestamp + (i + NUM_READS) as u32,
                cols.heap_read_aux[i].as_mut(),
            );
        }

        // Register read `i` uses timestamp `from_timestamp + i`.
        for i in (0..NUM_READS).rev() {
            mem_helper.fill(
                record.rs_read_aux[i].prev_timestamp,
                record.from_timestamp + i as u32,
                cols.rs_read_aux[i].as_mut(),
            );
        }

        // Register values as little-endian byte limbs.
        cols.rs_val
            .iter_mut()
            .rev()
            .zip(record.rs_vals.iter().rev())
            .for_each(|(col, record)| {
                *col = record.to_le_bytes().map(F::from_canonical_u8);
            });

        // Register pointers.
        cols.rs_ptr
            .iter_mut()
            .rev()
            .zip(record.rs_ptr.iter().rev())
            .for_each(|(col, record)| {
                *col = F::from_canonical_u32(*record);
            });

        // Starting execution state, written last.
        cols.from_state.timestamp = F::from_canonical_u32(record.from_timestamp);
        cols.from_state.pc = F::from_canonical_u32(record.from_pc);
    }
}