openvm_rv32_adapters/
heap_branch.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4    iter::once,
5    marker::PhantomData,
6};
7
8use itertools::izip;
9use openvm_circuit::{
10    arch::{
11        AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge,
12        ExecutionBus, ExecutionState, ImmInstruction, Result, VmAdapterAir, VmAdapterChip,
13        VmAdapterInterface,
14    },
15    system::{
16        memory::{
17            offline_checker::{MemoryBridge, MemoryReadAuxCols},
18            MemoryAddress, MemoryController, OfflineMemory, RecordId,
19        },
20        program::ProgramBus,
21    },
22};
23use openvm_circuit_primitives::bitwise_op_lookup::{
24    BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip,
25};
26use openvm_circuit_primitives_derive::AlignedBorrow;
27use openvm_instructions::{
28    instruction::Instruction,
29    program::DEFAULT_PC_STEP,
30    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
31};
32use openvm_rv32im_circuit::adapters::{
33    read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
34};
35use openvm_stark_backend::{
36    interaction::InteractionBuilder,
37    p3_air::BaseAir,
38    p3_field::{Field, FieldAlgebra, PrimeField32},
39};
40use serde::{Deserialize, Serialize};
41use serde_big_array::BigArray;
42
/// Trace columns for an adapter that reads from `NUM_READS <= 2` pointers.
/// * The data is read from the heap (address space 2), and the pointers
///   are read from registers (address space 1).
/// * Reads are from the addresses in `rs[0]` (and `rs[1]` if `NUM_READS = 2`).
/// * Each heap read covers `READ_SIZE` cells.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct Rv32HeapBranchAdapterCols<T, const NUM_READS: usize, const READ_SIZE: usize> {
    /// Execution state (pc and timestamp) before the instruction runs.
    pub from_state: ExecutionState<T>,

    /// Register pointers the heap addresses are loaded from.
    pub rs_ptr: [T; NUM_READS],
    /// Limbs of each register value; composed little-endian into a heap pointer.
    pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
    /// Memory auxiliary columns for the register reads.
    pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],

    /// Memory auxiliary columns for the heap reads.
    pub heap_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
}
58
/// AIR side of the heap-branch adapter: holds the buses/bridges its constraints
/// interact with and the pointer width used for the high-limb range check.
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32HeapBranchAdapterAir<const NUM_READS: usize, const READ_SIZE: usize> {
    /// Bridge used to constrain the execution-state transition of the instruction.
    pub(super) execution_bridge: ExecutionBridge,
    /// Bridge used to constrain the register and heap reads.
    pub(super) memory_bridge: MemoryBridge,
    /// Lookup bus the high-limb range checks are sent to.
    pub bus: BitwiseOperationLookupBus,
    /// Number of bits a heap pointer may occupy; determines the shift applied to the
    /// most significant pointer limb before it is range checked.
    address_bits: usize,
}
66
67impl<F: Field, const NUM_READS: usize, const READ_SIZE: usize> BaseAir<F>
68    for Rv32HeapBranchAdapterAir<NUM_READS, READ_SIZE>
69{
70    fn width(&self) -> usize {
71        Rv32HeapBranchAdapterCols::<F, NUM_READS, READ_SIZE>::width()
72    }
73}
74
impl<AB: InteractionBuilder, const NUM_READS: usize, const READ_SIZE: usize> VmAdapterAir<AB>
    for Rv32HeapBranchAdapterAir<NUM_READS, READ_SIZE>
{
    /// `NUM_READS` reads of `READ_SIZE` cells each, no writes, and an immediate operand.
    type Interface =
        BasicAdapterInterface<AB::Expr, ImmInstruction<AB::Expr>, NUM_READS, 0, READ_SIZE, 0>;

    /// Adds the adapter constraints for one row:
    /// 1. one register read per pointer,
    /// 2. a range check on the most significant limb of every pointer,
    /// 3. one heap read per pointer at the address composed from its limbs,
    /// 4. the execution-bus interaction for the instruction itself.
    ///
    /// Every interaction is gated on `ctx.instruction.is_valid`.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let cols: &Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = local.borrow();
        let timestamp = cols.from_state.timestamp;
        // Counts the memory accesses issued so far; each call to `timestamp_pp` yields the
        // timestamp for the next access (post-increment semantics).
        let mut timestamp_delta: usize = 0;
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
        };

        // `d` is the address space of the pointers, `e` the address space of the data.
        let d = AB::F::from_canonical_u32(RV32_REGISTER_AS);
        let e = AB::F::from_canonical_u32(RV32_MEMORY_AS);

        // Register reads: load the limbs of every heap pointer.
        for (ptr, data, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux) {
            self.memory_bridge
                .read(MemoryAddress::new(d, ptr), data, timestamp_pp(), aux)
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // We constrain the highest limbs of heap pointers to be less than
        // 2^(addr_bits - (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1))).
        // This ensures that no overflow occurs when computing memory pointers. Since the
        // number of cells accessed with each address will be small enough, and combined with
        // the memory argument, it ensures that all the cells accessed in the memory are less
        // than 2^addr_bits.
        let need_range_check: Vec<AB::Var> = cols
            .rs_val
            .iter()
            .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1])
            .collect();

        // Range checks constrain to RV32_CELL_BITS bits, so we need to shift the limbs to
        // constrain the correct amount of bits.
        let limb_shift = AB::F::from_canonical_usize(
            1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits),
        );

        // Note: since limbs are read from memory we already know that limb[i] < 2^RV32_CELL_BITS,
        //       thus range checking limb[i] * shift < 2^RV32_CELL_BITS gives us that
        //       limb[i] < 2^(addr_bits - (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1))).
        // The lookup takes two values per interaction, so limbs are sent in pairs.
        for pair in need_range_check.chunks(2) {
            self.bus
                .send_range(
                    pair[0] * limb_shift,
                    pair.get(1).map(|x| (*x).into()).unwrap_or(AB::Expr::ZERO) * limb_shift, // in case NUM_READS is odd
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Compose each pointer from its limbs, least significant limb first: iterating in
        // reverse and multiplying the accumulator by 2^RV32_CELL_BITS at each step.
        let heap_ptr = cols.rs_val.map(|r| {
            r.iter().rev().fold(AB::Expr::ZERO, |acc, limb| {
                acc * AB::F::from_canonical_u32(1 << RV32_CELL_BITS) + (*limb)
            })
        });
        // Heap reads: the data comes from the core chip via `ctx.reads`.
        for (ptr, data, aux) in izip!(heap_ptr, ctx.reads, &cols.heap_read_aux) {
            self.memory_bridge
                .read(MemoryAddress::new(e, ptr), data, timestamp_pp(), aux)
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Instruction operands are (rs_ptr[0], rs_ptr[1], immediate, d, e); a missing
        // pointer (NUM_READS < 2) is replaced by zero. The timestamp advances by the number
        // of memory accesses issued above.
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    cols.rs_ptr
                        .first()
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    cols.rs_ptr
                        .get(1)
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    ctx.instruction.immediate,
                    d.into(),
                    e.into(),
                ],
                cols.from_state,
                AB::F::from_canonical_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid);
    }

    /// Returns the program counter column of the row's starting execution state.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> = local.borrow();
        cols.from_state.pc
    }
}
169
/// Runtime side of the heap-branch adapter: pairs the AIR with the lookup chip
/// that receives the range-check requests made during trace generation.
pub struct Rv32HeapBranchAdapterChip<F: Field, const NUM_READS: usize, const READ_SIZE: usize> {
    /// AIR describing the constraints for this adapter.
    pub air: Rv32HeapBranchAdapterAir<NUM_READS, READ_SIZE>,
    /// Shared lookup chip used to range check the high limb of each heap pointer.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
    /// Ties the chip to the field type `F` without storing a value of it.
    _marker: PhantomData<F>,
}
175
176impl<F: PrimeField32, const NUM_READS: usize, const READ_SIZE: usize>
177    Rv32HeapBranchAdapterChip<F, NUM_READS, READ_SIZE>
178{
179    pub fn new(
180        execution_bus: ExecutionBus,
181        program_bus: ProgramBus,
182        memory_bridge: MemoryBridge,
183        address_bits: usize,
184        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
185    ) -> Self {
186        assert!(NUM_READS <= 2);
187        assert!(
188            RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS,
189            "address_bits={address_bits} needs to be large enough for high limb range check"
190        );
191        Self {
192            air: Rv32HeapBranchAdapterAir {
193                execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
194                memory_bridge,
195                bus: bitwise_lookup_chip.bus(),
196                address_bits,
197            },
198            bitwise_lookup_chip,
199            _marker: PhantomData,
200        }
201    }
202}
203
/// Record of the memory accesses performed while preprocessing one instruction;
/// consumed later by trace generation to fill the adapter columns.
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Rv32HeapBranchReadRecord<const NUM_READS: usize, const READ_SIZE: usize> {
    /// Ids of the register reads that produced the heap pointers.
    #[serde(with = "BigArray")]
    pub rs_reads: [RecordId; NUM_READS],
    /// Ids of the heap reads, in the same order as `rs_reads`.
    #[serde(with = "BigArray")]
    pub heap_reads: [RecordId; NUM_READS],
}
212
213impl<F: PrimeField32, const NUM_READS: usize, const READ_SIZE: usize> VmAdapterChip<F>
214    for Rv32HeapBranchAdapterChip<F, NUM_READS, READ_SIZE>
215{
216    type ReadRecord = Rv32HeapBranchReadRecord<NUM_READS, READ_SIZE>;
217    type WriteRecord = ExecutionState<u32>;
218    type Air = Rv32HeapBranchAdapterAir<NUM_READS, READ_SIZE>;
219    type Interface = BasicAdapterInterface<F, ImmInstruction<F>, NUM_READS, 0, READ_SIZE, 0>;
220
221    fn preprocess(
222        &mut self,
223        memory: &mut MemoryController<F>,
224        instruction: &Instruction<F>,
225    ) -> Result<(
226        <Self::Interface as VmAdapterInterface<F>>::Reads,
227        Self::ReadRecord,
228    )> {
229        let Instruction { a, b, d, e, .. } = *instruction;
230
231        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
232        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);
233
234        let mut rs_vals = [0; NUM_READS];
235        let rs_records: [_; NUM_READS] = from_fn(|i| {
236            let addr = if i == 0 { a } else { b };
237            let (record, val) = read_rv32_register(memory, d, addr);
238            rs_vals[i] = val;
239            record
240        });
241
242        let heap_records = rs_vals.map(|address| {
243            assert!(address as usize + READ_SIZE - 1 < (1 << self.air.address_bits));
244            memory.read::<READ_SIZE>(e, F::from_canonical_u32(address))
245        });
246
247        let record = Rv32HeapBranchReadRecord {
248            rs_reads: rs_records,
249            heap_reads: heap_records.map(|r| r.0),
250        };
251        Ok((heap_records.map(|r| r.1), record))
252    }
253
254    fn postprocess(
255        &mut self,
256        memory: &mut MemoryController<F>,
257        _instruction: &Instruction<F>,
258        from_state: ExecutionState<u32>,
259        output: AdapterRuntimeContext<F, Self::Interface>,
260        _read_record: &Self::ReadRecord,
261    ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
262        let timestamp_delta = memory.timestamp() - from_state.timestamp;
263        debug_assert!(
264            timestamp_delta == 4,
265            "timestamp delta is {}, expected 4",
266            timestamp_delta
267        );
268
269        Ok((
270            ExecutionState {
271                pc: output.to_pc.unwrap_or(from_state.pc + DEFAULT_PC_STEP),
272                timestamp: memory.timestamp(),
273            },
274            from_state,
275        ))
276    }
277
278    fn generate_trace_row(
279        &self,
280        row_slice: &mut [F],
281        read_record: Self::ReadRecord,
282        write_record: Self::WriteRecord,
283        memory: &OfflineMemory<F>,
284    ) {
285        let aux_cols_factory = memory.aux_cols_factory();
286        let row_slice: &mut Rv32HeapBranchAdapterCols<_, NUM_READS, READ_SIZE> =
287            row_slice.borrow_mut();
288        row_slice.from_state = write_record.map(F::from_canonical_u32);
289
290        let rs_reads = read_record.rs_reads.map(|r| memory.record_by_id(r));
291
292        for (i, rs_read) in rs_reads.iter().enumerate() {
293            row_slice.rs_ptr[i] = rs_read.pointer;
294            row_slice.rs_val[i].copy_from_slice(rs_read.data_slice());
295            aux_cols_factory.generate_read_aux(rs_read, &mut row_slice.rs_read_aux[i]);
296        }
297
298        for (i, heap_read) in read_record.heap_reads.iter().enumerate() {
299            let record = memory.record_by_id(*heap_read);
300            aux_cols_factory.generate_read_aux(record, &mut row_slice.heap_read_aux[i]);
301        }
302
303        // Range checks:
304        let need_range_check: Vec<u32> = rs_reads
305            .iter()
306            .map(|record| {
307                record
308                    .data_at(RV32_REGISTER_NUM_LIMBS - 1)
309                    .as_canonical_u32()
310            })
311            .chain(once(0)) // in case NUM_READS is odd
312            .collect();
313        debug_assert!(self.air.address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
314        let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.air.address_bits;
315        for pair in need_range_check.chunks_exact(2) {
316            self.bitwise_lookup_chip
317                .request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits);
318        }
319    }
320
321    fn air(&self) -> &Self::Air {
322        &self.air
323    }
324}