openvm_rv32_adapters/
vec_heap.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4    iter::{once, zip},
5};
6
7use itertools::izip;
8use openvm_circuit::{
9    arch::{
10        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
11        ExecutionBridge, ExecutionState, VecHeapAdapterInterface, VmAdapterAir,
12    },
13    system::memory::{
14        offline_checker::{
15            MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols,
16            MemoryWriteBytesAuxRecord,
17        },
18        online::TracingMemory,
19        MemoryAddress, MemoryAuxColsFactory,
20    },
21};
22use openvm_circuit_primitives::{
23    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
24    AlignedBytesBorrow,
25};
26use openvm_circuit_primitives_derive::AlignedBorrow;
27use openvm_instructions::{
28    instruction::Instruction,
29    program::DEFAULT_PC_STEP,
30    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
31};
32use openvm_rv32im_circuit::adapters::{
33    abstract_compose, tracing_read, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
34};
35use openvm_stark_backend::{
36    interaction::InteractionBuilder,
37    p3_air::BaseAir,
38    p3_field::{Field, FieldAlgebra, PrimeField32},
39};
40
/// Trace columns of an adapter that reads from `NUM_READS` (`NUM_READS <= 2`) pointers and writes
/// to 1 pointer.
/// * The data is read from the heap (address space 2), and the pointers are read from registers
///   (address space 1).
/// * Reads take the form of `BLOCKS_PER_READ` consecutive reads of size `READ_SIZE` from the heap,
///   starting from the addresses in `rs[0]` (and `rs[1]` if `NUM_READS = 2`).
/// * Writes take the form of `BLOCKS_PER_WRITE` consecutive writes of size `WRITE_SIZE` to the
///   heap, starting from the address in `rd`.
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct Rv32VecHeapAdapterCols<
    T,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Execution state (`pc`, `timestamp`) the instruction starts from.
    pub from_state: ExecutionState<T>,

    /// Register pointers of the source operands.
    pub rs_ptr: [T; NUM_READS],
    /// Register pointer of the destination operand.
    pub rd_ptr: T,

    /// Limbs of the values read from the `rs` registers; composed, they are the heap read
    /// addresses.
    pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
    /// Limbs of the value read from the `rd` register; composed, it is the heap write address.
    pub rd_val: [T; RV32_REGISTER_NUM_LIMBS],

    /// Auxiliary columns for the `rs` register reads.
    pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
    /// Auxiliary columns for the `rd` register read.
    pub rd_read_aux: MemoryReadAuxCols<T>,

    /// Auxiliary columns for every heap block read, indexed `[read operand][block]`.
    pub reads_aux: [[MemoryReadAuxCols<T>; BLOCKS_PER_READ]; NUM_READS],
    /// Auxiliary columns for every heap block write.
    pub writes_aux: [MemoryWriteAuxCols<T, WRITE_SIZE>; BLOCKS_PER_WRITE],
}
72
/// AIR of the vec-heap adapter: constrains the register reads, heap reads/writes and the
/// execution-state transition of one instruction row.
// NOTE(review): the reason for `allow(dead_code)` is not visible in this file — confirm it is
// still required.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32VecHeapAdapterAir<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Bridge used to constrain the instruction execution and pc/timestamp update.
    pub(super) execution_bridge: ExecutionBridge,
    /// Bridge used to constrain every register and heap access.
    pub(super) memory_bridge: MemoryBridge,
    /// Lookup bus used to range check the most significant limb of each heap pointer.
    pub bus: BitwiseOperationLookupBus,
    /// The max number of bits for an address in memory
    address_bits: usize,
}
88
89impl<
90        F: Field,
91        const NUM_READS: usize,
92        const BLOCKS_PER_READ: usize,
93        const BLOCKS_PER_WRITE: usize,
94        const READ_SIZE: usize,
95        const WRITE_SIZE: usize,
96    > BaseAir<F>
97    for Rv32VecHeapAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>
98{
99    fn width(&self) -> usize {
100        Rv32VecHeapAdapterCols::<
101            F,
102            NUM_READS,
103            BLOCKS_PER_READ,
104            BLOCKS_PER_WRITE,
105            READ_SIZE,
106            WRITE_SIZE,
107        >::width()
108    }
109}
110
/// Constraint logic of the vec-heap adapter.
///
/// The order in which memory accesses are constrained fixes the timestamp each one uses:
/// `rs` register reads, the `rd` register read, heap reads (operand by operand, block by block),
/// then heap writes. The executor and the trace filler in this file follow the same order.
impl<
        AB: InteractionBuilder,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCKS_PER_WRITE: usize,
        const READ_SIZE: usize,
        const WRITE_SIZE: usize,
    > VmAdapterAir<AB>
    for Rv32VecHeapAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>
{
    type Interface = VecHeapAdapterInterface<
        AB::Expr,
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >;

    /// Emits all adapter constraints for one row.
    ///
    /// * `local` is borrowed as [`Rv32VecHeapAdapterCols`].
    /// * `ctx` supplies the heap data (`reads`, `writes`), the instruction (`opcode`, `is_valid`)
    ///   and the optional `to_pc`.
    ///
    /// Every interaction is gated by `ctx.instruction.is_valid`.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let cols: &Rv32VecHeapAdapterCols<
            _,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = local.borrow();
        let timestamp = cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Post-increment: yields `timestamp + delta` for the current access, then bumps `delta`.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
        };

        // Read register values for rs, rd
        for (ptr, val, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux).chain(once((
            cols.rd_ptr,
            cols.rd_val,
            &cols.rd_read_aux,
        ))) {
            self.memory_bridge
                .read(
                    MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), ptr),
                    val,
                    timestamp_pp(),
                    aux,
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // We constrain the highest limbs of heap pointers to be less than 2^(addr_bits -
        // (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1))). This ensures that no overflow
        // occurs when computing memory pointers. Since the number of cells accessed with each
        // address will be small enough, and combined with the memory argument, it ensures
        // that all the cells accessed in the memory are less than 2^addr_bits.
        //
        // `rd_val` is appended twice so the limbs can be consumed in pairs below:
        // with one read operand the pairs are (rs0, rd) and the spare `rd` is dropped by
        // `chunks_exact`; with two read operands they are (rs0, rs1) and (rd, rd).
        let need_range_check: Vec<AB::Var> = cols
            .rs_val
            .iter()
            .chain(std::iter::repeat_n(&cols.rd_val, 2))
            .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1])
            .collect();

        // range checks constrain to RV32_CELL_BITS bits, so we need to shift the limbs to constrain
        // the correct amount of bits
        let limb_shift = AB::F::from_canonical_usize(
            1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits),
        );

        // Note: since limbs are read from memory we already know that limb[i] < 2^RV32_CELL_BITS
        //       thus range checking limb[i] * shift < 2^RV32_CELL_BITS, gives us that
        //       limb[i] < 2^(addr_bits - (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1)))
        for pair in need_range_check.chunks_exact(2) {
            self.bus
                .send_range(pair[0] * limb_shift, pair[1] * limb_shift)
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Compose the u32 register value into single field element, with `abstract_compose`
        let rd_val_f: AB::Expr = abstract_compose(cols.rd_val);
        let rs_val_f: [AB::Expr; NUM_READS] = cols.rs_val.map(abstract_compose);

        // Address space of the heap accesses (operand `e` of the instruction).
        let e = AB::F::from_canonical_u32(RV32_MEMORY_AS);
        // Reads from heap
        // Block `i` of an operand is read at `address + i * READ_SIZE`.
        for (address, reads, reads_aux) in izip!(rs_val_f, ctx.reads, &cols.reads_aux,) {
            for (i, (read, aux)) in zip(reads, reads_aux).enumerate() {
                self.memory_bridge
                    .read(
                        MemoryAddress::new(
                            e,
                            address.clone() + AB::Expr::from_canonical_usize(i * READ_SIZE),
                        ),
                        read,
                        timestamp_pp(),
                        aux,
                    )
                    .eval(builder, ctx.instruction.is_valid.clone());
            }
        }

        // Writes to heap
        // Block `i` is written at `rd_val + i * WRITE_SIZE`.
        for (i, (write, aux)) in zip(ctx.writes, &cols.writes_aux).enumerate() {
            self.memory_bridge
                .write(
                    MemoryAddress::new(
                        e,
                        rd_val_f.clone() + AB::Expr::from_canonical_usize(i * WRITE_SIZE),
                    ),
                    write,
                    timestamp_pp(),
                    aux,
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Instruction operands, in order: rd pointer, first rs pointer, second rs pointer (zero
        // when absent), register address space, heap address space. The timestamp advances by the
        // total number of accesses counted by `timestamp_pp`.
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    cols.rd_ptr.into(),
                    cols.rs_ptr
                        .first()
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    cols.rs_ptr
                        .get(1)
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    e.into(),
                ],
                cols.from_state,
                AB::F::from_canonical_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid.clone());
    }

    /// Returns the `pc` column of the row's starting execution state.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32VecHeapAdapterCols<
            _,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = local.borrow();
        cols.from_state.pc
    }
}
266
/// Compact per-instruction record written by the executor and later expanded into
/// [`Rv32VecHeapAdapterCols`] by the trace filler.
///
/// Intermediate type that should not be copied or cloned and should be directly written to
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32VecHeapAdapterRecord<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Program counter at the start of the instruction.
    pub from_pc: u32,
    /// Memory timestamp at the start of the instruction.
    pub from_timestamp: u32,

    /// Register pointers of the source operands.
    pub rs_ptrs: [u32; NUM_READS],
    /// Register pointer of the destination operand.
    pub rd_ptr: u32,

    /// Values read from the `rs` registers, used as heap read base addresses.
    pub rs_vals: [u32; NUM_READS],
    /// Value read from the `rd` register, used as the heap write base address.
    pub rd_val: u32,

    /// Previous-access data recorded for the `rs` register reads.
    pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS],
    /// Previous-access data recorded for the `rd` register read.
    pub rd_read_aux: MemoryReadAuxRecord,

    /// Previous-access data recorded for each heap block read, indexed `[read operand][block]`.
    pub reads_aux: [[MemoryReadAuxRecord; BLOCKS_PER_READ]; NUM_READS],
    /// Previous timestamp and previous bytes recorded for each heap block write.
    pub writes_aux: [MemoryWriteBytesAuxRecord<WRITE_SIZE>; BLOCKS_PER_WRITE],
}
292
/// Executor side of the vec-heap adapter; performs the traced memory accesses for one
/// instruction.
#[derive(derive_new::new, Clone, Copy)]
pub struct Rv32VecHeapAdapterExecutor<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Heap accesses are debug-asserted to stay below `2^pointer_max_bits`.
    pointer_max_bits: usize,
}
303
/// Trace-filling side of the vec-heap adapter; expands an executor record into trace columns.
#[derive(derive_new::new)]
pub struct Rv32VecHeapAdapterFiller<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Number of address bits; determines how far the top pointer limb is shifted before it is
    /// range checked.
    pointer_max_bits: usize,
    /// Lookup chip that receives the range-check requests for the top limbs of the heap pointers.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}
315
316impl<
317        F: PrimeField32,
318        const NUM_READS: usize,
319        const BLOCKS_PER_READ: usize,
320        const BLOCKS_PER_WRITE: usize,
321        const READ_SIZE: usize,
322        const WRITE_SIZE: usize,
323    > AdapterTraceExecutor<F>
324    for Rv32VecHeapAdapterExecutor<
325        NUM_READS,
326        BLOCKS_PER_READ,
327        BLOCKS_PER_WRITE,
328        READ_SIZE,
329        WRITE_SIZE,
330    >
331{
332    const WIDTH: usize = Rv32VecHeapAdapterCols::<
333        F,
334        NUM_READS,
335        BLOCKS_PER_READ,
336        BLOCKS_PER_WRITE,
337        READ_SIZE,
338        WRITE_SIZE,
339    >::width();
340    type ReadData = [[[u8; READ_SIZE]; BLOCKS_PER_READ]; NUM_READS];
341    type WriteData = [[u8; WRITE_SIZE]; BLOCKS_PER_WRITE];
342    type RecordMut<'a> = &'a mut Rv32VecHeapAdapterRecord<
343        NUM_READS,
344        BLOCKS_PER_READ,
345        BLOCKS_PER_WRITE,
346        READ_SIZE,
347        WRITE_SIZE,
348    >;
349
350    #[inline(always)]
351    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
352        record.from_pc = pc;
353        record.from_timestamp = memory.timestamp;
354    }
355
356    fn read(
357        &self,
358        memory: &mut TracingMemory,
359        instruction: &Instruction<F>,
360        record: &mut &mut Rv32VecHeapAdapterRecord<
361            NUM_READS,
362            BLOCKS_PER_READ,
363            BLOCKS_PER_WRITE,
364            READ_SIZE,
365            WRITE_SIZE,
366        >,
367    ) -> Self::ReadData {
368        let &Instruction { a, b, c, d, e, .. } = instruction;
369
370        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
371        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);
372
373        // Read register values
374        record.rs_vals = from_fn(|i| {
375            record.rs_ptrs[i] = if i == 0 { b } else { c }.as_canonical_u32();
376            u32::from_le_bytes(tracing_read(
377                memory,
378                RV32_REGISTER_AS,
379                record.rs_ptrs[i],
380                &mut record.rs_read_aux[i].prev_timestamp,
381            ))
382        });
383
384        record.rd_ptr = a.as_canonical_u32();
385        record.rd_val = u32::from_le_bytes(tracing_read(
386            memory,
387            RV32_REGISTER_AS,
388            a.as_canonical_u32(),
389            &mut record.rd_read_aux.prev_timestamp,
390        ));
391
392        // Read memory values
393        from_fn(|i| {
394            debug_assert!(
395                (record.rs_vals[i] + (READ_SIZE * BLOCKS_PER_READ - 1) as u32)
396                    < (1 << self.pointer_max_bits) as u32
397            );
398            from_fn(|j| {
399                tracing_read(
400                    memory,
401                    RV32_MEMORY_AS,
402                    record.rs_vals[i] + (j * READ_SIZE) as u32,
403                    &mut record.reads_aux[i][j].prev_timestamp,
404                )
405            })
406        })
407    }
408
409    fn write(
410        &self,
411        memory: &mut TracingMemory,
412        instruction: &Instruction<F>,
413        data: Self::WriteData,
414        record: &mut &mut Rv32VecHeapAdapterRecord<
415            NUM_READS,
416            BLOCKS_PER_READ,
417            BLOCKS_PER_WRITE,
418            READ_SIZE,
419            WRITE_SIZE,
420        >,
421    ) {
422        debug_assert_eq!(instruction.e.as_canonical_u32(), RV32_MEMORY_AS);
423
424        debug_assert!(
425            record.rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1
426                < (1 << self.pointer_max_bits)
427        );
428
429        #[allow(clippy::needless_range_loop)]
430        for i in 0..BLOCKS_PER_WRITE {
431            tracing_write(
432                memory,
433                RV32_MEMORY_AS,
434                record.rd_val + (i * WRITE_SIZE) as u32,
435                data[i],
436                &mut record.writes_aux[i].prev_timestamp,
437                &mut record.writes_aux[i].prev_data,
438            );
439        }
440    }
441}
442
/// Trace-filling half of the vec-heap adapter: turns the record left in a row by the executor
/// into the final [`Rv32VecHeapAdapterCols`] values.
impl<
        F: PrimeField32,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCKS_PER_WRITE: usize,
        const READ_SIZE: usize,
        const WRITE_SIZE: usize,
    > AdapterTraceFiller<F>
    for Rv32VecHeapAdapterFiller<
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >
{
    const WIDTH: usize = Rv32VecHeapAdapterCols::<
        F,
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >::width();

    /// Fills one adapter row in place.
    ///
    /// `adapter_row` initially holds an [`Rv32VecHeapAdapterRecord`]; the same slice is then
    /// written as columns. Per the notes below, the record must be consumed before the columns
    /// overwrite it, so the statement order in this function must not be changed: range-check
    /// requests first, then the columns are filled back to front.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        // SAFETY:
        // - caller ensures `adapter_row` contains a valid record representation that was previously
        //   written by the executor
        let record: &Rv32VecHeapAdapterRecord<
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = unsafe { get_record_from_slice(&mut adapter_row, ()) };

        let cols: &mut Rv32VecHeapAdapterCols<
            F,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = adapter_row.borrow_mut();

        // Range checks:
        // **NOTE**: Must do the range checks before overwriting the records
        // The requests mirror the pairs sent by the AIR's `eval`: (rs0, rs1) and (rd, rd) with two
        // read operands, (rs0, rd) otherwise. Each value is the most significant limb of a
        // pointer, shifted left so that the check bounds the pointer by 2^pointer_max_bits.
        debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
        let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits;
        // Right shift that isolates the most significant limb of a u32 register value.
        const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
        if NUM_READS > 1 {
            self.bitwise_lookup_chip.request_range(
                (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits,
                (record.rs_vals[1] >> MSL_SHIFT) << limb_shift_bits,
            );
            self.bitwise_lookup_chip.request_range(
                (record.rd_val >> MSL_SHIFT) << limb_shift_bits,
                (record.rd_val >> MSL_SHIFT) << limb_shift_bits,
            );
        } else {
            self.bitwise_lookup_chip.request_range(
                (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits,
                (record.rd_val >> MSL_SHIFT) << limb_shift_bits,
            );
        }

        // Total number of memory accesses: register reads (rs and rd), heap block reads, heap
        // block writes.
        let timestamp_delta = NUM_READS + 1 + NUM_READS * BLOCKS_PER_READ + BLOCKS_PER_WRITE;
        let mut timestamp = record.from_timestamp + timestamp_delta as u32;
        // Pre-decrement: walks the access timestamps backwards, starting from the last access.
        let mut timestamp_mm = || {
            timestamp -= 1;
            timestamp
        };

        // **NOTE**: Must iterate everything in reverse order to avoid overwriting the records
        // Heap writes, last block first.
        record
            .writes_aux
            .iter()
            .rev()
            .zip(cols.writes_aux.iter_mut().rev())
            .for_each(|(write, cols_write)| {
                cols_write.set_prev_data(write.prev_data.map(F::from_canonical_u8));
                mem_helper.fill(write.prev_timestamp, timestamp_mm(), cols_write.as_mut());
            });

        // Heap reads, last operand and last block first.
        record
            .reads_aux
            .iter()
            .zip(cols.reads_aux.iter_mut())
            .rev()
            .for_each(|(reads, cols_reads)| {
                reads
                    .iter()
                    .zip(cols_reads.iter_mut())
                    .rev()
                    .for_each(|(read, cols_read)| {
                        mem_helper.fill(read.prev_timestamp, timestamp_mm(), cols_read.as_mut());
                    });
            });

        // `rd` register read (performed after the `rs` register reads, so filled before them).
        mem_helper.fill(
            record.rd_read_aux.prev_timestamp,
            timestamp_mm(),
            cols.rd_read_aux.as_mut(),
        );

        // `rs` register reads, last operand first.
        record
            .rs_read_aux
            .iter()
            .zip(cols.rs_read_aux.iter_mut())
            .rev()
            .for_each(|(aux, cols_aux)| {
                mem_helper.fill(aux.prev_timestamp, timestamp_mm(), cols_aux.as_mut());
            });

        // Register values are expanded into little-endian byte limbs.
        cols.rd_val = record.rd_val.to_le_bytes().map(F::from_canonical_u8);
        cols.rs_val
            .iter_mut()
            .rev()
            .zip(record.rs_vals.iter().rev())
            .for_each(|(cols_val, val)| {
                *cols_val = val.to_le_bytes().map(F::from_canonical_u8);
            });
        cols.rd_ptr = F::from_canonical_u32(record.rd_ptr);
        cols.rs_ptr
            .iter_mut()
            .rev()
            .zip(record.rs_ptrs.iter().rev())
            .for_each(|(cols_ptr, ptr)| {
                *cols_ptr = F::from_canonical_u32(*ptr);
            });
        cols.from_state.timestamp = F::from_canonical_u32(record.from_timestamp);
        cols.from_state.pc = F::from_canonical_u32(record.from_pc);
    }
}