1use std::{
2 array::from_fn,
3 borrow::{Borrow, BorrowMut},
4 iter::once,
5 marker::PhantomData,
6};
7
8use itertools::izip;
9use openvm_circuit::{
10 arch::{
11 AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge,
12 ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip,
13 VmAdapterInterface,
14 },
15 system::{
16 memory::{
17 offline_checker::{MemoryBridge, MemoryReadAuxCols},
18 MemoryAddress, MemoryController, OfflineMemory, RecordId,
19 },
20 program::ProgramBus,
21 },
22};
23use openvm_circuit_primitives::bitwise_op_lookup::{
24 BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip,
25};
26use openvm_circuit_primitives_derive::AlignedBorrow;
27use openvm_instructions::{
28 instruction::Instruction,
29 program::DEFAULT_PC_STEP,
30 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
31};
32use openvm_rv32im_circuit::adapters::{
33 read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
34};
35use openvm_stark_backend::{
36 interaction::InteractionBuilder,
37 p3_air::BaseAir,
38 p3_field::{Field, FieldAlgebra, PrimeField32},
39};
40use serde::{Deserialize, Serialize};
41use serde_big_array::BigArray;
42
43#[repr(C)]
48#[derive(AlignedBorrow)]
49pub struct Rv32HeapBranchAdapterCols<T, const NUM_READS: usize, const READ_SIZE: usize> {
50 pub from_state: ExecutionState<T>,
51
52 pub rs_ptr: [T; NUM_READS],
53 pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
54 pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
55
56 pub heap_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
57}
58
59#[derive(Clone, Copy, Debug, derive_new::new)]
60pub struct Rv32HeapBranchAdapterAir<const NUM_READS: usize, const READ_SIZE: usize> {
61 pub(super) execution_bridge: ExecutionBridge,
62 pub(super) memory_bridge: MemoryBridge,
63 pub bus: BitwiseOperationLookupBus,
64 address_bits: usize,
65}
66
67impl<F: Field, const NUM_READS: usize, const READ_SIZE: usize> BaseAir<F>
68 for Rv32HeapBranchAdapterAir<NUM_READS, READ_SIZE>
69{
70 fn width(&self) -> usize {
71 Rv32HeapBranchAdapterCols::<F, NUM_READS, READ_SIZE>::width()
72 }
73}
74
75impl<AB: InteractionBuilder, const NUM_READS: usize, const READ_SIZE: usize> VmAdapterAir<AB>
76 for Rv32HeapBranchAdapterAir<NUM_READS, READ_SIZE>
77{
78 type Interface =
79 BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, NUM_READS, 0, READ_SIZE, 0>;
80
81 fn eval(
82 &self,
83 builder: &mut AB,
84 local: &[AB::Var],
85 ctx: AdapterAirContext<AB::Expr, Self::Interface>,
86 ) {
87 let cols: &Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = local.borrow();
88 let timestamp = cols.from_state.timestamp;
89 let mut timestamp_delta: usize = 0;
90 let mut timestamp_pp = || {
91 timestamp_delta += 1;
92 timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
93 };
94
95 let d = AB::F::from_canonical_u32(RV32_REGISTER_AS);
96 let e = AB::F::from_canonical_u32(RV32_MEMORY_AS);
97
98 for (ptr, data, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux) {
99 self.memory_bridge
100 .read(MemoryAddress::new(d, ptr), data, timestamp_pp(), aux)
101 .eval(builder, ctx.instruction.is_valid.clone());
102 }
103
104 let need_range_check: Vec<AB::Var> = cols
108 .rs_val
109 .iter()
110 .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1])
111 .collect();
112
113 let limb_shift = AB::F::from_canonical_usize(
115 1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits),
116 );
117
118 for pair in need_range_check.chunks(2) {
122 self.bus
123 .send_range(
124 pair[0] * limb_shift,
125 pair.get(1).map(|x| (*x).into()).unwrap_or(AB::Expr::ZERO) * limb_shift, )
127 .eval(builder, ctx.instruction.is_valid.clone());
128 }
129
130 let heap_ptr = cols.rs_val.map(|r| {
131 r.iter().rev().fold(AB::Expr::ZERO, |acc, limb| {
132 acc * AB::F::from_canonical_u32(1 << RV32_CELL_BITS) + (*limb)
133 })
134 });
135 for (ptr, data, aux) in izip!(heap_ptr, ctx.reads, &cols.heap_read_aux) {
136 self.memory_bridge
137 .read(MemoryAddress::new(e, ptr), data, timestamp_pp(), aux)
138 .eval(builder, ctx.instruction.is_valid.clone());
139 }
140
141 self.execution_bridge
142 .execute_and_increment_or_set_pc(
143 ctx.instruction.opcode,
144 [
145 cols.rs_ptr
146 .first()
147 .map(|&x| x.into())
148 .unwrap_or(AB::Expr::ZERO),
149 cols.rs_ptr
150 .get(1)
151 .map(|&x| x.into())
152 .unwrap_or(AB::Expr::ZERO),
153 ctx.instruction.immediate,
154 d.into(),
155 e.into(),
156 ],
157 cols.from_state,
158 AB::F::from_canonical_usize(timestamp_delta),
159 (DEFAULT_PC_STEP, ctx.to_pc),
160 )
161 .eval(builder, ctx.instruction.is_valid);
162 }
163
164 fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
165 let cols: &Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = local.borrow();
166 cols.from_state.pc
167 }
168}
169
170pub struct Rv32HeapBranchAdapterChip<F: Field, const NUM_READS: usize, const READ_SIZE: usize> {
171 pub air: Rv32HeapBranchAdapterAir<NUM_READS, READ_SIZE>,
172 pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
173 _marker: PhantomData<F>,
174}
175
176impl<F: PrimeField32, const NUM_READS: usize, const READ_SIZE: usize>
177 Rv32HeapBranchAdapterChip<F, NUM_READS, READ_SIZE>
178{
179 pub fn new(
180 execution_bus: ExecutionBus,
181 program_bus: ProgramBus,
182 memory_bridge: MemoryBridge,
183 address_bits: usize,
184 bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
185 ) -> Self {
186 assert!(NUM_READS <= 2);
187 assert!(
188 RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS,
189 "address_bits={address_bits} needs to be large enough for high limb range check"
190 );
191 Self {
192 air: Rv32HeapBranchAdapterAir {
193 execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
194 memory_bridge,
195 bus: bitwise_lookup_chip.bus(),
196 address_bits,
197 },
198 bitwise_lookup_chip,
199 _marker: PhantomData,
200 }
201 }
202}
203
204#[repr(C)]
205#[derive(Clone, Debug, Serialize, Deserialize)]
206pub struct Rv32HeapBranchReadRecord<const NUM_READS: usize, const READ_SIZE: usize> {
207 #[serde(with = "BigArray")]
208 pub rs_reads: [RecordId; NUM_READS],
209 #[serde(with = "BigArray")]
210 pub heap_reads: [RecordId; NUM_READS],
211}
212
213impl<F: PrimeField32, const NUM_READS: usize, const READ_SIZE: usize> VmAdapterChip<F>
214 for Rv32HeapBranchAdapterChip<F, NUM_READS, READ_SIZE>
215{
216 type ReadRecord = Rv32HeapBranchReadRecord<NUM_READS, READ_SIZE>;
217 type WriteRecord = ExecutionState<u32>;
218 type Air = Rv32HeapBranchAdapterAir<NUM_READS, READ_SIZE>;
219 type Interface = BasicAdapterInterface<F, ImmInstruction<F>, NUM_READS, 0, READ_SIZE, 0>;
220
221 fn preprocess(
222 &mut self,
223 memory: &mut MemoryController<F>,
224 instruction: &Instruction<F>,
225 ) -> Result<(
226 <Self::Interface as VmAdapterInterface<F>>::Reads,
227 Self::ReadRecord,
228 )> {
229 let Instruction { a, b, d, e, .. } = *instruction;
230
231 debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
232 debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);
233
234 let mut rs_vals = [0; NUM_READS];
235 let rs_records: [_; NUM_READS] = from_fn(|i| {
236 let addr = if i == 0 { a } else { b };
237 let (record, val) = read_rv32_register(memory, d, addr);
238 rs_vals[i] = val;
239 record
240 });
241
242 let heap_records = rs_vals.map(|address| {
243 assert!(address as usize + READ_SIZE - 1 < (1 << self.air.address_bits));
244 memory.read::<READ_SIZE>(e, F::from_canonical_u32(address))
245 });
246
247 let record = Rv32HeapBranchReadRecord {
248 rs_reads: rs_records,
249 heap_reads: heap_records.map(|r| r.0),
250 };
251 Ok((heap_records.map(|r| r.1), record))
252 }
253
254 fn postprocess(
255 &mut self,
256 memory: &mut MemoryController<F>,
257 _instruction: &Instruction<F>,
258 from_state: ExecutionState<u32>,
259 output: AdapterRuntimeContext<F, Self::Interface>,
260 _read_record: &Self::ReadRecord,
261 ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
262 let timestamp_delta = memory.timestamp() - from_state.timestamp;
263 debug_assert!(
264 timestamp_delta == 4,
265 "timestamp delta is {}, expected 4",
266 timestamp_delta
267 );
268
269 Ok((
270 ExecutionState {
271 pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP),
272 timestamp: memory.timestamp(),
273 },
274 from_state,
275 ))
276 }
277
278 fn generate_trace_row(
279 &self,
280 row_slice: &mut [F],
281 read_record: Self::ReadRecord,
282 write_record: Self::WriteRecord,
283 memory: &OfflineMemory<F>,
284 ) {
285 let aux_cols_factory = memory.aux_cols_factory();
286 let row_slice: &mut Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> =
287 row_slice.borrow_mut();
288 row_slice.from_state = write_record.map(F::from_canonical_u32);
289
290 let rs_reads = read_record.rs_reads.map(|r| memory.record_by_id(r));
291
292 for (i, rs_read) in rs_reads.iter().enumerate() {
293 row_slice.rs_ptr[i] = rs_read.pointer;
294 row_slice.rs_val[i].copy_from_slice(rs_read.data_slice());
295 aux_cols_factory.generate_read_aux(rs_read, &mut row_slice.rs_read_aux[i]);
296 }
297
298 for (i, heap_read) in read_record.heap_reads.iter().enumerate() {
299 let record = memory.record_by_id(*heap_read);
300 aux_cols_factory.generate_read_aux(record, &mut row_slice.heap_read_aux[i]);
301 }
302
303 let need_range_check: Vec<u32> = rs_reads
305 .iter()
306 .map(|record| {
307 record
308 .data_at(RV32_REGISTER_NUM_LIMBS - 1)
309 .as_canonical_u32()
310 })
311 .chain(once(0)) .collect();
313 debug_assert!(self.air.address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
314 let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.air.address_bits;
315 for pair in need_range_check.chunks_exact(2) {
316 self.bitwise_lookup_chip
317 .request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits);
318 }
319 }
320
321 fn air(&self) -> &Self::Air {
322 &self.air
323 }
324}