openvm_rv32_adapters/
heap_branch.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4};
5
6use itertools::izip;
7use openvm_circuit::{
8    arch::{
9        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
10        BasicAdapterInterface, ExecutionBridge, ExecutionState, ImmInstruction, VmAdapterAir,
11    },
12    system::memory::{
13        offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord},
14        online::TracingMemory,
15        MemoryAddress, MemoryAuxColsFactory,
16    },
17};
18use openvm_circuit_primitives::{
19    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
20    AlignedBytesBorrow,
21};
22use openvm_circuit_primitives_derive::AlignedBorrow;
23use openvm_instructions::{
24    instruction::Instruction,
25    program::DEFAULT_PC_STEP,
26    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
27};
28use openvm_rv32im_circuit::adapters::{tracing_read, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS};
29use openvm_stark_backend::{
30    interaction::InteractionBuilder,
31    p3_air::BaseAir,
32    p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
33};
34
35/// This adapter reads from NUM_READS <= 2 pointers.
36/// * The data is read from the heap (address space 2), and the pointers are read from registers
37///   (address space 1).
38/// * Reads are from the addresses in `rs[0]` (and `rs[1]` if `R = 2`).
39#[repr(C)]
40#[derive(AlignedBorrow)]
41pub struct Rv32HeapBranchAdapterCols<T, const NUM_READS: usize, const READ_SIZE: usize> {
42    pub from_state: ExecutionState<T>,
43
44    pub rs_ptr: [T; NUM_READS],
45    pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
46    pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
47
48    pub heap_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
49}
50
51#[derive(Clone, Copy, Debug, derive_new::new)]
52pub struct Rv32HeapBranchAdapterAir<const NUM_READS: usize, const READ_SIZE: usize> {
53    pub(super) execution_bridge: ExecutionBridge,
54    pub(super) memory_bridge: MemoryBridge,
55    pub bus: BitwiseOperationLookupBus,
56    address_bits: usize,
57}
58
59impl<F: Field, const NUM_READS: usize, const READ_SIZE: usize> BaseAir<F>
60    for Rv32HeapBranchAdapterAir<NUM_READS, READ_SIZE>
61{
62    fn width(&self) -> usize {
63        Rv32HeapBranchAdapterCols::<F, NUM_READS, READ_SIZE>::width()
64    }
65}
66
67impl<AB: InteractionBuilder, const NUM_READS: usize, const READ_SIZE: usize> VmAdapterAir<AB>
68    for Rv32HeapBranchAdapterAir<NUM_READS, READ_SIZE>
69{
70    type Interface =
71        BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, NUM_READS, 0, READ_SIZE, 0>;
72
73    fn eval(
74        &self,
75        builder: &mut AB,
76        local: &[AB::Var],
77        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
78    ) {
79        let cols: &Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = local.borrow();
80        let timestamp = cols.from_state.timestamp;
81        let mut timestamp_delta: usize = 0;
82        let mut timestamp_pp = || {
83            timestamp_delta += 1;
84            timestamp + AB::F::from_usize(timestamp_delta - 1)
85        };
86
87        let d = AB::F::from_u32(RV32_REGISTER_AS);
88        let e = AB::F::from_u32(RV32_MEMORY_AS);
89
90        for (ptr, data, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux) {
91            self.memory_bridge
92                .read(MemoryAddress::new(d, ptr), data, timestamp_pp(), aux)
93                .eval(builder, ctx.instruction.is_valid.clone());
94        }
95
96        // We constrain the highest limbs of heap pointers to be less than 2^(addr_bits -
97        // (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1))). This ensures that no overflow
98        // occurs when computing memory pointers. Since the number of cells accessed with each
99        // address will be small enough, and combined with the memory argument, it ensures
100        // that all the cells accessed in the memory are less than 2^addr_bits.
101        let need_range_check: Vec<AB::Var> = cols
102            .rs_val
103            .iter()
104            .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1])
105            .collect();
106
107        // range checks constrain to RV32_CELL_BITS bits, so we need to shift the limbs to constrain
108        // the correct amount of bits
109        let limb_shift =
110            AB::F::from_usize(1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits));
111
112        // Note: since limbs are read from memory we already know that limb[i] < 2^RV32_CELL_BITS
113        //       thus range checking limb[i] * shift < 2^RV32_CELL_BITS, gives us that
114        //       limb[i] < 2^(addr_bits - (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1)))
115        for pair in need_range_check.chunks(2) {
116            self.bus
117                .send_range(
118                    pair[0] * limb_shift,
119                    pair.get(1).map(|x| (*x).into()).unwrap_or(AB::Expr::ZERO) * limb_shift, // in case NUM_READS is odd
120                )
121                .eval(builder, ctx.instruction.is_valid.clone());
122        }
123
124        let heap_ptr = cols.rs_val.map(|r| {
125            r.iter().rev().fold(AB::Expr::ZERO, |acc, limb| {
126                acc * AB::F::from_u32(1 << RV32_CELL_BITS) + (*limb)
127            })
128        });
129        for (ptr, data, aux) in izip!(heap_ptr, ctx.reads, &cols.heap_read_aux) {
130            self.memory_bridge
131                .read(MemoryAddress::new(e, ptr), data, timestamp_pp(), aux)
132                .eval(builder, ctx.instruction.is_valid.clone());
133        }
134
135        self.execution_bridge
136            .execute_and_increment_or_set_pc(
137                ctx.instruction.opcode,
138                [
139                    cols.rs_ptr
140                        .first()
141                        .map(|&x| x.into())
142                        .unwrap_or(AB::Expr::ZERO),
143                    cols.rs_ptr
144                        .get(1)
145                        .map(|&x| x.into())
146                        .unwrap_or(AB::Expr::ZERO),
147                    ctx.instruction.immediate,
148                    d.into(),
149                    e.into(),
150                ],
151                cols.from_state,
152                AB::F::from_usize(timestamp_delta),
153                (DEFAULT_PC_STEP, ctx.to_pc),
154            )
155            .eval(builder, ctx.instruction.is_valid);
156    }
157
158    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
159        let cols: &Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = local.borrow();
160        cols.from_state.pc
161    }
162}
163
164#[repr(C)]
165#[derive(AlignedBytesBorrow, Debug)]
166pub struct Rv32HeapBranchAdapterRecord<const NUM_READS: usize> {
167    pub from_pc: u32,
168    pub from_timestamp: u32,
169
170    pub rs_ptr: [u32; NUM_READS],
171    pub rs_vals: [u32; NUM_READS],
172
173    pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS],
174    pub heap_read_aux: [MemoryReadAuxRecord; NUM_READS],
175}
176
/// Execution-time half of the heap-branch adapter: performs the register and heap reads and
/// fills in a [`Rv32HeapBranchAdapterRecord`].
#[derive(Clone, Copy)]
pub struct Rv32HeapBranchAdapterExecutor<const NUM_READS: usize, const READ_SIZE: usize> {
    /// Number of bits a valid heap pointer may occupy.
    pub pointer_max_bits: usize,
}
181
182#[derive(derive_new::new)]
183pub struct Rv32HeapBranchAdapterFiller<const NUM_READS: usize, const READ_SIZE: usize> {
184    pub pointer_max_bits: usize,
185    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
186}
187
188impl<const NUM_READS: usize, const READ_SIZE: usize>
189    Rv32HeapBranchAdapterExecutor<NUM_READS, READ_SIZE>
190{
191    pub fn new(pointer_max_bits: usize) -> Self {
192        assert!(NUM_READS <= 2);
193        assert!(
194            RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS,
195            "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check"
196        );
197        Self { pointer_max_bits }
198    }
199}
200
201impl<F: PrimeField32, const NUM_READS: usize, const READ_SIZE: usize> AdapterTraceExecutor<F>
202    for Rv32HeapBranchAdapterExecutor<NUM_READS, READ_SIZE>
203{
204    const WIDTH: usize = Rv32HeapBranchAdapterCols::<F, NUM_READS, READ_SIZE>::width();
205    type ReadData = [[u8; READ_SIZE]; NUM_READS];
206    type WriteData = ();
207    type RecordMut<'a> = &'a mut Rv32HeapBranchAdapterRecord<NUM_READS>;
208
209    fn start(pc: u32, memory: &TracingMemory, adapter_record: &mut Self::RecordMut<'_>) {
210        adapter_record.from_pc = pc;
211        adapter_record.from_timestamp = memory.timestamp;
212    }
213
214    fn read(
215        &self,
216        memory: &mut TracingMemory,
217        instruction: &Instruction<F>,
218        record: &mut Self::RecordMut<'_>,
219    ) -> Self::ReadData {
220        let Instruction { a, b, d, e, .. } = *instruction;
221
222        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
223        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);
224
225        // Read register values
226        record.rs_vals = from_fn(|i| {
227            record.rs_ptr[i] = if i == 0 { a } else { b }.as_canonical_u32();
228            u32::from_le_bytes(tracing_read(
229                memory,
230                RV32_REGISTER_AS,
231                record.rs_ptr[i],
232                &mut record.rs_read_aux[i].prev_timestamp,
233            ))
234        });
235
236        // Read memory values
237        from_fn(|i| {
238            debug_assert!(
239                record.rs_vals[i] as usize + READ_SIZE - 1 < (1 << self.pointer_max_bits)
240            );
241            tracing_read(
242                memory,
243                RV32_MEMORY_AS,
244                record.rs_vals[i],
245                &mut record.heap_read_aux[i].prev_timestamp,
246            )
247        })
248    }
249
250    fn write(
251        &self,
252        _memory: &mut TracingMemory,
253        _instruction: &Instruction<F>,
254        _data: Self::WriteData,
255        _record: &mut Self::RecordMut<'_>,
256    ) {
257        // This adapter doesn't write anything
258    }
259}
260
261impl<F: PrimeField32, const NUM_READS: usize, const READ_SIZE: usize> AdapterTraceFiller<F>
262    for Rv32HeapBranchAdapterFiller<NUM_READS, READ_SIZE>
263{
264    const WIDTH: usize = Rv32HeapBranchAdapterCols::<F, NUM_READS, READ_SIZE>::width();
265
266    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
267        // SAFETY:
268        // - caller ensures `adapter_row` contains a valid record representation that was previously
269        //   written by the executor
270        let record: &Rv32HeapBranchAdapterRecord<NUM_READS> =
271            unsafe { get_record_from_slice(&mut adapter_row, ()) };
272        let cols: &mut Rv32HeapBranchAdapterCols<F, NUM_READS, READ_SIZE> =
273            adapter_row.borrow_mut();
274
275        // Range checks:
276        // **NOTE**: Must do the range checks before overwriting the records
277        debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
278        let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits;
279        const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
280        self.bitwise_lookup_chip.request_range(
281            (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits,
282            if NUM_READS > 1 {
283                (record.rs_vals[1] >> MSL_SHIFT) << limb_shift_bits
284            } else {
285                0
286            },
287        );
288
289        // **NOTE**: Must iterate everything in reverse order to avoid overwriting the records
290        for i in (0..NUM_READS).rev() {
291            mem_helper.fill(
292                record.heap_read_aux[i].prev_timestamp,
293                record.from_timestamp + (i + NUM_READS) as u32,
294                cols.heap_read_aux[i].as_mut(),
295            );
296        }
297
298        for i in (0..NUM_READS).rev() {
299            mem_helper.fill(
300                record.rs_read_aux[i].prev_timestamp,
301                record.from_timestamp + i as u32,
302                cols.rs_read_aux[i].as_mut(),
303            );
304        }
305
306        cols.rs_val
307            .iter_mut()
308            .rev()
309            .zip(record.rs_vals.iter().rev())
310            .for_each(|(col, record)| {
311                *col = record.to_le_bytes().map(F::from_u8);
312            });
313
314        cols.rs_ptr
315            .iter_mut()
316            .rev()
317            .zip(record.rs_ptr.iter().rev())
318            .for_each(|(col, record)| {
319                *col = F::from_u32(*record);
320            });
321
322        cols.from_state.timestamp = F::from_u32(record.from_timestamp);
323        cols.from_state.pc = F::from_u32(record.from_pc);
324    }
325}