openvm_rv32_adapters/
eq_mod.rs

1use std::{
2    array::from_fn,
3    borrow::{Borrow, BorrowMut},
4    marker::PhantomData,
5};
6
7use itertools::izip;
8use openvm_circuit::{
9    arch::{
10        AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge,
11        ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip,
12        VmAdapterInterface,
13    },
14    system::{
15        memory::{
16            offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
17            MemoryAddress, MemoryController, OfflineMemory, RecordId,
18        },
19        program::ProgramBus,
20    },
21};
22use openvm_circuit_primitives::bitwise_op_lookup::{
23    BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip,
24};
25use openvm_circuit_primitives_derive::AlignedBorrow;
26use openvm_instructions::{
27    instruction::Instruction,
28    program::DEFAULT_PC_STEP,
29    riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
30};
31use openvm_rv32im_circuit::adapters::{
32    read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
33};
34use openvm_stark_backend::{
35    interaction::InteractionBuilder,
36    p3_air::BaseAir,
37    p3_field::{Field, FieldAlgebra, PrimeField32},
38};
39use serde::{Deserialize, Serialize};
40use serde_big_array::BigArray;
41use serde_with::serde_as;
42
43/// This adapter reads from NUM_READS <= 2 pointers and writes to a register.
44/// * The data is read from the heap (address space 2), and the pointers
45///   are read from registers (address space 1).
46/// * Reads take the form of `BLOCKS_PER_READ` consecutive reads of size
47///   `BLOCK_SIZE` from the heap, starting from the addresses in `rs[0]`
48///   (and `rs[1]` if `R = 2`).
49/// * Writes are to 32-bit register rd.
50#[repr(C)]
51#[derive(AlignedBorrow)]
52pub struct Rv32IsEqualModAdapterCols<
53    T,
54    const NUM_READS: usize,
55    const BLOCKS_PER_READ: usize,
56    const BLOCK_SIZE: usize,
57> {
58    pub from_state: ExecutionState<T>,
59
60    pub rs_ptr: [T; NUM_READS],
61    pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
62    pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
63    pub heap_read_aux: [[MemoryReadAuxCols<T>; BLOCKS_PER_READ]; NUM_READS],
64
65    pub rd_ptr: T,
66    pub writes_aux: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
67}
68
69#[allow(dead_code)]
70#[derive(Clone, Copy, Debug, derive_new::new)]
71pub struct Rv32IsEqualModAdapterAir<
72    const NUM_READS: usize,
73    const BLOCKS_PER_READ: usize,
74    const BLOCK_SIZE: usize,
75    const TOTAL_READ_SIZE: usize,
76> {
77    pub(super) execution_bridge: ExecutionBridge,
78    pub(super) memory_bridge: MemoryBridge,
79    pub bus: BitwiseOperationLookupBus,
80    address_bits: usize,
81}
82
83impl<
84        F: Field,
85        const NUM_READS: usize,
86        const BLOCKS_PER_READ: usize,
87        const BLOCK_SIZE: usize,
88        const TOTAL_READ_SIZE: usize,
89    > BaseAir<F>
90    for Rv32IsEqualModAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
91{
92    fn width(&self) -> usize {
93        Rv32IsEqualModAdapterCols::<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE>::width()
94    }
95}
96
97impl<
98        AB: InteractionBuilder,
99        const NUM_READS: usize,
100        const BLOCKS_PER_READ: usize,
101        const BLOCK_SIZE: usize,
102        const TOTAL_READ_SIZE: usize,
103    > VmAdapterAir<AB>
104    for Rv32IsEqualModAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
105{
106    type Interface = BasicAdapterInterface<
107        AB::Expr,
108        MinimalInstruction<AB::Expr>,
109        NUM_READS,
110        1,
111        TOTAL_READ_SIZE,
112        RV32_REGISTER_NUM_LIMBS,
113    >;
114
115    fn eval(
116        &self,
117        builder: &mut AB,
118        local: &[AB::Var],
119        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
120    ) {
121        let cols: &Rv32IsEqualModAdapterCols<_, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
122            local.borrow();
123        let timestamp = cols.from_state.timestamp;
124        let mut timestamp_delta: usize = 0;
125        let mut timestamp_pp = || {
126            timestamp_delta += 1;
127            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
128        };
129
130        // Address spaces
131        let d = AB::F::from_canonical_u32(RV32_REGISTER_AS);
132        let e = AB::F::from_canonical_u32(RV32_MEMORY_AS);
133
134        // Read register values for rs
135        for (ptr, val, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux) {
136            self.memory_bridge
137                .read(MemoryAddress::new(d, ptr), val, timestamp_pp(), aux)
138                .eval(builder, ctx.instruction.is_valid.clone());
139        }
140
141        // Compose the u32 register value into single field element, with
142        // a range check on the highest limb.
143        let rs_val_f = cols.rs_val.map(|decomp| {
144            decomp.iter().rev().fold(AB::Expr::ZERO, |acc, &limb| {
145                acc * AB::Expr::from_canonical_usize(1 << RV32_CELL_BITS) + limb
146            })
147        });
148
149        let need_range_check: [_; 2] = from_fn(|i| {
150            if i < NUM_READS {
151                cols.rs_val[i][RV32_REGISTER_NUM_LIMBS - 1].into()
152            } else {
153                AB::Expr::ZERO
154            }
155        });
156
157        let limb_shift = AB::F::from_canonical_usize(
158            1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits),
159        );
160
161        self.bus
162            .send_range(
163                need_range_check[0].clone() * limb_shift,
164                need_range_check[1].clone() * limb_shift,
165            )
166            .eval(builder, ctx.instruction.is_valid.clone());
167
168        // Reads from heap
169        assert_eq!(TOTAL_READ_SIZE, BLOCKS_PER_READ * BLOCK_SIZE);
170        let read_block_data: [[[_; BLOCK_SIZE]; BLOCKS_PER_READ]; NUM_READS] =
171            ctx.reads.map(|r: [AB::Expr; TOTAL_READ_SIZE]| {
172                let mut r_it = r.into_iter();
173                from_fn(|_| from_fn(|_| r_it.next().unwrap()))
174            });
175        let block_ptr_offset: [_; BLOCKS_PER_READ] =
176            from_fn(|i| AB::F::from_canonical_usize(i * BLOCK_SIZE));
177
178        for (ptr, block_data, block_aux) in izip!(rs_val_f, read_block_data, &cols.heap_read_aux) {
179            for (offset, data, aux) in izip!(block_ptr_offset, block_data, block_aux) {
180                self.memory_bridge
181                    .read(
182                        MemoryAddress::new(e, ptr.clone() + offset),
183                        data,
184                        timestamp_pp(),
185                        aux,
186                    )
187                    .eval(builder, ctx.instruction.is_valid.clone());
188            }
189        }
190
191        // Write to rd register
192        self.memory_bridge
193            .write(
194                MemoryAddress::new(d, cols.rd_ptr),
195                ctx.writes[0].clone(),
196                timestamp_pp(),
197                &cols.writes_aux,
198            )
199            .eval(builder, ctx.instruction.is_valid.clone());
200
201        self.execution_bridge
202            .execute_and_increment_or_set_pc(
203                ctx.instruction.opcode,
204                [
205                    cols.rd_ptr.into(),
206                    cols.rs_ptr
207                        .first()
208                        .map(|&x| x.into())
209                        .unwrap_or(AB::Expr::ZERO),
210                    cols.rs_ptr
211                        .get(1)
212                        .map(|&x| x.into())
213                        .unwrap_or(AB::Expr::ZERO),
214                    d.into(),
215                    e.into(),
216                ],
217                cols.from_state,
218                AB::F::from_canonical_usize(timestamp_delta),
219                (DEFAULT_PC_STEP, ctx.to_pc),
220            )
221            .eval(builder, ctx.instruction.is_valid.clone());
222    }
223
224    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
225        let cols: &Rv32IsEqualModAdapterCols<_, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
226            local.borrow();
227        cols.from_state.pc
228    }
229}
230
231pub struct Rv32IsEqualModAdapterChip<
232    F: Field,
233    const NUM_READS: usize,
234    const BLOCKS_PER_READ: usize,
235    const BLOCK_SIZE: usize,
236    const TOTAL_READ_SIZE: usize,
237> {
238    pub air: Rv32IsEqualModAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>,
239    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
240    _marker: PhantomData<F>,
241}
242
243impl<
244        F: PrimeField32,
245        const NUM_READS: usize,
246        const BLOCKS_PER_READ: usize,
247        const BLOCK_SIZE: usize,
248        const TOTAL_READ_SIZE: usize,
249    > Rv32IsEqualModAdapterChip<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
250{
251    pub fn new(
252        execution_bus: ExecutionBus,
253        program_bus: ProgramBus,
254        memory_bridge: MemoryBridge,
255        address_bits: usize,
256        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
257    ) -> Self {
258        assert!(NUM_READS <= 2);
259        assert_eq!(TOTAL_READ_SIZE, BLOCKS_PER_READ * BLOCK_SIZE);
260        assert!(
261            RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS,
262            "address_bits={address_bits} needs to be large enough for high limb range check"
263        );
264        Self {
265            air: Rv32IsEqualModAdapterAir {
266                execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
267                memory_bridge,
268                bus: bitwise_lookup_chip.bus(),
269                address_bits,
270            },
271            bitwise_lookup_chip,
272            _marker: PhantomData,
273        }
274    }
275}
276
277#[repr(C)]
278#[serde_as]
279#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
280pub struct Rv32IsEqualModReadRecord<
281    const NUM_READS: usize,
282    const BLOCKS_PER_READ: usize,
283    const BLOCK_SIZE: usize,
284> {
285    #[serde(with = "BigArray")]
286    pub rs: [RecordId; NUM_READS],
287    #[serde_as(as = "[[_; BLOCKS_PER_READ]; NUM_READS]")]
288    pub reads: [[RecordId; BLOCKS_PER_READ]; NUM_READS],
289}
290
291#[repr(C)]
292#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
293pub struct Rv32IsEqualModWriteRecord {
294    pub from_state: ExecutionState<u32>,
295    pub rd_id: RecordId,
296}
297
298impl<
299        F: PrimeField32,
300        const NUM_READS: usize,
301        const BLOCKS_PER_READ: usize,
302        const BLOCK_SIZE: usize,
303        const TOTAL_READ_SIZE: usize,
304    > VmAdapterChip<F>
305    for Rv32IsEqualModAdapterChip<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
306{
307    type ReadRecord = Rv32IsEqualModReadRecord<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE>;
308    type WriteRecord = Rv32IsEqualModWriteRecord;
309    type Air = Rv32IsEqualModAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>;
310    type Interface = BasicAdapterInterface<
311        F,
312        MinimalInstruction<F>,
313        NUM_READS,
314        1,
315        TOTAL_READ_SIZE,
316        RV32_REGISTER_NUM_LIMBS,
317    >;
318
319    fn preprocess(
320        &mut self,
321        memory: &mut MemoryController<F>,
322        instruction: &Instruction<F>,
323    ) -> Result<(
324        <Self::Interface as VmAdapterInterface<F>>::Reads,
325        Self::ReadRecord,
326    )> {
327        let Instruction { b, c, d, e, .. } = *instruction;
328
329        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
330        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);
331
332        let mut rs_vals = [0; NUM_READS];
333        let rs_records: [_; NUM_READS] = from_fn(|i| {
334            let addr = if i == 0 { b } else { c };
335            let (record, val) = read_rv32_register(memory, d, addr);
336            rs_vals[i] = val;
337            record
338        });
339
340        let read_records = rs_vals.map(|address| {
341            debug_assert!(address < (1 << self.air.address_bits));
342            from_fn(|i| {
343                memory
344                    .read::<BLOCK_SIZE>(e, F::from_canonical_u32(address + (i * BLOCK_SIZE) as u32))
345            })
346        });
347
348        let read_data = read_records.map(|r| {
349            let read = r.map(|x| x.1);
350            let mut read_it = read.iter().flatten();
351            from_fn(|_| *(read_it.next().unwrap()))
352        });
353        let record = Rv32IsEqualModReadRecord {
354            rs: rs_records,
355            reads: read_records.map(|r| r.map(|x| x.0)),
356        };
357
358        Ok((read_data, record))
359    }
360
361    fn postprocess(
362        &mut self,
363        memory: &mut MemoryController<F>,
364        instruction: &Instruction<F>,
365        from_state: ExecutionState<u32>,
366        output: AdapterRuntimeContext<F, Self::Interface>,
367        _read_record: &Self::ReadRecord,
368    ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
369        let Instruction { a, d, .. } = *instruction;
370        let (rd_id, _) = memory.write(d, a, output.writes[0]);
371
372        debug_assert!(
373            memory.timestamp() - from_state.timestamp
374                == (NUM_READS * (BLOCKS_PER_READ + 1) + 1) as u32,
375            "timestamp delta is {}, expected {}",
376            memory.timestamp() - from_state.timestamp,
377            NUM_READS * (BLOCKS_PER_READ + 1) + 1
378        );
379
380        Ok((
381            ExecutionState {
382                pc: from_state.pc + DEFAULT_PC_STEP,
383                timestamp: memory.timestamp(),
384            },
385            Self::WriteRecord { from_state, rd_id },
386        ))
387    }
388
389    fn generate_trace_row(
390        &self,
391        row_slice: &mut [F],
392        read_record: Self::ReadRecord,
393        write_record: Self::WriteRecord,
394        memory: &OfflineMemory<F>,
395    ) {
396        let aux_cols_factory = memory.aux_cols_factory();
397        let row_slice: &mut Rv32IsEqualModAdapterCols<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
398            row_slice.borrow_mut();
399        row_slice.from_state = write_record.from_state.map(F::from_canonical_u32);
400
401        let rs = read_record.rs.map(|r| memory.record_by_id(r));
402        for (i, r) in rs.iter().enumerate() {
403            row_slice.rs_ptr[i] = r.pointer;
404            row_slice.rs_val[i].copy_from_slice(r.data_slice());
405            aux_cols_factory.generate_read_aux(r, &mut row_slice.rs_read_aux[i]);
406            for (j, x) in read_record.reads[i].iter().enumerate() {
407                let read = memory.record_by_id(*x);
408                aux_cols_factory.generate_read_aux(read, &mut row_slice.heap_read_aux[i][j]);
409            }
410        }
411
412        let rd = memory.record_by_id(write_record.rd_id);
413        row_slice.rd_ptr = rd.pointer;
414        aux_cols_factory.generate_write_aux(rd, &mut row_slice.writes_aux);
415
416        // Range checks
417        let need_range_check: [u32; 2] = from_fn(|i| {
418            if i < NUM_READS {
419                rs[i]
420                    .data_at(RV32_REGISTER_NUM_LIMBS - 1)
421                    .as_canonical_u32()
422            } else {
423                0
424            }
425        });
426        let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.air.address_bits;
427        self.bitwise_lookup_chip.request_range(
428            need_range_check[0] << limb_shift_bits,
429            need_range_check[1] << limb_shift_bits,
430        );
431    }
432
433    fn air(&self) -> &Self::Air {
434        &self.air
435    }
436}