// openvm_rv32_adapters/heap_branch.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4};
5
6use itertools::izip;
7use openvm_circuit::{
8    arch::{
9        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
10        BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir,
11    },
12    system::memory::{
13        offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord},
14        online::TracingMemory,
15        MemoryAddress, MemoryAuxColsFactory,
16    },
17};
18use openvm_circuit_primitives::{
19    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
20    AlignedBytesBorrow,
21};
22use openvm_circuit_primitives_derive::AlignedBorrow;
23use openvm_instructions::{
24    instruction::Instruction,
25    program::DEFAULT_PC_STEP,
26    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
27};
28use openvm_rv32im_circuit::adapters::{tracing_read, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};
29use openvm_stark_backend::{
30    interaction::InteractionBuilder,
31    p3_air::BaseAir,
32    p3_field::{Field, FieldAlgebra, PrimeField32},
33};
34
/// This adapter reads from NUM_READS <= 2 pointers.
/// * The data is read from the heap (address space 2), and the pointers are read from registers
///   (address space 1).
/// * Reads are from the addresses in `rs[0]` (and `rs[1]` if `NUM_READS = 2`).
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct Rv32HeapBranchAdapterCols<T, const NUM_READS: usize, const READ_SIZE: usize> {
    /// Execution state (pc, timestamp) at the start of this instruction.
    pub from_state: ExecutionState<T>,

    /// Pointers (indices) of the registers holding the heap addresses.
    pub rs_ptr: [T; NUM_READS],
    /// Limb decompositions of the register values, i.e. the heap addresses.
    pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
    /// Memory-read auxiliary columns for the register reads (address space 1).
    pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],

    /// Memory-read auxiliary columns for the heap data reads (address space 2).
    pub heap_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
}
50
/// AIR for the heap-branch adapter: constrains `NUM_READS` register reads,
/// high-limb range checks on the resulting heap pointers, `NUM_READS` heap
/// reads, and the execution-bridge interaction.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32HeapBranchAdapterAir<const NUM_READS: usize, const READ_SIZE: usize> {
    pub(super) execution_bridge: ExecutionBridge,
    pub(super) memory_bridge: MemoryBridge,
    /// Bitwise-operation lookup bus used for the pointer high-limb range checks.
    pub bus: BitwiseOperationLookupBus,
    /// Number of bits of a valid heap pointer; sizes the high-limb range check.
    address_bits: usize,
}
58
impl<F: Field, const NUM_READS: usize, const READ_SIZE: usize> BaseAir<F>
    for Rv32HeapBranchAdapterAir<NUM_READS, READ_SIZE>
{
    /// Trace width is the flattened width of the adapter's column struct.
    fn width(&self) -> usize {
        Rv32HeapBranchAdapterCols::<F, NUM_READS, READ_SIZE>::width()
    }
}
66
impl<AB: InteractionBuilder, const NUM_READS: usize, const READ_SIZE: usize> VmAdapterAir<AB>
    for Rv32HeapBranchAdapterAir<NUM_READS, READ_SIZE>
{
    // NUM_READS reads of READ_SIZE cells each and zero writes; the instruction
    // carries an immediate (the branch offset).
    type Interface =
        BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, NUM_READS, 0, READ_SIZE, 0>;

    /// Evaluates all adapter constraints: register reads, heap-pointer range
    /// checks, heap data reads, and the execution-bridge interaction. All
    /// interactions are gated on `ctx.instruction.is_valid`.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let cols: &Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = local.borrow();
        let timestamp = cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Each call yields the current timestamp expression and advances the
        // delta, so consecutive memory interactions get consecutive timestamps.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
        };

        // Address-space constants: d = registers, e = heap memory.
        let d = AB::F::from_canonical_u32(RV32_REGISTER_AS);
        let e = AB::F::from_canonical_u32(RV32_MEMORY_AS);

        // Read the heap pointers out of the registers.
        for (ptr, data, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux) {
            self.memory_bridge
                .read(MemoryAddress::new(d, ptr), data, timestamp_pp(), aux)
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // We constrain the highest limbs of heap pointers to be less than 2^(addr_bits -
        // (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1))). This ensures that no overflow
        // occurs when computing memory pointers. Since the number of cells accessed with each
        // address will be small enough, and combined with the memory argument, it ensures
        // that all the cells accessed in the memory are less than 2^addr_bits.
        let need_range_check: Vec<AB::Var> = cols
            .rs_val
            .iter()
            .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1])
            .collect();

        // range checks constrain to RV32_CELL_BITS bits, so we need to shift the limbs to constrain
        // the correct amount of bits
        let limb_shift = AB::F::from_canonical_usize(
            1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits),
        );

        // Note: since limbs are read from memory we already know that limb[i] < 2^RV32_CELL_BITS
        //       thus range checking limb[i] * shift < 2^RV32_CELL_BITS, gives us that
        //       limb[i] < 2^(addr_bits - (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1)))
        // The lookup bus takes pairs, so high limbs are range-checked two at a time.
        for pair in need_range_check.chunks(2) {
            self.bus
                .send_range(
                    pair[0] * limb_shift,
                    pair.get(1).map(|x| (*x).into()).unwrap_or(AB::Expr::ZERO) * limb_shift, // in case NUM_READS is odd
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Recompose each heap pointer from its little-endian limbs.
        let heap_ptr = cols.rs_val.map(|r| {
            r.iter().rev().fold(AB::Expr::ZERO, |acc, limb| {
                acc * AB::F::from_canonical_u32(1 << RV32_CELL_BITS) + (*limb)
            })
        });
        // Read the heap data at each recomposed pointer.
        for (ptr, data, aux) in izip!(heap_ptr, ctx.reads, &cols.heap_read_aux) {
            self.memory_bridge
                .read(MemoryAddress::new(e, ptr), data, timestamp_pp(), aux)
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Operand layout: [rs1 ptr, rs2 ptr (or 0 if NUM_READS < 2), imm, d, e].
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    cols.rs_ptr
                        .first()
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    cols.rs_ptr
                        .get(1)
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    ctx.instruction.immediate,
                    d.into(),
                    e.into(),
                ],
                cols.from_state,
                AB::F::from_canonical_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid);
    }

    /// Returns the column holding the pc at the start of the instruction.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = local.borrow();
        cols.from_state.pc
    }
}
164
/// Execution record written in place into the trace row by the executor and
/// later expanded into [`Rv32HeapBranchAdapterCols`] by the trace filler.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32HeapBranchAdapterRecord<const NUM_READS: usize> {
    pub from_pc: u32,
    pub from_timestamp: u32,

    /// Register indices that were read for the heap pointers.
    pub rs_ptr: [u32; NUM_READS],
    /// Heap pointer values read out of those registers.
    pub rs_vals: [u32; NUM_READS],

    /// Previous-timestamp records for the register reads.
    pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS],
    /// Previous-timestamp records for the heap data reads.
    pub heap_read_aux: [MemoryReadAuxRecord; NUM_READS],
}
177
/// Runtime executor for the heap-branch adapter: performs the register and
/// heap reads and logs them into [`Rv32HeapBranchAdapterRecord`].
#[derive(Clone, Copy)]
pub struct Rv32HeapBranchAdapterExecutor<const NUM_READS: usize, const READ_SIZE: usize> {
    /// Maximum number of bits allowed in a heap pointer.
    pub pointer_max_bits: usize,
}
182
/// Trace filler for the heap-branch adapter: expands serialized records into
/// full trace rows and issues the pointer range-check lookups.
#[derive(derive_new::new)]
pub struct Rv32HeapBranchAdapterFiller<const NUM_READS: usize, const READ_SIZE: usize> {
    /// Maximum number of bits allowed in a heap pointer.
    pub pointer_max_bits: usize,
    /// Lookup chip receiving the high-limb range-check requests.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}
188
189impl<const NUM_READS: usize, const READ_SIZE: usize>
190    Rv32HeapBranchAdapterExecutor<NUM_READS, READ_SIZE>
191{
192    pub fn new(pointer_max_bits: usize) -> Self {
193        assert!(NUM_READS <= 2);
194        assert!(
195            RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS,
196            "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check"
197        );
198        Self { pointer_max_bits }
199    }
200}
201
impl<F: PrimeField32, const NUM_READS: usize, const READ_SIZE: usize> AdapterTraceExecutor<F>
    for Rv32HeapBranchAdapterExecutor<NUM_READS, READ_SIZE>
{
    const WIDTH: usize = Rv32HeapBranchAdapterCols::<F, NUM_READS, READ_SIZE>::width();
    // One READ_SIZE-byte heap read per pointer register.
    type ReadData = [[u8; READ_SIZE]; NUM_READS];
    // Branch adapters write nothing.
    type WriteData = ();
    type RecordMut<'a> = &'a mut Rv32HeapBranchAdapterRecord<NUM_READS>;

    /// Captures the starting pc and timestamp before any memory access.
    fn start(pc: u32, memory: &TracingMemory, adapter_record: &mut Self::RecordMut<'_>) {
        adapter_record.from_pc = pc;
        adapter_record.from_timestamp = memory.timestamp;
    }

    /// Performs the register reads followed by the heap reads, logging the
    /// previous access timestamps into `record`. The order (all registers,
    /// then all heap reads) must match the filler's timestamp arithmetic.
    fn read(
        &self,
        memory: &mut TracingMemory,
        instruction: &Instruction<F>,
        record: &mut Self::RecordMut<'_>,
    ) -> Self::ReadData {
        let Instruction { a, b, d, e, .. } = *instruction;

        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);

        // Read register values: operand `a` is the first pointer register,
        // `b` the second (used only when NUM_READS == 2).
        record.rs_vals = from_fn(|i| {
            record.rs_ptr[i] = if i == 0 { a } else { b }.as_canonical_u32();
            u32::from_le_bytes(tracing_read(
                memory,
                RV32_REGISTER_AS,
                record.rs_ptr[i],
                &mut record.rs_read_aux[i].prev_timestamp,
            ))
        });

        // Read memory values; every accessed cell must lie below 2^pointer_max_bits.
        from_fn(|i| {
            debug_assert!(
                record.rs_vals[i] as usize + READ_SIZE - 1 < (1 << self.pointer_max_bits)
            );
            tracing_read(
                memory,
                RV32_MEMORY_AS,
                record.rs_vals[i],
                &mut record.heap_read_aux[i].prev_timestamp,
            )
        })
    }

    fn write(
        &self,
        _memory: &mut TracingMemory,
        _instruction: &Instruction<F>,
        _data: Self::WriteData,
        _record: &mut Self::RecordMut<'_>,
    ) {
        // This adapter doesn't write anything
    }
}
261
impl<F: PrimeField32, const NUM_READS: usize, const READ_SIZE: usize> AdapterTraceFiller<F>
    for Rv32HeapBranchAdapterFiller<NUM_READS, READ_SIZE>
{
    const WIDTH: usize = Rv32HeapBranchAdapterCols::<F, NUM_READS, READ_SIZE>::width();

    /// Expands the serialized [`Rv32HeapBranchAdapterRecord`] occupying
    /// `adapter_row` into the full column representation, in place, and issues
    /// the high-limb range-check lookups.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        // SAFETY:
        // - caller ensures `adapter_row` contains a valid record representation that was previously
        //   written by the executor
        let record: &Rv32HeapBranchAdapterRecord<NUM_READS> =
            unsafe { get_record_from_slice(&mut adapter_row, ()) };
        let cols: &mut Rv32HeapBranchAdapterCols<F, NUM_READS, READ_SIZE> =
            adapter_row.borrow_mut();

        // Range checks:
        // **NOTE**: Must do the range checks before overwriting the records
        debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
        let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits;
        // Bit offset of the most significant limb within a pointer value.
        const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
        self.bitwise_lookup_chip.request_range(
            (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits,
            if NUM_READS > 1 {
                (record.rs_vals[1] >> MSL_SHIFT) << limb_shift_bits
            } else {
                // Pad the paired lookup with zero when there is only one read.
                0
            },
        );

        // **NOTE**: Must iterate everything in reverse order to avoid overwriting the records
        // Heap reads occupied timestamps from_timestamp + NUM_READS .. + 2 * NUM_READS.
        for i in (0..NUM_READS).rev() {
            mem_helper.fill(
                record.heap_read_aux[i].prev_timestamp,
                record.from_timestamp + (i + NUM_READS) as u32,
                cols.heap_read_aux[i].as_mut(),
            );
        }

        // Register reads occupied timestamps from_timestamp .. + NUM_READS.
        for i in (0..NUM_READS).rev() {
            mem_helper.fill(
                record.rs_read_aux[i].prev_timestamp,
                record.from_timestamp + i as u32,
                cols.rs_read_aux[i].as_mut(),
            );
        }

        // Decompose each recorded pointer value into little-endian byte limbs.
        cols.rs_val
            .iter_mut()
            .rev()
            .zip(record.rs_vals.iter().rev())
            .for_each(|(col, record)| {
                *col = record.to_le_bytes().map(F::from_canonical_u8);
            });

        cols.rs_ptr
            .iter_mut()
            .rev()
            .zip(record.rs_ptr.iter().rev())
            .for_each(|(col, record)| {
                *col = F::from_canonical_u32(*record);
            });

        cols.from_state.timestamp = F::from_canonical_u32(record.from_timestamp);
        cols.from_state.pc = F::from_canonical_u32(record.from_pc);
    }
}
326}