// openvm_rv32_adapters/eq_mod.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4};
5
6use itertools::izip;
7use openvm_circuit::{
8    arch::{
9        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
10        BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir,
11    },
12    system::memory::{
13        offline_checker::{
14            MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols,
15            MemoryWriteBytesAuxRecord,
16        },
17        online::TracingMemory,
18        MemoryAddress, MemoryAuxColsFactory,
19    },
20};
21use openvm_circuit_primitives::{
22    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
23    AlignedBytesBorrow,
24};
25use openvm_circuit_primitives_derive::AlignedBorrow;
26use openvm_instructions::{
27    instruction::Instruction,
28    program::DEFAULT_PC_STEP,
29    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
30};
31use openvm_rv32im_circuit::adapters::{
32    tracing_read, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
33};
34use openvm_stark_backend::{
35    interaction::InteractionBuilder,
36    p3_air::BaseAir,
37    p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
38};
39
/// Trace columns of [`Rv32IsEqualModAdapterAir`].
///
/// This adapter reads from `NUM_READS <= 2` pointers and writes to a register.
/// * The data is read from the heap (address space 2, `RV32_MEMORY_AS`), and the pointers are
///   read from registers (address space 1, `RV32_REGISTER_AS`).
/// * Reads take the form of `BLOCKS_PER_READ` consecutive reads of size `BLOCK_SIZE` from the heap,
///   starting from the addresses in `rs[0]` (and `rs[1]` if `NUM_READS = 2`).
/// * Writes are to 32-bit register rd.
///
/// The struct is `#[repr(C)]` and borrowed directly from a trace row, so the field order below is
/// the column order. `BLOCK_SIZE` is not used by any field.
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct Rv32IsEqualModAdapterCols<
    T,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
> {
    /// Program counter and timestamp at the start of the instruction.
    pub from_state: ExecutionState<T>,

    /// Register pointers whose values are the heap addresses (instruction operands `b`, `c`).
    pub rs_ptr: [T; NUM_READS],
    /// Value of each pointer register, as little-endian limbs of `RV32_CELL_BITS` bits.
    pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
    /// Auxiliary columns for the register reads of `rs_ptr`.
    pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
    /// Auxiliary columns for each of the `BLOCKS_PER_READ` heap block reads, per pointer.
    pub heap_read_aux: [[MemoryReadAuxCols<T>; BLOCKS_PER_READ]; NUM_READS],

    /// Destination register pointer (instruction operand `a`).
    pub rd_ptr: T,
    /// Auxiliary columns for the write to `rd_ptr`.
    pub writes_aux: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
}
64
/// AIR for the adapter whose trace columns are [`Rv32IsEqualModAdapterCols`].
///
/// `TOTAL_READ_SIZE` is the number of cells read per pointer; `eval` asserts it equals
/// `BLOCKS_PER_READ * BLOCK_SIZE`.
// NOTE(review): every field below is read in `eval`, so it is unclear what `dead_code` is being
// silenced for — consider removing the allow or documenting why it is needed.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32IsEqualModAdapterAir<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    /// Bridge for the execution-bus interaction (pc / timestamp transition).
    pub(super) execution_bridge: ExecutionBridge,
    /// Bridge for the register and heap memory accesses.
    pub(super) memory_bridge: MemoryBridge,
    /// Bitwise lookup bus used to range check the high limb of each pointer value.
    pub bus: BitwiseOperationLookupBus,
    /// Bit width used to derive the shift applied to each pointer's high limb before its range
    /// check; presumably the maximum allowed heap pointer width — confirm against the config.
    address_bits: usize,
}
78
79impl<
80        F: Field,
81        const NUM_READS: usize,
82        const BLOCKS_PER_READ: usize,
83        const BLOCK_SIZE: usize,
84        const TOTAL_READ_SIZE: usize,
85    > BaseAir<F>
86    for Rv32IsEqualModAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
87{
88    fn width(&self) -> usize {
89        Rv32IsEqualModAdapterCols::<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE>::width()
90    }
91}
92
93impl<
94        AB: InteractionBuilder,
95        const NUM_READS: usize,
96        const BLOCKS_PER_READ: usize,
97        const BLOCK_SIZE: usize,
98        const TOTAL_READ_SIZE: usize,
99    > VmAdapterAir<AB>
100    for Rv32IsEqualModAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
101{
102    type Interface = BasicAdapterInterface<
103        AB::Expr,
104        MinimalInstruction<AB::Expr>,
105        NUM_READS,
106        1,
107        TOTAL_READ_SIZE,
108        RV32_REGISTER_NUM_LIMBS,
109    >;
110
111    fn eval(
112        &self,
113        builder: &mut AB,
114        local: &[AB::Var],
115        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
116    ) {
117        let cols: &Rv32IsEqualModAdapterCols<_, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
118            local.borrow();
119        let timestamp = cols.from_state.timestamp;
120        let mut timestamp_delta: usize = 0;
121        let mut timestamp_pp = || {
122            timestamp_delta += 1;
123            timestamp + AB::F::from_usize(timestamp_delta - 1)
124        };
125
126        // Address spaces
127        let d = AB::F::from_u32(RV32_REGISTER_AS);
128        let e = AB::F::from_u32(RV32_MEMORY_AS);
129
130        // Read register values for rs
131        for (ptr, val, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux) {
132            self.memory_bridge
133                .read(MemoryAddress::new(d, ptr), val, timestamp_pp(), aux)
134                .eval(builder, ctx.instruction.is_valid.clone());
135        }
136
137        // Compose the u32 register value into single field element, with
138        // a range check on the highest limb.
139        let rs_val_f = cols.rs_val.map(|decomp| {
140            decomp.iter().rev().fold(AB::Expr::ZERO, |acc, &limb| {
141                acc * AB::Expr::from_usize(1 << RV32_CELL_BITS) + limb
142            })
143        });
144
145        let need_range_check: [_; 2] = from_fn(|i| {
146            if i < NUM_READS {
147                cols.rs_val[i][RV32_REGISTER_NUM_LIMBS - 1].into()
148            } else {
149                AB::Expr::ZERO
150            }
151        });
152
153        let limb_shift =
154            AB::F::from_usize(1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits));
155
156        self.bus
157            .send_range(
158                need_range_check[0].clone() * limb_shift,
159                need_range_check[1].clone() * limb_shift,
160            )
161            .eval(builder, ctx.instruction.is_valid.clone());
162
163        // Reads from heap
164        assert_eq!(TOTAL_READ_SIZE, BLOCKS_PER_READ * BLOCK_SIZE);
165        let read_block_data: [[[_; BLOCK_SIZE]; BLOCKS_PER_READ]; NUM_READS] =
166            ctx.reads.map(|r: [AB::Expr; TOTAL_READ_SIZE]| {
167                let mut r_it = r.into_iter();
168                from_fn(|_| from_fn(|_| r_it.next().unwrap()))
169            });
170        let block_ptr_offset: [_; BLOCKS_PER_READ] = from_fn(|i| AB::F::from_usize(i * BLOCK_SIZE));
171
172        for (ptr, block_data, block_aux) in izip!(rs_val_f, read_block_data, &cols.heap_read_aux) {
173            for (offset, data, aux) in izip!(block_ptr_offset, block_data, block_aux) {
174                self.memory_bridge
175                    .read(
176                        MemoryAddress::new(e, ptr.clone() + offset),
177                        data,
178                        timestamp_pp(),
179                        aux,
180                    )
181                    .eval(builder, ctx.instruction.is_valid.clone());
182            }
183        }
184
185        // Write to rd register
186        self.memory_bridge
187            .write(
188                MemoryAddress::new(d, cols.rd_ptr),
189                ctx.writes[0].clone(),
190                timestamp_pp(),
191                &cols.writes_aux,
192            )
193            .eval(builder, ctx.instruction.is_valid.clone());
194
195        self.execution_bridge
196            .execute_and_increment_or_set_pc(
197                ctx.instruction.opcode,
198                [
199                    cols.rd_ptr.into(),
200                    cols.rs_ptr
201                        .first()
202                        .map(|&x| x.into())
203                        .unwrap_or(AB::Expr::ZERO),
204                    cols.rs_ptr
205                        .get(1)
206                        .map(|&x| x.into())
207                        .unwrap_or(AB::Expr::ZERO),
208                    d.into(),
209                    e.into(),
210                ],
211                cols.from_state,
212                AB::F::from_usize(timestamp_delta),
213                (DEFAULT_PC_STEP, ctx.to_pc),
214            )
215            .eval(builder, ctx.instruction.is_valid.clone());
216    }
217
218    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
219        let cols: &Rv32IsEqualModAdapterCols<_, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
220            local.borrow();
221        cols.from_state.pc
222    }
223}
224
/// Per-instruction record produced by [`Rv32IsEqualModAdapterExecutor`] and later expanded into
/// trace columns by the `AdapterTraceFiller` impl of [`Rv32IsEqualModAdapterFiller`]. That filler
/// reinterprets the start of a trace row as this struct, so the `#[repr(C)]` layout matters.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32IsEqualModAdapterRecord<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    /// Program counter at the start of the instruction (set in `start`).
    pub from_pc: u32,
    /// Memory timestamp at the start of the instruction (set in `start`).
    pub timestamp: u32,

    /// Pointer registers read (instruction operands `b`, `c`).
    pub rs_ptr: [u32; NUM_READS],
    /// Values of the pointer registers, i.e. the heap base addresses.
    pub rs_val: [u32; NUM_READS],
    /// Previous-access timestamps of the register reads.
    pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS],
    /// Previous-access timestamps of each heap block read, per pointer.
    pub heap_read_aux: [[MemoryReadAuxRecord; BLOCKS_PER_READ]; NUM_READS],

    /// Destination register pointer (instruction operand `a`).
    pub rd_ptr: u32,
    /// Previous timestamp and previous contents of the destination register.
    pub writes_aux: MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS>,
}
244
/// Executor that performs this adapter's register reads, heap reads and register write during
/// tracing, filling an [`Rv32IsEqualModAdapterRecord`].
#[derive(Clone, Copy)]
pub struct Rv32IsEqualModAdapterExecutor<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    /// Bit width heap reads are expected to stay within; validated in `new` and checked with
    /// `debug_assert!` in `read`.
    pointer_max_bits: usize,
}
254
/// Expands an [`Rv32IsEqualModAdapterRecord`] into [`Rv32IsEqualModAdapterCols`] in place, and
/// requests the pointer high-limb range checks that the AIR sends on the bitwise lookup bus.
#[derive(derive_new::new)]
pub struct Rv32IsEqualModAdapterFiller<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    /// Bit width used to compute the high-limb shift for the range-check requests.
    pointer_max_bits: usize,
    /// Lookup chip that receives the range-check requests.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}
265
266impl<
267        const NUM_READS: usize,
268        const BLOCKS_PER_READ: usize,
269        const BLOCK_SIZE: usize,
270        const TOTAL_READ_SIZE: usize,
271    > Rv32IsEqualModAdapterExecutor<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
272{
273    pub fn new(pointer_max_bits: usize) -> Self {
274        assert!(NUM_READS <= 2);
275        assert_eq!(TOTAL_READ_SIZE, BLOCKS_PER_READ * BLOCK_SIZE);
276        assert!(
277            RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS,
278            "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check"
279        );
280        Self { pointer_max_bits }
281    }
282}
283
284impl<
285        F: PrimeField32,
286        const NUM_READS: usize,
287        const BLOCKS_PER_READ: usize,
288        const BLOCK_SIZE: usize,
289        const TOTAL_READ_SIZE: usize,
290    > AdapterTraceExecutor<F>
291    for Rv32IsEqualModAdapterExecutor<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
292where
293    F: PrimeField32,
294{
295    const WIDTH: usize =
296        Rv32IsEqualModAdapterCols::<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE>::width();
297    type ReadData = [[u8; TOTAL_READ_SIZE]; NUM_READS];
298    type WriteData = [u8; RV32_REGISTER_NUM_LIMBS];
299    type RecordMut<'a> = &'a mut Rv32IsEqualModAdapterRecord<
300        NUM_READS,
301        BLOCKS_PER_READ,
302        BLOCK_SIZE,
303        TOTAL_READ_SIZE,
304    >;
305
306    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
307        record.from_pc = pc;
308        record.timestamp = memory.timestamp;
309    }
310
311    fn read(
312        &self,
313        memory: &mut TracingMemory,
314        instruction: &Instruction<F>,
315        record: &mut Self::RecordMut<'_>,
316    ) -> Self::ReadData {
317        let Instruction { b, c, d, e, .. } = *instruction;
318
319        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
320        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);
321
322        // Read register values
323        record.rs_val = from_fn(|i| {
324            record.rs_ptr[i] = if i == 0 { b } else { c }.as_canonical_u32();
325
326            u32::from_le_bytes(tracing_read(
327                memory,
328                RV32_REGISTER_AS,
329                record.rs_ptr[i],
330                &mut record.rs_read_aux[i].prev_timestamp,
331            ))
332        });
333
334        // Read memory values
335        from_fn(|i| {
336            debug_assert!(
337                record.rs_val[i] as usize + TOTAL_READ_SIZE - 1 < (1 << self.pointer_max_bits)
338            );
339            from_fn::<_, BLOCKS_PER_READ, _>(|j| {
340                tracing_read::<BLOCK_SIZE>(
341                    memory,
342                    RV32_MEMORY_AS,
343                    record.rs_val[i] + (j * BLOCK_SIZE) as u32,
344                    &mut record.heap_read_aux[i][j].prev_timestamp,
345                )
346            })
347            .concat()
348            .try_into()
349            .unwrap()
350        })
351    }
352
353    fn write(
354        &self,
355        memory: &mut TracingMemory,
356        instruction: &Instruction<F>,
357        data: Self::WriteData,
358        record: &mut Self::RecordMut<'_>,
359    ) {
360        let Instruction { a, .. } = *instruction;
361        record.rd_ptr = a.as_canonical_u32();
362        tracing_write(
363            memory,
364            RV32_REGISTER_AS,
365            record.rd_ptr,
366            data,
367            &mut record.writes_aux.prev_timestamp,
368            &mut record.writes_aux.prev_data,
369        );
370    }
371}
372
impl<
        F: PrimeField32,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCK_SIZE: usize,
        const TOTAL_READ_SIZE: usize,
    > AdapterTraceFiller<F>
    for Rv32IsEqualModAdapterFiller<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
{
    const WIDTH: usize =
        Rv32IsEqualModAdapterCols::<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE>::width();

    /// Converts the [`Rv32IsEqualModAdapterRecord`] stored at the start of `adapter_row` into the
    /// adapter's trace columns, in place. It also requests the pointer high-limb range checks
    /// from `bitwise_lookup_chip`.
    ///
    /// The record and the columns are views of the same `adapter_row` buffer. Each column group
    /// is therefore written only after every record field it still needs has been read. That is
    /// why the columns are filled from the last group to the first, and why the range-check
    /// request comes before any write.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        // SAFETY:
        // - caller ensures `adapter_row` contains a valid record representation that was previously
        //   written by the executor
        let record: &Rv32IsEqualModAdapterRecord<
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCK_SIZE,
            TOTAL_READ_SIZE,
        > = unsafe { get_record_from_slice(&mut adapter_row, ()) };

        let cols: &mut Rv32IsEqualModAdapterCols<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
            adapter_row.borrow_mut();

        // One past the last access timestamp. The accesses are NUM_READS register reads,
        // NUM_READS * BLOCKS_PER_READ heap reads, and one register write. `timestamp_mm`
        // pre-decrements, so successive calls yield the access timestamps from last (the rd
        // write) to first (the first register read). This mirrors the AIR's forward
        // `timestamp_pp` sequence.
        let mut timestamp = record.timestamp + (NUM_READS + NUM_READS * BLOCKS_PER_READ) as u32 + 1;
        let mut timestamp_mm = || {
            timestamp -= 1;
            timestamp
        };
        // Do range checks before writing anything:
        // Request the same lookup the AIR sends. Each pointer's most significant limb
        // (`rs_val >> MSL_SHIFT`) is shifted left by `limb_shift_bits`. The second slot is 0
        // when there is only one pointer, matching the AIR's zero padding.
        debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
        let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits;
        const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
        self.bitwise_lookup_chip.request_range(
            (record.rs_val[0] >> MSL_SHIFT) << limb_shift_bits,
            if NUM_READS > 1 {
                (record.rs_val[1] >> MSL_SHIFT) << limb_shift_bits
            } else {
                0
            },
        );
        // Writing in reverse order
        // rd write: uses the last access timestamp.
        cols.writes_aux
            .set_prev_data(record.writes_aux.prev_data.map(F::from_u8));
        mem_helper.fill(
            record.writes_aux.prev_timestamp,
            timestamp_mm(),
            cols.writes_aux.as_mut(),
        );
        cols.rd_ptr = F::from_u32(record.rd_ptr);

        // **NOTE**: Must iterate everything in reverse order to avoid overwriting the records
        // Heap reads: last pointer's last block first, so timestamps count down in step with
        // `timestamp_mm`.
        cols.heap_read_aux
            .iter_mut()
            .rev()
            .zip(record.heap_read_aux.iter().rev())
            .for_each(|(col_reads, record_reads)| {
                col_reads
                    .iter_mut()
                    .rev()
                    .zip(record_reads.iter().rev())
                    .for_each(|(col, record)| {
                        mem_helper.fill(record.prev_timestamp, timestamp_mm(), col.as_mut());
                    });
            });

        // Register reads: these take the earliest timestamps.
        cols.rs_read_aux
            .iter_mut()
            .rev()
            .zip(record.rs_read_aux.iter().rev())
            .for_each(|(col, record)| {
                mem_helper.fill(record.prev_timestamp, timestamp_mm(), col.as_mut());
            });

        // `map` copies each record array out before the assignment overwrites the shared bytes.
        cols.rs_val = record.rs_val.map(|val| val.to_le_bytes().map(F::from_u8));
        cols.rs_ptr = record.rs_ptr.map(|ptr| F::from_u32(ptr));

        cols.from_state.timestamp = F::from_u32(record.timestamp);
        cols.from_state.pc = F::from_u32(record.from_pc);
    }
}