openvm_rv32_adapters/
vec_heap.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4    iter::{once, zip},
5};
6
7use itertools::izip;
8use openvm_circuit::{
9    arch::{
10        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
11        ExecutionBridge, ExecutionState, VecHeapAdapterInterface, VmAdapterAir,
12    },
13    system::memory::{
14        offline_checker::{
15            MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols,
16            MemoryWriteBytesAuxRecord,
17        },
18        online::TracingMemory,
19        MemoryAddress, MemoryAuxColsFactory,
20    },
21};
22use openvm_circuit_primitives::{
23    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
24    AlignedBytesBorrow,
25};
26use openvm_circuit_primitives_derive::AlignedBorrow;
27use openvm_instructions::{
28    instruction::Instruction,
29    program::DEFAULT_PC_STEP,
30    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
31};
32use openvm_rv32im_circuit::adapters::{
33    abstract_compose, tracing_read, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
34};
35use openvm_stark_backend::{
36    interaction::InteractionBuilder,
37    p3_air::BaseAir,
38    p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
39};
40
/// This adapter reads from R (R <= 2) pointers and writes to 1 pointer.
/// * The data is read from the heap (address space 2), and the pointers are read from registers
///   (address space 1).
/// * Reads take the form of `BLOCKS_PER_READ` consecutive reads of size `READ_SIZE` from the heap,
///   starting from the addresses in `rs[0]` (and `rs[1]` if `R = 2`).
/// * Writes take the form of `BLOCKS_PER_WRITE` consecutive writes of size `WRITE_SIZE` to the
///   heap, starting from the address in `rd`.
///
/// This struct is the trace-row layout of the adapter: one value of type `T` per column.
/// Field order is part of the layout (`#[repr(C)]`), so fields must not be reordered.
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct Rv32VecHeapAdapterCols<
    T,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Execution state (pc and timestamp) at the start of the instruction.
    pub from_state: ExecutionState<T>,

    /// Register pointers of the source operands (instruction operands `b`, `c`).
    pub rs_ptr: [T; NUM_READS],
    /// Register pointer of the destination operand (instruction operand `a`).
    pub rd_ptr: T,

    /// Limbs of the values read from the `rs_ptr` registers; composed into heap read addresses.
    pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
    /// Limbs of the value read from the `rd_ptr` register; composed into the heap write address.
    pub rd_val: [T; RV32_REGISTER_NUM_LIMBS],

    /// Memory auxiliary columns for the register reads of `rs_ptr`.
    pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
    /// Memory auxiliary columns for the register read of `rd_ptr`.
    pub rd_read_aux: MemoryReadAuxCols<T>,

    /// Memory auxiliary columns for each heap block read, indexed by `[operand][block]`.
    pub reads_aux: [[MemoryReadAuxCols<T>; BLOCKS_PER_READ]; NUM_READS],
    /// Memory auxiliary columns (including previous data) for each heap block write.
    pub writes_aux: [MemoryWriteAuxCols<T, WRITE_SIZE>; BLOCKS_PER_WRITE],
}
72
/// AIR for the vec-heap adapter: holds the buses/bridges that `eval` uses to constrain register
/// reads, heap reads/writes, pointer range checks and the execution-state transition.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32VecHeapAdapterAir<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Bridge used to emit the execution (opcode / pc / timestamp) interaction.
    pub(super) execution_bridge: ExecutionBridge,
    /// Bridge used to emit memory read and write interactions.
    pub(super) memory_bridge: MemoryBridge,
    /// Lookup bus used for the range checks on the most significant pointer limbs.
    pub bus: BitwiseOperationLookupBus,
    /// The max number of bits for an address in memory
    address_bits: usize,
}
88
89impl<
90        F: Field,
91        const NUM_READS: usize,
92        const BLOCKS_PER_READ: usize,
93        const BLOCKS_PER_WRITE: usize,
94        const READ_SIZE: usize,
95        const WRITE_SIZE: usize,
96    > BaseAir<F>
97    for Rv32VecHeapAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>
98{
99    fn width(&self) -> usize {
100        Rv32VecHeapAdapterCols::<
101            F,
102            NUM_READS,
103            BLOCKS_PER_READ,
104            BLOCKS_PER_WRITE,
105            READ_SIZE,
106            WRITE_SIZE,
107        >::width()
108    }
109}
110
impl<
        AB: InteractionBuilder,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCKS_PER_WRITE: usize,
        const READ_SIZE: usize,
        const WRITE_SIZE: usize,
    > VmAdapterAir<AB>
    for Rv32VecHeapAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCKS_PER_WRITE, READ_SIZE, WRITE_SIZE>
{
    /// Shape of the data exchanged with the core AIR: `NUM_READS x BLOCKS_PER_READ` read blocks
    /// of `READ_SIZE` cells and `BLOCKS_PER_WRITE` write blocks of `WRITE_SIZE` cells.
    type Interface = VecHeapAdapterInterface<
        AB::Expr,
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >;

    /// Emits all adapter constraints for one row.
    ///
    /// In order, and each guarded by `ctx.instruction.is_valid`:
    /// 1. register reads of every `rs_ptr` followed by `rd_ptr`,
    /// 2. range checks on the most significant limb of every pointer value,
    /// 3. heap reads of `ctx.reads`, operand by operand and block by block,
    /// 4. heap writes of `ctx.writes`, block by block,
    /// 5. the execution-bridge interaction tying the instruction to the state transition.
    ///
    /// Every memory access consumes exactly one timestamp, in the order listed above. The trace
    /// filler assigns timestamps assuming this same order, so it must not be changed here alone.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let cols: &Rv32VecHeapAdapterCols<
            _,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = local.borrow();
        let timestamp = cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Post-increment helper: returns `timestamp + delta` and then bumps `delta` by one.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_usize(timestamp_delta - 1)
        };

        // Read register values for rs, rd
        for (ptr, val, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux).chain(once((
            cols.rd_ptr,
            cols.rd_val,
            &cols.rd_read_aux,
        ))) {
            self.memory_bridge
                .read(
                    MemoryAddress::new(AB::F::from_u32(RV32_REGISTER_AS), ptr),
                    val,
                    timestamp_pp(),
                    aux,
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // We constrain the highest limbs of heap pointers to be less than 2^(addr_bits -
        // (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1))). This ensures that no overflow
        // occurs when computing memory pointers. Since the number of cells accessed with each
        // address will be small enough, and combined with the memory argument, it ensures
        // that all the cells accessed in the memory are less than 2^addr_bits.
        //
        // The bus checks two values per interaction, so the limbs are grouped in pairs below.
        // `rd_val` is appended twice so that it is always covered: when `NUM_READS` is odd it
        // pairs with the last `rs` limb and the extra copy is dropped by `chunks_exact`; when
        // `NUM_READS` is even it is paired with itself.
        let need_range_check: Vec<AB::Var> = cols
            .rs_val
            .iter()
            .chain(std::iter::repeat_n(&cols.rd_val, 2))
            .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1])
            .collect();

        // range checks constrain to RV32_CELL_BITS bits, so we need to shift the limbs to constrain
        // the correct amount of bits
        let limb_shift =
            AB::F::from_usize(1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits));

        // Note: since limbs are read from memory we already know that limb[i] < 2^RV32_CELL_BITS
        //       thus range checking limb[i] * shift < 2^RV32_CELL_BITS, gives us that
        //       limb[i] < 2^(addr_bits - (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1)))
        for pair in need_range_check.chunks_exact(2) {
            self.bus
                .send_range(pair[0] * limb_shift, pair[1] * limb_shift)
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Compose the u32 register value into single field element, with `abstract_compose`
        let rd_val_f: AB::Expr = abstract_compose(cols.rd_val);
        let rs_val_f: [AB::Expr; NUM_READS] = cols.rs_val.map(abstract_compose);

        // Address space of the heap accesses; also passed as operand `e` of the instruction.
        let e = AB::F::from_u32(RV32_MEMORY_AS);
        // Reads from heap
        for (address, reads, reads_aux) in izip!(rs_val_f, ctx.reads, &cols.reads_aux,) {
            for (i, (read, aux)) in zip(reads, reads_aux).enumerate() {
                self.memory_bridge
                    .read(
                        MemoryAddress::new(
                            e,
                            // Block `i` starts `i * READ_SIZE` cells after the base pointer.
                            address.clone() + AB::Expr::from_usize(i * READ_SIZE),
                        ),
                        read,
                        timestamp_pp(),
                        aux,
                    )
                    .eval(builder, ctx.instruction.is_valid.clone());
            }
        }

        // Writes to heap
        for (i, (write, aux)) in zip(ctx.writes, &cols.writes_aux).enumerate() {
            self.memory_bridge
                .write(
                    // Block `i` starts `i * WRITE_SIZE` cells after the base pointer.
                    MemoryAddress::new(e, rd_val_f.clone() + AB::Expr::from_usize(i * WRITE_SIZE)),
                    write,
                    timestamp_pp(),
                    aux,
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Instruction operands, in order: a = rd_ptr, b = rs_ptr[0], c = rs_ptr[1],
        // d = register address space, e = heap address space. Missing `rs` operands are zero.
        // NOTE(review): presumably the pc advances by `DEFAULT_PC_STEP` unless `ctx.to_pc` is
        // set - confirm against `ExecutionBridge::execute_and_increment_or_set_pc`.
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    cols.rd_ptr.into(),
                    cols.rs_ptr
                        .first()
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    cols.rs_ptr
                        .get(1)
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    AB::Expr::from_u32(RV32_REGISTER_AS),
                    e.into(),
                ],
                cols.from_state,
                // Total number of timestamps consumed by the memory accesses above.
                AB::F::from_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid.clone());
    }

    /// Returns the `from_state.pc` column of the adapter row `local`.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32VecHeapAdapterCols<
            _,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = local.borrow();
        cols.from_state.pc
    }
}
262
// Intermediate type that should not be copied or cloned and should be directly written to
//
// The executor fills this record during execution; the trace filler later reads it back from
// the row buffer and expands it into `Rv32VecHeapAdapterCols`. Field order is part of the byte
// layout (`#[repr(C)]`) and must not be changed.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32VecHeapAdapterRecord<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    // Program counter and memory timestamp at the start of the instruction.
    pub from_pc: u32,
    pub from_timestamp: u32,

    // Register pointers taken from the instruction operands (`b`, `c` for rs; `a` for rd).
    pub rs_ptrs: [u32; NUM_READS],
    pub rd_ptr: u32,

    // Register contents (little-endian u32), used as heap base addresses.
    pub rs_vals: [u32; NUM_READS],
    pub rd_val: u32,

    // Auxiliary data recorded for the register reads.
    pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS],
    pub rd_read_aux: MemoryReadAuxRecord,

    // Auxiliary data recorded for every heap block read / write.
    pub reads_aux: [[MemoryReadAuxRecord; BLOCKS_PER_READ]; NUM_READS],
    pub writes_aux: [MemoryWriteBytesAuxRecord<WRITE_SIZE>; BLOCKS_PER_WRITE],
}
288
/// Execution-side counterpart of the adapter: performs the traced register and heap accesses
/// and records them into an `Rv32VecHeapAdapterRecord`.
#[derive(derive_new::new, Clone, Copy)]
pub struct Rv32VecHeapAdapterExecutor<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Number of bits a heap pointer may occupy; only used for debug-time bounds assertions.
    pointer_max_bits: usize,
}
299
/// Trace-generation counterpart of the adapter: expands a recorded
/// `Rv32VecHeapAdapterRecord` into a row of `Rv32VecHeapAdapterCols`.
#[derive(derive_new::new)]
pub struct Rv32VecHeapAdapterFiller<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Number of bits a heap pointer may occupy; determines the shift used for range checks.
    pointer_max_bits: usize,
    /// Lookup chip that receives the range-check requests for the top pointer limbs.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}
311
312impl<
313        F: PrimeField32,
314        const NUM_READS: usize,
315        const BLOCKS_PER_READ: usize,
316        const BLOCKS_PER_WRITE: usize,
317        const READ_SIZE: usize,
318        const WRITE_SIZE: usize,
319    > AdapterTraceExecutor<F>
320    for Rv32VecHeapAdapterExecutor<
321        NUM_READS,
322        BLOCKS_PER_READ,
323        BLOCKS_PER_WRITE,
324        READ_SIZE,
325        WRITE_SIZE,
326    >
327{
328    const WIDTH: usize = Rv32VecHeapAdapterCols::<
329        F,
330        NUM_READS,
331        BLOCKS_PER_READ,
332        BLOCKS_PER_WRITE,
333        READ_SIZE,
334        WRITE_SIZE,
335    >::width();
336    type ReadData = [[[u8; READ_SIZE]; BLOCKS_PER_READ]; NUM_READS];
337    type WriteData = [[u8; WRITE_SIZE]; BLOCKS_PER_WRITE];
338    type RecordMut<'a> = &'a mut Rv32VecHeapAdapterRecord<
339        NUM_READS,
340        BLOCKS_PER_READ,
341        BLOCKS_PER_WRITE,
342        READ_SIZE,
343        WRITE_SIZE,
344    >;
345
346    #[inline(always)]
347    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
348        record.from_pc = pc;
349        record.from_timestamp = memory.timestamp;
350    }
351
352    fn read(
353        &self,
354        memory: &mut TracingMemory,
355        instruction: &Instruction<F>,
356        record: &mut &mut Rv32VecHeapAdapterRecord<
357            NUM_READS,
358            BLOCKS_PER_READ,
359            BLOCKS_PER_WRITE,
360            READ_SIZE,
361            WRITE_SIZE,
362        >,
363    ) -> Self::ReadData {
364        let &Instruction { a, b, c, d, e, .. } = instruction;
365
366        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
367        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);
368
369        // Read register values
370        record.rs_vals = from_fn(|i| {
371            record.rs_ptrs[i] = if i == 0 { b } else { c }.as_canonical_u32();
372            u32::from_le_bytes(tracing_read(
373                memory,
374                RV32_REGISTER_AS,
375                record.rs_ptrs[i],
376                &mut record.rs_read_aux[i].prev_timestamp,
377            ))
378        });
379
380        record.rd_ptr = a.as_canonical_u32();
381        record.rd_val = u32::from_le_bytes(tracing_read(
382            memory,
383            RV32_REGISTER_AS,
384            a.as_canonical_u32(),
385            &mut record.rd_read_aux.prev_timestamp,
386        ));
387
388        // Read memory values
389        from_fn(|i| {
390            debug_assert!(
391                (record.rs_vals[i] + (READ_SIZE * BLOCKS_PER_READ - 1) as u32)
392                    < (1 << self.pointer_max_bits) as u32
393            );
394            from_fn(|j| {
395                tracing_read(
396                    memory,
397                    RV32_MEMORY_AS,
398                    record.rs_vals[i] + (j * READ_SIZE) as u32,
399                    &mut record.reads_aux[i][j].prev_timestamp,
400                )
401            })
402        })
403    }
404
405    fn write(
406        &self,
407        memory: &mut TracingMemory,
408        instruction: &Instruction<F>,
409        data: Self::WriteData,
410        record: &mut &mut Rv32VecHeapAdapterRecord<
411            NUM_READS,
412            BLOCKS_PER_READ,
413            BLOCKS_PER_WRITE,
414            READ_SIZE,
415            WRITE_SIZE,
416        >,
417    ) {
418        debug_assert_eq!(instruction.e.as_canonical_u32(), RV32_MEMORY_AS);
419
420        debug_assert!(
421            record.rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1
422                < (1 << self.pointer_max_bits)
423        );
424
425        #[allow(clippy::needless_range_loop)]
426        for i in 0..BLOCKS_PER_WRITE {
427            tracing_write(
428                memory,
429                RV32_MEMORY_AS,
430                record.rd_val + (i * WRITE_SIZE) as u32,
431                data[i],
432                &mut record.writes_aux[i].prev_timestamp,
433                &mut record.writes_aux[i].prev_data,
434            );
435        }
436    }
437}
438
impl<
        F: PrimeField32,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCKS_PER_WRITE: usize,
        const READ_SIZE: usize,
        const WRITE_SIZE: usize,
    > AdapterTraceFiller<F>
    for Rv32VecHeapAdapterFiller<
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >
{
    /// Width of the adapter part of the trace row, in columns.
    const WIDTH: usize = Rv32VecHeapAdapterCols::<
        F,
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >::width();

    /// Expands the record stored in `adapter_row` into the final column values, in place.
    ///
    /// The record and the columns are both views of the same `adapter_row` buffer, so the
    /// statement order below is significant: every record field is consumed before the column
    /// write that may overwrite it, which is why columns are filled from last to first.
    /// Range-check requests for the top pointer limbs are sent to `bitwise_lookup_chip`.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        // SAFETY:
        // - caller ensures `adapter_row` contains a valid record representation that was previously
        //   written by the executor
        let record: &Rv32VecHeapAdapterRecord<
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = unsafe { get_record_from_slice(&mut adapter_row, ()) };

        let cols: &mut Rv32VecHeapAdapterCols<
            F,
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = adapter_row.borrow_mut();

        // Range checks:
        // **NOTE**: Must do the range checks before overwriting the records
        debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
        let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits;
        // Shift that isolates the most significant limb of a u32 register value.
        const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
        // The requests mirror the pairs sent by the AIR: (rs[0], rs[1]) and (rd, rd) for two
        // reads, (rs[0], rd) for a single read.
        if NUM_READS > 1 {
            self.bitwise_lookup_chip.request_range(
                (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits,
                (record.rs_vals[1] >> MSL_SHIFT) << limb_shift_bits,
            );
            self.bitwise_lookup_chip.request_range(
                (record.rd_val >> MSL_SHIFT) << limb_shift_bits,
                (record.rd_val >> MSL_SHIFT) << limb_shift_bits,
            );
        } else {
            self.bitwise_lookup_chip.request_range(
                (record.rs_vals[0] >> MSL_SHIFT) << limb_shift_bits,
                (record.rd_val >> MSL_SHIFT) << limb_shift_bits,
            );
        }

        // One timestamp per memory access: rs register reads, the rd register read, the heap
        // block reads and the heap block writes.
        let timestamp_delta = NUM_READS + 1 + NUM_READS * BLOCKS_PER_READ + BLOCKS_PER_WRITE;
        let mut timestamp = record.from_timestamp + timestamp_delta as u32;
        // Pre-decrement helper: accesses are visited last-to-first, so timestamps count down.
        let mut timestamp_mm = || {
            timestamp -= 1;
            timestamp
        };

        // **NOTE**: Must iterate everything in reverse order to avoid overwriting the records
        record
            .writes_aux
            .iter()
            .rev()
            .zip(cols.writes_aux.iter_mut().rev())
            .for_each(|(write, cols_write)| {
                cols_write.set_prev_data(write.prev_data.map(F::from_u8));
                mem_helper.fill(write.prev_timestamp, timestamp_mm(), cols_write.as_mut());
            });

        // Heap reads: last operand first, and within an operand last block first.
        record
            .reads_aux
            .iter()
            .zip(cols.reads_aux.iter_mut())
            .rev()
            .for_each(|(reads, cols_reads)| {
                reads
                    .iter()
                    .zip(cols_reads.iter_mut())
                    .rev()
                    .for_each(|(read, cols_read)| {
                        mem_helper.fill(read.prev_timestamp, timestamp_mm(), cols_read.as_mut());
                    });
            });

        // Register read of rd (performed after the rs register reads during execution).
        mem_helper.fill(
            record.rd_read_aux.prev_timestamp,
            timestamp_mm(),
            cols.rd_read_aux.as_mut(),
        );

        // Register reads of rs, last operand first.
        record
            .rs_read_aux
            .iter()
            .zip(cols.rs_read_aux.iter_mut())
            .rev()
            .for_each(|(aux, cols_aux)| {
                mem_helper.fill(aux.prev_timestamp, timestamp_mm(), cols_aux.as_mut());
            });

        // Register values are expanded into little-endian byte limbs.
        cols.rd_val = record.rd_val.to_le_bytes().map(F::from_u8);
        cols.rs_val
            .iter_mut()
            .rev()
            .zip(record.rs_vals.iter().rev())
            .for_each(|(cols_val, val)| {
                *cols_val = val.to_le_bytes().map(F::from_u8);
            });
        cols.rd_ptr = F::from_u32(record.rd_ptr);
        cols.rs_ptr
            .iter_mut()
            .rev()
            .zip(record.rs_ptrs.iter().rev())
            .for_each(|(cols_ptr, ptr)| {
                *cols_ptr = F::from_u32(*ptr);
            });
        // The execution state sits at the very start of the row and is therefore written last.
        cols.from_state.timestamp = F::from_u32(record.from_timestamp);
        cols.from_state.pc = F::from_u32(record.from_pc);
    }
}
573}