openvm_rv32_adapters/
eq_mod.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4};
5
6use itertools::izip;
7use openvm_circuit::{
8    arch::{
9        get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
10        BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir,
11    },
12    system::memory::{
13        offline_checker::{
14            MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols,
15            MemoryWriteBytesAuxRecord,
16        },
17        online::TracingMemory,
18        MemoryAddress, MemoryAuxColsFactory,
19    },
20};
21use openvm_circuit_primitives::{
22    bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
23    AlignedBytesBorrow,
24};
25use openvm_circuit_primitives_derive::AlignedBorrow;
26use openvm_instructions::{
27    instruction::Instruction,
28    program::DEFAULT_PC_STEP,
29    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
30};
31use openvm_rv32im_circuit::adapters::{
32    tracing_read, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
33};
34use openvm_stark_backend::{
35    interaction::InteractionBuilder,
36    p3_air::BaseAir,
37    p3_field::{Field, FieldAlgebra, PrimeField32},
38};
39
/// Trace columns of the adapter. The adapter reads from `NUM_READS <= 2` pointers and writes to
/// a register.
/// * The data is read from the heap (address space 2, `RV32_MEMORY_AS`), and the pointers are
///   read from registers (address space 1, `RV32_REGISTER_AS`).
/// * Reads take the form of `BLOCKS_PER_READ` consecutive reads of size `BLOCK_SIZE` from the heap,
///   starting from the addresses in `rs[0]` (and `rs[1]` if `NUM_READS = 2`); block `j` is read
///   at `base + j * BLOCK_SIZE`.
/// * Writes are to 32-bit register rd.
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct Rv32IsEqualModAdapterCols<
    T,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
> {
    /// pc and timestamp at the start of the instruction.
    pub from_state: ExecutionState<T>,

    /// Pointers of the source registers (instruction operands `b` and, if present, `c`).
    pub rs_ptr: [T; NUM_READS],
    /// Little-endian limbs of each source register; the composed value is the heap base address.
    pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
    /// Auxiliary columns for the source-register reads.
    pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
    /// Auxiliary columns for each heap block read, grouped per source register.
    pub heap_read_aux: [[MemoryReadAuxCols<T>; BLOCKS_PER_READ]; NUM_READS],

    /// Pointer of the destination register (instruction operand `a`).
    pub rd_ptr: T,
    /// Auxiliary columns for the write to `rd`.
    pub writes_aux: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
}
64
/// AIR for the adapter: constrains the pointer-register reads, the pointer range check, the heap
/// block reads, the write to `rd` and the execution-bus transition of one instruction.
// NOTE(review): every field below is read in `eval`, so it is unclear what
// `#[allow(dead_code)]` still suppresses — confirm whether it can be removed.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32IsEqualModAdapterAir<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    /// Bridge for the execution-bus interaction (pc and timestamp transition).
    pub(super) execution_bridge: ExecutionBridge,
    /// Bridge for the register and heap memory interactions.
    pub(super) memory_bridge: MemoryBridge,
    /// Bitwise lookup bus that receives the pointer range check.
    pub bus: BitwiseOperationLookupBus,
    /// Intended bit width of heap pointers; `eval` scales each pointer's top limb by
    /// `2^(RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits)` before range-checking it.
    address_bits: usize,
}
78
79impl<
80        F: Field,
81        const NUM_READS: usize,
82        const BLOCKS_PER_READ: usize,
83        const BLOCK_SIZE: usize,
84        const TOTAL_READ_SIZE: usize,
85    > BaseAir<F>
86    for Rv32IsEqualModAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
87{
88    fn width(&self) -> usize {
89        Rv32IsEqualModAdapterCols::<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE>::width()
90    }
91}
92
93impl<
94        AB: InteractionBuilder,
95        const NUM_READS: usize,
96        const BLOCKS_PER_READ: usize,
97        const BLOCK_SIZE: usize,
98        const TOTAL_READ_SIZE: usize,
99    > VmAdapterAir<AB>
100    for Rv32IsEqualModAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
101{
102    type Interface = BasicAdapterInterface<
103        AB::Expr,
104        MinimalInstruction<AB::Expr>,
105        NUM_READS,
106        1,
107        TOTAL_READ_SIZE,
108        RV32_REGISTER_NUM_LIMBS,
109    >;
110
111    fn eval(
112        &self,
113        builder: &mut AB,
114        local: &[AB::Var],
115        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
116    ) {
117        let cols: &Rv32IsEqualModAdapterCols<_, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
118            local.borrow();
119        let timestamp = cols.from_state.timestamp;
120        let mut timestamp_delta: usize = 0;
121        let mut timestamp_pp = || {
122            timestamp_delta += 1;
123            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
124        };
125
126        // Address spaces
127        let d = AB::F::from_canonical_u32(RV32_REGISTER_AS);
128        let e = AB::F::from_canonical_u32(RV32_MEMORY_AS);
129
130        // Read register values for rs
131        for (ptr, val, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux) {
132            self.memory_bridge
133                .read(MemoryAddress::new(d, ptr), val, timestamp_pp(), aux)
134                .eval(builder, ctx.instruction.is_valid.clone());
135        }
136
137        // Compose the u32 register value into single field element, with
138        // a range check on the highest limb.
139        let rs_val_f = cols.rs_val.map(|decomp| {
140            decomp.iter().rev().fold(AB::Expr::ZERO, |acc, &limb| {
141                acc * AB::Expr::from_canonical_usize(1 << RV32_CELL_BITS) + limb
142            })
143        });
144
145        let need_range_check: [_; 2] = from_fn(|i| {
146            if i < NUM_READS {
147                cols.rs_val[i][RV32_REGISTER_NUM_LIMBS - 1].into()
148            } else {
149                AB::Expr::ZERO
150            }
151        });
152
153        let limb_shift = AB::F::from_canonical_usize(
154            1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits),
155        );
156
157        self.bus
158            .send_range(
159                need_range_check[0].clone() * limb_shift,
160                need_range_check[1].clone() * limb_shift,
161            )
162            .eval(builder, ctx.instruction.is_valid.clone());
163
164        // Reads from heap
165        assert_eq!(TOTAL_READ_SIZE, BLOCKS_PER_READ * BLOCK_SIZE);
166        let read_block_data: [[[_; BLOCK_SIZE]; BLOCKS_PER_READ]; NUM_READS] =
167            ctx.reads.map(|r: [AB::Expr; TOTAL_READ_SIZE]| {
168                let mut r_it = r.into_iter();
169                from_fn(|_| from_fn(|_| r_it.next().unwrap()))
170            });
171        let block_ptr_offset: [_; BLOCKS_PER_READ] =
172            from_fn(|i| AB::F::from_canonical_usize(i * BLOCK_SIZE));
173
174        for (ptr, block_data, block_aux) in izip!(rs_val_f, read_block_data, &cols.heap_read_aux) {
175            for (offset, data, aux) in izip!(block_ptr_offset, block_data, block_aux) {
176                self.memory_bridge
177                    .read(
178                        MemoryAddress::new(e, ptr.clone() + offset),
179                        data,
180                        timestamp_pp(),
181                        aux,
182                    )
183                    .eval(builder, ctx.instruction.is_valid.clone());
184            }
185        }
186
187        // Write to rd register
188        self.memory_bridge
189            .write(
190                MemoryAddress::new(d, cols.rd_ptr),
191                ctx.writes[0].clone(),
192                timestamp_pp(),
193                &cols.writes_aux,
194            )
195            .eval(builder, ctx.instruction.is_valid.clone());
196
197        self.execution_bridge
198            .execute_and_increment_or_set_pc(
199                ctx.instruction.opcode,
200                [
201                    cols.rd_ptr.into(),
202                    cols.rs_ptr
203                        .first()
204                        .map(|&x| x.into())
205                        .unwrap_or(AB::Expr::ZERO),
206                    cols.rs_ptr
207                        .get(1)
208                        .map(|&x| x.into())
209                        .unwrap_or(AB::Expr::ZERO),
210                    d.into(),
211                    e.into(),
212                ],
213                cols.from_state,
214                AB::F::from_canonical_usize(timestamp_delta),
215                (DEFAULT_PC_STEP, ctx.to_pc),
216            )
217            .eval(builder, ctx.instruction.is_valid.clone());
218    }
219
220    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
221        let cols: &Rv32IsEqualModAdapterCols<_, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
222            local.borrow();
223        cols.from_state.pc
224    }
225}
226
/// Compact record written by [`Rv32IsEqualModAdapterExecutor`] during execution and later
/// expanded into [`Rv32IsEqualModAdapterCols`] by [`Rv32IsEqualModAdapterFiller`].
///
/// `fill_trace_row` reads this record from the start of the same row buffer that it then
/// overwrites with columns, so the `#[repr(C)]` field order is load-bearing: changing it may
/// require revisiting the order in which the filler writes columns.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32IsEqualModAdapterRecord<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    /// Program counter at the start of the instruction (set in `start`).
    pub from_pc: u32,
    /// Memory timestamp at the start of the instruction (set in `start`).
    pub timestamp: u32,

    /// Source-register pointers: operand `b`, then operand `c` when `NUM_READS = 2`.
    pub rs_ptr: [u32; NUM_READS],
    /// Source-register contents decoded with `u32::from_le_bytes`; used as heap base addresses.
    pub rs_val: [u32; NUM_READS],
    /// Aux records for the register reads; `read` stores the `prev_timestamp` from `tracing_read`.
    pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS],
    /// Aux records for each heap block read, grouped per source register.
    pub heap_read_aux: [[MemoryReadAuxRecord; BLOCKS_PER_READ]; NUM_READS],

    /// Destination-register pointer (operand `a`).
    pub rd_ptr: u32,
    /// Previous timestamp and previous bytes of `rd`, captured by `write`.
    pub writes_aux: MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS>,
}
246
/// Execution half of the adapter: performs the pointer-register reads, the heap reads and the
/// write to `rd`, recording everything the filler later needs.
///
/// Construct it with `new`, which validates `pointer_max_bits` and the const parameters.
#[derive(Copy, Clone)]
pub struct Rv32IsEqualModAdapterExecutor<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    /// Bit width that heap reads must stay within (checked by a `debug_assert!` in `read`).
    pointer_max_bits: usize,
}
256
/// Trace filler for the adapter: expands an [`Rv32IsEqualModAdapterRecord`] into
/// [`Rv32IsEqualModAdapterCols`] and requests the pointer range checks from the bitwise chip.
///
/// The derived constructor does not validate `pointer_max_bits`; `fill_trace_row` only
/// `debug_assert!`s that it fits in a register.
#[derive(derive_new::new)]
pub struct Rv32IsEqualModAdapterFiller<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    // NOTE(review): presumably must equal the AIR's `address_bits`, since both compute the same
    // limb shift and the requested range checks must match the ones `eval` sends — confirm at
    // construction sites.
    pointer_max_bits: usize,
    /// Shared bitwise lookup chip that receives the pointer range-check requests.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}
267
268impl<
269        const NUM_READS: usize,
270        const BLOCKS_PER_READ: usize,
271        const BLOCK_SIZE: usize,
272        const TOTAL_READ_SIZE: usize,
273    > Rv32IsEqualModAdapterExecutor<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
274{
275    pub fn new(pointer_max_bits: usize) -> Self {
276        assert!(NUM_READS <= 2);
277        assert_eq!(TOTAL_READ_SIZE, BLOCKS_PER_READ * BLOCK_SIZE);
278        assert!(
279            RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS,
280            "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check"
281        );
282        Self { pointer_max_bits }
283    }
284}
285
286impl<
287        F: PrimeField32,
288        const NUM_READS: usize,
289        const BLOCKS_PER_READ: usize,
290        const BLOCK_SIZE: usize,
291        const TOTAL_READ_SIZE: usize,
292    > AdapterTraceExecutor<F>
293    for Rv32IsEqualModAdapterExecutor<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
294where
295    F: PrimeField32,
296{
297    const WIDTH: usize =
298        Rv32IsEqualModAdapterCols::<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE>::width();
299    type ReadData = [[u8; TOTAL_READ_SIZE]; NUM_READS];
300    type WriteData = [u8; RV32_REGISTER_NUM_LIMBS];
301    type RecordMut<'a> = &'a mut Rv32IsEqualModAdapterRecord<
302        NUM_READS,
303        BLOCKS_PER_READ,
304        BLOCK_SIZE,
305        TOTAL_READ_SIZE,
306    >;
307
308    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
309        record.from_pc = pc;
310        record.timestamp = memory.timestamp;
311    }
312
313    fn read(
314        &self,
315        memory: &mut TracingMemory,
316        instruction: &Instruction<F>,
317        record: &mut Self::RecordMut<'_>,
318    ) -> Self::ReadData {
319        let Instruction { b, c, d, e, .. } = *instruction;
320
321        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
322        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);
323
324        // Read register values
325        record.rs_val = from_fn(|i| {
326            record.rs_ptr[i] = if i == 0 { b } else { c }.as_canonical_u32();
327
328            u32::from_le_bytes(tracing_read(
329                memory,
330                RV32_REGISTER_AS,
331                record.rs_ptr[i],
332                &mut record.rs_read_aux[i].prev_timestamp,
333            ))
334        });
335
336        // Read memory values
337        from_fn(|i| {
338            debug_assert!(
339                record.rs_val[i] as usize + TOTAL_READ_SIZE - 1 < (1 << self.pointer_max_bits)
340            );
341            from_fn::<_, BLOCKS_PER_READ, _>(|j| {
342                tracing_read::<BLOCK_SIZE>(
343                    memory,
344                    RV32_MEMORY_AS,
345                    record.rs_val[i] + (j * BLOCK_SIZE) as u32,
346                    &mut record.heap_read_aux[i][j].prev_timestamp,
347                )
348            })
349            .concat()
350            .try_into()
351            .unwrap()
352        })
353    }
354
355    fn write(
356        &self,
357        memory: &mut TracingMemory,
358        instruction: &Instruction<F>,
359        data: Self::WriteData,
360        record: &mut Self::RecordMut<'_>,
361    ) {
362        let Instruction { a, .. } = *instruction;
363        record.rd_ptr = a.as_canonical_u32();
364        tracing_write(
365            memory,
366            RV32_REGISTER_AS,
367            record.rd_ptr,
368            data,
369            &mut record.writes_aux.prev_timestamp,
370            &mut record.writes_aux.prev_data,
371        );
372    }
373}
374
375impl<
376        F: PrimeField32,
377        const NUM_READS: usize,
378        const BLOCKS_PER_READ: usize,
379        const BLOCK_SIZE: usize,
380        const TOTAL_READ_SIZE: usize,
381    > AdapterTraceFiller<F>
382    for Rv32IsEqualModAdapterFiller<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
383{
384    const WIDTH: usize =
385        Rv32IsEqualModAdapterCols::<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE>::width();
386
387    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
388        // SAFETY:
389        // - caller ensures `adapter_row` contains a valid record representation that was previously
390        //   written by the executor
391        let record: &Rv32IsEqualModAdapterRecord<
392            NUM_READS,
393            BLOCKS_PER_READ,
394            BLOCK_SIZE,
395            TOTAL_READ_SIZE,
396        > = unsafe { get_record_from_slice(&mut adapter_row, ()) };
397
398        let cols: &mut Rv32IsEqualModAdapterCols<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
399            adapter_row.borrow_mut();
400
401        let mut timestamp = record.timestamp + (NUM_READS + NUM_READS * BLOCKS_PER_READ) as u32 + 1;
402        let mut timestamp_mm = || {
403            timestamp -= 1;
404            timestamp
405        };
406        // Do range checks before writing anything:
407        debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
408        let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits;
409        const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
410        self.bitwise_lookup_chip.request_range(
411            (record.rs_val[0] >> MSL_SHIFT) << limb_shift_bits,
412            if NUM_READS > 1 {
413                (record.rs_val[1] >> MSL_SHIFT) << limb_shift_bits
414            } else {
415                0
416            },
417        );
418        // Writing in reverse order
419        cols.writes_aux
420            .set_prev_data(record.writes_aux.prev_data.map(F::from_canonical_u8));
421        mem_helper.fill(
422            record.writes_aux.prev_timestamp,
423            timestamp_mm(),
424            cols.writes_aux.as_mut(),
425        );
426        cols.rd_ptr = F::from_canonical_u32(record.rd_ptr);
427
428        // **NOTE**: Must iterate everything in reverse order to avoid overwriting the records
429        cols.heap_read_aux
430            .iter_mut()
431            .rev()
432            .zip(record.heap_read_aux.iter().rev())
433            .for_each(|(col_reads, record_reads)| {
434                col_reads
435                    .iter_mut()
436                    .rev()
437                    .zip(record_reads.iter().rev())
438                    .for_each(|(col, record)| {
439                        mem_helper.fill(record.prev_timestamp, timestamp_mm(), col.as_mut());
440                    });
441            });
442
443        cols.rs_read_aux
444            .iter_mut()
445            .rev()
446            .zip(record.rs_read_aux.iter().rev())
447            .for_each(|(col, record)| {
448                mem_helper.fill(record.prev_timestamp, timestamp_mm(), col.as_mut());
449            });
450
451        cols.rs_val = record
452            .rs_val
453            .map(|val| val.to_le_bytes().map(F::from_canonical_u8));
454        cols.rs_ptr = record.rs_ptr.map(|ptr| F::from_canonical_u32(ptr));
455
456        cols.from_state.timestamp = F::from_canonical_u32(record.timestamp);
457        cols.from_state.pc = F::from_canonical_u32(record.from_pc);
458    }
459}