1use std::borrow::{Borrow, BorrowMut};
2
3use openvm_circuit::{
4 arch::*,
5 system::memory::{
6 offline_checker::{
7 MemoryBridge, MemoryReadAuxCols, MemoryReadAuxRecord, MemoryWriteAuxCols,
8 MemoryWriteBytesAuxRecord,
9 },
10 online::TracingMemory,
11 MemoryAddress, MemoryAuxColsFactory,
12 },
13};
14use openvm_circuit_primitives::{
15 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
16 utils::not,
17};
18use openvm_circuit_primitives_derive::{AlignedBorrow, AlignedBytesBorrow};
19use openvm_instructions::{
20 instruction::Instruction,
21 program::DEFAULT_PC_STEP,
22 riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
23 LocalOpcode,
24};
25use openvm_rv32im_transpiler::{
26 Rv32HintStoreOpcode,
27 Rv32HintStoreOpcode::{HINT_BUFFER, HINT_STOREW},
28};
29use openvm_stark_backend::{
30 interaction::InteractionBuilder,
31 p3_air::{Air, AirBuilder, BaseAir},
32 p3_field::{Field, PrimeCharacteristicRing, PrimeField32},
33 p3_matrix::{dense::RowMajorMatrix, Matrix},
34 p3_maybe_rayon::prelude::*,
35 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
36};
37
38use crate::adapters::{read_rv32_register, tracing_read, tracing_write};
39
40mod execution;
41
42#[cfg(feature = "cuda")]
43mod cuda;
44#[cfg(feature = "cuda")]
45pub use cuda::*;
46
47#[cfg(test)]
48mod tests;
49
/// Columns for one trace row of the hint-store chip.
///
/// A HINT_STOREW instruction occupies a single row; a HINT_BUFFER instruction
/// occupies one row per word written.
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct Rv32HintStoreCols<T> {
    /// 1 iff this row is a HINT_STOREW row.
    pub is_single: T,
    /// 1 iff this row belongs to a HINT_BUFFER instruction.
    pub is_buffer: T,
    /// Little-endian limbs of the number of words still to write,
    /// including this row's word.
    pub rem_words_limbs: [T; RV32_REGISTER_NUM_LIMBS],

    /// pc and timestamp at the start of this row.
    pub from_state: ExecutionState<T>,
    /// Register address holding the destination memory pointer (operand `b`).
    pub mem_ptr_ptr: T,
    /// Little-endian limbs of this row's destination memory pointer.
    pub mem_ptr_limbs: [T; RV32_REGISTER_NUM_LIMBS],
    /// Aux columns for the `mem_ptr` register read (start rows only).
    pub mem_ptr_aux_cols: MemoryReadAuxCols<T>,

    /// Aux columns (prev timestamp/data) for this row's data write.
    pub write_aux: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
    /// The hint bytes written to memory on this row.
    pub data: [T; RV32_REGISTER_NUM_LIMBS],

    /// 1 iff this row is the first row of a HINT_BUFFER instruction.
    pub is_buffer_start: T,
    /// Register address holding the word count (operand `a`); used on
    /// buffer-start rows only, zero elsewhere.
    pub num_words_ptr: T,
    /// Aux columns for the word-count register read (buffer-start rows only).
    pub num_words_aux_cols: MemoryReadAuxCols<T>,
}
72
/// AIR for the RV32 hint-store chip (HINT_STOREW / HINT_BUFFER).
#[derive(Copy, Clone, Debug, derive_new::new)]
pub struct Rv32HintStoreAir {
    pub execution_bridge: ExecutionBridge,
    pub memory_bridge: MemoryBridge,
    /// Bus used to range-check pointer/count limbs and every written byte.
    pub bitwise_operation_lookup_bus: BitwiseOperationLookupBus,
    /// Global offset of `Rv32HintStoreOpcode` in the opcode space.
    pub offset: usize,
    /// Maximum bit width allowed for memory pointers (and the word count).
    pointer_max_bits: usize,
}
81
82impl<F: Field> BaseAir<F> for Rv32HintStoreAir {
83 fn width(&self) -> usize {
84 Rv32HintStoreCols::<F>::width()
85 }
86}
87
// Default (empty) implementations: this AIR exposes no public values and
// uses no custom trace partitioning.
impl<F: Field> BaseAirWithPublicValues<F> for Rv32HintStoreAir {}
impl<F: Field> PartitionedBaseAir<F> for Rv32HintStoreAir {}
90
impl<AB: InteractionBuilder> Air<AB> for Rv32HintStoreAir {
    /// Constrains one row of the hint-store trace.
    ///
    /// Row types: HINT_STOREW rows (`is_single = 1`), HINT_BUFFER rows
    /// (`is_buffer = 1`, with `is_buffer_start = 1` only on the first row of
    /// the instruction), and padding rows (both flags zero). Each valid row
    /// performs up to three memory accesses, so it consumes exactly three
    /// timestamps.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0).expect("window should have two elements");
        let local_cols: &Rv32HintStoreCols<AB::Var> = (*local).borrow();
        let next = main.row_slice(1).expect("window should have two elements");
        let next_cols: &Rv32HintStoreCols<AB::Var> = (*next).borrow();

        let timestamp: AB::Var = local_cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        // Hands out consecutive timestamps starting at `from_state.timestamp`;
        // after all three calls below, `timestamp_delta == 3`.
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::Expr::from_usize(timestamp_delta - 1)
        };

        // Row-type flags are boolean, mutually exclusive, and
        // `is_buffer_start` implies `is_buffer`.
        builder.assert_bool(local_cols.is_single);
        builder.assert_bool(local_cols.is_buffer);
        builder.assert_bool(local_cols.is_buffer_start);
        builder
            .when(local_cols.is_buffer_start)
            .assert_one(local_cols.is_buffer);
        builder.assert_bool(local_cols.is_single + local_cols.is_buffer);

        let is_valid = local_cols.is_single + local_cols.is_buffer;
        let is_start = local_cols.is_single + local_cols.is_buffer_start;
        // This row is the last of its instruction iff the next row does not
        // continue the same buffer (next row is either not a buffer row, or
        // starts a new buffer).
        let is_end = not::<AB::Expr>(next_cols.is_buffer) + next_cols.is_buffer_start;

        // Recompose `rem_words` and `mem_ptr` (and the next row's values) from
        // their little-endian limbs, most significant limb first.
        let mut rem_words = AB::Expr::ZERO;
        let mut next_rem_words = AB::Expr::ZERO;
        let mut mem_ptr = AB::Expr::ZERO;
        let mut next_mem_ptr = AB::Expr::ZERO;
        for i in (0..RV32_REGISTER_NUM_LIMBS).rev() {
            rem_words =
                rem_words * AB::F::from_u32(1 << RV32_CELL_BITS) + local_cols.rem_words_limbs[i];
            next_rem_words = next_rem_words * AB::F::from_u32(1 << RV32_CELL_BITS)
                + next_cols.rem_words_limbs[i];
            mem_ptr = mem_ptr * AB::F::from_u32(1 << RV32_CELL_BITS) + local_cols.mem_ptr_limbs[i];
            next_mem_ptr =
                next_mem_ptr * AB::F::from_u32(1 << RV32_CELL_BITS) + next_cols.mem_ptr_limbs[i];
        }

        // Once a padding row appears, every following row is padding.
        builder
            .when_transition()
            .when(not::<AB::Expr>(is_valid.clone()))
            .assert_zero(next_cols.is_single + next_cols.is_buffer);

        // A HINT_STOREW row is always the last row of its instruction.
        builder
            .when(local_cols.is_single)
            .assert_one(is_end.clone());
        // The first trace row cannot be a buffer continuation row.
        builder
            .when_first_row()
            .assert_one(not::<AB::Expr>(local_cols.is_buffer) + local_cols.is_buffer_start);

        // Timestamp +0: read the destination pointer register (start rows only).
        self.memory_bridge
            .read(
                MemoryAddress::new(AB::F::from_u32(RV32_REGISTER_AS), local_cols.mem_ptr_ptr),
                local_cols.mem_ptr_limbs,
                timestamp_pp(),
                &local_cols.mem_ptr_aux_cols,
            )
            .eval(builder, is_start.clone());

        // Timestamp +1: read the word-count register (buffer-start rows only).
        self.memory_bridge
            .read(
                MemoryAddress::new(AB::F::from_u32(RV32_REGISTER_AS), local_cols.num_words_ptr),
                local_cols.rem_words_limbs,
                timestamp_pp(),
                &local_cols.num_words_aux_cols,
            )
            .eval(builder, local_cols.is_buffer_start);

        // Timestamp +2: write this row's hint word to memory (every valid row).
        self.memory_bridge
            .write(
                MemoryAddress::new(AB::F::from_u32(RV32_MEMORY_AS), mem_ptr.clone()),
                local_cols.data,
                timestamp_pp(),
                &local_cols.write_aux,
            )
            .eval(builder, is_valid.clone());
        let expected_opcode = (local_cols.is_single
            * AB::F::from_usize(HINT_STOREW as usize + self.offset))
            + (local_cols.is_buffer * AB::F::from_usize(HINT_BUFFER as usize + self.offset));

        // One execution-bus interaction per instruction (start rows only).
        // The instruction as a whole advances the clock by
        // `rem_words * timestamp_delta` (= rem_words * 3), matching the three
        // timestamps consumed per row above.
        self.execution_bridge
            .execute_and_increment_pc(
                expected_opcode,
                [
                    local_cols.is_buffer * (local_cols.num_words_ptr),
                    local_cols.mem_ptr_ptr.into(),
                    AB::Expr::ZERO,
                    AB::Expr::from_u32(RV32_REGISTER_AS),
                    AB::Expr::from_u32(RV32_MEMORY_AS),
                ],
                local_cols.from_state,
                rem_words.clone() * AB::F::from_usize(timestamp_delta),
            )
            .eval(builder, is_start.clone());

        // Range-check the most significant limbs of `mem_ptr` and `rem_words`,
        // scaled so both full values are constrained to `pointer_max_bits` bits.
        self.bitwise_operation_lookup_bus
            .send_range(
                local_cols.mem_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1]
                    * AB::F::from_usize(
                        1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits),
                    ),
                local_cols.rem_words_limbs[RV32_REGISTER_NUM_LIMBS - 1]
                    * AB::F::from_usize(
                        1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits),
                    ),
            )
            .eval(builder, is_start.clone());

        // Range-check every written byte to RV32_CELL_BITS bits, two per lookup.
        for i in 0..RV32_REGISTER_NUM_LIMBS / 2 {
            self.bitwise_operation_lookup_bus
                .send_range(local_cols.data[2 * i], local_cols.data[(2 * i) + 1])
                .eval(builder, is_valid.clone());
        }

        // On the last row of an instruction exactly one word remains.
        builder
            .when(is_valid)
            .when(is_end.clone())
            .assert_one(rem_words.clone());

        // Buffer continuation constraints: per row, `rem_words` decreases by 1,
        // `mem_ptr` advances by one word (RV32_REGISTER_NUM_LIMBS bytes), and
        // the next row's timestamp picks up where this row stopped.
        let mut when_buffer_transition = builder.when(not::<AB::Expr>(is_end.clone()));
        when_buffer_transition.assert_one(rem_words.clone() - next_rem_words.clone());
        when_buffer_transition.assert_eq(
            next_mem_ptr.clone() - mem_ptr.clone(),
            AB::F::from_usize(RV32_REGISTER_NUM_LIMBS),
        );
        when_buffer_transition.assert_eq(
            timestamp + AB::F::from_usize(timestamp_delta),
            next_cols.from_state.timestamp,
        );
    }
}
257
/// Layout metadata for one hint-store record: the record owns `num_words`
/// trace rows.
#[derive(Copy, Clone, Debug)]
pub struct Rv32HintStoreMetadata {
    // Number of words written by the instruction (1 for HINT_STOREW).
    num_words: usize,
}
262
263impl MultiRowMetadata for Rv32HintStoreMetadata {
264 #[inline(always)]
265 fn get_num_rows(&self) -> usize {
266 self.num_words
267 }
268}
269
/// Record layout with a variable row count, driven by `num_words`.
pub type Rv32HintStoreLayout = MultiRowLayout<Rv32HintStoreMetadata>;
271
/// Fixed-size header of an execution record, stored once per instruction.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32HintStoreRecordHeader {
    /// Number of words written (1 for HINT_STOREW).
    pub num_words: u32,

    /// pc at the start of the instruction.
    pub from_pc: u32,
    /// Timestamp at the start of the instruction.
    pub timestamp: u32,

    /// Register address of the destination pointer (operand `b`).
    pub mem_ptr_ptr: u32,
    /// Destination pointer value read from that register.
    pub mem_ptr: u32,
    /// Previous-access timestamp for the `mem_ptr` register read.
    pub mem_ptr_aux_record: MemoryReadAuxRecord,

    /// Register address of the word count (operand `a`); set to `u32::MAX`
    /// for HINT_STOREW, which performs no word-count read.
    pub num_words_ptr: u32,
    /// Previous-access timestamp for the word-count register read.
    pub num_words_read: MemoryReadAuxRecord,
}
289
/// Per-word part of an execution record: one entry per word written.
#[repr(C)]
#[derive(AlignedBytesBorrow, Debug)]
pub struct Rv32HintStoreVar {
    /// Previous timestamp and previous bytes of the overwritten memory word.
    pub data_write_aux: MemoryWriteBytesAuxRecord<RV32_REGISTER_NUM_LIMBS>,
    /// The hint bytes written for this word.
    pub data: [u8; RV32_REGISTER_NUM_LIMBS],
}
297
/// Mutable view of one full record inside the arena buffer: a header followed
/// by `num_words` per-word entries.
#[derive(Debug)]
pub struct Rv32HintStoreRecordMut<'a> {
    pub inner: &'a mut Rv32HintStoreRecordHeader,
    pub var: &'a mut [Rv32HintStoreVar],
}
306
impl<'a> CustomBorrow<'a, Rv32HintStoreRecordMut<'a>, Rv32HintStoreLayout> for [u8] {
    /// Reinterprets the front of this byte buffer as a record: a header
    /// followed by `layout.metadata.num_words` var entries.
    fn custom_borrow(&'a mut self, layout: Rv32HintStoreLayout) -> Rv32HintStoreRecordMut<'a> {
        // SAFETY: no length check is performed here — the caller must supply a
        // buffer of at least `SizedRecord::size(&layout)` bytes (presumably
        // guaranteed by the record arena; confirm at call sites).
        let (header_buf, rest) =
            unsafe { self.split_at_mut_unchecked(size_of::<Rv32HintStoreRecordHeader>()) };

        // SAFETY: `align_to_mut` returns a correctly aligned middle slice of
        // `Rv32HintStoreVar`; the unaligned prefix/suffix are discarded. The
        // slice must contain at least `num_words` entries, which holds when
        // the buffer was sized via `SizedRecord::size`.
        let (_, vars, _) = unsafe { rest.align_to_mut::<Rv32HintStoreVar>() };
        Rv32HintStoreRecordMut {
            inner: header_buf.borrow_mut(),
            var: &mut vars[..layout.metadata.num_words],
        }
    }

    /// Recovers the layout from an already-initialized record buffer by
    /// reading `num_words` out of the header.
    ///
    /// # Safety
    /// `self` must start with a previously written, valid
    /// `Rv32HintStoreRecordHeader`.
    unsafe fn extract_layout(&self) -> Rv32HintStoreLayout {
        let header: &Rv32HintStoreRecordHeader = self.borrow();
        MultiRowLayout::new(Rv32HintStoreMetadata {
            num_words: header.num_words as usize,
        })
    }
}
338
339impl SizedRecord<Rv32HintStoreLayout> for Rv32HintStoreRecordMut<'_> {
340 fn size(layout: &Rv32HintStoreLayout) -> usize {
341 let mut total_len = size_of::<Rv32HintStoreRecordHeader>();
342 total_len = total_len.next_multiple_of(align_of::<Rv32HintStoreVar>());
344 total_len += size_of::<Rv32HintStoreVar>() * layout.metadata.num_words;
345 total_len
346 }
347
348 fn alignment(_layout: &Rv32HintStoreLayout) -> usize {
349 align_of::<Rv32HintStoreRecordHeader>()
350 }
351}
352
/// Preflight executor for HINT_STOREW / HINT_BUFFER.
#[derive(Clone, Copy, derive_new::new)]
pub struct Rv32HintStoreExecutor {
    /// Maximum bit width allowed for memory pointers (and the word count).
    pub pointer_max_bits: usize,
    /// Global offset of `Rv32HintStoreOpcode` in the opcode space.
    pub offset: usize,
}
358
/// Trace filler that expands hint-store execution records into trace rows.
#[derive(Clone, derive_new::new)]
pub struct Rv32HintStoreFiller {
    // Maximum bit width of memory pointers; mirrors the AIR's range checks.
    pointer_max_bits: usize,
    // Lookup chip used to emit the range-check requests the AIR expects.
    bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}
364
impl<F, RA> PreflightExecutor<F, RA> for Rv32HintStoreExecutor
where
    F: PrimeField32,
    for<'buf> RA:
        RecordArena<'buf, MultiRowLayout<Rv32HintStoreMetadata>, Rv32HintStoreRecordMut<'buf>>,
{
    /// Maps a global opcode back to its mnemonic for diagnostics.
    fn get_opcode_name(&self, opcode: usize) -> String {
        if opcode == HINT_STOREW.global_opcode().as_usize() {
            String::from("HINT_STOREW")
        } else if opcode == HINT_BUFFER.global_opcode().as_usize() {
            String::from("HINT_BUFFER")
        } else {
            unreachable!("unsupported opcode: {opcode}")
        }
    }

    /// Executes one HINT_STOREW/HINT_BUFFER instruction: pops
    /// `RV32_REGISTER_NUM_LIMBS * num_words` bytes from the hint stream,
    /// writes them to memory, and records everything the trace filler needs.
    ///
    /// Returns `ExecutionError::HintOutOfBounds` if the hint stream holds
    /// fewer bytes than the instruction would write.
    fn execute(
        &self,
        state: VmStateMut<F, TracingMemory, RA>,
        instruction: &Instruction<F>,
    ) -> Result<(), ExecutionError> {
        let &Instruction {
            opcode, a, b, d, e, ..
        } = instruction;

        let a = a.as_canonical_u32();
        let b = b.as_canonical_u32();
        debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
        debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);

        let local_opcode = Rv32HintStoreOpcode::from_usize(opcode.local_opcode_idx(self.offset));

        // HINT_STOREW writes exactly one word; HINT_BUFFER takes the word
        // count from register `a` (untraced peek here — the traced read
        // happens below so timestamps line up with the AIR).
        let num_words = if local_opcode == HINT_STOREW {
            1
        } else {
            read_rv32_register(state.memory.data(), a)
        };

        // Allocate the record up front; it is sized by `num_words`.
        let record = state.ctx.alloc(MultiRowLayout::new(Rv32HintStoreMetadata {
            num_words: num_words as usize,
        }));

        record.inner.from_pc = *state.pc;
        record.inner.timestamp = state.memory.timestamp;
        record.inner.mem_ptr_ptr = b;

        // Timestamp +0: traced read of the destination pointer from register `b`.
        record.inner.mem_ptr = u32::from_le_bytes(tracing_read(
            state.memory,
            RV32_REGISTER_AS,
            b,
            &mut record.inner.mem_ptr_aux_record.prev_timestamp,
        ));

        debug_assert!(record.inner.mem_ptr <= (1 << self.pointer_max_bits));
        debug_assert_ne!(num_words, 0);
        debug_assert!(num_words <= (1 << self.pointer_max_bits));

        record.inner.num_words = num_words;
        // Timestamp +1: traced read of the word count (HINT_BUFFER only).
        // HINT_STOREW skips the read but still burns the timestamp so that
        // every row consumes exactly 3 timestamps as the AIR requires;
        // `u32::MAX` in `num_words_ptr` marks "no read" for the filler.
        if local_opcode == HINT_STOREW {
            state.memory.increment_timestamp();
            record.inner.num_words_ptr = u32::MAX;
        } else {
            record.inner.num_words_ptr = a;
            tracing_read::<RV32_REGISTER_NUM_LIMBS>(
                state.memory,
                RV32_REGISTER_AS,
                record.inner.num_words_ptr,
                &mut record.inner.num_words_read.prev_timestamp,
            );
        };

        if state.streams.hint_stream.len() < RV32_REGISTER_NUM_LIMBS * num_words as usize {
            return Err(ExecutionError::HintOutOfBounds { pc: *state.pc });
        }

        for idx in 0..(num_words as usize) {
            // Rows after the first perform no register reads, but the AIR
            // allots 3 timestamps per row — skip the two unused read slots.
            if idx != 0 {
                state.memory.increment_timestamp();
                state.memory.increment_timestamp();
            }

            // Pop one word of hint bytes from the stream.
            let data_f: [F; RV32_REGISTER_NUM_LIMBS] =
                std::array::from_fn(|_| state.streams.hint_stream.pop_front().unwrap());
            let data: [u8; RV32_REGISTER_NUM_LIMBS] =
                data_f.map(|byte| byte.as_canonical_u32() as u8);

            record.var[idx].data = data;

            // Timestamp +2 of this row: traced write of the hint word.
            tracing_write(
                state.memory,
                RV32_MEMORY_AS,
                record.inner.mem_ptr + (RV32_REGISTER_NUM_LIMBS * idx) as u32,
                data,
                &mut record.var[idx].data_write_aux.prev_timestamp,
                &mut record.var[idx].data_write_aux.prev_data,
            );
        }
        *state.pc = state.pc.wrapping_add(DEFAULT_PC_STEP);

        Ok(())
    }
}
468
impl<F: PrimeField32> TraceFiller<F> for Rv32HintStoreFiller {
    /// Expands the packed execution records stored in `trace` (written there
    /// during preflight execution) into full column rows, in place.
    fn fill_trace(
        &self,
        mem_helper: &MemoryAuxColsFactory<F>,
        trace: &mut RowMajorMatrix<F>,
        rows_used: usize,
    ) {
        if rows_used == 0 {
            return;
        }

        let width = trace.width;
        debug_assert_eq!(width, size_of::<Rv32HintStoreCols<u8>>());
        let mut trace = &mut trace.values[..width * rows_used];
        let mut sizes = Vec::with_capacity(rows_used);
        let mut chunks = Vec::with_capacity(rows_used);

        // Split the used part of the trace into one chunk per record, each
        // `num_words` rows long; only the header is needed to size a chunk.
        while !trace.is_empty() {
            // SAFETY: assumes the buffer prefix holds a valid
            // `Rv32HintStoreRecordHeader` written by the preflight executor.
            let record: &Rv32HintStoreRecordHeader =
                unsafe { get_record_from_slice(&mut trace, ()) };
            let (chunk, rest) = trace.split_at_mut(width * record.num_words as usize);
            sizes.push(record.num_words);
            chunks.push(chunk);
            trace = rest;
        }

        // Shift that isolates the most significant limb of a u32 pointer, and
        // the scale that pushes it to `pointer_max_bits` — mirrors the AIR's
        // range-check interaction exactly.
        let msl_rshift: u32 = ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS) as u32;
        let msl_lshift: u32 =
            (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits) as u32;

        chunks
            .par_iter_mut()
            .zip(sizes.par_iter())
            .for_each(|(chunk, &num_words)| {
                // SAFETY: assumes the chunk begins with a record of exactly
                // this layout, as written by `execute` for this instruction.
                let record: Rv32HintStoreRecordMut = unsafe {
                    get_record_from_slice(
                        chunk,
                        MultiRowLayout::new(Rv32HintStoreMetadata {
                            num_words: num_words as usize,
                        }),
                    )
                };
                // Range-check request matching the AIR's `send_range` on the
                // most significant limbs of mem_ptr and rem_words (which
                // equals num_words on the start row).
                self.bitwise_lookup_chip.request_range(
                    (record.inner.mem_ptr >> msl_rshift) << msl_lshift,
                    (num_words >> msl_rshift) << msl_lshift,
                );

                // Walk timestamps and pointers backwards from the end of the
                // instruction: each row consumes 3 timestamps and one word.
                let mut timestamp = record.inner.timestamp + num_words * 3;
                let mut mem_ptr = record.inner.mem_ptr + num_words * RV32_REGISTER_NUM_LIMBS as u32;

                // Fill rows last-to-first: the record bytes at the front of
                // the chunk alias the first rows, so each record entry must be
                // consumed before its row's columns are written over it.
                chunk
                    .rchunks_exact_mut(width)
                    .zip(record.var.iter().enumerate().rev())
                    .for_each(|(row, (idx, var))| {
                        // Byte range checks for the written data, two at a
                        // time, matching the AIR's data lookups.
                        for pair in var.data.chunks_exact(2) {
                            self.bitwise_lookup_chip
                                .request_range(pair[0] as u32, pair[1] as u32);
                        }

                        let cols: &mut Rv32HintStoreCols<F> = row.borrow_mut();
                        // `u32::MAX` in `num_words_ptr` marks a HINT_STOREW record.
                        let is_single = record.inner.num_words_ptr == u32::MAX;
                        timestamp -= 3;
                        // Word-count read (timestamp +1) exists only on the
                        // first row of a HINT_BUFFER instruction.
                        if idx == 0 && !is_single {
                            mem_helper.fill(
                                record.inner.num_words_read.prev_timestamp,
                                timestamp + 1,
                                cols.num_words_aux_cols.as_mut(),
                            );
                            cols.num_words_ptr = F::from_u32(record.inner.num_words_ptr);
                        } else {
                            mem_helper.fill_zero(cols.num_words_aux_cols.as_mut());
                            cols.num_words_ptr = F::ZERO;
                        }

                        cols.is_buffer_start = F::from_bool(idx == 0 && !is_single);

                        cols.data = var.data.map(|x| F::from_u8(x));

                        // Data write occupies the third timestamp of the row.
                        cols.write_aux
                            .set_prev_data(var.data_write_aux.prev_data.map(|x| F::from_u8(x)));
                        mem_helper.fill(
                            var.data_write_aux.prev_timestamp,
                            timestamp + 2,
                            cols.write_aux.as_mut(),
                        );

                        // Pointer-register read (timestamp +0) exists only on
                        // the first row of the instruction.
                        if idx == 0 {
                            mem_helper.fill(
                                record.inner.mem_ptr_aux_record.prev_timestamp,
                                timestamp,
                                cols.mem_ptr_aux_cols.as_mut(),
                            );
                        } else {
                            mem_helper.fill_zero(cols.mem_ptr_aux_cols.as_mut());
                        }

                        mem_ptr -= RV32_REGISTER_NUM_LIMBS as u32;
                        cols.mem_ptr_limbs = mem_ptr.to_le_bytes().map(|x| F::from_u8(x));
                        cols.mem_ptr_ptr = F::from_u32(record.inner.mem_ptr_ptr);

                        cols.from_state.timestamp = F::from_u32(timestamp);
                        cols.from_state.pc = F::from_u32(record.inner.from_pc);

                        // Remaining word count including this row, LE limbs.
                        cols.rem_words_limbs = (num_words - idx as u32)
                            .to_le_bytes()
                            .map(|x| F::from_u8(x));
                        cols.is_buffer = F::from_bool(!is_single);
                        cols.is_single = F::from_bool(is_single);
                    });
            })
    }
}
596
/// The complete hint-store chip: the trace filler wrapped by the common VM
/// chip wrapper.
pub type Rv32HintStoreChip<F> = VmChipWrapper<F, Rv32HintStoreFiller>;