1use std::{
2 borrow::{Borrow, BorrowMut},
3 sync::{Arc, Mutex, OnceLock},
4};
5
6use openvm_circuit::{
7 arch::{
8 ExecutionBridge, ExecutionBus, ExecutionError, ExecutionState, InstructionExecutor, Streams,
9 },
10 system::{
11 memory::{
12 offline_checker::{MemoryBridge, MemoryReadAuxCols, MemoryWriteAuxCols},
13 MemoryAddress, MemoryAuxColsFactory, MemoryController, OfflineMemory, RecordId,
14 },
15 program::ProgramBus,
16 },
17};
18use openvm_circuit_primitives::{
19 bitwise_op_lookup::{BitwiseOperationLookupBus, SharedBitwiseOperationLookupChip},
20 utils::{next_power_of_two_or_zero, not},
21};
22use openvm_circuit_primitives_derive::AlignedBorrow;
23use openvm_instructions::{
24 instruction::Instruction,
25 program::DEFAULT_PC_STEP,
26 riscv::{RV32_CELL_BITS, RV32_MEMORY_AS, RV32_REGISTER_AS, RV32_REGISTER_NUM_LIMBS},
27 LocalOpcode,
28};
29use openvm_rv32im_transpiler::{
30 Rv32HintStoreOpcode,
31 Rv32HintStoreOpcode::{HINT_BUFFER, HINT_STOREW},
32};
33use openvm_stark_backend::{
34 config::{StarkGenericConfig, Val},
35 interaction::InteractionBuilder,
36 p3_air::{Air, AirBuilder, BaseAir},
37 p3_field::{Field, FieldAlgebra, PrimeField32},
38 p3_matrix::{dense::RowMajorMatrix, Matrix},
39 prover::types::AirProofInput,
40 rap::{AnyRap, BaseAirWithPublicValues, PartitionedBaseAir},
41 Chip, ChipUsageGetter,
42};
43use serde::{Deserialize, Serialize};
44
45use crate::adapters::{compose, decompose};
46
47#[cfg(test)]
48mod tests;
49
/// Column layout of a single trace row of the hint-store chip.
///
/// Every valid row writes one word (`RV32_REGISTER_NUM_LIMBS` limbs) of hint data to memory.
/// A `HINT_STOREW` instruction occupies one row; a `HINT_BUFFER` instruction occupies one row
/// per word written. The struct is `#[repr(C)]` and reinterpreted from a flat slice via
/// `AlignedBorrow`, so the field order below is the column order.
#[repr(C)]
#[derive(AlignedBorrow, Debug)]
pub struct Rv32HintStoreCols<T> {
    /// Boolean flag: this row is a `HINT_STOREW` (single word) row.
    pub is_single: T,
    /// Boolean flag: this row belongs to a `HINT_BUFFER` instruction.
    pub is_buffer: T,
    /// Limbs (least significant first) of the number of words still to be written, counting
    /// this row. The AIR constrains the composed value to be 1 on the last row of an
    /// instruction.
    pub rem_words_limbs: [T; RV32_REGISTER_NUM_LIMBS],

    /// Execution state (pc, timestamp) associated with this row.
    pub from_state: ExecutionState<T>,
    /// Register address that holds the destination memory pointer (instruction operand `b`).
    pub mem_ptr_ptr: T,
    /// Limbs (least significant first) of the memory address written by this row.
    pub mem_ptr_limbs: [T; RV32_REGISTER_NUM_LIMBS],
    /// Auxiliary columns for the register read of the memory pointer.
    pub mem_ptr_aux_cols: MemoryReadAuxCols<T>,

    /// Auxiliary columns for the memory write of `data`.
    pub write_aux: MemoryWriteAuxCols<T, RV32_REGISTER_NUM_LIMBS>,
    /// The word of hint data written by this row.
    pub data: [T; RV32_REGISTER_NUM_LIMBS],

    /// Boolean flag: this row is the first row of a `HINT_BUFFER` instruction.
    pub is_buffer_start: T,
    /// Register address that holds the word count (instruction operand `a`).
    pub num_words_ptr: T,
    /// Auxiliary columns for the register read of the word count.
    pub num_words_aux_cols: MemoryReadAuxCols<T>,
}
72
/// AIR for the hint-store chip: holds the bus handles used by the constraints plus the
/// opcode offset and the pointer width bound.
#[derive(Copy, Clone, Debug)]
pub struct Rv32HintStoreAir {
    /// Bridge used to emit the execution/program interaction for each instruction.
    pub execution_bridge: ExecutionBridge,
    /// Bridge used to emit the register reads and the memory write.
    pub memory_bridge: MemoryBridge,
    /// Lookup bus used for the range checks on pointer limbs and data limbs.
    pub bitwise_operation_lookup_bus: BitwiseOperationLookupBus,
    /// Value added to the local opcode index to obtain the opcode sent on the execution bridge.
    pub offset: usize,
    /// Bit width used to bound the most significant limb of the memory pointer and of the
    /// word count.
    pointer_max_bits: usize,
}
81
impl<F: Field> BaseAir<F> for Rv32HintStoreAir {
    /// The trace width is the number of columns in [`Rv32HintStoreCols`].
    fn width(&self) -> usize {
        Rv32HintStoreCols::<F>::width()
    }
}

// Both traits are implemented with their default items only.
impl<F: Field> BaseAirWithPublicValues<F> for Rv32HintStoreAir {}
impl<F: Field> PartitionedBaseAir<F> for Rv32HintStoreAir {}
90
impl<AB: InteractionBuilder> Air<AB> for Rv32HintStoreAir {
    /// Emits all constraints and bus interactions for a pair of consecutive rows
    /// (`local`, `next`) of the hint-store trace.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local_cols: &Rv32HintStoreCols<AB::Var> = (*local).borrow();
        let next = main.row_slice(1);
        let next_cols: &Rv32HintStoreCols<AB::Var> = (*next).borrow();

        // `timestamp_pp()` yields `timestamp + k` on its k-th call (k starting at 0) and bumps
        // `timestamp_delta`. It is called once per memory access below regardless of whether
        // that access is enabled, so `timestamp_delta` ends up as the per-row timestamp stride.
        let timestamp: AB::Var = local_cols.from_state.timestamp;
        let mut timestamp_delta: usize = 0;
        let mut timestamp_pp = || {
            timestamp_delta += 1;
            timestamp + AB::Expr::from_canonical_usize(timestamp_delta - 1)
        };

        // All flags are boolean, `is_buffer_start` implies `is_buffer`, and `is_single` and
        // `is_buffer` are mutually exclusive (their sum is boolean).
        builder.assert_bool(local_cols.is_single);
        builder.assert_bool(local_cols.is_buffer);
        builder.assert_bool(local_cols.is_buffer_start);
        builder
            .when(local_cols.is_buffer_start)
            .assert_one(local_cols.is_buffer);
        builder.assert_bool(local_cols.is_single + local_cols.is_buffer);

        // `is_valid`: the row belongs to some instruction.
        // `is_start`: the row is the first row of an instruction.
        // `is_end`: equals 0 exactly when the next row has `is_buffer = 1` and
        // `is_buffer_start = 0`, i.e. when the next row continues a buffer. Because
        // `is_buffer_start = 1` forces `is_buffer = 1`, the expression is boolean.
        let is_valid = local_cols.is_single + local_cols.is_buffer;
        let is_start = local_cols.is_single + local_cols.is_buffer_start;
        let is_end = not::<AB::Expr>(next_cols.is_buffer) + next_cols.is_buffer_start;

        // Compose the limb arrays into single field expressions. Iterating in reverse and
        // multiplying by 2^RV32_CELL_BITS each step makes limb 0 the least significant.
        let mut rem_words = AB::Expr::ZERO;
        let mut next_rem_words = AB::Expr::ZERO;
        let mut mem_ptr = AB::Expr::ZERO;
        let mut next_mem_ptr = AB::Expr::ZERO;
        for i in (0..RV32_REGISTER_NUM_LIMBS).rev() {
            rem_words = rem_words * AB::F::from_canonical_u32(1 << RV32_CELL_BITS)
                + local_cols.rem_words_limbs[i];
            next_rem_words = next_rem_words * AB::F::from_canonical_u32(1 << RV32_CELL_BITS)
                + next_cols.rem_words_limbs[i];
            mem_ptr = mem_ptr * AB::F::from_canonical_u32(1 << RV32_CELL_BITS)
                + local_cols.mem_ptr_limbs[i];
            next_mem_ptr = next_mem_ptr * AB::F::from_canonical_u32(1 << RV32_CELL_BITS)
                + next_cols.mem_ptr_limbs[i];
        }

        // If the local row is not valid, the next row must not be valid either.
        builder
            .when_transition()
            .when(not::<AB::Expr>(is_valid.clone()))
            .assert_zero(next_cols.is_single + next_cols.is_buffer);

        // A single-word row must not be followed by a buffer continuation row, and the first
        // row of the trace must not itself be a buffer continuation row.
        builder
            .when(local_cols.is_single)
            .assert_one(is_end.clone());
        builder
            .when_first_row()
            .assert_one(not::<AB::Expr>(local_cols.is_buffer) + local_cols.is_buffer_start);

        // Register read of the memory pointer at `timestamp + 0`, enabled only on the first
        // row of an instruction.
        self.memory_bridge
            .read(
                MemoryAddress::new(
                    AB::F::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.mem_ptr_ptr,
                ),
                local_cols.mem_ptr_limbs,
                timestamp_pp(),
                &local_cols.mem_ptr_aux_cols,
            )
            .eval(builder, is_start.clone());

        // Register read of the word count at `timestamp + 1`, enabled only on the first row
        // of a buffer instruction.
        self.memory_bridge
            .read(
                MemoryAddress::new(
                    AB::F::from_canonical_u32(RV32_REGISTER_AS),
                    local_cols.num_words_ptr,
                ),
                local_cols.rem_words_limbs,
                timestamp_pp(),
                &local_cols.num_words_aux_cols,
            )
            .eval(builder, local_cols.is_buffer_start);

        // Memory write of this row's data word at `timestamp + 2`, enabled on every valid row.
        self.memory_bridge
            .write(
                MemoryAddress::new(AB::F::from_canonical_u32(RV32_MEMORY_AS), mem_ptr.clone()),
                local_cols.data,
                timestamp_pp(),
                &local_cols.write_aux,
            )
            .eval(builder, is_valid.clone());

        // Exactly one of the flags is set on a start row, so this selects the matching opcode.
        let expected_opcode = (local_cols.is_single
            * AB::F::from_canonical_usize(HINT_STOREW as usize + self.offset))
            + (local_cols.is_buffer
                * AB::F::from_canonical_usize(HINT_BUFFER as usize + self.offset));

        // One execution interaction per instruction (enabled on `is_start`). The operand list
        // is [a, b, c, d, e]; operand `a` is zeroed for single-word rows. The last argument is
        // `rem_words * timestamp_delta` (NOTE(review): presumably the total timestamp change
        // of the instruction; `execute_and_increment_pc` is defined elsewhere, confirm).
        self.execution_bridge
            .execute_and_increment_pc(
                expected_opcode,
                [
                    local_cols.is_buffer * (local_cols.num_words_ptr),
                    local_cols.mem_ptr_ptr.into(),
                    AB::Expr::ZERO,
                    AB::Expr::from_canonical_u32(RV32_REGISTER_AS),
                    AB::Expr::from_canonical_u32(RV32_MEMORY_AS),
                ],
                local_cols.from_state,
                rem_words.clone() * AB::F::from_canonical_usize(timestamp_delta),
            )
            .eval(builder, is_start.clone());

        // Range check, on start rows, of the most significant limb of `mem_ptr` and of
        // `rem_words`, each scaled by 2^(RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS -
        // pointer_max_bits). This is intended to bound both values by `pointer_max_bits` bits
        // (NOTE(review): relies on `send_range` checking each argument against the cell width;
        // confirm against the lookup bus definition).
        self.bitwise_operation_lookup_bus
            .send_range(
                local_cols.mem_ptr_limbs[RV32_REGISTER_NUM_LIMBS - 1]
                    * AB::F::from_canonical_usize(
                        1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits),
                    ),
                local_cols.rem_words_limbs[RV32_REGISTER_NUM_LIMBS - 1]
                    * AB::F::from_canonical_usize(
                        1 << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - self.pointer_max_bits),
                    ),
            )
            .eval(builder, is_start.clone());

        // Range check the data limbs, two per lookup, on every valid row.
        for i in 0..RV32_REGISTER_NUM_LIMBS / 2 {
            self.bitwise_operation_lookup_bus
                .send_range(local_cols.data[2 * i], local_cols.data[(2 * i) + 1])
                .eval(builder, is_valid.clone());
        }

        // On the last row of an instruction (valid row with `is_end = 1`) exactly one word
        // remains, namely the one written by this row.
        builder
            .when(is_valid)
            .when(is_end.clone())
            .assert_one(rem_words.clone());

        // When the next row continues the current buffer (`is_end = 0`): the remaining word
        // count drops by one, the memory pointer advances by one word
        // (`RV32_REGISTER_NUM_LIMBS` cells), and the timestamp advances by `timestamp_delta`.
        let mut when_buffer_transition = builder.when(not::<AB::Expr>(is_end.clone()));
        when_buffer_transition.assert_one(rem_words.clone() - next_rem_words.clone());
        when_buffer_transition.assert_eq(
            next_mem_ptr.clone() - mem_ptr.clone(),
            AB::F::from_canonical_usize(RV32_REGISTER_NUM_LIMBS),
        );
        when_buffer_transition.assert_eq(
            timestamp + AB::F::from_canonical_usize(timestamp_delta),
            next_cols.from_state.timestamp,
        );
    }
}
262
/// Everything recorded while executing one hint-store instruction, later expanded into
/// trace rows by the chip.
#[derive(Serialize, Deserialize)]
#[serde(bound = "F: Field")]
pub struct Rv32HintStoreRecord<F: Field> {
    /// Execution state (pc, timestamp) before the instruction.
    pub from_state: ExecutionState<u32>,
    /// The executed instruction.
    pub instruction: Instruction<F>,
    /// Memory record id of the register read that produced `mem_ptr`.
    pub mem_ptr_read: RecordId,
    /// Address of the first word written.
    pub mem_ptr: u32,
    /// Number of words written (1 for `HINT_STOREW`).
    pub num_words: u32,

    /// Memory record id of the register read of the word count; `None` for `HINT_STOREW`,
    /// which performs no such read.
    pub num_words_read: Option<RecordId>,
    /// One entry per written word: the data and the memory record id of its write.
    pub hints: Vec<([F; RV32_REGISTER_NUM_LIMBS], RecordId)>,
}
275
/// Chip that executes `HINT_STOREW` / `HINT_BUFFER` instructions and produces their trace.
pub struct Rv32HintStoreChip<F: Field> {
    /// The AIR (constraints and bus configuration) for this chip.
    air: Rv32HintStoreAir,
    /// One record per executed instruction, in execution order.
    pub records: Vec<Rv32HintStoreRecord<F>>,
    /// Number of trace rows used so far (one per written word), before padding.
    pub height: usize,
    /// Offline memory consulted during trace generation to resolve memory record ids.
    offline_memory: Arc<Mutex<OfflineMemory<F>>>,
    /// Shared streams whose `hint_stream` supplies the data; set once via `set_streams`.
    pub streams: OnceLock<Arc<Mutex<Streams<F>>>>,
    /// Lookup chip that receives the range-check requests made during trace generation.
    bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
}
284
285impl<F: PrimeField32> Rv32HintStoreChip<F> {
286 pub fn new(
287 execution_bus: ExecutionBus,
288 program_bus: ProgramBus,
289 bitwise_lookup_chip: SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
290 memory_bridge: MemoryBridge,
291 offline_memory: Arc<Mutex<OfflineMemory<F>>>,
292 pointer_max_bits: usize,
293 offset: usize,
294 ) -> Self {
295 let air = Rv32HintStoreAir {
296 execution_bridge: ExecutionBridge::new(execution_bus, program_bus),
297 memory_bridge,
298 bitwise_operation_lookup_bus: bitwise_lookup_chip.bus(),
299 offset,
300 pointer_max_bits,
301 };
302 Self {
303 records: vec![],
304 air,
305 height: 0,
306 offline_memory,
307 streams: OnceLock::new(),
308 bitwise_lookup_chip,
309 }
310 }
311 pub fn set_streams(&mut self, streams: Arc<Mutex<Streams<F>>>) {
312 self.streams.set(streams).unwrap();
313 }
314}
315
316impl<F: PrimeField32> InstructionExecutor<F> for Rv32HintStoreChip<F> {
317 fn execute(
318 &mut self,
319 memory: &mut MemoryController<F>,
320 instruction: &Instruction<F>,
321 from_state: ExecutionState<u32>,
322 ) -> Result<ExecutionState<u32>, ExecutionError> {
323 let &Instruction {
324 opcode,
325 a: num_words_ptr,
326 b: mem_ptr_ptr,
327 d,
328 e,
329 ..
330 } = instruction;
331 debug_assert_eq!(d.as_canonical_u32(), RV32_REGISTER_AS);
332 debug_assert_eq!(e.as_canonical_u32(), RV32_MEMORY_AS);
333 let local_opcode =
334 Rv32HintStoreOpcode::from_usize(opcode.local_opcode_idx(self.air.offset));
335
336 let (mem_ptr_read, mem_ptr_limbs) = memory.read::<RV32_REGISTER_NUM_LIMBS>(d, mem_ptr_ptr);
337 let (num_words, num_words_read) = if local_opcode == HINT_STOREW {
338 memory.increment_timestamp();
339 (1, None)
340 } else {
341 let (num_words_read, num_words_limbs) =
342 memory.read::<RV32_REGISTER_NUM_LIMBS>(d, num_words_ptr);
343 (compose(num_words_limbs), Some(num_words_read))
344 };
345 debug_assert_ne!(num_words, 0);
346 debug_assert!(num_words <= (1 << self.air.pointer_max_bits));
347
348 let mem_ptr = compose(mem_ptr_limbs);
349
350 debug_assert!(mem_ptr <= (1 << self.air.pointer_max_bits));
351
352 let mut streams = self.streams.get().unwrap().lock().unwrap();
353 if streams.hint_stream.len() < RV32_REGISTER_NUM_LIMBS * num_words as usize {
354 return Err(ExecutionError::HintOutOfBounds { pc: from_state.pc });
355 }
356
357 let mut record = Rv32HintStoreRecord {
358 from_state,
359 instruction: instruction.clone(),
360 mem_ptr_read,
361 mem_ptr,
362 num_words,
363 num_words_read,
364 hints: vec![],
365 };
366
367 for word_index in 0..num_words {
368 if word_index != 0 {
369 memory.increment_timestamp();
370 memory.increment_timestamp();
371 }
372
373 let data: [F; RV32_REGISTER_NUM_LIMBS] =
374 std::array::from_fn(|_| streams.hint_stream.pop_front().unwrap());
375 let (write, _) = memory.write(
376 e,
377 F::from_canonical_u32(mem_ptr + (RV32_REGISTER_NUM_LIMBS as u32 * word_index)),
378 data,
379 );
380 record.hints.push((data, write));
381 }
382
383 self.height += record.hints.len();
384 self.records.push(record);
385
386 let next_state = ExecutionState {
387 pc: from_state.pc + DEFAULT_PC_STEP,
388 timestamp: memory.timestamp(),
389 };
390 Ok(next_state)
391 }
392
393 fn get_opcode_name(&self, opcode: usize) -> String {
394 if opcode == HINT_STOREW.global_opcode().as_usize() {
395 String::from("HINT_STOREW")
396 } else if opcode == HINT_BUFFER.global_opcode().as_usize() {
397 String::from("HINT_BUFFER")
398 } else {
399 unreachable!("unsupported opcode: {}", opcode)
400 }
401 }
402}
403
404impl<F: Field> ChipUsageGetter for Rv32HintStoreChip<F> {
405 fn air_name(&self) -> String {
406 "Rv32HintStoreAir".to_string()
407 }
408
409 fn current_trace_height(&self) -> usize {
410 self.height
411 }
412
413 fn trace_width(&self) -> usize {
414 Rv32HintStoreCols::<F>::width()
415 }
416}
417
418impl<F: PrimeField32> Rv32HintStoreChip<F> {
419 fn record_to_rows(
421 record: Rv32HintStoreRecord<F>,
422 aux_cols_factory: &MemoryAuxColsFactory<F>,
423 slice: &mut [F],
424 memory: &OfflineMemory<F>,
425 bitwise_lookup_chip: &SharedBitwiseOperationLookupChip<RV32_CELL_BITS>,
426 pointer_max_bits: usize,
427 ) -> usize {
428 let width = Rv32HintStoreCols::<F>::width();
429 let cols: &mut Rv32HintStoreCols<F> = slice[..width].borrow_mut();
430
431 cols.is_single = F::from_bool(record.num_words_read.is_none());
432 cols.is_buffer = F::from_bool(record.num_words_read.is_some());
433 cols.is_buffer_start = cols.is_buffer;
434
435 cols.from_state = record.from_state.map(F::from_canonical_u32);
436 cols.mem_ptr_ptr = record.instruction.b;
437 aux_cols_factory.generate_read_aux(
438 memory.record_by_id(record.mem_ptr_read),
439 &mut cols.mem_ptr_aux_cols,
440 );
441
442 cols.num_words_ptr = record.instruction.a;
443 if let Some(num_words_read) = record.num_words_read {
444 aux_cols_factory.generate_read_aux(
445 memory.record_by_id(num_words_read),
446 &mut cols.num_words_aux_cols,
447 );
448 }
449
450 let mut mem_ptr = record.mem_ptr;
451 let mut rem_words = record.num_words;
452 let mut used_u32s = 0;
453
454 let mem_ptr_msl = mem_ptr >> ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS);
455 let rem_words_msl = rem_words >> ((RV32_REGISTER_NUM_LIMBS - 1) * RV32_CELL_BITS);
456 bitwise_lookup_chip.request_range(
457 mem_ptr_msl << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - pointer_max_bits),
458 rem_words_msl << (RV32_REGISTER_NUM_LIMBS * RV32_CELL_BITS - pointer_max_bits),
459 );
460 for (i, &(data, write)) in record.hints.iter().enumerate() {
461 for half in 0..(RV32_REGISTER_NUM_LIMBS / 2) {
462 bitwise_lookup_chip.request_range(
463 data[2 * half].as_canonical_u32(),
464 data[2 * half + 1].as_canonical_u32(),
465 );
466 }
467
468 let cols: &mut Rv32HintStoreCols<F> = slice[used_u32s..used_u32s + width].borrow_mut();
469 cols.from_state.timestamp =
470 F::from_canonical_u32(record.from_state.timestamp + (3 * i as u32));
471 cols.data = data;
472 aux_cols_factory.generate_write_aux(memory.record_by_id(write), &mut cols.write_aux);
473 cols.rem_words_limbs = decompose(rem_words);
474 cols.mem_ptr_limbs = decompose(mem_ptr);
475 if i != 0 {
476 cols.is_buffer = F::ONE;
477 }
478 used_u32s += width;
479 mem_ptr += RV32_REGISTER_NUM_LIMBS as u32;
480 rem_words -= 1;
481 }
482
483 used_u32s
484 }
485
486 fn generate_trace(self) -> RowMajorMatrix<F> {
487 let width = self.trace_width();
488 let height = next_power_of_two_or_zero(self.height);
489 let mut flat_trace = F::zero_vec(width * height);
490
491 let memory = self.offline_memory.lock().unwrap();
492
493 let aux_cols_factory = memory.aux_cols_factory();
494
495 let mut used_u32s = 0;
496 for record in self.records {
497 used_u32s += Self::record_to_rows(
498 record,
499 &aux_cols_factory,
500 &mut flat_trace[used_u32s..],
501 &memory,
502 &self.bitwise_lookup_chip,
503 self.air.pointer_max_bits,
504 );
505 }
506 RowMajorMatrix::new(flat_trace, width)
508 }
509}
510
511impl<SC: StarkGenericConfig> Chip<SC> for Rv32HintStoreChip<Val<SC>>
512where
513 Val<SC>: PrimeField32,
514{
515 fn air(&self) -> Arc<dyn AnyRap<SC>> {
516 Arc::new(self.air)
517 }
518 fn generate_air_proof_input(self) -> AirProofInput<SC> {
519 AirProofInput::simple_no_pis(self.generate_trace())
520 }
521}