1use std::{
2 array::from_fn,
3 borrow::{Borrow, BorrowMut},
4};
5
6use itertools::izip;
7use openvm_circuit::{
8 arch::{
9 get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
10 BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir,
11 },
12 system::memory::{
13 offline_checker::{
14 MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols,
15 MemoryWriteBytesAuxRecord,
16 },
17 online::TracingMemory,
18 MemoryAddress, MemoryAuxColsFactory,
19 },
20};
21use openvm_circuit_primitives::{
22 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
23 AlignedBytesBorrow,
24};
25use openvm_circuit_primitives_derive::AlignedBorrow;
26use openvm_instructions::{
27 instruction::Instruction,
28 program::DEFAULT_PC_STEP,
29 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
30};
31use openvm_rv32im_circuit::adapters::{
32 tracing_read, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
33};
34use openvm_stark_backend::{
35 interaction::InteractionBuilder,
36 p3_air::BaseAir,
37 p3_field::{Field, FieldAlgebra, PrimeField32},
38};
39
/// Trace columns for the IsEqualMod adapter row.
///
/// `#[repr(C)]` + `AlignedBorrow` allow this struct to be reinterpreted
/// directly from a flat `&[T]` row slice, so the field order here *is* the
/// column layout — do not reorder fields.
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct Rv32IsEqualModAdapterCols<
    T,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
> {
    /// Execution state (pc, timestamp) at the start of this instruction.
    pub from_state: ExecutionState<T>,

    /// Register indices of the rs source registers.
    pub rs_ptr: [T; NUM_READS],
    /// Limb decomposition of each rs register's value; each value is used as
    /// a heap pointer for the block reads below.
    pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
    /// Memory offline-checker aux columns for the register reads.
    pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
    /// Aux columns for the `BLOCKS_PER_READ` heap block reads per pointer.
    pub heap_read_aux: [[MemoryReadAuxCols<T>; BLOCKS_PER_READ]; NUM_READS],

    /// Register index of the destination register rd.
    pub rd_ptr: T,
    /// Aux columns for the write of the result limbs into rd.
    pub writes_aux: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
}
64
/// AIR for an RV32 adapter that reads `NUM_READS` pointer registers, follows
/// each pointer into memory for `BLOCKS_PER_READ` contiguous blocks of
/// `BLOCK_SIZE` cells (`TOTAL_READ_SIZE = BLOCKS_PER_READ * BLOCK_SIZE`), and
/// writes a `RV32_REGISTER_NUM_LIMBS`-limb result back to a register.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32IsEqualModAdapterAir<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    pub(super) execution_bridge: ExecutionBridge,
    pub(super) memory_bridge: MemoryBridge,
    /// Bitwise-operation lookup bus; used here to range-check the high limb
    /// of each heap pointer (see `eval`).
    pub bus: BitwiseOperationLookupBus,
    /// Number of bits in a valid memory address (pointer_max_bits).
    address_bits: usize,
}
78
impl<
        F: Field,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCK_SIZE: usize,
        const TOTAL_READ_SIZE: usize,
    > BaseAir<F>
    for Rv32IsEqualModAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
{
    /// Trace width is exactly the flattened column struct width.
    fn width(&self) -> usize {
        Rv32IsEqualModAdapterCols::<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE>::width()
    }
}
92
impl<
        AB: InteractionBuilder,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCK_SIZE: usize,
        const TOTAL_READ_SIZE: usize,
    > VmAdapterAir<AB>
    for Rv32IsEqualModAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
{
    // The core chip sees `NUM_READS` reads of `TOTAL_READ_SIZE` cells each and
    // produces 1 write of `RV32_REGISTER_NUM_LIMBS` cells.
    type Interface = BasicAdapterInterface<
        AB::Expr,
        MinimalInstruction<AB::Expr>,
        NUM_READS,
        1,
        TOTAL_READ_SIZE,
        RV32_REGISTER_NUM_LIMBS,
    >;

    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let cols: &Rv32IsEqualModAdapterCols<_, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
            local.borrow();
        let timestamp = cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Returns the timestamp for the next memory operation, then advances
        // the running delta. The call order below therefore fixes the
        // timestamp assigned to each memory bridge interaction.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_canonical_usize(timestamp_delta - 1)
        };

        // d = register address space, e = memory (heap) address space.
        let d = AB::F::from_canonical_u32(RV32_REGISTER_AS);
        let e = AB::F::from_canonical_u32(RV32_MEMORY_AS);

        // 1. Read each rs register; its limbs land in `rs_val`.
        for (ptr, val, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux) {
            self.memory_bridge
                .read(MemoryAddress::new(d, ptr), val, timestamp_pp(), aux)
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Recompose each register's little-endian limbs into a single field
        // element (folding from the most significant limb down), giving the
        // heap pointer for the block reads.
        let rs_val_f = cols.rs_val.map(|decomp| {
            decomp.iter().rev().fold(AB::Expr::ZERO, |acc, &limb| {
                acc * AB::Expr::from_canonical_usize(1 << RV32_CELL_BITS) + limb
            })
        });

        // Most-significant limb of each pointer, zero-padded to exactly 2
        // entries so a single 2-argument range-check interaction suffices
        // (NUM_READS <= 2 is asserted in the executor's constructor).
        let need_range_check: [_; 2] = from_fn(|i| {
            if i < NUM_READS {
                cols.rs_val[i][RV32_REGISTER_NUM_LIMBS - 1].into()
            } else {
                AB::Expr::ZERO
            }
        });

        // Shifting the high limb left by this amount means the range check
        // below enforces pointer < 2^address_bits.
        let limb_shift = AB::F::from_canonical_usize(
            1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits),
        );

        // 2. Range-check both (shifted) high limbs via the bitwise lookup bus.
        self.bus
            .send_range(
                need_range_check[0].clone() * limb_shift,
                need_range_check[1].clone() * limb_shift,
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // 3. Heap reads: split each TOTAL_READ_SIZE read into BLOCKS_PER_READ
        //    blocks of BLOCK_SIZE, read at pointer + block offset.
        assert_eq!(TOTAL_READ_SIZE, BLOCKS_PER_READ * BLOCK_SIZE);
        let read_block_data: [[[_; BLOCK_SIZE]; BLOCKS_PER_READ]; NUM_READS] =
            ctx.reads.map(|r: [AB::Expr; TOTAL_READ_SIZE]| {
                let mut r_it = r.into_iter();
                from_fn(|_| from_fn(|_| r_it.next().unwrap()))
            });
        let block_ptr_offset: [_; BLOCKS_PER_READ] =
            from_fn(|i| AB::F::from_canonical_usize(i * BLOCK_SIZE));

        for (ptr, block_data, block_aux) in izip!(rs_val_f, read_block_data, &cols.heap_read_aux) {
            for (offset, data, aux) in izip!(block_ptr_offset, block_data, block_aux) {
                self.memory_bridge
                    .read(
                        MemoryAddress::new(e, ptr.clone() + offset),
                        data,
                        timestamp_pp(),
                        aux,
                    )
                    .eval(builder, ctx.instruction.is_valid.clone());
            }
        }

        // 4. Write the core chip's result limbs into the rd register.
        self.memory_bridge
            .write(
                MemoryAddress::new(d, cols.rd_ptr),
                ctx.writes[0].clone(),
                timestamp_pp(),
                &cols.writes_aux,
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // 5. Tie this row to the execution bus: operands are
        //    [a = rd_ptr, b = rs0_ptr, c = rs1_ptr (or 0 if absent), d, e];
        //    pc advances by DEFAULT_PC_STEP unless ctx.to_pc overrides it.
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    cols.rd_ptr.into(),
                    cols.rs_ptr
                        .first()
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    cols.rs_ptr
                        .get(1)
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    d.into(),
                    e.into(),
                ],
                cols.from_state,
                AB::F::from_canonical_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid.clone());
    }

    /// The pc this row started executing at.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32IsEqualModAdapterCols<_, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
            local.borrow();
        cols.from_state.pc
    }
}
226
/// Execution-time record for one IsEqualMod instruction, later expanded into
/// trace columns by the filler.
///
/// `#[repr(C)]` + `AlignedBytesBorrow` allow the record to be reinterpreted
/// from raw bytes of the trace buffer, so field order is part of the layout.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32IsEqualModAdapterRecord<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    pub from_pc: u32,
    pub timestamp: u32,

    /// Register indices of the rs registers.
    pub rs_ptr: [u32; NUM_READS],
    /// Values read from the rs registers (heap pointers).
    pub rs_val: [u32; NUM_READS],
    /// Previous-timestamp aux data for the register reads.
    pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS],
    /// Previous-timestamp aux data for the heap block reads.
    pub heap_read_aux: [[MemoryReadAuxRecord; BLOCKS_PER_READ]; NUM_READS],

    /// Register index of rd.
    pub rd_ptr: u32,
    /// Previous timestamp and previous data for the rd write.
    pub writes_aux: MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS>,
}
246
/// Runtime executor half of the adapter: performs the actual memory reads and
/// write during execution and fills in the record.
#[derive(Clone, Copy)]
pub struct Rv32IsEqualModAdapterExecutor<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    /// Maximum number of bits in a valid heap pointer.
    pointer_max_bits: usize,
}
256
/// Trace-filling half of the adapter: expands records into trace rows and
/// replays the pointer range checks against the bitwise lookup chip.
#[derive(derive_new::new)]
pub struct Rv32IsEqualModAdapterFiller<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    /// Maximum number of bits in a valid heap pointer; must match the AIR's
    /// `address_bits`.
    pointer_max_bits: usize,
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}
267
impl<
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCK_SIZE: usize,
        const TOTAL_READ_SIZE: usize,
    > Rv32IsEqualModAdapterExecutor<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
{
    /// Creates an executor, validating the const-generic configuration.
    ///
    /// # Panics
    /// - if `NUM_READS > 2` (the AIR range-checks at most two pointers);
    /// - if `TOTAL_READ_SIZE != BLOCKS_PER_READ * BLOCK_SIZE`;
    /// - if `pointer_max_bits` is too small for the single-limb range check
    ///   used by the AIR (the shift amount must stay below `RV32_CELL_BITS`).
    pub fn new(pointer_max_bits: usize) -> Self {
        assert!(NUM_READS <= 2);
        assert_eq!(TOTAL_READ_SIZE, BLOCKS_PER_READ * BLOCK_SIZE);
        assert!(
            RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS,
            "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check"
        );
        Self { pointer_max_bits }
    }
}
285
impl<
        F: PrimeField32,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCK_SIZE: usize,
        const TOTAL_READ_SIZE: usize,
    > AdapterTraceExecutor<F>
    for Rv32IsEqualModAdapterExecutor<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
where
    F: PrimeField32,
{
    const WIDTH: usize =
        Rv32IsEqualModAdapterCols::<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE>::width();
    type ReadData = [[u8; TOTAL_READ_SIZE]; NUM_READS];
    type WriteData = [u8; RV32_REGISTER_NUM_LIMBS];
    type RecordMut<'a> = &'a mut Rv32IsEqualModAdapterRecord<
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCK_SIZE,
        TOTAL_READ_SIZE,
    >;

    /// Snapshots pc and the memory timestamp at instruction start.
    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
        record.from_pc = pc;
        record.timestamp = memory.timestamp;
    }

    /// Reads the rs registers (operands `b`, then `c`), then dereferences each
    /// register value as a heap pointer and reads `BLOCKS_PER_READ` contiguous
    /// `BLOCK_SIZE`-byte blocks. The read order here must match the timestamp
    /// order assumed by the AIR's `eval`.
    fn read(
        &self,
        memory: &mut TracingMemory,
        instruction: &Instruction<F>,
        record: &mut Self::RecordMut<'_>,
    ) -> Self::ReadData {
        let Instruction { b, c, d, e, .. } = *instruction;

        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);

        // Register reads: rs_ptr[0] = b, rs_ptr[1] = c (when NUM_READS == 2).
        // `from_fn` evaluates in ascending index order, fixing the timestamps.
        record.rs_val = from_fn(|i| {
            record.rs_ptr[i] = if i == 0 { b } else { c }.as_canonical_u32();

            u32::from_le_bytes(tracing_read(
                memory,
                RV32_REGISTER_AS,
                record.rs_ptr[i],
                &mut record.rs_read_aux[i].prev_timestamp,
            ))
        });

        // Heap reads: block j of pointer i lives at rs_val[i] + j * BLOCK_SIZE;
        // the blocks are concatenated back into one TOTAL_READ_SIZE array.
        from_fn(|i| {
            // The whole read must fit below the configured pointer bound.
            debug_assert!(
                record.rs_val[i] as usize + TOTAL_READ_SIZE - 1 < (1 << self.pointer_max_bits)
            );
            from_fn::<_, BLOCKS_PER_READ, _>(|j| {
                tracing_read::<BLOCK_SIZE>(
                    memory,
                    RV32_MEMORY_AS,
                    record.rs_val[i] + (j * BLOCK_SIZE) as u32,
                    &mut record.heap_read_aux[i][j].prev_timestamp,
                )
            })
            .concat()
            .try_into()
            .unwrap()
        })
    }

    /// Writes the result limbs into register `a`, recording the previous
    /// timestamp and previous data for the offline checker.
    fn write(
        &self,
        memory: &mut TracingMemory,
        instruction: &Instruction<F>,
        data: Self::WriteData,
        record: &mut Self::RecordMut<'_>,
    ) {
        let Instruction { a, .. } = *instruction;
        record.rd_ptr = a.as_canonical_u32();
        tracing_write(
            memory,
            RV32_REGISTER_AS,
            record.rd_ptr,
            data,
            &mut record.writes_aux.prev_timestamp,
            &mut record.writes_aux.prev_data,
        );
    }
}
374
impl<
        F: PrimeField32,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCK_SIZE: usize,
        const TOTAL_READ_SIZE: usize,
    > AdapterTraceFiller<F>
    for Rv32IsEqualModAdapterFiller<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
{
    const WIDTH: usize =
        Rv32IsEqualModAdapterCols::<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE>::width();

    /// Expands one record into one trace row.
    ///
    /// NOTE(review): the record is reinterpreted from the *same* row buffer
    /// (`get_record_from_slice`), which is presumably why the columns are
    /// written back-to-front below — each record field must be consumed before
    /// the column write that overlaps it. Preserve the fill order.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        let record: &Rv32IsEqualModAdapterRecord<
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCK_SIZE,
            TOTAL_READ_SIZE,
        > = unsafe { get_record_from_slice(&mut adapter_row, ()) };

        let cols: &mut Rv32IsEqualModAdapterCols<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
            adapter_row.borrow_mut();

        // Timestamps are reconstructed in reverse: there are
        // NUM_READS register reads + NUM_READS * BLOCKS_PER_READ heap reads
        // + 1 write. Starting one past the last op and pre-decrementing makes
        // the first `timestamp_mm()` call yield the write's timestamp.
        let mut timestamp = record.timestamp + (NUM_READS + NUM_READS * BLOCKS_PER_READ) as u32 + 1;
        let mut timestamp_mm = || {
            timestamp -= 1;
            timestamp
        };
        debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
        let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits;
        // Shift of the most-significant limb within the u32 pointer value.
        const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
        // Replay the AIR's pointer range check: shifted high limb of each
        // pointer (0 for the absent second read when NUM_READS == 1).
        self.bitwise_lookup_chip.request_range(
            (record.rs_val[0] >> MSL_SHIFT) << limb_shift_bits,
            if NUM_READS > 1 {
                (record.rs_val[1] >> MSL_SHIFT) << limb_shift_bits
            } else {
                0
            },
        );
        // --- Columns are filled in reverse of struct order from here on. ---
        // rd write aux (last memory op, so first timestamp_mm() call).
        cols.writes_aux
            .set_prev_data(record.writes_aux.prev_data.map(F::from_canonical_u8));
        mem_helper.fill(
            record.writes_aux.prev_timestamp,
            timestamp_mm(),
            cols.writes_aux.as_mut(),
        );
        cols.rd_ptr = F::from_canonical_u32(record.rd_ptr);

        // Heap read aux, iterated in reverse to match descending timestamps.
        cols.heap_read_aux
            .iter_mut()
            .rev()
            .zip(record.heap_read_aux.iter().rev())
            .for_each(|(col_reads, record_reads)| {
                col_reads
                    .iter_mut()
                    .rev()
                    .zip(record_reads.iter().rev())
                    .for_each(|(col, record)| {
                        mem_helper.fill(record.prev_timestamp, timestamp_mm(), col.as_mut());
                    });
            });

        // Register read aux, also in reverse.
        cols.rs_read_aux
            .iter_mut()
            .rev()
            .zip(record.rs_read_aux.iter().rev())
            .for_each(|(col, record)| {
                mem_helper.fill(record.prev_timestamp, timestamp_mm(), col.as_mut());
            });

        // Decompose each pointer value into little-endian byte limbs.
        cols.rs_val = record
            .rs_val
            .map(|val| val.to_le_bytes().map(F::from_canonical_u8));
        cols.rs_ptr = record.rs_ptr.map(|ptr| F::from_canonical_u32(ptr));

        cols.from_state.timestamp = F::from_canonical_u32(record.timestamp);
        cols.from_state.pc = F::from_canonical_u32(record.from_pc);
    }
}