1use std::{
2 array::from_fn,
3 borrow::{Borrow, BorrowMut},
4 marker::PhantomData,
5};
6
7use itertools::izip;
8use openvm_circuit::{
9 arch::{
10 AdapterAirContext, AdapterRuntimeContext, BasicAdapterInterface, ExecutionBridge,
11 ExecutionBus, ExecutionState, MinimalInstruction, Result, VmAdapterAir, VmAdapterChip,
12 VmAdapterInterface,
13 },
14 system::{
15 memory::{
16 offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
17 MemoryAddress, MemoryController, OfflineMemory, RecordId,
18 },
19 program::ProgramBus,
20 },
21};
22use openvm_circuit_primitives::bitwise_op_lookup::{
23 BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip,
24};
25use openvm_circuit_primitives_derive::AlignedBorrow;
26use openvm_instructions::{
27 instruction::Instruction,
28 program::DEFAULT_PC_STEP,
29 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
30};
31use openvm_rv32im_circuit::adapters::{
32 read_rv32_register, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
33};
34use openvm_stark_backend::{
35 interaction::InteractionBuilder,
36 p3_air::BaseAir,
37 p3_field::{Field, FieldAlgebra, PrimeField32},
38};
39use serde::{Deserialize, Serialize};
40use serde_big_array::BigArray;
41use serde_with::serde_as;
42
/// Trace columns for [`Rv32IsEqualModAdapterAir`].
///
/// One row per instruction: the starting execution state, `NUM_READS` register
/// reads (pointer, limb-decomposed value, and offline-checker aux columns),
/// the `BLOCKS_PER_READ` heap block reads performed through each register
/// value, and the single destination-register write.
#[repr(C)]
#[derive(AlignedBorrow)]
pub struct Rv32IsEqualModAdapterCols<
    T,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
> {
    // `(pc, timestamp)` at the start of this instruction.
    pub from_state: ExecutionState<T>,

    // Pointers of the source registers that are read.
    pub rs_ptr: [T; NUM_READS],
    // Limb decomposition of each source register's value; the recomposed
    // value is used as a heap address in the AIR constraints.
    pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
    // Offline-memory-checker aux columns for the register reads.
    pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
    // Aux columns for the `BLOCKS_PER_READ` heap block reads per register.
    pub heap_read_aux: [[MemoryReadAuxCols<T>; BLOCKS_PER_READ]; NUM_READS],

    // Pointer of the destination register.
    pub rd_ptr: T,
    // Aux columns for the `RV32_REGISTER_NUM_LIMBS`-cell register write.
    pub writes_aux: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
}
68
/// Adapter AIR for instructions that read `NUM_READS` pointer registers, use
/// each register value as a heap address to read `BLOCKS_PER_READ` blocks of
/// `BLOCK_SIZE` cells (`TOTAL_READ_SIZE = BLOCKS_PER_READ * BLOCK_SIZE` cells
/// per read), and write one `RV32_REGISTER_NUM_LIMBS`-limb register result.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32IsEqualModAdapterAir<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    pub(super) execution_bridge: ExecutionBridge,
    pub(super) memory_bridge: MemoryBridge,
    // Lookup bus used to range-check the high limb of each pointer register.
    pub bus: BitwiseOperationLookupBus,
    // Number of bits in a valid memory address; pointers read from registers
    // must fit within this many bits.
    address_bits: usize,
}
82
impl<
        F: Field,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCK_SIZE: usize,
        const TOTAL_READ_SIZE: usize,
    > BaseAir<F>
    for Rv32IsEqualModAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
{
    // Trace width is fully determined by the `AlignedBorrow`-generated layout
    // of the columns struct.
    fn width(&self) -> usize {
        Rv32IsEqualModAdapterCols::<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE>::width()
    }
}
96
impl<
        AB: InteractionBuilder,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCK_SIZE: usize,
        const TOTAL_READ_SIZE: usize,
    > VmAdapterAir<AB>
    for Rv32IsEqualModAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
{
    // `NUM_READS` reads of `TOTAL_READ_SIZE` cells each; one write of
    // `RV32_REGISTER_NUM_LIMBS` cells.
    type Interface = BasicAdapterInterface<
        AB::Expr,
        MinimalInstruction<AB::Expr>,
        NUM_READS,
        1,
        TOTAL_READ_SIZE,
        RV32_REGISTER_NUM_LIMBS,
    >;

    /// Emits all memory- and execution-bus interactions for one row.
    ///
    /// The interaction order fixes the timestamp order: register reads, then
    /// heap block reads, then the register write, and finally the execution
    /// bridge with the accumulated timestamp delta. Every interaction is
    /// gated on `ctx.instruction.is_valid`.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let cols: &Rv32IsEqualModAdapterCols<_, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
            local.borrow();
        let timestamp = cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Returns the timestamp for the next memory access and bumps the
        // running delta (so the first call yields `timestamp + 0`).
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
        };

        // Address spaces: `d` = registers, `e` = heap memory.
        let d = AB::F::from_canonical_u32(RV32_REGISTER_AS);
        let e = AB::F::from_canonical_u32(RV32_MEMORY_AS);

        // Read each pointer register from the register address space.
        for (ptr, val, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux) {
            self.memory_bridge
                .read(MemoryAddress::new(d, ptr), val, timestamp_pp(), aux)
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Recompose each register's limbs (little-endian, RV32_CELL_BITS per
        // limb) into a single field element: the heap address to read from.
        let rs_val_f = cols.rs_val.map(|decomp| {
            decomp.iter().rev().fold(AB::Expr::ZERO, |acc, &limb| {
                acc * AB::Expr::from_canonical_usize(1 << RV32_CELL_BITS) + limb
            })
        });

        // High limb of each pointer, zero-padded to 2 entries because the
        // bitwise lookup bus range-checks values in pairs.
        let need_range_check: [_; 2] = from_fn(|i| {
            if i < NUM_READS {
                cols.rs_val[i][RV32_REGISTER_NUM_LIMBS - 1].into()
            } else {
                AB::Expr::ZERO
            }
        });

        // Shift the high limb so that the bus's RV32_CELL_BITS-bit range check
        // on the shifted value implies the full pointer fits in `address_bits`
        // bits (see the corresponding assert in `new`).
        let limb_shift = AB::F::from_canonical_usize(
            1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits),
        );

        self.bus
            .send_range(
                need_range_check[0].clone() * limb_shift,
                need_range_check[1].clone() * limb_shift,
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // Regroup the flat `TOTAL_READ_SIZE` read data into
        // `BLOCKS_PER_READ` blocks of `BLOCK_SIZE` cells each.
        assert_eq!(TOTAL_READ_SIZE, BLOCKS_PER_READ * BLOCK_SIZE);
        let read_block_data: [[[_; BLOCK_SIZE]; BLOCKS_PER_READ]; NUM_READS] =
            ctx.reads.map(|r: [AB::Expr; TOTAL_READ_SIZE]| {
                let mut r_it = r.into_iter();
                from_fn(|_| from_fn(|_| r_it.next().unwrap()))
            });
        let block_ptr_offset: [_; BLOCKS_PER_READ] =
            from_fn(|i| AB::F::from_canonical_usize(i * BLOCK_SIZE));

        // Read each block from heap memory at `pointer + block offset`.
        for (ptr, block_data, block_aux) in izip!(rs_val_f, read_block_data, &cols.heap_read_aux) {
            for (offset, data, aux) in izip!(block_ptr_offset, block_data, block_aux) {
                self.memory_bridge
                    .read(
                        MemoryAddress::new(e, ptr.clone() + offset),
                        data,
                        timestamp_pp(),
                        aux,
                    )
                    .eval(builder, ctx.instruction.is_valid.clone());
            }
        }

        // Write the result to the destination register.
        self.memory_bridge
            .write(
                MemoryAddress::new(d, cols.rd_ptr),
                ctx.writes[0].clone(),
                timestamp_pp(),
                &cols.writes_aux,
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // Operands sent to the execution bridge:
        // [a = rd_ptr, b = rs1_ptr, c = rs2_ptr (or 0 if absent), d, e].
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    cols.rd_ptr.into(),
                    cols.rs_ptr
                        .first()
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    cols.rs_ptr
                        .get(1)
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    d.into(),
                    e.into(),
                ],
                cols.from_state,
                AB::F::from_canonical_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid.clone());
    }

    // The `pc` of this row's starting execution state.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32IsEqualModAdapterCols<_, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
            local.borrow();
        cols.from_state.pc
    }
}
230
/// Runtime adapter chip paired with [`Rv32IsEqualModAdapterAir`]: performs the
/// register/heap reads and register write during execution and fills the
/// adapter trace columns.
pub struct Rv32IsEqualModAdapterChip<
    F: Field,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    pub air: Rv32IsEqualModAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>,
    // Shared lookup chip used to request the pointer high-limb range checks
    // that the AIR sends on its bus.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
    // The chip stores no `F` values directly; `F` only parameterizes the trait impls.
    _marker: PhantomData<F>,
}
242
impl<
        F: PrimeField32,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCK_SIZE: usize,
        const TOTAL_READ_SIZE: usize,
    > Rv32IsEqualModAdapterChip<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
{
    /// Constructs the adapter chip and its AIR from the shared buses/bridges.
    ///
    /// # Panics
    /// - if `NUM_READS > 2` (the execution bridge carries at most two
    ///   register-pointer operands),
    /// - if `TOTAL_READ_SIZE != BLOCKS_PER_READ * BLOCK_SIZE`,
    /// - if `address_bits` is too small for the single-limb range check used
    ///   by the AIR (the shift amount must stay below `RV32_CELL_BITS`).
    pub fn new(
        execution_bus: ExecutionBus,
        program_bus: ProgramBus,
        memory_bridge: MemoryBridge,
        address_bits: usize,
        bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
    ) -> Self {
        assert!(NUM_READS <= 2);
        assert_eq!(TOTAL_READ_SIZE, BLOCKS_PER_READ * BLOCK_SIZE);
        assert!(
            RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - address_bits < RV32_CELL_BITS,
            "address_bits={address_bits} needs to be large enough for high limb range check"
        );
        Self {
            air: Rv32IsEqualModAdapterAir {
                execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
                memory_bridge,
                bus: bitwise_lookup_chip.bus(),
                address_bits,
            },
            bitwise_lookup_chip,
            _marker: PhantomData,
        }
    }
}
276
/// Record of the memory reads performed by `preprocess`, identified by
/// offline-memory `RecordId`s (resolved back to data in `generate_trace_row`).
#[repr(C)]
#[serde_as]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Rv32IsEqualModReadRecord<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
> {
    // Register reads of the pointer operands. `BigArray` / `serde_as` are
    // needed because serde lacks built-in impls for const-generic arrays.
    #[serde(with = "BigArray")]
    pub rs: [RecordId; NUM_READS],
    // Heap block reads: one record per block, per register pointer.
    #[serde_as(as = "[[_; BLOCKS_PER_READ]; NUM_READS]")]
    pub reads: [[RecordId; BLOCKS_PER_READ]; NUM_READS],
}
290
/// Record of the destination-register write performed by `postprocess`,
/// plus the execution state the instruction started from.
#[repr(C)]
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
pub struct Rv32IsEqualModWriteRecord {
    // `(pc, timestamp)` at the start of the instruction.
    pub from_state: ExecutionState<u32>,
    // Offline-memory record of the write to `rd`.
    pub rd_id: RecordId,
}
297
impl<
        F: PrimeField32,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCK_SIZE: usize,
        const TOTAL_READ_SIZE: usize,
    > VmAdapterChip<F>
    for Rv32IsEqualModAdapterChip<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
{
    type ReadRecord = Rv32IsEqualModReadRecord<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE>;
    type WriteRecord = Rv32IsEqualModWriteRecord;
    type Air = Rv32IsEqualModAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>;
    type Interface = BasicAdapterInterface<
        F,
        MinimalInstruction<F>,
        NUM_READS,
        1,
        TOTAL_READ_SIZE,
        RV32_REGISTER_NUM_LIMBS,
    >;

    /// Performs the runtime reads for one instruction.
    ///
    /// Reads the pointer registers (operand `b`, then `c` for the second
    /// read), then dereferences each pointer to read `BLOCKS_PER_READ`
    /// contiguous blocks from heap memory. Returns the flattened read data
    /// and the record ids needed later for trace generation. The access order
    /// must match `eval` so the timestamps line up.
    fn preprocess(
        &mut self,
        memory: &mut MemoryController<F>,
        instruction: &Instruction<F>,
    ) -> Result<(
        <Self::Interface as VmAdapterInterface<F>>::Reads,
        Self::ReadRecord,
    )> {
        let Instruction { b, c, d, e, .. } = *instruction;

        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);

        // Read the pointer registers: operand `b` for the first, `c` for the second.
        let mut rs_vals = [0; NUM_READS];
        let rs_records: [_; NUM_READS] = from_fn(|i| {
            let addr = if i == 0 { b } else { c };
            let (record, val) = read_rv32_register(memory, d, addr);
            rs_vals[i] = val;
            record
        });

        // Dereference each pointer: read BLOCKS_PER_READ consecutive blocks
        // of BLOCK_SIZE cells starting at the pointer value.
        let read_records = rs_vals.map(|address| {
            // Mirrors the AIR's range check: pointers must fit in address_bits.
            debug_assert!(address < (1 << self.air.address_bits));
            from_fn(|i| {
                memory
                    .read::<BLOCK_SIZE>(e, F::from_canonical_u32(address + (i * BLOCK_SIZE) as u32))
            })
        });

        // Flatten each read's blocks into a single TOTAL_READ_SIZE array.
        let read_data = read_records.map(|r| {
            let read = r.map(|x| x.1);
            let mut read_it = read.iter().flatten();
            from_fn(|_| *(read_it.next().unwrap()))
        });
        let record = Rv32IsEqualModReadRecord {
            rs: rs_records,
            reads: read_records.map(|r| r.map(|x| x.0)),
        };

        Ok((read_data, record))
    }

    /// Writes the core chip's output to register `a` and advances the
    /// execution state by one instruction (`pc += DEFAULT_PC_STEP`).
    fn postprocess(
        &mut self,
        memory: &mut MemoryController<F>,
        instruction: &Instruction<F>,
        from_state: ExecutionState<u32>,
        output: AdapterRuntimeContext<F, Self::Interface>,
        _read_record: &Self::ReadRecord,
    ) -> Result<(ExecutionState<u32>, Self::WriteRecord)> {
        let Instruction { a, d, .. } = *instruction;
        let (rd_id, _) = memory.write(d, a, output.writes[0]);

        // Sanity check: the number of memory accesses (NUM_READS register
        // reads + NUM_READS * BLOCKS_PER_READ heap reads + 1 write) must match
        // the timestamp delta the AIR charges in `eval`.
        debug_assert!(
            memory.timestamp() - from_state.timestamp
                == (NUM_READS * (BLOCKS_PER_READ + 1) + 1) as u32,
            "timestamp delta is {}, expected {}",
            memory.timestamp() - from_state.timestamp,
            NUM_READS * (BLOCKS_PER_READ + 1) + 1
        );

        Ok((
            ExecutionState {
                pc: from_state.pc + DEFAULT_PC_STEP,
                timestamp: memory.timestamp(),
            },
            Self::WriteRecord { from_state, rd_id },
        ))
    }

    /// Fills one row of adapter columns from the execution records and issues
    /// the range-check requests matching the AIR's `send_range` interaction.
    fn generate_trace_row(
        &self,
        row_slice: &mut [F],
        read_record: Self::ReadRecord,
        write_record: Self::WriteRecord,
        memory: &OfflineMemory<F>,
    ) {
        let aux_cols_factory = memory.aux_cols_factory();
        let row_slice: &mut Rv32IsEqualModAdapterCols<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
            row_slice.borrow_mut();
        row_slice.from_state = write_record.from_state.map(F::from_canonical_u32);

        // Register-read columns: pointer, value limbs, and read aux, plus the
        // heap-block read aux for each dereferenced block.
        let rs = read_record.rs.map(|r| memory.record_by_id(r));
        for (i, r) in rs.iter().enumerate() {
            row_slice.rs_ptr[i] = r.pointer;
            row_slice.rs_val[i].copy_from_slice(r.data_slice());
            aux_cols_factory.generate_read_aux(r, &mut row_slice.rs_read_aux[i]);
            for (j, x) in read_record.reads[i].iter().enumerate() {
                let read = memory.record_by_id(*x);
                aux_cols_factory.generate_read_aux(read, &mut row_slice.heap_read_aux[i][j]);
            }
        }

        // Register-write columns.
        let rd = memory.record_by_id(write_record.rd_id);
        row_slice.rd_ptr = rd.pointer;
        aux_cols_factory.generate_write_aux(rd, &mut row_slice.writes_aux);

        // Request the same pair of shifted high-limb range checks that the
        // AIR sends on the bitwise lookup bus (zero-padded to 2 entries).
        let need_range_check: [u32; 2] = from_fn(|i| {
            if i < NUM_READS {
                rs[i]
                    .data_at(RV32_REGISTER_NUM_LIMBS - 1)
                    .as_canonical_u32()
            } else {
                0
            }
        });
        let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.air.address_bits;
        self.bitwise_lookup_chip.request_range(
            need_range_check[0] << limb_shift_bits,
            need_range_check[1] << limb_shift_bits,
        );
    }

    fn air(&self) -> &Self::Air {
        &self.air
    }
}