1use std::{
2 array::from_fn,
3 borrow::{Borrow, BorrowMut},
4};
5
6use itertools::izip;
7use openvm_circuit::{
8 arch::{
9 get_record_from_slice, AdapterAirContext, AdapterTraceExecutor, AdapterTraceFiller,
10 BasicAdapterInterface, ExecutionBridge, ExecutionState, MinimalInstruction, VmAdapterAir,
11 },
12 system::memory::{
13 offline_checker::{
14 MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols,
15 MemoryWriteBytesAuxRecord,
16 },
17 online::TracingMemory,
18 MemoryAddress, MemoryAuxColsFactory,
19 },
20};
21use openvm_circuit_primitives::{
22 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
23 AlignedBytesBorrow,
24};
25use openvm_circuit_primitives_derive::AlignedBorrow;
26use openvm_instructions::{
27 instruction::Instruction,
28 program::DEFAULT_PC_STEP,
29 riscv::{RV32_MEMORY_AS, RV32_REGISTER_AS},
30};
31use openvm_rv32im_circuit::adapters::{
32 tracing_read, tracing_write, RV32_CELL_BITS, RV32_REGISTER_NUM_LIMBS,
33};
34use openvm_stark_backend::{
35 interaction::InteractionBuilder,
36 p3_air::BaseAir,
37 p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
38};
39
/// Trace columns for the is-equal-mod adapter: `NUM_READS` pointer registers,
/// each dereferenced into `BLOCKS_PER_READ` heap blocks of `BLOCK_SIZE` cells,
/// plus one register-sized write of the result.
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct Rv32IsEqualModAdapterCols<
    T,
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
> {
    /// Program counter and timestamp at the start of the instruction.
    pub from_state: ExecutionState<T>,

    /// Indices of the registers holding the heap pointers.
    pub rs_ptr: [T; NUM_READS],
    /// Little-endian byte limbs of each pointer value read from the registers.
    pub rs_val: [[T; RV32_REGISTER_NUM_LIMBS]; NUM_READS],
    /// Memory aux columns for the pointer-register reads.
    pub rs_read_aux: [MemoryReadAuxCols<T>; NUM_READS],
    /// Memory aux columns for each heap block read, per pointer.
    pub heap_read_aux: [[MemoryReadAuxCols<T>; BLOCKS_PER_READ]; NUM_READS],

    /// Index of the destination register.
    pub rd_ptr: T,
    /// Memory aux columns (including previous data) for the register write.
    pub writes_aux: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
}
64
/// AIR for the is-equal-mod adapter: constrains the register and heap reads,
/// the pointer range check, the result write, and the execution transition.
#[allow(dead_code)]
#[derive(Clone, Copy, Debug, derive_new::new)]
pub struct Rv32IsEqualModAdapterAir<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    pub(super) execution_bridge: ExecutionBridge,
    pub(super) memory_bridge: MemoryBridge,
    /// Bitwise-operation lookup bus used for the pointer range check in `eval`.
    pub bus: BitwiseOperationLookupBus,
    /// Maximum number of bits in a heap pointer; the most significant pointer
    /// limb is scaled and range-checked so each pointer fits in this many bits.
    address_bits: usize,
}
78
impl<
        F: Field,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCK_SIZE: usize,
        const TOTAL_READ_SIZE: usize,
    > BaseAir<F>
    for Rv32IsEqualModAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
{
    /// Trace width is derived from the column struct layout (note that
    /// `TOTAL_READ_SIZE` does not appear in the columns, only in the interface).
    fn width(&self) -> usize {
        Rv32IsEqualModAdapterCols::<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE>::width()
    }
}
92
impl<
        AB: InteractionBuilder,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCK_SIZE: usize,
        const TOTAL_READ_SIZE: usize,
    > VmAdapterAir<AB>
    for Rv32IsEqualModAdapterAir<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
{
    /// Interface: `NUM_READS` reads of `TOTAL_READ_SIZE` cells each, one
    /// register-sized write, driven by a minimal (opcode + is_valid) instruction.
    type Interface = BasicAdapterInterface<
        AB::Expr,
        MinimalInstruction<AB::Expr>,
        NUM_READS,
        1,
        TOTAL_READ_SIZE,
        RV32_REGISTER_NUM_LIMBS,
    >;

    /// Constrains one row: pointer-register reads, a range check keeping each
    /// pointer below `2^address_bits`, the blockwise heap reads, the register
    /// write of the result, and the pc/timestamp transition.
    fn eval(
        &self,
        builder: &mut AB,
        local: &[AB::Var],
        ctx: AdapterAirContext<AB::Expr, Self::Interface>,
    ) {
        let cols: &Rv32IsEqualModAdapterCols<_, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
            local.borrow();
        let timestamp = cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Returns the timestamp of the next memory operation while counting the
        // total number of operations performed by this row.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::F::from_usize(timestamp_delta - 1)
        };

        let d = AB::F::from_u32(RV32_REGISTER_AS);
        let e = AB::F::from_u32(RV32_MEMORY_AS);

        // Read each pointer register from the register address space `d`.
        for (ptr, val, aux) in izip!(cols.rs_ptr, cols.rs_val, &cols.rs_read_aux) {
            self.memory_bridge
                .read(MemoryAddress::new(d, ptr), val, timestamp_pp(), aux)
                .eval(builder, ctx.instruction.is_valid.clone());
        }

        // Recompose each pointer from its little-endian limbs into a single
        // field-element address (most significant limb folded in first).
        let rs_val_f = cols.rs_val.map(|decomp| {
            decomp.iter().rev().fold(AB::Expr::ZERO, |acc, &limb| {
                acc * AB::Expr::from_usize(1 << RV32_CELL_BITS) + limb
            })
        });

        // The lookup bus checks values in pairs, so pad the second slot with
        // zero when NUM_READS < 2.
        let need_range_check: [_; 2] = from_fn(|i| {
            if i < NUM_READS {
                cols.rs_val[i][RV32_REGISTER_NUM_LIMBS - 1].into()
            } else {
                AB::Expr::ZERO
            }
        });

        // Scaling the most significant limb by `limb_shift` makes the range
        // check enforce that each pointer fits in `self.address_bits` bits.
        let limb_shift =
            AB::F::from_usize(1 << (RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.address_bits));

        self.bus
            .send_range(
                need_range_check[0].clone() * limb_shift,
                need_range_check[1].clone() * limb_shift,
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // Split each flat `TOTAL_READ_SIZE` read into `BLOCKS_PER_READ` blocks
        // of `BLOCK_SIZE` cells for the per-block memory bridge reads.
        assert_eq!(TOTAL_READ_SIZE, BLOCKS_PER_READ * BLOCK_SIZE);
        let read_block_data: [[[_; BLOCK_SIZE]; BLOCKS_PER_READ]; NUM_READS] =
            ctx.reads.map(|r: [AB::Expr; TOTAL_READ_SIZE]| {
                let mut r_it = r.into_iter();
                from_fn(|_| from_fn(|_| r_it.next().unwrap()))
            });
        let block_ptr_offset: [_; BLOCKS_PER_READ] = from_fn(|i| AB::F::from_usize(i * BLOCK_SIZE));

        // Heap reads (address space `e`): block `i` lives at `ptr + i * BLOCK_SIZE`.
        for (ptr, block_data, block_aux) in izip!(rs_val_f, read_block_data, &cols.heap_read_aux) {
            for (offset, data, aux) in izip!(block_ptr_offset, block_data, block_aux) {
                self.memory_bridge
                    .read(
                        MemoryAddress::new(e, ptr.clone() + offset),
                        data,
                        timestamp_pp(),
                        aux,
                    )
                    .eval(builder, ctx.instruction.is_valid.clone());
            }
        }

        // Write the result to the destination register.
        self.memory_bridge
            .write(
                MemoryAddress::new(d, cols.rd_ptr),
                ctx.writes[0].clone(),
                timestamp_pp(),
                &cols.writes_aux,
            )
            .eval(builder, ctx.instruction.is_valid.clone());

        // Instruction operands [a, b, c, d, e]: destination register, up to
        // two source registers (zero-padded), register and memory address spaces.
        self.execution_bridge
            .execute_and_increment_or_set_pc(
                ctx.instruction.opcode,
                [
                    cols.rd_ptr.into(),
                    cols.rs_ptr
                        .first()
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    cols.rs_ptr
                        .get(1)
                        .map(|&x| x.into())
                        .unwrap_or(AB::Expr::ZERO),
                    d.into(),
                    e.into(),
                ],
                cols.from_state,
                AB::F::from_usize(timestamp_delta),
                (DEFAULT_PC_STEP, ctx.to_pc),
            )
            .eval(builder, ctx.instruction.is_valid.clone());
    }

    /// The row's starting program counter, straight from `from_state`.
    fn get_from_pc(&self, local: &[AB::Var]) -> AB::Var {
        let cols: &Rv32IsEqualModAdapterCols<_, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
            local.borrow();
        cols.from_state.pc
    }
}
224
/// Compact per-row execution record written by the trace executor and later
/// expanded into [`Rv32IsEqualModAdapterCols`] by the trace filler.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32IsEqualModAdapterRecord<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    /// Program counter at the start of the instruction.
    pub from_pc: u32,
    /// Memory timestamp at the start of the instruction.
    pub timestamp: u32,

    /// Indices of the registers holding the heap pointers.
    pub rs_ptr: [u32; NUM_READS],
    /// Pointer values read from those registers.
    pub rs_val: [u32; NUM_READS],
    /// Previous-access timestamps for the pointer-register reads.
    pub rs_read_aux: [MemoryReadAuxRecord; NUM_READS],
    /// Previous-access timestamps for each heap block read, per pointer.
    pub heap_read_aux: [[MemoryReadAuxRecord; BLOCKS_PER_READ]; NUM_READS],

    /// Index of the destination register.
    pub rd_ptr: u32,
    /// Previous timestamp and previous bytes for the register write.
    pub writes_aux: MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS>,
}
244
/// Trace-time executor for the is-equal-mod adapter; performs the actual
/// memory reads/writes and fills the execution record.
#[derive(Clone, Copy)]
pub struct Rv32IsEqualModAdapterExecutor<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    /// Maximum number of bits in a heap pointer (see `new` for the bounds it
    /// must satisfy).
    pointer_max_bits: usize,
}
254
/// Trace filler that expands [`Rv32IsEqualModAdapterRecord`]s into full trace
/// rows and issues the pointer range-check lookups.
#[derive(derive_new::new)]
pub struct Rv32IsEqualModAdapterFiller<
    const NUM_READS: usize,
    const BLOCKS_PER_READ: usize,
    const BLOCK_SIZE: usize,
    const TOTAL_READ_SIZE: usize,
> {
    /// Maximum number of bits in a heap pointer; must match the AIR's
    /// `address_bits`.
    pointer_max_bits: usize,
    /// Lookup chip receiving the most-significant-limb range-check requests.
    pub bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}
265
266impl<
267 const NUM_READS: usize,
268 const BLOCKS_PER_READ: usize,
269 const BLOCK_SIZE: usize,
270 const TOTAL_READ_SIZE: usize,
271 > Rv32IsEqualModAdapterExecutor<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
272{
273 pub fn new(pointer_max_bits: usize) -> Self {
274 assert!(NUM_READS <= 2);
275 assert_eq!(TOTAL_READ_SIZE, BLOCKS_PER_READ * BLOCK_SIZE);
276 assert!(
277 RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - pointer_max_bits < RV32_CELL_BITS,
278 "pointer_max_bits={pointer_max_bits} needs to be large enough for high limb range check"
279 );
280 Self { pointer_max_bits }
281 }
282}
283
impl<
        F: PrimeField32,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCK_SIZE: usize,
        const TOTAL_READ_SIZE: usize,
    > AdapterTraceExecutor<F>
    for Rv32IsEqualModAdapterExecutor<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
// NOTE(review): this `where` clause duplicates the `F: PrimeField32` bound in
// the generic parameter list above; it is redundant but harmless.
where
    F: PrimeField32,
{
    const WIDTH: usize =
        Rv32IsEqualModAdapterCols::<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE>::width();
    // One flat `TOTAL_READ_SIZE`-byte array per pointer read.
    type ReadData = [[u8; TOTAL_READ_SIZE]; NUM_READS];
    type WriteData = [u8; RV32_REGISTER_NUM_LIMBS];
    type RecordMut<'a> = &'a mut Rv32IsEqualModAdapterRecord<
        NUM_READS,
        BLOCKS_PER_READ,
        BLOCK_SIZE,
        TOTAL_READ_SIZE,
    >;

    /// Snapshots the pc and memory timestamp at the start of the instruction.
    fn start(pc: u32, memory: &TracingMemory, record: &mut Self::RecordMut<'_>) {
        record.from_pc = pc;
        record.timestamp = memory.timestamp;
    }

    /// Reads the pointer registers (operands `b` and `c`), then dereferences
    /// each pointer into `BLOCKS_PER_READ` consecutive `BLOCK_SIZE`-byte heap
    /// blocks, recording every previous-access timestamp in `record`.
    fn read(
        &self,
        memory: &mut TracingMemory,
        instruction: &Instruction<F>,
        record: &mut Self::RecordMut<'_>,
    ) -> Self::ReadData {
        let Instruction { b, c, d, e, .. } = *instruction;

        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);

        // Pointer 0 comes from operand `b`, pointer 1 (if any) from `c`.
        record.rs_val = from_fn(|i| {
            record.rs_ptr[i] = if i == 0 { b } else { c }.as_canonical_u32();

            // Registers hold the heap pointer as little-endian byte limbs.
            u32::from_le_bytes(tracing_read(
                memory,
                RV32_REGISTER_AS,
                record.rs_ptr[i],
                &mut record.rs_read_aux[i].prev_timestamp,
            ))
        });

        // Heap reads: block `j` of read `i` starts at `rs_val[i] + j * BLOCK_SIZE`;
        // blocks are concatenated into one flat `TOTAL_READ_SIZE` array.
        from_fn(|i| {
            debug_assert!(
                record.rs_val[i] as usize + TOTAL_READ_SIZE - 1 < (1 << self.pointer_max_bits)
            );
            from_fn::<_, BLOCKS_PER_READ, _>(|j| {
                tracing_read::<BLOCK_SIZE>(
                    memory,
                    RV32_MEMORY_AS,
                    record.rs_val[i] + (j * BLOCK_SIZE) as u32,
                    &mut record.heap_read_aux[i][j].prev_timestamp,
                )
            })
            .concat()
            .try_into()
            .unwrap()
        })
    }

    /// Writes the result bytes to the destination register (operand `a`),
    /// recording the previous timestamp and previous data for the trace.
    fn write(
        &self,
        memory: &mut TracingMemory,
        instruction: &Instruction<F>,
        data: Self::WriteData,
        record: &mut Self::RecordMut<'_>,
    ) {
        let Instruction { a, .. } = *instruction;
        record.rd_ptr = a.as_canonical_u32();
        tracing_write(
            memory,
            RV32_REGISTER_AS,
            record.rd_ptr,
            data,
            &mut record.writes_aux.prev_timestamp,
            &mut record.writes_aux.prev_data,
        );
    }
}
372
impl<
        F: PrimeField32,
        const NUM_READS: usize,
        const BLOCKS_PER_READ: usize,
        const BLOCK_SIZE: usize,
        const TOTAL_READ_SIZE: usize,
    > AdapterTraceFiller<F>
    for Rv32IsEqualModAdapterFiller<NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE, TOTAL_READ_SIZE>
{
    const WIDTH: usize =
        Rv32IsEqualModAdapterCols::<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE>::width();

    /// Expands the compact execution record stored in `adapter_row` into the
    /// full column layout expected by the AIR, and issues the pointer
    /// range-check lookup mirroring the AIR's `send_range`.
    fn fill_trace_row(&self, mem_helper: &MemoryAuxColsFactory<F>, mut adapter_row: &mut [F]) {
        // SAFETY: the trace executor wrote a `Rv32IsEqualModAdapterRecord`
        // with these const-generic parameters into this row during execution,
        // so the reinterpretation matches the bytes actually stored there.
        let record: &Rv32IsEqualModAdapterRecord<
            NUM_READS,
            BLOCKS_PER_READ,
            BLOCK_SIZE,
            TOTAL_READ_SIZE,
        > = unsafe { get_record_from_slice(&mut adapter_row, ()) };

        let cols: &mut Rv32IsEqualModAdapterCols<F, NUM_READS, BLOCKS_PER_READ, BLOCK_SIZE> =
            adapter_row.borrow_mut();

        // Total memory ops = NUM_READS register reads + NUM_READS * BLOCKS_PER_READ
        // heap reads + 1 write. `timestamp_mm` pre-decrements, so its first call
        // yields the write's timestamp and later calls walk backwards.
        let mut timestamp = record.timestamp + (NUM_READS + NUM_READS * BLOCKS_PER_READ) as u32 + 1;
        let mut timestamp_mm = || {
            timestamp -= 1;
            timestamp
        };
        debug_assert!(self.pointer_max_bits <= RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS);
        let limb_shift_bits = RV32_CELL_BITS * RV32_REGISTER_NUM_LIMBS - self.pointer_max_bits;
        const MSL_SHIFT: usize = RV32_CELL_BITS * (RV32_REGISTER_NUM_LIMBS - 1);
        // Mirror of the AIR's range check: extract the most significant limb of
        // each pointer and scale it by `limb_shift_bits` so the lookup proves it
        // fits in `pointer_max_bits` bits (second slot is 0 when NUM_READS == 1).
        self.bitwise_lookup_chip.request_range(
            (record.rs_val[0] >> MSL_SHIFT) << limb_shift_bits,
            if NUM_READS > 1 {
                (record.rs_val[1] >> MSL_SHIFT) << limb_shift_bits
            } else {
                0
            },
        );
        // NOTE(review): the columns are filled back-to-front because `record`
        // aliases the same row buffer as `cols` — each record field must be
        // consumed before the columns overwriting its bytes are written. Keep
        // this ordering in sync with the column/record layouts.
        cols.writes_aux
            .set_prev_data(record.writes_aux.prev_data.map(F::from_u8));
        mem_helper.fill(
            record.writes_aux.prev_timestamp,
            timestamp_mm(),
            cols.writes_aux.as_mut(),
        );
        cols.rd_ptr = F::from_u32(record.rd_ptr);

        // Heap read aux columns, last block of the last read first.
        cols.heap_read_aux
            .iter_mut()
            .rev()
            .zip(record.heap_read_aux.iter().rev())
            .for_each(|(col_reads, record_reads)| {
                col_reads
                    .iter_mut()
                    .rev()
                    .zip(record_reads.iter().rev())
                    .for_each(|(col, record)| {
                        mem_helper.fill(record.prev_timestamp, timestamp_mm(), col.as_mut());
                    });
            });

        // Register read aux columns, last read first.
        cols.rs_read_aux
            .iter_mut()
            .rev()
            .zip(record.rs_read_aux.iter().rev())
            .for_each(|(col, record)| {
                mem_helper.fill(record.prev_timestamp, timestamp_mm(), col.as_mut());
            });

        // Decompose each pointer back into little-endian byte limbs.
        cols.rs_val = record.rs_val.map(|val| val.to_le_bytes().map(F::from_u8));
        cols.rs_ptr = record.rs_ptr.map(|ptr| F::from_u32(ptr));

        cols.from_state.timestamp = F::from_u32(record.timestamp);
        cols.from_state.pc = F::from_u32(record.from_pc);
    }
}