openvm_rv32_adapters/
vec_heap_two_reads.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4    iter::zip,
5    marker::PhantomData,
6};
7
8use itertools::izip;
9use openvm_circuit::{
10    arch::{
11        AdapterAirContext, AdapterRuntimeContext, ExecutionBridge, ExecutionBus, ExecutionState,
12        Result, VecHeapTwoReadsAdapterInterface, VmAdapterAir, VmAdapterChip, VmAdapterInterface,
13    },
14    system::{
15        memory::{
16            offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
17            MemoryAddress, MemoryController, OfflineMemory, RecordId,
18        },
19        program::ProgramBus,
20    },
21};
22use openvm_circuit_primitives::bitwise_op_lookup::{
23    BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip,
24};
25use openvm_circuit_primitives_derive::AlignedBorrow;
26use openvm_instructions::{
27    instruction::Instruction,
28    program::DEFAULT_PC_STEP,
29    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
30};
31use openvm_rv32im_circuit::adapters::{
32    abstract_compose, read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
33};
34use openvm_stark_backend::{
35    interaction::InteractionBuilder,
36    p3_air::BaseAir,
37    p3_field::{Field, FieldAlgebra, PrimeField32},
38};
39use serde::{Deserialize, Serialize};
40use serde_with::serde_as;
41
/// This adapter reads from 2 pointers and writes to 1 pointer.
/// * The data is read from the heap (address space 2), and the pointers
///   are read from registers (address space 1).
/// * Reads take the form of `BLOCKS_PER_READX` consecutive reads of size
///   `READ_SIZE` from the heap, starting from the addresses in `rs[X]`
/// * NOTE that the two reads can read different numbers of blocks.
/// * Writes take the form of `BLOCKS_PER_WRITE` consecutive writes of
///   size `WRITE_SIZE` to the heap, starting from the address in `rd`.
pub struct Rv32VecHeapTwoReadsAdapterChip<
    F: Field,
    const BLOCKS_PER_READ1: usize,
    const BLOCKS_PER_READ2: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// AIR for this adapter; it also stores `address_bits`, which execution and
    /// trace generation read back from it.
    pub air: Rv32VecHeapTwoReadsAdapterAir<
        BLOCKS_PER_READ1,
        BLOCKS_PER_READ2,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >,
    /// Lookup chip that receives the range-check requests for the highest limb
    /// of each pointer during trace generation.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
    /// Ties the chip to the field type `F` without storing a value of it.
    _marker: PhantomData<F>,
}
68
69impl<
70        F: PrimeField32,
71        const BLOCKS_PER_READ1: usize,
72        const BLOCKS_PER_READ2: usize,
73        const BLOCKS_PER_WRITE: usize,
74        const READ_SIZE: usize,
75        const WRITE_SIZE: usize,
76    >
77    Rv32VecHeapTwoReadsAdapterChip<
78        F,
79        BLOCKS_PER_READ1,
80        BLOCKS_PER_READ2,
81        BLOCKS_PER_WRITE,
82        READ_SIZE,
83        WRITE_SIZE,
84    >
85{
86    pub fn new(
87        execution_bus: ExecutionBus,
88        program_bus: ProgramBus,
89        memory_bridge: MemoryBridge,
90        address_bits: usize,
91        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
92    ) -> Self {
93        assert!(
94            RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS,
95            "address_bits={address_bits} needs to be large enough for high limb range check"
96        );
97        Self {
98            air: Rv32VecHeapTwoReadsAdapterAir {
99                execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
100                memory_bridge,
101                bus: bitwise_lookup_chip.bus(),
102                address_bits,
103            },
104            bitwise_lookup_chip,
105            _marker: PhantomData,
106        }
107    }
108}
109
/// Record of every memory read performed while preprocessing one instruction:
/// the three register reads (pointers) followed by the heap block reads.
#[repr(C)]
#[serde_as]
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct Rv32VecHeapTwoReadsReadRecord<
    F: Field,
    const BLOCKS_PER_READ1: usize,
    const BLOCKS_PER_READ2: usize,
    const READ_SIZE: usize,
> {
    /// Record ids of the `rs1` / `rs2` register reads (instruction operands `b` / `c`),
    /// performed in the register address space `d`.
    pub rs1: RecordId,
    pub rs2: RecordId,
    /// Record id of the `rd` register read (instruction operand `a`), also performed
    /// in the register address space `d`.
    pub rd: RecordId,

    /// Value read from the `rd` register, as a field element; it is the base heap
    /// address used for the block writes in `postprocess`.
    pub rd_val: F,

    /// Record ids of the heap block reads that start at the address held in `rs1`.
    #[serde_as(as = "[_; BLOCKS_PER_READ1]")]
    pub reads1: [RecordId; BLOCKS_PER_READ1],
    /// Record ids of the heap block reads that start at the address held in `rs2`.
    #[serde_as(as = "[_; BLOCKS_PER_READ2]")]
    pub reads2: [RecordId; BLOCKS_PER_READ2],
}
133
/// Record produced by `postprocess`: the execution state the instruction started
/// from and the ids of the heap block writes it performed.
#[repr(C)]
#[serde_as]
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Rv32VecHeapTwoReadsWriteRecord<const BLOCKS_PER_WRITE: usize, const WRITE_SIZE: usize> {
    /// Execution state (pc and timestamp) before the instruction was executed.
    pub from_state: ExecutionState<u32>,
    /// Record ids of the heap block writes, in increasing address order starting
    /// at the address held in `rd`.
    #[serde_as(as = "[_; BLOCKS_PER_WRITE]")]
    pub writes: [RecordId; BLOCKS_PER_WRITE],
}
142
/// Trace columns of the adapter. A trace row (`&[T]`) is reinterpreted as this
/// struct via `borrow()` / `borrow_mut()`, so the field order is the column order.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct Rv32VecHeapTwoReadsAdapterCols<
    T,
    const BLOCKS_PER_READ1: usize,
    const BLOCKS_PER_READ2: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Execution state (pc and timestamp) before the instruction.
    pub from_state: ExecutionState<T>,

    // Register pointers, i.e. where in the register address space the heap
    // addresses are stored.
    pub rs1_ptr: T,
    pub rs2_ptr: T,
    pub rd_ptr: T,

    // Limbs of the values read from the registers above; composed into the heap
    // addresses used for the block reads and writes.
    pub rs1_val: [T; RV32_REGISTER_NUM_LIMBS],
    pub rs2_val: [T; RV32_REGISTER_NUM_LIMBS],
    pub rd_val: [T; RV32_REGISTER_NUM_LIMBS],

    // Auxiliary columns for the three register reads.
    pub rs1_read_aux: MemoryReadAuxCols<T>,
    pub rs2_read_aux: MemoryReadAuxCols<T>,
    pub rd_read_aux: MemoryReadAuxCols<T>,

    // Auxiliary columns for the heap accesses: one entry per block read from the
    // `rs1` address, per block read from the `rs2` address, and per block written
    // to the `rd` address.
    pub reads1_aux: [MemoryReadAuxCols<T>; BLOCKS_PER_READ1],
    pub reads2_aux: [MemoryReadAuxCols<T>; BLOCKS_PER_READ2],
    pub writes_aux: [MemoryWriteAuxCols<T, WRITE_SIZE>; BLOCKS_PER_WRITE],
}
171
/// AIR of the two-reads vec-heap adapter: holds the bridges/buses used to emit
/// the execution, memory and range-check interactions in `eval`.
// NOTE(review): the reason for `allow(dead_code)` is not evident from this file — confirm it is still needed.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32VecHeapTwoReadsAdapterAir<
    const BLOCKS_PER_READ1: usize,
    const BLOCKS_PER_READ2: usize,
    const BLOCKS_PER_WRITE: usize,
    const READ_SIZE: usize,
    const WRITE_SIZE: usize,
> {
    /// Bridge used to constrain the instruction execution and pc/timestamp update.
    pub(super) execution_bridge: ExecutionBridge,
    /// Bridge used to constrain all register and heap memory accesses.
    pub(super) memory_bridge: MemoryBridge,
    /// Bitwise lookup bus used to range check the highest limb of each pointer.
    pub bus: BitwiseOperationLookupBus,
    /// The max number of bits for an address in memory
    address_bits: usize,
}
187
188impl<
189        F: Field,
190        const BLOCKS_PER_READ1: usize,
191        const BLOCKS_PER_READ2: usize,
192        const BLOCKS_PER_WRITE: usize,
193        const READ_SIZE: usize,
194        const WRITE_SIZE: usize,
195    > BaseAir<F>
196    for Rv32VecHeapTwoReadsAdapterAir<
197        BLOCKS_PER_READ1,
198        BLOCKS_PER_READ2,
199        BLOCKS_PER_WRITE,
200        READ_SIZE,
201        WRITE_SIZE,
202    >
203{
204    fn width(&self) -> usize {
205        Rv32VecHeapTwoReadsAdapterCols::<
206            F,
207            BLOCKS_PER_READ1,
208            BLOCKS_PER_READ2,
209            BLOCKS_PER_WRITE,
210            READ_SIZE,
211            WRITE_SIZE,
212        >::width()
213    }
214}
215
/// Constraint side of the adapter: constrains the three register reads, the
/// pointer range checks, the heap block reads/writes and the execution step.
impl<
        AB: InteractionBuilder,
        const BLOCKS_PER_READ1: usize,
        const BLOCKS_PER_READ2: usize,
        const BLOCKS_PER_WRITE: usize,
        const READ_SIZE: usize,
        const WRITE_SIZE: usize,
    > VmAdapterAir<AB>
    for Rv32VecHeapTwoReadsAdapterAir<
        BLOCKS_PER_READ1,
        BLOCKS_PER_READ2,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >
{
    type Interface = VecHeapTwoReadsAdapterInterface<
        AB::Expr,
        BLOCKS_PER_READ1,
        BLOCKS_PER_READ2,
        BLOCKS_PER_WRITE,
        READ_SIZE,
        WRITE_SIZE,
    >;

    /// Emits all adapter interactions for one row.
    ///
    /// * `local` is the adapter's slice of the trace row, viewed as
    ///   `Rv32VecHeapTwoReadsAdapterCols`.
    /// * `ctx` carries the core chip's view of the instruction: the block data
    ///   read (`ctx.reads.0` / `ctx.reads.1`), the block data to write
    ///   (`ctx.writes`), the opcode, the `is_valid` selector and the optional
    ///   target pc.
    ///
    /// Every interaction is gated by `ctx.instruction.is_valid`.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let cols: &Rv32VecHeapTwoReadsAdapterCols<
            _,
            BLOCKS_PER_READ1,
            BLOCKS_PER_READ2,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = local.borrow();
        let timestamp = cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Post-increment helper: yields `timestamp + delta` for the current memory
        // access and then bumps `delta`, so consecutive accesses get consecutive
        // timestamps and `timestamp_delta` ends up counting all accesses.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
        };

        let ptrs = [cols.rs1_ptr, cols.rs2_ptr, cols.rd_ptr];
        let vals = [cols.rs1_val, cols.rs2_val, cols.rd_val];
        let auxs = [&cols.rs1_read_aux, &cols.rs2_read_aux, &cols.rd_read_aux];

        // Read register values for rs1, rs2, rd (in that order) from the register
        // address space.
        for (ptr, val, aux) in izip!(ptrs, vals, auxs) {
            self.memory_bridge
                .read(
                    MemoryAddress::new(AB::F::from_canonical_u32(RV32_REGISTER_AS), ptr),
                    val,
                    timestamp_pp(),
                    aux,
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Range checks: see Rv32VecHeapAdapterAir
        // Only the highest limb of each pointer needs checking. The lookup takes
        // two values per request, so `rd_val` is listed twice to fill the second pair.
        let need_range_check = [&cols.rs1_val, &cols.rs2_val, &cols.rd_val, &cols.rd_val]
            .map(|val| val[RV32_REGISTER_NUM_LIMBS - 1]);

        // range checks constrain to RV32_CELL_BITS bits, so we need to shift the limbs to constrain the correct amount of bits
        let limb_shift = AB::F::from_canonical_usize(
            1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits),
        );

        // Note: since limbs are read from memory we already know that limb[i] < 2^RV32_CELL_BITS
        //       thus range checking limb[i] * shift < 2^RV32_CELL_BITS, gives us that
        //       limb[i] < 2^(addr_bits - (RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1)))
        for pair in need_range_check.chunks_exact(2) {
            self.bus
                .send_range(pair[0] * limb_shift, pair[1] * limb_shift)
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Compose the register limbs into single field expressions: these are the
        // base heap addresses of the two reads and of the write.
        let rd_val_f: AB::Expr = abstract_compose(cols.rd_val);
        let rs1_val_f: AB::Expr = abstract_compose(cols.rs1_val);
        let rs2_val_f: AB::Expr = abstract_compose(cols.rs2_val);

        let e = AB::F::from_canonical_u32(RV32_MEMORY_AS);
        // Reads from heap: block `i` of the first operand lives at
        // `rs1_val + i * READ_SIZE`.
        for (i, (read, aux)) in zip(ctx.reads.0, &cols.reads1_aux).enumerate() {
            self.memory_bridge
                .read(
                    MemoryAddress::new(
                        e,
                        rs1_val_f.clone() + AB::Expr::from_canonical_usize(i * READ_SIZE),
                    ),
                    read,
                    timestamp_pp(),
                    aux,
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }
        // Same for the second operand, starting at `rs2_val`.
        for (i, (read, aux)) in zip(ctx.reads.1, &cols.reads2_aux).enumerate() {
            self.memory_bridge
                .read(
                    MemoryAddress::new(
                        e,
                        rs2_val_f.clone() + AB::Expr::from_canonical_usize(i * READ_SIZE),
                    ),
                    read,
                    timestamp_pp(),
                    aux,
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Writes to heap: block `i` goes to `rd_val + i * WRITE_SIZE`.
        for (i, (write, aux)) in zip(ctx.writes, &cols.writes_aux).enumerate() {
            self.memory_bridge
                .write(
                    MemoryAddress::new(
                        e,
                        rd_val_f.clone() + AB::Expr::from_canonical_usize(i * WRITE_SIZE),
                    ),
                    write,
                    timestamp_pp(),
                    aux,
                )
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Constrain the executed instruction: operands are
        // (rd_ptr, rs1_ptr, rs2_ptr, register address space, memory address space),
        // and the timestamp advances by the total number of memory accesses above.
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    cols.rd_ptr.into(),
                    cols.rs1_ptr.into(),
                    cols.rs2_ptr.into(),
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    e.into(),
                ],
                cols.from_state,
                AB::F::from_canonical_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid.clone());
    }

    /// Returns the pc column of the row's starting execution state.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32VecHeapTwoReadsAdapterCols<
            _,
            BLOCKS_PER_READ1,
            BLOCKS_PER_READ2,
            BLOCKS_PER_WRITE,
            READ_SIZE,
            WRITE_SIZE,
        > = local.borrow();
        cols.from_state.pc
    }
}
373
374impl<
375        F: PrimeField32,
376        const BLOCKS_PER_READ1: usize,
377        const BLOCKS_PER_READ2: usize,
378        const BLOCKS_PER_WRITE: usize,
379        const READ_SIZE: usize,
380        const WRITE_SIZE: usize,
381    > VmAdapterChip<F>
382    for Rv32VecHeapTwoReadsAdapterChip<
383        F,
384        BLOCKS_PER_READ1,
385        BLOCKS_PER_READ2,
386        BLOCKS_PER_WRITE,
387        READ_SIZE,
388        WRITE_SIZE,
389    >
390{
391    type ReadRecord =
392        Rv32VecHeapTwoReadsReadRecord<F, BLOCKS_PER_READ1, BLOCKS_PER_READ2, READ_SIZE>;
393    type WriteRecord = Rv32VecHeapTwoReadsWriteRecord<BLOCKS_PER_WRITE, WRITE_SIZE>;
394    type Air = Rv32VecHeapTwoReadsAdapterAir<
395        BLOCKS_PER_READ1,
396        BLOCKS_PER_READ2,
397        BLOCKS_PER_WRITE,
398        READ_SIZE,
399        WRITE_SIZE,
400    >;
401    type Interface = VecHeapTwoReadsAdapterInterface<
402        F,
403        BLOCKS_PER_READ1,
404        BLOCKS_PER_READ2,
405        BLOCKS_PER_WRITE,
406        READ_SIZE,
407        WRITE_SIZE,
408    >;
409
410    fn preprocess(
411        &mut self,
412        memory: &mut MemoryController<F>,
413        instruction: &Instruction<F>,
414    ) -> Result<(
415        <Self::Interface as VmAdapterInterface<F>>::Reads,
416        Self::ReadRecord,
417    )> {
418        let Instruction { a, b, c, d, e, .. } = *instruction;
419
420        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
421        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);
422
423        let (rs1_record, rs1_val) = read_rv32_register(memory, d, b);
424        let (rs2_record, rs2_val) = read_rv32_register(memory, d, c);
425        let (rd_record, rd_val) = read_rv32_register(memory, d, a);
426
427        assert!(rs1_val as usize + READ_SIZE * BLOCKS_PER_READ1 - 1 < (1 << self.air.address_bits));
428        let read1_records = from_fn(|i| {
429            memory.read::<READ_SIZE>(e, F::from_canonical_u32(rs1_val + (i * READ_SIZE) as u32))
430        });
431        let read1_data = read1_records.map(|r| r.1);
432        assert!(rs2_val as usize + READ_SIZE * BLOCKS_PER_READ2 - 1 < (1 << self.air.address_bits));
433        let read2_records = from_fn(|i| {
434            memory.read::<READ_SIZE>(e, F::from_canonical_u32(rs2_val + (i * READ_SIZE) as u32))
435        });
436        let read2_data = read2_records.map(|r| r.1);
437        assert!(rd_val as usize + WRITE_SIZE * BLOCKS_PER_WRITE - 1 < (1 << self.air.address_bits));
438
439        let record = Rv32VecHeapTwoReadsReadRecord {
440            rs1: rs1_record,
441            rs2: rs2_record,
442            rd: rd_record,
443            rd_val: F::from_canonical_u32(rd_val),
444            reads1: read1_records.map(|r| r.0),
445            reads2: read2_records.map(|r| r.0),
446        };
447
448        Ok(((read1_data, read2_data), record))
449    }
450
451    fn postprocess(
452        &mut self,
453        memory: &mut MemoryController<F>,
454        instruction: &Instruction<F>,
455        from_state: ExecutionState<u32>,
456        output: AdapterRuntimeContext<F, Self::Interface>,
457        read_record: &Self::ReadRecord,
458    ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
459        let e = instruction.e;
460        let mut i = 0;
461        let writes = output.writes.map(|write| {
462            let (record_id, _) = memory.write(
463                e,
464                read_record.rd_val + F::from_canonical_u32((i * WRITE_SIZE) as u32),
465                write,
466            );
467            i += 1;
468            record_id
469        });
470
471        Ok((
472            ExecutionState {
473                pc: from_state.pc + DEFAULT_PC_STEP,
474                timestamp: memory.timestamp(),
475            },
476            Self::WriteRecord { from_state, writes },
477        ))
478    }
479
480    fn generate_trace_row(
481        &self,
482        row_slice: &mut [F],
483        read_record: Self::ReadRecord,
484        write_record: Self::WriteRecord,
485        memory: &OfflineMemory<F>,
486    ) {
487        vec_heap_two_reads_generate_trace_row_impl(
488            row_slice,
489            &read_record,
490            &write_record,
491            self.bitwise_lookup_chip.clone(),
492            self.air.address_bits,
493            memory,
494        )
495    }
496
497    fn air(&self) -> &Self::Air {
498        &self.air
499    }
500}
501
502pub(super) fn vec_heap_two_reads_generate_trace_row_impl<
503    F: PrimeField32,
504    const BLOCKS_PER_READ1: usize,
505    const BLOCKS_PER_READ2: usize,
506    const BLOCKS_PER_WRITE: usize,
507    const READ_SIZE: usize,
508    const WRITE_SIZE: usize,
509>(
510    row_slice: &mut [F],
511    read_record: &Rv32VecHeapTwoReadsReadRecord<F, BLOCKS_PER_READ1, BLOCKS_PER_READ2, READ_SIZE>,
512    write_record: &Rv32VecHeapTwoReadsWriteRecord<BLOCKS_PER_WRITE, WRITE_SIZE>,
513    bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
514    address_bits: usize,
515    memory: &OfflineMemory<F>,
516) {
517    let aux_cols_factory = memory.aux_cols_factory();
518    let row_slice: &mut Rv32VecHeapTwoReadsAdapterCols<
519        F,
520        BLOCKS_PER_READ1,
521        BLOCKS_PER_READ2,
522        BLOCKS_PER_WRITE,
523        READ_SIZE,
524        WRITE_SIZE,
525    > = row_slice.borrow_mut();
526    row_slice.from_state = write_record.from_state.map(F::from_canonical_u32);
527
528    let rd = memory.record_by_id(read_record.rd);
529    let rs1 = memory.record_by_id(read_record.rs1);
530    let rs2 = memory.record_by_id(read_record.rs2);
531
532    row_slice.rd_ptr = rd.pointer;
533    row_slice.rs1_ptr = rs1.pointer;
534    row_slice.rs2_ptr = rs2.pointer;
535
536    row_slice.rd_val.copy_from_slice(rd.data_slice());
537    row_slice.rs1_val.copy_from_slice(rs1.data_slice());
538    row_slice.rs2_val.copy_from_slice(rs2.data_slice());
539
540    aux_cols_factory.generate_read_aux(rs1, &mut row_slice.rs1_read_aux);
541    aux_cols_factory.generate_read_aux(rs2, &mut row_slice.rs2_read_aux);
542    aux_cols_factory.generate_read_aux(rd, &mut row_slice.rd_read_aux);
543
544    for (i, r) in read_record.reads1.iter().enumerate() {
545        let record = memory.record_by_id(*r);
546        aux_cols_factory.generate_read_aux(record, &mut row_slice.reads1_aux[i]);
547    }
548
549    for (i, r) in read_record.reads2.iter().enumerate() {
550        let record = memory.record_by_id(*r);
551        aux_cols_factory.generate_read_aux(record, &mut row_slice.reads2_aux[i]);
552    }
553
554    for (i, w) in write_record.writes.iter().enumerate() {
555        let record = memory.record_by_id(*w);
556        aux_cols_factory.generate_write_aux(record, &mut row_slice.writes_aux[i]);
557    }
558    // Range checks:
559    let need_range_check = [
560        &read_record.rs1,
561        &read_record.rs2,
562        &read_record.rd,
563        &read_record.rd,
564    ]
565    .map(|record| {
566        memory
567            .record_by_id(*record)
568            .data_at(RV32_REGISTER_NUM_LIMBS - 1)
569            .as_canonical_u32()
570    });
571    debug_assert!(address_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
572    let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits;
573    for pair in need_range_check.chunks_exact(2) {
574        bitwise_lookup_chip.request_range(pair[0] << limb_shift_bits, pair[1] << limb_shift_bits);
575    }
576}