1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::{
4 arch::*,
5 system::memory::{
6 offline_checker::{
7 MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols,
8 MemoryWriteBytesAuxRecord,
9 },
10 online::TracingMemory,
11 MemoryAddress, MemoryAuxColsFactory,
12 },
13};
14use openvm_circuit_primitives::{
15 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
16 utils::not,
17};
18use openvm_circuit_primitives_derive::{AlignedBorrow, AlignedBytesBorrow};
19use openvm_instructions::{
20 instruction::Instruction,
21 program::DEFAULT_PC_STEP,
22 riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
23 LocalOpcode,
24};
25use openvm_rv32im_transpiler::{
26 Rv32HintStoreOpcode,
27 Rv32HintStoreOpcode::{HINT_BUFFER, HINT_STOREW},
28};
29use openvm_stark_backend::{
30 interaction::InteractionBuilder,
31 p3_air::{Air, AirBuilder, BaseAir},
32 p3_field::{Field, FieldAlgebra, PrimeField32},
33 p3_matrix::{dense::RowMajorMatrix, Matrix},
34 p3_maybe_rayon::prelude::*,
35 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
36};
37
38use crate::adapters::{read_rv32_register, tracing_read, tracing_write};
39
40mod execution;
41
42#[cfg(feature = "cuda")]
43mod cuda;
44#[cfg(feature = "cuda")]
45pub use cuda::*;
46
47#[cfg(test)]
48mod tests;
49
/// Trace columns for one row of the RV32 hint-store chip.
///
/// The trace filler in this file emits one row per word written: a single row for
/// `HINT_STOREW` and `num_words` consecutive rows for `HINT_BUFFER`.
///
/// The struct is `#[repr(C)]`, so the declaration order of the fields is their memory
/// order; do not reorder fields.
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct Rv32HintStoreCols<T> {
    /// Selector set by the filler on the row of a `HINT_STOREW` instruction.
    pub is_single: T,
    /// Selector set by the filler on every row of a `HINT_BUFFER` instruction.
    pub is_buffer: T,
    /// Little-endian limbs of the number of words still to be written, counting this row.
    /// The filler writes `num_words - idx`, so the value is 1 on the last (or only) row.
    pub rem_words_limbs: [T; RV32_REGISTER_NUM_LIMBS],

    /// Execution state for this row. The filler writes the instruction's pc on every row
    /// and a timestamp that advances by 3 per row.
    pub from_state: ExecutionState<T>,
    /// Register address from which the destination memory pointer is read.
    pub mem_ptr_ptr: T,
    /// Little-endian limbs of the memory address this row writes to.
    pub mem_ptr_limbs: [T; RV32_REGISTER_NUM_LIMBS],
    /// Auxiliary columns for the register read of the memory pointer. The filler fills
    /// them on the first row of an instruction and zeroes them on the other rows.
    pub mem_ptr_aux_cols: MemoryReadAuxCols<T>,

    /// Auxiliary columns for the memory write of `data`.
    pub write_aux: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
    /// The word written to memory by this row, one limb per cell.
    pub data: [T; RV32_REGISTER_NUM_LIMBS],

    /// Set by the filler only on the first row of a `HINT_BUFFER` instruction.
    pub is_buffer_start: T,
    /// Register address from which the word count is read. The filler writes zero on
    /// every row that is not a buffer-start row.
    pub num_words_ptr: T,
    /// Auxiliary columns for the register read of the word count. The filler fills them
    /// on the buffer-start row and zeroes them on the other rows.
    pub num_words_aux_cols: MemoryReadAuxCols<T>,
}
72
/// AIR for the RV32 hint-store chip (`HINT_STOREW` and `HINT_BUFFER`).
///
/// NOTE(review): the constructor comes from `derive_new::new` and presumably takes the
/// fields in declaration order -- confirm before reordering fields.
#[derive(Copy, Clone, Debug, derive_new::new)]
pub struct Rv32HintStoreAir {
    /// Bridge used in `eval` to emit the execute-and-increment-pc interaction.
    pub execution_bridge: ExecutionBridge,
    /// Bridge used in `eval` for the two register reads and the memory write.
    pub memory_bridge: MemoryBridge,
    /// Bus on which `eval` sends the range checks for pointer/count limbs and data limbs.
    pub bitwise_operation_lookup_bus: BitwiseOperationLookupBus,
    /// Added to the local opcode (`HINT_STOREW` / `HINT_BUFFER`) to form the expected opcode.
    pub offset: usize,
    /// Determines the scale factor `2^(RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - pointer_max_bits)`
    /// applied to the top limbs of `mem_ptr` and `rem_words` before they are range checked.
    pointer_max_bits: usize,
}
81
82impl<F: Field> BaseAir<F> for Rv32HintStoreAir {
83 fn width(&self) -> usize {
84 Rv32HintStoreCols::<F>::width()
85 }
86}
87
// Empty impl: every `BaseAirWithPublicValues` method keeps its trait default.
impl<F: Field> BaseAirWithPublicValues<F> for Rv32HintStoreAir {}
// Empty impl: every `PartitionedBaseAir` method keeps its trait default.
impl<F: Field> PartitionedBaseAir<F> for Rv32HintStoreAir {}
90
impl<AB: InteractionBuilder> Air<AB> for Rv32HintStoreAir {
    /// Emits all constraints and bus interactions for one (local, next) row pair.
    ///
    /// In order: selector constraints, the two register reads and the memory write,
    /// the execution-bridge interaction, the range checks, and finally the
    /// row-to-row transition constraints that chain the rows of one `HINT_BUFFER`.
    ///
    /// NOTE(review): the order of the `timestamp_pp()` calls fixes the timestamp offset
    /// of each memory access (0, 1, 2); keep it in sync with the executor and filler.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local_cols: &Rv32HintStoreCols<AB::Var> = (*local).borrow();
        let next = main.row_slice(1);
        let next_cols: &Rv32HintStoreCols<AB::Var> = (*next).borrow();

        // Each call yields `timestamp + k` for k = 0, 1, 2, ... and counts the calls in
        // `timestamp_delta`.
        let timestamp: AB::Var = local_cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1)
        };

        // Selectors are boolean, `is_buffer_start` implies `is_buffer`, and `is_single`
        // and `is_buffer` are never both set (their sum is boolean).
        builder.assert_bool(local_cols.is_single);
        builder.assert_bool(local_cols.is_buffer);
        builder.assert_bool(local_cols.is_buffer_start);
        builder
            .when(local_cols.is_buffer_start)
            .assert_one(local_cols.is_buffer);
        builder.assert_bool(local_cols.is_single + local_cols.is_buffer);

        // `is_valid`: this row belongs to some instruction.
        // `is_start`: this row is the first row of an instruction.
        // `is_end`: built from the NEXT row's selectors. Assuming `not(x)` is `1 - x`, it is
        // 0 exactly when the next row has `is_buffer = 1` and `is_buffer_start = 0`.
        let is_valid = local_cols.is_single + local_cols.is_buffer;
        let is_start = local_cols.is_single + local_cols.is_buffer_start;
        let is_end = not::<AB::Expr>(next_cols.is_buffer) + next_cols.is_buffer_start;

        // Compose the little-endian limbs into single expressions, starting from the most
        // significant limb and multiplying by 2^RV32_CELL_BITS at each step.
        let mut rem_words = AB::Expr::ZERO;
        let mut next_rem_words = AB::Expr::ZERO;
        let mut mem_ptr = AB::Expr::ZERO;
        let mut next_mem_ptr = AB::Expr::ZERO;
        for i in (0..RV32_REGISTER_NUM_LIMBS).rev() {
            rem_words = rem_words * AB::F::from_canonical_u32(1 << RV32_CELL_BITS)
                + local_cols.rem_words_limbs[i];
            next_rem_words = next_rem_words * AB::F::from_canonical_u32(1 << RV32_CELL_BITS)
                + next_cols.rem_words_limbs[i];
            mem_ptr = mem_ptr * AB::F::from_canonical_u32(1 << RV32_CELL_BITS)
                + local_cols.mem_ptr_limbs[i];
            next_mem_ptr = next_mem_ptr * AB::F::from_canonical_u32(1 << RV32_CELL_BITS)
                + next_cols.mem_ptr_limbs[i];
        }

        // On transitions: if the local row is not valid, the next row is not valid either.
        builder
            .when_transition()
            .when(not::<AB::Expr>(is_valid.clone()))
            .assert_zero(next_cols.is_single + next_cols.is_buffer);

        // A single-word row must be followed by a row with `is_end == 1`, i.e. the next
        // row may not be a non-start buffer row.
        builder
            .when(local_cols.is_single)
            .assert_one(is_end.clone());
        // On the first row, `is_buffer` may only be set together with `is_buffer_start`.
        builder
            .when_first_row()
            .assert_one(not::<AB::Expr>(local_cols.is_buffer) + local_cols.is_buffer_start);

        // Read `mem_ptr_limbs` from register `mem_ptr_ptr` at `timestamp + 0`; enabled on
        // the first row of an instruction.
        self.memory_bridge
            .read(
                MemoryAddress::new(
                    AB::F::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.mem_ptr_ptr,
                ),
                local_cols.mem_ptr_limbs,
                timestamp_pp(),
                &local_cols.mem_ptr_aux_cols,
            )
            .eval(builder, is_start.clone());

        // Read `rem_words_limbs` from register `num_words_ptr` at `timestamp + 1`; enabled
        // only on the first row of a buffer instruction.
        self.memory_bridge
            .read(
                MemoryAddress::new(
                    AB::F::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.num_words_ptr,
                ),
                local_cols.rem_words_limbs,
                timestamp_pp(),
                &local_cols.num_words_aux_cols,
            )
            .eval(builder, local_cols.is_buffer_start);

        // Write `data` to memory at the composed `mem_ptr` at `timestamp + 2`; enabled on
        // every valid row.
        self.memory_bridge
            .write(
                MemoryAddress::new(AB::F::from_canonical_u32(RV32_MEMORY_AS), mem_ptr.clone()),
                local_cols.data,
                timestamp_pp(),
                &local_cols.write_aux,
            )
            .eval(builder, is_valid.clone());
        // Opcode selected by whichever of the two selectors is set.
        let expected_opcode = (local_cols.is_single
            * AB::F::from_canonical_usize(HINT_STOREW as usize + self.offset))
            + (local_cols.is_buffer
                * AB::F::from_canonical_usize(HINT_BUFFER as usize + self.offset));

        // Enabled only on the first row of an instruction. The first operand is
        // `num_words_ptr` for buffer rows and 0 for single rows. The timestamp change is
        // `rem_words * timestamp_delta`; `timestamp_delta` is 3 here because
        // `timestamp_pp()` was called three times above.
        self.execution_bridge
            .execute_and_increment_pc(
                expected_opcode,
                [
                    local_cols.is_buffer * (local_cols.num_words_ptr),
                    local_cols.mem_ptr_ptr.into(),
                    AB::Expr::ZERO,
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    AB::Expr::from_canonical_u32(RV32_MEMORY_AS),
                ],
                local_cols.from_state,
                rem_words.clone() * AB::F::from_canonical_usize(timestamp_delta),
            )
            .eval(builder, is_start.clone());

        // On the first row of an instruction, send the most significant limbs of `mem_ptr`
        // and `rem_words`, each scaled by
        // 2^(RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - pointer_max_bits), to the lookup bus.
        // NOTE(review): presumably `send_range` checks that both values fit in
        // RV32_CELL_BITS bits, which would bound `mem_ptr` and `rem_words` by
        // 2^pointer_max_bits -- confirm against the bus definition.
        self.bitwise_operation_lookup_bus
            .send_range(
                local_cols.mem_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1]
                    * AB::F::from_canonical_usize(
                        1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits),
                    ),
                local_cols.rem_words_limbs[RV32_REGISTER_NUM_LIMBS - 1]
                    * AB::F::from_canonical_usize(
                        1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits),
                    ),
            )
            .eval(builder, is_start.clone());

        // On every valid row, send the data limbs to the lookup bus two at a time.
        for i in 0..RV32_REGISTER_NUM_LIMBS / 2 {
            self.bitwise_operation_lookup_bus
                .send_range(local_cols.data[2 * i], local_cols.data[(2 * i) + 1])
                .eval(builder, is_valid.clone());
        }

        // A valid row with `is_end == 1` is the last row of its instruction and must have
        // `rem_words == 1`.
        builder
            .when(is_valid)
            .when(is_end.clone())
            .assert_one(rem_words.clone());

        // When `is_end == 0` (the next row continues the same buffer instruction):
        let mut when_buffer_transition = builder.when(not::<AB::Expr>(is_end.clone()));
        // - the remaining word count decreases by exactly one,
        when_buffer_transition.assert_one(rem_words.clone() - next_rem_words.clone());
        // - the composed memory pointer advances by one word (RV32_REGISTER_NUM_LIMBS),
        when_buffer_transition.assert_eq(
            next_mem_ptr.clone() - mem_ptr.clone(),
            AB::F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS),
        );
        // - the next row's timestamp is this row's timestamp plus `timestamp_delta`.
        when_buffer_transition.assert_eq(
            timestamp + AB::F::from_canonical_usize(timestamp_delta),
            next_cols.from_state.timestamp,
        );
    }
}
265
/// Layout metadata for one hint-store record.
#[derive(Copy, Clone, Debug)]
pub struct Rv32HintStoreMetadata {
    /// Number of words the instruction writes. It is returned as the record's row count
    /// and used as the length of the record's `var` slice.
    num_words: usize,
}
270
271impl MultiRowMetadata for Rv32HintStoreMetadata {
272 #[inline(always)]
273 fn get_num_rows(&self) -> usize {
274 self.num_words
275 }
276}
277
/// Record layout for the hint-store chip: a `MultiRowLayout` carrying
/// [`Rv32HintStoreMetadata`].
pub type Rv32HintStoreLayout = MultiRowLayout<Rv32HintStoreMetadata>;
279
/// Fixed-size header at the start of every hint-store record.
///
/// The struct is `#[repr(C)]` and is reinterpreted from raw bytes elsewhere in this
/// file; do not reorder fields.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32HintStoreRecordHeader {
    /// Number of words written by the instruction (1 for `HINT_STOREW`).
    pub num_words: u32,

    /// Program counter at the start of the instruction.
    pub from_pc: u32,
    /// Memory timestamp at the start of the instruction.
    pub timestamp: u32,

    /// Register address (operand `b`) from which the destination pointer was read.
    pub mem_ptr_ptr: u32,
    /// Destination memory pointer read from that register.
    pub mem_ptr: u32,
    /// Auxiliary data recorded for the `mem_ptr` register read.
    pub mem_ptr_aux_record: MemoryReadAuxRecord,

    /// Register address (operand `a`) from which the word count was read.
    /// The executor stores `u32::MAX` here for `HINT_STOREW`; the filler uses that value
    /// to recognise a single-word record.
    pub num_words_ptr: u32,
    /// Auxiliary data recorded for the word-count register read. The executor only
    /// writes it for `HINT_BUFFER`.
    pub num_words_read: MemoryReadAuxRecord,
}
297
/// Per-word part of a hint-store record; a record holds `num_words` of these after
/// the header.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32HintStoreVar {
    /// Previous timestamp and previous data captured by the memory write of this word.
    pub data_write_aux: MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS>,
    /// The bytes written to memory for this word.
    pub data: [u8; RV32_REGISTER_NUM_LIMBS],
}
305
/// Mutable view over one hint-store record stored in a byte buffer: the fixed header
/// plus the slice of per-word entries that follows it.
#[derive(Debug)]
pub struct Rv32HintStoreRecordMut<'a> {
    /// The record header.
    pub inner: &'a mut Rv32HintStoreRecordHeader,
    /// One entry per written word; its length is the layout's `num_words`.
    pub var: &'a mut [Rv32HintStoreVar],
}
314
impl<'a> CustomBorrow<'a, Rv32HintStoreRecordMut<'a>, Rv32HintStoreLayout> for [u8] {
    /// Reinterprets this byte buffer as a hint-store record: the header occupies the
    /// first `size_of::<Rv32HintStoreRecordHeader>()` bytes, and the remainder, after
    /// alignment to `Rv32HintStoreVar`, supplies `layout.metadata.num_words` entries.
    ///
    /// # Panics
    /// Panics if the aligned remainder holds fewer than `layout.metadata.num_words`
    /// entries (checked slice indexing).
    fn custom_borrow(&'a mut self, layout: Rv32HintStoreLayout) -> Rv32HintStoreRecordMut<'a> {
        // SAFETY: `split_at_mut_unchecked` requires `self.len()` to be at least the header
        // size. This is not checked here.
        // NOTE(review): presumably guaranteed by the caller allocating the buffer via
        // `SizedRecord::size` for the same layout -- confirm.
        let (header_buf, rest) =
            unsafe { self.split_at_mut_unchecked(size_of::<Rv32HintStoreRecordHeader>()) };

        // SAFETY: `align_to_mut` reinterprets the bytes as `Rv32HintStoreVar`.
        // NOTE(review): this assumes every bit pattern is a valid `Rv32HintStoreVar`
        // (plain integer/byte fields) -- confirm for `MemoryWriteBytesAuxRecord`.
        // The unaligned prefix and suffix are discarded.
        let (_, vars, _) = unsafe { rest.align_to_mut::<Rv32HintStoreVar>() };
        Rv32HintStoreRecordMut {
            inner: header_buf.borrow_mut(),
            var: &mut vars[..layout.metadata.num_words],
        }
    }

    /// Builds the layout for the record stored in this buffer by reading `num_words`
    /// from its header.
    ///
    /// # Safety
    /// NOTE(review): presumably the caller must guarantee that the buffer starts with an
    /// initialized `Rv32HintStoreRecordHeader` -- confirm against the trait contract.
    unsafe fn extract_layout(&self) -> Rv32HintStoreLayout {
        let header: &Rv32HintStoreRecordHeader = self.borrow();
        MultiRowLayout::new(Rv32HintStoreMetadata {
            num_words: header.num_words as usize,
        })
    }
}
346
347impl SizedRecord<Rv32HintStoreLayout> for Rv32HintStoreRecordMut<'_> {
348 fn size(layout: &Rv32HintStoreLayout) -> usize {
349 let mut total_len = size_of::<Rv32HintStoreRecordHeader>();
350 total_len = total_len.next_multiple_of(align_of::<Rv32HintStoreVar>());
352 total_len += size_of::<Rv32HintStoreVar>() * layout.metadata.num_words;
353 total_len
354 }
355
356 fn alignment(_layout: &Rv32HintStoreLayout) -> usize {
357 align_of::<Rv32HintStoreRecordHeader>()
358 }
359}
360
/// Preflight executor for `HINT_STOREW` and `HINT_BUFFER`.
///
/// NOTE(review): the constructor comes from `derive_new::new` and presumably takes the
/// fields in declaration order -- confirm before reordering fields.
#[derive(Clone, Copy, derive_new::new)]
pub struct Rv32HintStoreExecutor {
    /// Exponent of the upper bound used by the debug assertions on `mem_ptr` and
    /// `num_words` during execution.
    pub pointer_max_bits: usize,
    /// Opcode offset used to turn the instruction's opcode into a local opcode index.
    pub offset: usize,
}
366
/// Trace filler that expands hint-store records into `Rv32HintStoreCols` rows.
///
/// NOTE(review): the constructor comes from `derive_new::new` and presumably takes the
/// fields in declaration order -- confirm before reordering fields.
#[derive(Clone, derive_new::new)]
pub struct Rv32HintStoreFiller {
    /// Used to compute the scale applied to the top limbs of `mem_ptr` and `num_words`
    /// when requesting their range checks.
    pointer_max_bits: usize,
    /// Lookup chip that receives the range-check requests made while filling the trace.
    bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}
372
impl<F, RA> PreflightExecutor<F, RA> for Rv32HintStoreExecutor
where
    F: PrimeField32,
    for<'buf> RA:
        RecordArena<'buf, MultiRowLayout<Rv32HintStoreMetadata>, Rv32HintStoreRecordMut<'buf>>,
{
    /// Returns the mnemonic for `opcode`.
    ///
    /// # Panics
    /// Panics via `unreachable!` if `opcode` is neither `HINT_STOREW` nor `HINT_BUFFER`.
    fn get_opcode_name(&self, opcode: usize) -> String {
        if opcode == HINT_STOREW.global_opcode().as_usize() {
            String::from("HINT_STOREW")
        } else if opcode == HINT_BUFFER.global_opcode().as_usize() {
            String::from("HINT_BUFFER")
        } else {
            unreachable!("unsupported opcode: {}", opcode)
        }
    }

    /// Executes one `HINT_STOREW` / `HINT_BUFFER` instruction: moves `num_words` words
    /// from the hint stream into memory starting at the pointer held in register `b`,
    /// and records everything the trace filler needs.
    ///
    /// # Errors
    /// Returns `ExecutionError::HintOutOfBounds` when the hint stream holds fewer than
    /// `RV32_REGISTER_NUM_LIMBS * num_words` elements. The check runs after the record
    /// has been allocated and the register reads have been traced.
    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let &Instruction {
            opcode, a, b, d, e, ..
        } = instruction;

        // `a`: register holding the word count (read only for HINT_BUFFER).
        // `b`: register holding the destination memory pointer.
        let a = a.as_canonical_u32();
        let b = b.as_canonical_u32();
        // Address-space operands are only verified in debug builds.
        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);

        let local_opcode = Rv32HintStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset));

        // The word count is needed to size the record, so it is first read through
        // `state.memory.data()` rather than `tracing_read`; the traced read of the same
        // register happens further down.
        let num_words = if local_opcode == HINT_STOREW {
            1
        } else {
            read_rv32_register(state.memory.data(), a)
        };

        let record = state.ctx.alloc(MultiRowLayout::new(Rv32HintStoreMetadata {
            num_words: num_words as usize,
        }));

        record.inner.from_pc = *state.pc;
        record.inner.timestamp = state.memory.timestamp;
        record.inner.mem_ptr_ptr = b;

        // First traced access: read the destination pointer from register `b`.
        record.inner.mem_ptr = u32::from_le_bytes(tracing_read(
            state.memory,
            RV32_REGISTER_AS,
            b,
            &mut record.inner.mem_ptr_aux_record.prev_timestamp,
        ));

        // NOTE(review): these bounds are only enforced in debug builds; in release builds
        // a zero or oversized `num_words` is not rejected here.
        debug_assert!(record.inner.mem_ptr <= (1 << self.pointer_max_bits));
        debug_assert_ne!(num_words, 0);
        debug_assert!(num_words <= (1 << self.pointer_max_bits));

        record.inner.num_words = num_words;
        if local_opcode == HINT_STOREW {
            // No word-count register to read: advance the timestamp once in its place and
            // mark the record as single-word with the `u32::MAX` sentinel.
            state.memory.increment_timestamp();
            record.inner.num_words_ptr = u32::MAX;
        } else {
            // Second traced access: read the word count from register `a`. The value
            // itself was already obtained above, so the result is discarded.
            record.inner.num_words_ptr = a;
            tracing_read::<RV32_REGISTER_NUM_LIMBS>(
                state.memory,
                RV32_REGISTER_AS,
                record.inner.num_words_ptr,
                &mut record.inner.num_words_read.prev_timestamp,
            );
        };

        if state.streams.hint_stream.len() < RV32_REGISTER_NUM_LIMBS * num_words as usize {
            return Err(ExecutionError::HintOutOfBounds { pc: *state.pc });
        }

        for idx in 0..(num_words as usize) {
            // Rows after the first perform no register reads; advance the timestamp twice
            // instead. The filler steps the timestamp by 3 per row, so presumably
            // `tracing_write` accounts for the third step -- confirm.
            if idx != 0 {
                state.memory.increment_timestamp();
                state.memory.increment_timestamp();
            }

            // `unwrap` cannot fail: the length check above guarantees enough elements.
            let data_f: [F; RV32_REGISTER_NUM_LIMBS] =
                std::array::from_fn(|_| state.streams.hint_stream.pop_front().unwrap());
            // NOTE(review): truncating cast -- a hint element >= 256 would be silently
            // reduced to its low byte here.
            let data: [u8; RV32_REGISTER_NUM_LIMBS] =
                data_f.map(|byte| byte.as_canonical_u32() as u8);

            record.var[idx].data = data;

            // Write word `idx` at `mem_ptr + RV32_REGISTER_NUM_LIMBS * idx`, recording the
            // previous timestamp and data of the overwritten cells.
            tracing_write(
                state.memory,
                RV32_MEMORY_AS,
                record.inner.mem_ptr + (RV32_REGISTER_NUM_LIMBS * idx) as u32,
                data,
                &mut record.var[idx].data_write_aux.prev_timestamp,
                &mut record.var[idx].data_write_aux.prev_data,
            );
        }
        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);

        Ok(())
    }
}
476
impl<F: PrimeField32> TraceFiller<F> for Rv32HintStoreFiller {
    /// Expands the records stored in the first `rows_used` rows of `trace` into
    /// `Rv32HintStoreCols` rows, in place, and requests the matching range checks from
    /// the bitwise lookup chip.
    ///
    /// Does nothing when `rows_used == 0`.
    fn fill_trace(
        &self,
        mem_helper: &MemoryAuxColsFactory<F>,
        trace: &mut RowMajorMatrix<F>,
        rows_used: usize,
    ) {
        if rows_used == 0 {
            return;
        }

        let width = trace.width;
        // Debug-only sanity check of the trace width against the column struct.
        debug_assert_eq!(width, size_of::<Rv32HintStoreCols<u8>>());
        let mut trace = &mut trace.values[..width * rows_used];
        let mut sizes = Vec::with_capacity(rows_used);
        let mut chunks = Vec::with_capacity(rows_used);

        // Serial pass: read each record's header out of the trace buffer itself to learn
        // its `num_words`, then split off the `num_words` rows that belong to it.
        while !trace.is_empty() {
            // SAFETY: NOTE(review): presumably relies on the remaining trace starting with
            // a record header written during execution -- confirm against
            // `get_record_from_slice`.
            let record: &Rv32HintStoreRecordHeader =
                unsafe { get_record_from_slice(&mut trace, ()) };
            let (chunk, rest) = trace.split_at_mut(width * record.num_words as usize);
            sizes.push(record.num_words);
            chunks.push(chunk);
            trace = rest;
        }

        // `msl_rshift` isolates the most significant limb of a 32-bit value; `msl_lshift`
        // is the same scale exponent the AIR applies before range checking that limb.
        let msl_rshift: u32 = ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS) as u32;
        let msl_lshift: u32 =
            (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits) as u32;

        // Parallel pass: each chunk holds exactly one record and is filled independently.
        chunks
            .par_iter_mut()
            .zip(sizes.par_iter())
            .for_each(|(chunk, &num_words)| {
                // SAFETY: NOTE(review): the record view is taken from the same `chunk`
                // buffer that the rows below are written into; presumably sound only
                // because of the reverse fill order used below -- confirm.
                let record: Rv32HintStoreRecordMut = unsafe {
                    get_record_from_slice(
                        chunk,
                        MultiRowLayout::new(Rv32HintStoreMetadata {
                            num_words: num_words as usize,
                        }),
                    )
                };
                // One request per instruction, matching the AIR's range check of the top
                // limbs of `mem_ptr` and `rem_words` on the start row.
                self.bitwise_lookup_chip.request_range(
                    (record.inner.mem_ptr >> msl_rshift) << msl_lshift,
                    (num_words >> msl_rshift) << msl_lshift,
                );

                // Start one step past the last row; both values are decremented before
                // use inside the loop.
                let mut timestamp = record.inner.timestamp + num_words * 3;
                let mut mem_ptr = record.inner.mem_ptr + num_words * RV32_REGISTER_NUM_LIMBS as u32;

                // Rows are visited from last to first, paired with the per-word entries in
                // the same reversed order.
                chunk
                    .rchunks_exact_mut(width)
                    .zip(record.var.iter().enumerate().rev())
                    .for_each(|(row, (idx, var))| {
                        // Range-check requests for the data bytes, two at a time, matching
                        // the AIR's per-row data check.
                        for pair in var.data.chunks_exact(2) {
                            self.bitwise_lookup_chip
                                .request_range(pair[0] as u32, pair[1] as u32);
                        }

                        let cols: &mut Rv32HintStoreCols<F> = row.borrow_mut();
                        // The executor stores `u32::MAX` in `num_words_ptr` for HINT_STOREW.
                        let is_single = record.inner.num_words_ptr == u32::MAX;
                        // Now the timestamp at which this row starts.
                        timestamp -= 3;
                        // The word-count read (at `timestamp + 1`) exists only on the first
                        // row of a buffer instruction.
                        if idx == 0 && !is_single {
                            mem_helper.fill(
                                record.inner.num_words_read.prev_timestamp,
                                timestamp + 1,
                                cols.num_words_aux_cols.as_mut(),
                            );
                            cols.num_words_ptr = F::from_canonical_u32(record.inner.num_words_ptr);
                        } else {
                            mem_helper.fill_zero(cols.num_words_aux_cols.as_mut());
                            cols.num_words_ptr = F::ZERO;
                        }

                        cols.is_buffer_start = F::from_bool(idx == 0 && !is_single);

                        cols.data = var.data.map(|x| F::from_canonical_u8(x));

                        // The data write happens at `timestamp + 2` on every row.
                        cols.write_aux.set_prev_data(
                            var.data_write_aux
                                .prev_data
                                .map(|x| F::from_canonical_u8(x)),
                        );
                        mem_helper.fill(
                            var.data_write_aux.prev_timestamp,
                            timestamp + 2,
                            cols.write_aux.as_mut(),
                        );

                        // The pointer read (at `timestamp`) exists only on the first row.
                        if idx == 0 {
                            mem_helper.fill(
                                record.inner.mem_ptr_aux_record.prev_timestamp,
                                timestamp,
                                cols.mem_ptr_aux_cols.as_mut(),
                            );
                        } else {
                            mem_helper.fill_zero(cols.mem_ptr_aux_cols.as_mut());
                        }

                        // Now the address this row writes to.
                        mem_ptr -= RV32_REGISTER_NUM_LIMBS as u32;
                        cols.mem_ptr_limbs = mem_ptr.to_le_bytes().map(|x| F::from_canonical_u8(x));
                        cols.mem_ptr_ptr = F::from_canonical_u32(record.inner.mem_ptr_ptr);

                        cols.from_state.timestamp = F::from_canonical_u32(timestamp);
                        cols.from_state.pc = F::from_canonical_u32(record.inner.from_pc);

                        // Words remaining including this row: `num_words` on the first
                        // row, 1 on the last.
                        cols.rem_words_limbs = (num_words - idx as u32)
                            .to_le_bytes()
                            .map(|x| F::from_canonical_u8(x));
                        cols.is_buffer = F::from_bool(!is_single);
                        cols.is_single = F::from_bool(is_single);
                    });
            })
    }
}
607
/// The hint-store chip: a `VmChipWrapper` parameterised with [`Rv32HintStoreFiller`].
pub type Rv32HintStoreChip<F> = VmChipWrapper<F, Rv32HintStoreFiller>;