openvm_rv32_adapters/
vec_heap.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4    iter::{once, zip},
5    marker::PhantomData,
6};
7
8use itertools::izip;
9use openvm_circuit::{
10    arch::{
11        AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, ExecutionBus, ExecutionState,
12        Result, VecHeapAdapterInterface, VmAdapterAir, VmAdapterChip, VmAdapterInterface,
13    },
14    system::{
15        memory::{
16            offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
17            MemoryAddress, MemoryController, OfflineMemory, RecordId,
18        },
19        program::ProgramBus,
20    },
21};
22use openvm_circuit_primitives::bitwise_op_lookup::{
23    BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip,
24};
25use openvm_circuit_primitives_derive::AlignedBorrow;
26use openvm_instructions::{
27    instruction::Instruction,
28    program::DEFAULT_PC_STEP,
29    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
30};
31use openvm_rv32im_circuit::adapters::{
32    abstract_compose, read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
33};
34use openvm_stark_backend::{
35    interaction::InteractionBuilder,
36    p3_air::BaseAir,
37    p3_field::{Field, FieldAlgebra, PrimeField32},
38};
39use serde::{Deserialize, Serialize};
40use serde_with::serde_as;
41
/// This adapter reads from `NUM_READS` (`NUM_READS <= 2`) pointers and writes to 1 pointer.
/// * The data is read from the heap (address space 2), and the pointers
///   are read from registers (address space 1).
/// * Reads take the form of `BLOCKS_PER_READ` consecutive reads of size
///   `READ_SIZE` from the heap, starting from the addresses in `rs[0]`
///   (and `rs[1]` if `NUM_READS = 2`).
/// * Writes take the form of `BLOCKS_PER_WRITE` consecutive writes of
///   size `WRITE_SIZE` to the heap, starting from the address in `rd`.
#[derive(Clone)]
pub struct Rv32VecHeapAdapterChip<
    F: Field,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// AIR of this adapter; it holds the bridges, the lookup bus and `address_bits`.
    pub air:
        Rv32VecHeapAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>,
    /// Lookup chip that receives the range-check requests for the highest limb of each pointer
    /// during trace generation.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
    /// Ties the chip to the field type `F` without storing a value of that type.
    _marker: PhantomData<F>,
}
64
65impl<
66        F: PrimeField32,
67        const NUM_READS: usize,
68        const BLOCKS_PER_READ: usize,
69        const BLOCKS_PER_WRITE: usize,
70        const READ_SIZE: usize,
71        const WRITE_SIZE: usize,
72    >
73    Rv32VecHeapAdapterChip<F, NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>
74{
75    pub fn new(
76        execution_bus: ExecutionBus,
77        program_bus: ProgramBus,
78        memory_bridge: MemoryBridge,
79        address_bits: usize,
80        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
81    ) -> Self {
82        assert!(NUM_READS <= 2);
83        assert!(
84            RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS,
85            "address_bits={address_bits} needs to be large enough for high limb range check"
86        );
87        Self {
88            air: Rv32VecHeapAdapterAir {
89                execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
90                memory_bridge,
91                bus: bitwise_lookup_chip.bus(),
92                address_bits,
93            },
94            bitwise_lookup_chip,
95            _marker: PhantomData,
96        }
97    }
98}
99
/// Record produced by the read phase (`preprocess`) of the vec-heap adapter.
#[repr(C)]
#[serde_as]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
#[serde(bound = "F: Field")]
pub struct Rv32VecHeapReadRecord<
    F: Field,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const READ_SIZE: usize,
> {
    /// Record ids of the `rs` register reads (registers are read from address space d=1)
    #[serde_as(as = "[_; NUM_READS]")]
    pub rs: [RecordId; NUM_READS],
    /// Record id of the `rd` register read (address space d=1)
    pub rd: RecordId,

    /// Value read from the `rd` register, stored as a field element. `postprocess` uses it as
    /// the base pointer of the heap writes.
    pub rd_val: F,

    /// Record ids of the heap reads: `BLOCKS_PER_READ` blocks for each of the `NUM_READS`
    /// pointers.
    #[serde_as(as = "[[_; BLOCKS_PER_READ]; NUM_READS]")]
    pub reads: [[RecordId; BLOCKS_PER_READ]; NUM_READS],
}
121
/// Record produced by the write phase (`postprocess`) of the vec-heap adapter.
#[repr(C)]
#[serde_as]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Rv32VecHeapWriteRecord<const BLOCKS_PER_WRITE: usize, const WRITE_SIZE: usize> {
    /// Execution state (pc and timestamp) that was passed to `postprocess`, i.e. the state the
    /// instruction started from.
    pub from_state: ExecutionState<u32>,
    /// Record ids of the heap writes, one per written block.
    #[serde_as(as = "[_; BLOCKS_PER_WRITE]")]
    pub writes: [RecordId; BLOCKS_PER_WRITE],
}
130
/// Trace columns of the vec-heap adapter.
///
/// The struct is `#[repr(C)]` and borrowed directly from a row slice, so the field order
/// below is the column order of the trace.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct Rv32VecHeapAdapterCols<
    T,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Execution state (pc and timestamp) the instruction starts from.
    pub from_state: ExecutionState<T>,

    /// Register pointers of the `rs` operands.
    pub rs_ptr: [T; NUM_READS],
    /// Register pointer of the `rd` operand.
    pub rd_ptr: T,

    /// Limbs of the values read from the `rs` registers (heap read pointers).
    pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
    /// Limbs of the value read from the `rd` register (heap write pointer).
    pub rd_val: [T; RV32_REGISTER_NUM_LIMBS],

    /// Auxiliary columns for the `rs` register reads.
    pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
    /// Auxiliary columns for the `rd` register read.
    pub rd_read_aux: MemoryReadAuxCols<T>,

    /// Auxiliary columns for the heap reads, indexed by pointer and then by block.
    pub reads_aux: [[MemoryReadAuxCols<T>; BLOCKS_PER_READ]; NUM_READS],
    /// Auxiliary columns for the heap writes, one per written block.
    pub writes_aux: [MemoryWriteAuxCols<T, WRITE_SIZE>; BLOCKS_PER_WRITE],
}
155
/// AIR of the vec-heap adapter: the bridges and bus used to emit interactions, plus the
/// address width used for the pointer range checks.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file — confirm it is
// still needed.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32VecHeapAdapterAir<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Bridge used to emit the execution interaction of the instruction.
    pub(super) execution_bridge: ExecutionBridge,
    /// Bridge used to emit the register and heap memory interactions.
    pub(super) memory_bridge: MemoryBridge,
    /// Bitwise lookup bus used to range check the highest limb of each pointer.
    pub bus: BitwiseOperationLookupBus,
    /// The max number of bits for an address in memory
    address_bits: usize,
}
171
172impl<
173        F: Field,
174        const NUM_READS: usize,
175        const BLOCKS_PER_READ: usize,
176        const BLOCKS_PER_WRITE: usize,
177        const READ_SIZE: usize,
178        const WRITE_SIZE: usize,
179    > BaseAir<F>
180    for Rv32VecHeapAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>
181{
182    fn width(&self) -> usize {
183        Rv32VecHeapAdapterCols::<
184            F,
185            NUM_READS,
186            BLOCKS_PER_READ,
187            BLOCKS_PER_WRITE,
188            READ_SIZE,
189            WRITE_SIZE,
190        >::width()
191    }
192}
193
impl<
        AB: InteractionBuilder,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCKS_PER_WRITE: usize,
        const READ_SIZE: usize,
        const WRITE_SIZE: usize,
    > VmAdapterAir<AB>
    for Rv32VecHeapAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>
{
    type Interface = VecHeapAdapterInterface<
        AB::Expr,
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >;

    /// Emits the adapter's interactions for one row.
    ///
    /// In timestamp order:
    /// 1. register reads of every `rs` pointer, then of the `rd` pointer
    ///    (address space `RV32_REGISTER_AS`);
    /// 2. heap reads of `BLOCKS_PER_READ` blocks per `rs` pointer
    ///    (address space `RV32_MEMORY_AS`), with the data taken from `ctx.reads`;
    /// 3. heap writes of `BLOCKS_PER_WRITE` blocks starting at `rd`, with the data taken
    ///    from `ctx.writes`.
    ///
    /// It also sends range checks for the highest limb of every pointer on the bitwise
    /// lookup bus and emits the execution interaction through the execution bridge. Every
    /// interaction is evaluated with `ctx.instruction.is_valid`.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let cols: &Rv32VecHeapAdapterCols<
            _,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = local.borrow();
        let timestamp = cols.from_state.timestamp;
        // Number of memory accesses emitted so far; each access gets its own timestamp.
        let mut timestamp_delta: usize = 0;
        // Returns `timestamp + timestamp_delta` and then advances `timestamp_delta` by one.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
        };

        // Read register values for rs, rd
        for (ptr, val, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux).chain(once((
            cols.rd_ptr,
            cols.rd_val,
            &cols.rd_read_aux,
        ))) {
            self.memory_bridge
                .read(
                    MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), ptr),
                    val,
                    timestamp_pp(),
                    aux,
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // We constrain the highest limbs of heap pointers to be less than 2^(addr_bits - (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1))).
        // This ensures that no overflow occurs when computing memory pointers. Since the number of cells accessed with each address
        // will be small enough, and combined with the memory argument, it ensures that all the cells accessed in the memory are less than 2^addr_bits.
        //
        // `rd_val` is chained twice so that the list can always be consumed in pairs below:
        // with an odd `NUM_READS` the first copy completes the last pair and the second copy is
        // dropped by `chunks_exact`; with an even `NUM_READS` the `rd` limb is paired with itself.
        let need_range_check: Vec<AB::Var> = cols
            .rs_val
            .iter()
            .chain(std::iter::repeat(&cols.rd_val).take(2))
            .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1])
            .collect();

        // range checks constrain to RV32_CELL_BITS bits, so we need to shift the limbs to constrain the correct amount of bits
        let limb_shift = AB::F::from_canonical_usize(
            1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits),
        );

        // Note: since limbs are read from memory we already know that limb[i] < 2^RV32_CELL_BITS
        //       thus range checking limb[i] * shift < 2^RV32_CELL_BITS, gives us that
        //       limb[i] < 2^(addr_bits - (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1)))
        for pair in need_range_check.chunks_exact(2) {
            self.bus
                .send_range(pair[0] * limb_shift, pair[1] * limb_shift)
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Compose the u32 register value into single field element, with `abstract_compose`
        let rd_val_f: AB::Expr = abstract_compose(cols.rd_val);
        let rs_val_f: [AB::Expr; NUM_READS] = cols.rs_val.map(abstract_compose);

        // Address space of the heap accesses.
        let e = AB::F::from_canonical_u32(RV32_MEMORY_AS);
        // Reads from heap: block `i` of each pointer is read at `address + i * READ_SIZE`
        for (address, reads, reads_aux) in izip!(rs_val_f, ctx.reads, &cols.reads_aux,) {
            for (i, (read, aux)) in zip(reads, reads_aux).enumerate() {
                self.memory_bridge
                    .read(
                        MemoryAddress::new(
                            e,
                            address.clone() + AB::Expr::from_canonical_usize(i * READ_SIZE),
                        ),
                        read,
                        timestamp_pp(),
                        aux,
                    )
                    .eval(builder, ctx.instruction.is_valid.clone());
            }
        }

        // Writes to heap: block `i` is written at `rd_val + i * WRITE_SIZE`
        for (i, (write, aux)) in zip(ctx.writes, &cols.writes_aux).enumerate() {
            self.memory_bridge
                .write(
                    MemoryAddress::new(
                        e,
                        rd_val_f.clone() + AB::Expr::from_canonical_usize(i * WRITE_SIZE),
                    ),
                    write,
                    timestamp_pp(),
                    aux,
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Instruction operands are [rd_ptr, rs_ptr[0], rs_ptr[1], register AS, heap AS];
        // `rs` operands that do not exist for this `NUM_READS` are filled with zero.
        // `timestamp_delta` now equals the total number of memory accesses emitted above.
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    cols.rd_ptr.into(),
                    cols.rs_ptr
                        .first()
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    cols.rs_ptr
                        .get(1)
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    e.into(),
                ],
                cols.from_state,
                AB::F::from_canonical_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid.clone());
    }

    /// Returns the `from_state.pc` column of the given row.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32VecHeapAdapterCols<
            _,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = local.borrow();
        cols.from_state.pc
    }
}
346
347impl<
348        F: PrimeField32,
349        const NUM_READS: usize,
350        const BLOCKS_PER_READ: usize,
351        const BLOCKS_PER_WRITE: usize,
352        const READ_SIZE: usize,
353        const WRITE_SIZE: usize,
354    > VmAdapterChip<F>
355    for Rv32VecHeapAdapterChip<
356        F,
357        NUM_READS,
358        BLOCKS_PER_READ,
359        BLOCKS_PER_WRITE,
360        READ_SIZE,
361        WRITE_SIZE,
362    >
363{
364    type ReadRecord = Rv32VecHeapReadRecord<F, NUM_READS, BLOCKS_PER_READ, READ_SIZE>;
365    type WriteRecord = Rv32VecHeapWriteRecord<BLOCKS_PER_WRITE, WRITE_SIZE>;
366    type Air =
367        Rv32VecHeapAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>;
368    type Interface = VecHeapAdapterInterface<
369        F,
370        NUM_READS,
371        BLOCKS_PER_READ,
372        BLOCKS_PER_WRITE,
373        READ_SIZE,
374        WRITE_SIZE,
375    >;
376
377    fn preprocess(
378        &mut self,
379        memory: &mut MemoryController<F>,
380        instruction: &Instruction<F>,
381    ) -> Result<(
382        <Self::Interface as VmAdapterInterface<F>>::Reads,
383        Self::ReadRecord,
384    )> {
385        let Instruction { a, b, c, d, e, .. } = *instruction;
386
387        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
388        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);
389
390        // Read register values
391        let mut rs_vals = [0; NUM_READS];
392        let rs_records: [_; NUM_READS] = from_fn(|i| {
393            let addr = if i == 0 { b } else { c };
394            let (record, val) = read_rv32_register(memory, d, addr);
395            rs_vals[i] = val;
396            record
397        });
398        let (rd_record, rd_val) = read_rv32_register(memory, d, a);
399
400        // Read memory values
401        let read_records = rs_vals.map(|address| {
402            assert!(
403                address as usize + READ_SIZE * BLOCKS_PER_READ - 1 < (1 << self.air.address_bits)
404            );
405            from_fn(|i| {
406                memory.read::<READ_SIZE>(e, F::from_canonical_u32(address + (i * READ_SIZE) as u32))
407            })
408        });
409        let read_data = read_records.map(|r| r.map(|x| x.1));
410        assert!(rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 < (1 << self.air.address_bits));
411
412        let record = Rv32VecHeapReadRecord {
413            rs: rs_records,
414            rd: rd_record,
415            rd_val: F::from_canonical_u32(rd_val),
416            reads: read_records.map(|r| r.map(|x| x.0)),
417        };
418
419        Ok((read_data, record))
420    }
421
422    fn postprocess(
423        &mut self,
424        memory: &mut MemoryController<F>,
425        instruction: &Instruction<F>,
426        from_state: ExecutionState<u32>,
427        output: AdapterRuntimeContext<F, Self::Interface>,
428        read_record: &Self::ReadRecord,
429    ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
430        let e = instruction.e;
431        let mut i = 0;
432        let writes = output.writes.map(|write| {
433            let (record_id, _) = memory.write(
434                e,
435                read_record.rd_val + F::from_canonical_u32((i * WRITE_SIZE) as u32),
436                write,
437            );
438            i += 1;
439            record_id
440        });
441
442        Ok((
443            ExecutionState {
444                pc: from_state.pc + DEFAULT_PC_STEP,
445                timestamp: memory.timestamp(),
446            },
447            Self::WriteRecord { from_state, writes },
448        ))
449    }
450
451    fn generate_trace_row(
452        &self,
453        row_slice: &mut [F],
454        read_record: Self::ReadRecord,
455        write_record: Self::WriteRecord,
456        memory: &OfflineMemory<F>,
457    ) {
458        vec_heap_generate_trace_row_impl(
459            row_slice,
460            &read_record,
461            &write_record,
462            self.bitwise_lookup_chip.clone(),
463            self.air.address_bits,
464            memory,
465        )
466    }
467
468    fn air(&self) -> &Self::Air {
469        &self.air
470    }
471}
472
473pub(super) fn vec_heap_generate_trace_row_impl<
474    F: PrimeField32,
475    const NUM_READS: usize,
476    const BLOCKS_PER_READ: usize,
477    const BLOCKS_PER_WRITE: usize,
478    const READ_SIZE: usize,
479    const WRITE_SIZE: usize,
480>(
481    row_slice: &mut [F],
482    read_record: &Rv32VecHeapReadRecord<F, NUM_READS, BLOCKS_PER_READ, READ_SIZE>,
483    write_record: &Rv32VecHeapWriteRecord<BLOCKS_PER_WRITE, WRITE_SIZE>,
484    bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
485    address_bits: usize,
486    memory: &OfflineMemory<F>,
487) {
488    let aux_cols_factory = memory.aux_cols_factory();
489    let row_slice: &mut Rv32VecHeapAdapterCols<
490        F,
491        NUM_READS,
492        BLOCKS_PER_READ,
493        BLOCKS_PER_WRITE,
494        READ_SIZE,
495        WRITE_SIZE,
496    > = row_slice.borrow_mut();
497    row_slice.from_state = write_record.from_state.map(F::from_canonical_u32);
498
499    let rd = memory.record_by_id(read_record.rd);
500    let rs = read_record
501        .rs
502        .into_iter()
503        .map(|r| memory.record_by_id(r))
504        .collect::<Vec<_>>();
505
506    row_slice.rd_ptr = rd.pointer;
507    row_slice.rd_val.copy_from_slice(rd.data_slice());
508
509    for (i, r) in rs.iter().enumerate() {
510        row_slice.rs_ptr[i] = r.pointer;
511        row_slice.rs_val[i].copy_from_slice(r.data_slice());
512        aux_cols_factory.generate_read_aux(r, &mut row_slice.rs_read_aux[i]);
513    }
514
515    aux_cols_factory.generate_read_aux(rd, &mut row_slice.rd_read_aux);
516
517    for (i, reads) in read_record.reads.iter().enumerate() {
518        for (j, &x) in reads.iter().enumerate() {
519            let record = memory.record_by_id(x);
520            aux_cols_factory.generate_read_aux(record, &mut row_slice.reads_aux[i][j]);
521        }
522    }
523
524    for (i, &w) in write_record.writes.iter().enumerate() {
525        let record = memory.record_by_id(w);
526        aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i]);
527    }
528
529    // Range checks:
530    let need_range_check: Vec<u32> = rs
531        .iter()
532        .chain(std::iter::repeat(&rd).take(2))
533        .map(|record| {
534            record
535                .data_at(RV32_REGISTER_NUM_LIMBS - 1)
536                .as_canonical_u32()
537        })
538        .collect();
539    debug_assert!(address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
540    let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits;
541    for pair in need_range_check.chunks_exact(2) {
542        bitwise_lookup_chip.request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits);
543    }
544}