openvm_rv32_adapters/
heap.rs

1use std::{
2    array::{self, from_fn},
3    borrow::Borrow,
4    marker::PhantomData,
5};
6
7use openvm_circuit::{
8    arch::{
9        AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge,
10        ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip,
11        VmAdapterInterface,
12    },
13    system::{
14        memory::{offline_checker::MemoryBridge, MemoryController, OfflineMemory},
15        program::ProgramBus,
16    },
17};
18use openvm_circuit_primitives::bitwise_op_lookup::{
19    BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip,
20};
21use openvm_instructions::{
22    instruction::Instruction,
23    program::DEFAULT_PC_STEP,
24    riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
25};
26use openvm_rv32im_circuit::adapters::read_rv32_register;
27use openvm_stark_backend::{
28    interaction::InteractionBuilder,
29    p3_air::BaseAir,
30    p3_field::{Field, PrimeField32},
31};
32
33use super::{
34    vec_heap_generate_trace_row_impl, Rv32VecHeapAdapterAir, Rv32VecHeapAdapterCols,
35    Rv32VecHeapReadRecord, Rv32VecHeapWriteRecord,
36};
37
/// This adapter reads from NUM_READS <= 2 pointers and writes to 1 pointer.
/// * The data is read from and written to the heap (address space 2), and the
///   pointers are read from registers (address space 1).
/// * Reads are from the addresses in `rs[0]` (and `rs[1]` if `NUM_READS = 2`).
/// * Writes are to the address in `rd`.
43
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32HeapAdapterAir<
    const NUM_READS: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    // NOTE: the field order below is the argument order of the `derive_new`-generated
    // `new` constructor; reordering fields changes that public signature.
    /// Execution bridge forwarded to the inner `Rv32VecHeapAdapterAir` in `eval`;
    /// `Rv32HeapAdapterChip::new` builds it from the execution bus and the program bus.
    pub(super) execution_bridge: ExecutionBridge,
    /// Memory bridge forwarded to the inner `Rv32VecHeapAdapterAir` in `eval`.
    pub(super) memory_bridge: MemoryBridge,
    /// Bitwise-operation lookup bus; `Rv32HeapAdapterChip::new` sets it to
    /// `bitwise_lookup_chip.bus()`.
    pub bus: BitwiseOperationLookupBus,
    /// The max number of bits for an address in memory
    address_bits: usize,
}
56
57impl<F: Field, const NUM_READS: usize, const READ_SIZE: usize, const WRITE_SIZE: usize> BaseAir<F>
58    for Rv32HeapAdapterAir<NUM_READS, READ_SIZE, WRITE_SIZE>
59{
60    fn width(&self) -> usize {
61        Rv32VecHeapAdapterCols::<F, NUM_READS, 1, 1, READ_SIZE, WRITE_SIZE>::width()
62    }
63}
64
65impl<
66        AB: InteractionBuilder,
67        const NUM_READS: usize,
68        const READ_SIZE: usize,
69        const WRITE_SIZE: usize,
70    > VmAdapterAir<AB> for Rv32HeapAdapterAir<NUM_READS, READ_SIZE, WRITE_SIZE>
71{
72    type Interface = BasicAdapterInterface<
73        AB::Expr,
74        MinimalInstruction<AB::Expr>,
75        NUM_READS,
76        1,
77        READ_SIZE,
78        WRITE_SIZE,
79    >;
80
81    fn eval(
82        &self,
83        builder: &mut AB,
84        local: &[AB::Var],
85        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
86    ) {
87        let vec_heap_air: Rv32VecHeapAdapterAir<NUM_READS, 1, 1, READ_SIZE, WRITE_SIZE> =
88            Rv32VecHeapAdapterAir::new(
89                self.execution_bridge,
90                self.memory_bridge,
91                self.bus,
92                self.address_bits,
93            );
94        vec_heap_air.eval(builder, local, ctx.into());
95    }
96
97    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
98        let cols: &Rv32VecHeapAdapterCols<_, NUM_READS, 1, 1, READ_SIZE, WRITE_SIZE> =
99            local.borrow();
100        cols.from_state.pc
101    }
102}
103
/// Runtime counterpart of [`Rv32HeapAdapterAir`].
///
/// It performs the register and heap memory accesses for an instruction and fills in the
/// adapter's trace rows.
pub struct Rv32HeapAdapterChip<
    F: Field,
    const NUM_READS: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// The AIR whose columns this chip populates.
    pub air: Rv32HeapAdapterAir<NUM_READS, READ_SIZE, WRITE_SIZE>,
    /// Shared bitwise-operation lookup chip; its bus is stored in `air`, and a clone is passed
    /// to the trace-row generator.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
    /// Marks the field type `F`, which no other field mentions.
    _marker: PhantomData<F>,
}
114
115impl<F: PrimeField32, const NUM_READS: usize, const READ_SIZE: usize, const WRITE_SIZE: usize>
116    Rv32HeapAdapterChip<F, NUM_READS, READ_SIZE, WRITE_SIZE>
117{
118    pub fn new(
119        execution_bus: ExecutionBus,
120        program_bus: ProgramBus,
121        memory_bridge: MemoryBridge,
122        address_bits: usize,
123        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
124    ) -> Self {
125        assert!(NUM_READS <= 2);
126        assert!(
127            RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS,
128            "address_bits={address_bits} needs to be large enough for high limb range check"
129        );
130        Self {
131            air: Rv32HeapAdapterAir {
132                execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
133                memory_bridge,
134                bus: bitwise_lookup_chip.bus(),
135                address_bits,
136            },
137            bitwise_lookup_chip,
138            _marker: PhantomData,
139        }
140    }
141}
142
143impl<F: PrimeField32, const NUM_READS: usize, const READ_SIZE: usize, const WRITE_SIZE: usize>
144    VmAdapterChip<F> for Rv32HeapAdapterChip<F, NUM_READS, READ_SIZE, WRITE_SIZE>
145{
146    type ReadRecord = Rv32VecHeapReadRecord<F, NUM_READS, 1, READ_SIZE>;
147    type WriteRecord = Rv32VecHeapWriteRecord<1, WRITE_SIZE>;
148    type Air = Rv32HeapAdapterAir<NUM_READS, READ_SIZE, WRITE_SIZE>;
149    type Interface =
150        BasicAdapterInterface<F, MinimalInstruction<F>, NUM_READS, 1, READ_SIZE, WRITE_SIZE>;
151
152    fn preprocess(
153        &mut self,
154        memory: &mut MemoryController<F>,
155        instruction: &Instruction<F>,
156    ) -> Result<(
157        <Self::Interface as VmAdapterInterface<F>>::Reads,
158        Self::ReadRecord,
159    )> {
160        let Instruction { a, b, c, d, e, .. } = *instruction;
161
162        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
163        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);
164
165        let mut rs_vals = [0; NUM_READS];
166        let rs_records: [_; NUM_READS] = from_fn(|i| {
167            let addr = if i == 0 { b } else { c };
168            let (record, val) = read_rv32_register(memory, d, addr);
169            rs_vals[i] = val;
170            record
171        });
172        let (rd_record, rd_val) = read_rv32_register(memory, d, a);
173
174        let read_records = rs_vals.map(|address| {
175            debug_assert!(address as usize + READ_SIZE - 1 < (1 << self.air.address_bits));
176            [memory.read::<READ_SIZE>(e, F::from_canonical_u32(address))]
177        });
178        let read_data = read_records.map(|r| r[0].1);
179
180        let record = Rv32VecHeapReadRecord {
181            rs: rs_records,
182            rd: rd_record,
183            rd_val: F::from_canonical_u32(rd_val),
184            reads: read_records.map(|r| array::from_fn(|i| r[i].0)),
185        };
186
187        Ok((read_data, record))
188    }
189
190    fn postprocess(
191        &mut self,
192        memory: &mut MemoryController<F>,
193        instruction: &Instruction<F>,
194        from_state: ExecutionState<u32>,
195        output: AdapterRuntimeContext<F, Self::Interface>,
196        read_record: &Self::ReadRecord,
197    ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
198        let e = instruction.e;
199        let writes = [memory.write(e, read_record.rd_val, output.writes[0]).0];
200
201        let timestamp_delta = memory.timestamp() - from_state.timestamp;
202        debug_assert!(
203            timestamp_delta == 6,
204            "timestamp delta is {}, expected 6",
205            timestamp_delta
206        );
207
208        Ok((
209            ExecutionState {
210                pc: from_state.pc + DEFAULT_PC_STEP,
211                timestamp: memory.timestamp(),
212            },
213            Self::WriteRecord { from_state, writes },
214        ))
215    }
216
217    fn generate_trace_row(
218        &self,
219        row_slice: &mut [F],
220        read_record: Self::ReadRecord,
221        write_record: Self::WriteRecord,
222        memory: &OfflineMemory<F>,
223    ) {
224        vec_heap_generate_trace_row_impl(
225            row_slice,
226            &read_record,
227            &write_record,
228            self.bitwise_lookup_chip.clone(),
229            self.air.address_bits,
230            memory,
231        );
232    }
233
234    fn air(&self) -> &Self::Air {
235        &self.air
236    }
237}