1use std::sync::Arc;
2
3use itertools::zip_eq;
4use openvm_circuit_primitives::var_range::{
5 SharedVariableRangeCheckerChip, VariableRangeCheckerBus, VariableRangeCheckerChip,
6};
7use openvm_instructions::{instruction::Instruction, riscv::RV32_REGISTER_AS, NATIVE_AS};
8use openvm_stark_backend::{
9 config::{StarkGenericConfig, Val},
10 engine::VerificationData,
11 interaction::PermutationCheckBus,
12 p3_matrix::dense::RowMajorMatrix,
13 p3_util::log2_strict_usize,
14 prover::{
15 cpu::{CpuBackend, CpuDevice},
16 types::AirProvingContext,
17 },
18 rap::AnyRap,
19 verifier::VerificationError,
20 AirRef, Chip,
21};
22use openvm_stark_sdk::{
23 config::{
24 baby_bear_blake3::{BabyBearBlake3Config, BabyBearBlake3Engine},
25 baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine},
26 setup_tracing_with_log_level, FriParameters,
27 },
28 engine::{StarkEngine, StarkFriEngine},
29 p3_baby_bear::BabyBear,
30};
31use rand::{rngs::StdRng, RngCore, SeedableRng};
32use tracing::Level;
33
34use crate::{
35 arch::{
36 testing::{
37 execution::air::ExecutionDummyAir,
38 program::{air::ProgramDummyAir, ProgramTester},
39 ExecutionTester, MemoryTester, TestBuilder, TestChipHarness, EXECUTION_BUS, MEMORY_BUS,
40 MEMORY_MERKLE_BUS, POSEIDON2_DIRECT_BUS, RANGE_CHECKER_BUS, READ_INSTRUCTION_BUS,
41 },
42 vm_poseidon2_config, Arena, ExecutionBridge, ExecutionBus, ExecutionState,
43 MatrixRecordArena, MemoryConfig, PreflightExecutor, Streams, VmField, VmStateMut,
44 },
45 system::{
46 memory::{
47 adapter::records::arena_size_bound,
48 offline_checker::{MemoryBridge, MemoryBus},
49 online::TracingMemory,
50 MemoryAirInventory, MemoryController, SharedMemoryHelper, CHUNK,
51 },
52 poseidon2::Poseidon2PeripheryChip,
53 program::ProgramBus,
54 SystemPort,
55 },
56};
57
58pub struct VmChipTestBuilder<F: VmField> {
59 pub memory: MemoryTester<F>,
60 pub streams: Streams<F>,
61 pub rng: StdRng,
62 pub execution: ExecutionTester<F>,
63 pub program: ProgramTester<F>,
64 internal_rng: StdRng,
65 custom_pvs: Vec<Option<F>>,
66 default_register: usize,
67 default_pointer: usize,
68}
69
70impl<F> TestBuilder<F> for VmChipTestBuilder<F>
71where
72 F: VmField,
73{
74 fn execute<E, RA>(&mut self, executor: &mut E, arena: &mut RA, instruction: &Instruction<F>)
75 where
76 E: PreflightExecutor<F, RA>,
77 RA: Arena,
78 {
79 let initial_pc = self.next_elem_size_u32();
80 self.execute_with_pc(executor, arena, instruction, initial_pc);
81 }
82
83 fn execute_with_pc<E, RA>(
84 &mut self,
85 executor: &mut E,
86 arena: &mut RA,
87 instruction: &Instruction<F>,
88 initial_pc: u32,
89 ) where
90 E: PreflightExecutor<F, RA>,
91 RA: Arena,
92 {
93 let initial_state = ExecutionState {
94 pc: initial_pc,
95 timestamp: self.memory.memory.timestamp(),
96 };
97 tracing::debug!("initial_timestamp={}", self.memory.memory.timestamp());
98
99 let mut pc = initial_pc;
100 let state_mut = VmStateMut {
101 pc: &mut pc,
102 memory: &mut self.memory.memory,
103 streams: &mut self.streams,
104 rng: &mut self.rng,
105 custom_pvs: &mut self.custom_pvs,
106 ctx: arena,
107 #[cfg(feature = "metrics")]
108 metrics: &mut Default::default(),
109 };
110 executor
111 .execute(state_mut, instruction)
112 .expect("Expected the execution not to fail");
113 let final_state = ExecutionState {
114 pc,
115 timestamp: self.memory.memory.timestamp(),
116 };
117
118 self.program.execute(instruction, &initial_state);
119 self.execution.execute(initial_state, final_state);
120 }
121
122 fn read<const N: usize>(&mut self, address_space: usize, pointer: usize) -> [F; N] {
123 self.memory.read(address_space, pointer)
124 }
125
126 fn write<const N: usize>(&mut self, address_space: usize, pointer: usize, value: [F; N]) {
127 self.memory.write(address_space, pointer, value);
128 }
129
130 fn write_usize<const N: usize>(
131 &mut self,
132 address_space: usize,
133 pointer: usize,
134 value: [usize; N],
135 ) {
136 self.memory
137 .write(address_space, pointer, value.map(F::from_usize));
138 }
139
140 fn write_cell(&mut self, address_space: usize, pointer: usize, value: F) {
141 self.write(address_space, pointer, [value]);
142 }
143
144 fn read_cell(&mut self, address_space: usize, pointer: usize) -> F {
145 self.read::<1>(address_space, pointer)[0]
146 }
147
148 fn address_bits(&self) -> usize {
149 self.memory.controller.memory_config().pointer_max_bits
150 }
151
152 fn last_to_pc(&self) -> F {
153 self.execution.last_to_pc()
154 }
155
156 fn last_from_pc(&self) -> F {
157 self.execution.last_from_pc()
158 }
159
160 fn execution_final_state(&self) -> ExecutionState<F> {
161 self.execution.records.last().unwrap().final_state
162 }
163
164 fn streams_mut(&mut self) -> &mut Streams<F> {
165 &mut self.streams
166 }
167
168 fn get_default_register(&mut self, increment: usize) -> usize {
169 self.default_register += increment;
170 self.default_register - increment
171 }
172
173 fn get_default_pointer(&mut self, increment: usize) -> usize {
174 self.default_pointer += increment;
175 self.default_pointer - increment
176 }
177
178 fn write_heap_pointer_default(
179 &mut self,
180 reg_increment: usize,
181 pointer_increment: usize,
182 ) -> (usize, usize) {
183 let register = self.get_default_register(reg_increment);
184 let pointer = self.get_default_pointer(pointer_increment);
185 self.write(1, register, pointer.to_le_bytes().map(F::from_u8));
186 (register, pointer)
187 }
188
189 fn write_heap_default<const NUM_LIMBS: usize>(
190 &mut self,
191 reg_increment: usize,
192 pointer_increment: usize,
193 writes: Vec<[F; NUM_LIMBS]>,
194 ) -> (usize, usize) {
195 let register = self.get_default_register(reg_increment);
196 let pointer = self.get_default_pointer(pointer_increment);
197 self.write_heap(register, pointer, writes);
198 (register, pointer)
199 }
200}
201
202impl<F: VmField> VmChipTestBuilder<F> {
203 pub fn new(
204 controller: MemoryController<F>,
205 memory: TracingMemory,
206 streams: Streams<F>,
207 rng: StdRng,
208 execution_bus: ExecutionBus,
209 program_bus: ProgramBus,
210 internal_rng: StdRng,
211 ) -> Self {
212 setup_tracing_with_log_level(Level::WARN);
213 Self {
214 memory: MemoryTester::new(controller, memory),
215 streams,
216 rng,
217 custom_pvs: Vec::new(),
218 execution: ExecutionTester::new(execution_bus),
219 program: ProgramTester::new(program_bus),
220 internal_rng,
221 default_register: 0,
222 default_pointer: 0,
223 }
224 }
225
226 fn next_elem_size_u32(&mut self) -> u32 {
227 self.internal_rng.next_u32() % (1 << (F::bits() - 2))
228 }
229
230 pub fn set_num_public_values(&mut self, num_public_values: usize) {
231 self.custom_pvs.resize(num_public_values, None);
232 }
233
234 fn write_heap<const NUM_LIMBS: usize>(
235 &mut self,
236 register: usize,
237 pointer: usize,
238 writes: Vec<[F; NUM_LIMBS]>,
239 ) {
240 self.write(1usize, register, pointer.to_le_bytes().map(F::from_u8));
241 if NUM_LIMBS.is_power_of_two() {
242 for (i, &write) in writes.iter().enumerate() {
243 self.write(2usize, pointer + i * NUM_LIMBS, write);
244 }
245 } else {
246 for (i, &write) in writes.iter().enumerate() {
247 let ptr = pointer + i * NUM_LIMBS;
248 for j in (0..NUM_LIMBS).step_by(4) {
249 self.write::<4>(2usize, ptr + j, write[j..j + 4].try_into().unwrap());
250 }
251 }
252 }
253 }
254
255 pub fn system_port(&self) -> SystemPort {
256 SystemPort {
257 execution_bus: self.execution.bus,
258 program_bus: self.program.bus,
259 memory_bridge: self.memory_bridge(),
260 }
261 }
262
263 pub fn execution_bridge(&self) -> ExecutionBridge {
264 ExecutionBridge::new(self.execution.bus, self.program.bus)
265 }
266
267 pub fn execution_bus(&self) -> ExecutionBus {
268 self.execution.bus
269 }
270
271 pub fn program_bus(&self) -> ProgramBus {
272 self.program.bus
273 }
274
275 pub fn memory_bus(&self) -> MemoryBus {
276 self.memory.controller.memory_bus
277 }
278
279 pub fn range_checker(&self) -> SharedVariableRangeCheckerChip {
280 self.memory.controller.range_checker.clone()
281 }
282
283 pub fn memory_bridge(&self) -> MemoryBridge {
284 self.memory.controller.memory_bridge()
285 }
286
287 pub fn memory_helper(&self) -> SharedMemoryHelper<F> {
288 self.memory.controller.helper()
289 }
290}
291
292pub type TestSC = BabyBearBlake3Config;
294
295impl VmChipTestBuilder<BabyBear> {
296 pub fn build(self) -> VmChipTester<TestSC> {
297 let tester = VmChipTester {
298 memory: Some(self.memory),
299 ..Default::default()
300 };
301 let tester =
302 tester.load_periphery((ExecutionDummyAir::new(self.execution.bus), self.execution));
303 tester.load_periphery((ProgramDummyAir::new(self.program.bus), self.program))
304 }
305 pub fn build_babybear_poseidon2(self) -> VmChipTester<BabyBearPoseidon2Config> {
306 let tester = VmChipTester {
307 memory: Some(self.memory),
308 ..Default::default()
309 };
310 let tester =
311 tester.load_periphery((ExecutionDummyAir::new(self.execution.bus), self.execution));
312 tester.load_periphery((ProgramDummyAir::new(self.program.bus), self.program))
313 }
314}
315
316impl<F: VmField> VmChipTestBuilder<F> {
317 pub fn default_persistent() -> Self {
318 let mut mem_config = MemoryConfig::default();
319 mem_config.addr_spaces[RV32_REGISTER_AS as usize].num_cells = 1 << 29;
320 mem_config.addr_spaces[NATIVE_AS as usize].num_cells = 0;
321 Self::persistent(mem_config)
322 }
323
324 pub fn default_native() -> Self {
325 Self::volatile(MemoryConfig::aggregation())
326 }
327
328 fn range_checker_and_memory(
329 mem_config: &MemoryConfig,
330 init_block_size: usize,
331 ) -> (SharedVariableRangeCheckerChip, TracingMemory) {
332 let range_checker = Arc::new(VariableRangeCheckerChip::new(VariableRangeCheckerBus::new(
333 RANGE_CHECKER_BUS,
334 mem_config.decomp,
335 )));
336 let max_access_adapter_n = log2_strict_usize(mem_config.max_access_adapter_n);
337 let arena_size_bound = arena_size_bound(&vec![1 << 16; max_access_adapter_n]);
338 let memory = TracingMemory::new(mem_config, init_block_size, arena_size_bound);
339
340 (range_checker, memory)
341 }
342
343 pub fn persistent(mem_config: MemoryConfig) -> Self {
344 setup_tracing_with_log_level(Level::INFO);
345 let (range_checker, memory) = Self::range_checker_and_memory(&mem_config, CHUNK);
346 let hasher_chip = Arc::new(Poseidon2PeripheryChip::new(
347 vm_poseidon2_config(),
348 POSEIDON2_DIRECT_BUS,
349 3,
350 ));
351 let memory_controller = MemoryController::with_persistent_memory(
352 MemoryBus::new(MEMORY_BUS),
353 mem_config,
354 range_checker,
355 PermutationCheckBus::new(MEMORY_MERKLE_BUS),
356 PermutationCheckBus::new(POSEIDON2_DIRECT_BUS),
357 hasher_chip,
358 );
359 Self {
360 memory: MemoryTester::new(memory_controller, memory),
361 streams: Default::default(),
362 rng: StdRng::seed_from_u64(0),
363 custom_pvs: Vec::new(),
364 execution: ExecutionTester::new(ExecutionBus::new(EXECUTION_BUS)),
365 program: ProgramTester::new(ProgramBus::new(READ_INSTRUCTION_BUS)),
366 internal_rng: StdRng::seed_from_u64(0),
367 default_register: 0,
368 default_pointer: 0,
369 }
370 }
371
372 pub fn volatile(mem_config: MemoryConfig) -> Self {
373 setup_tracing_with_log_level(Level::INFO);
374 let (range_checker, memory) = Self::range_checker_and_memory(&mem_config, 1);
375 let memory_controller = MemoryController::with_volatile_memory(
376 MemoryBus::new(MEMORY_BUS),
377 mem_config,
378 range_checker,
379 );
380 Self {
381 memory: MemoryTester::new(memory_controller, memory),
382 streams: Default::default(),
383 rng: StdRng::seed_from_u64(0),
384 custom_pvs: Vec::new(),
385 execution: ExecutionTester::new(ExecutionBus::new(EXECUTION_BUS)),
386 program: ProgramTester::new(ProgramBus::new(READ_INSTRUCTION_BUS)),
387 internal_rng: StdRng::seed_from_u64(0),
388 default_register: 0,
389 default_pointer: 0,
390 }
391 }
392}
393
394impl<F: VmField> Default for VmChipTestBuilder<F> {
395 fn default() -> Self {
396 let mut mem_config = MemoryConfig::default();
397 mem_config.addr_spaces[RV32_REGISTER_AS as usize].num_cells = 1 << 29;
400 mem_config.addr_spaces[NATIVE_AS as usize].num_cells = 0;
401 Self::volatile(mem_config)
402 }
403}
404
405pub struct VmChipTester<SC: StarkGenericConfig>
406where
407 Val<SC>: VmField,
408{
409 pub memory: Option<MemoryTester<Val<SC>>>,
410 pub air_ctxs: Vec<(AirRef<SC>, AirProvingContext<CpuBackend<SC>>)>,
411}
412
413impl<SC> Default for VmChipTester<SC>
414where
415 SC: StarkGenericConfig,
416 Val<SC>: VmField,
417{
418 fn default() -> Self {
419 Self {
420 memory: None,
421 air_ctxs: vec![],
422 }
423 }
424}
425
426impl<SC> VmChipTester<SC>
427where
428 SC: StarkGenericConfig,
429 Val<SC>: VmField,
430{
431 pub fn load<E, A, C>(
432 mut self,
433 harness: TestChipHarness<Val<SC>, E, A, C, MatrixRecordArena<Val<SC>>>,
434 ) -> Self
435 where
436 A: AnyRap<SC> + 'static,
437 C: Chip<MatrixRecordArena<Val<SC>>, CpuBackend<SC>>,
438 {
439 let arena = harness.arena;
440 let rows_used = arena.trace_offset.div_ceil(arena.width);
441 if rows_used > 0 {
442 let air = Arc::new(harness.air) as AirRef<SC>;
443 let ctx = harness.chip.generate_proving_ctx(arena);
444 tracing::debug!("Generated air proving context for {}", air.name());
445 self.air_ctxs.push((air, ctx));
446 }
447
448 self
449 }
450
451 pub fn load_periphery<A, C>(self, (air, chip): (A, C)) -> Self
452 where
453 A: AnyRap<SC> + 'static,
454 C: Chip<(), CpuBackend<SC>>,
455 {
456 let air = Arc::new(air) as AirRef<SC>;
457 self.load_periphery_ref((air, chip))
458 }
459
460 pub fn load_periphery_ref<C>(mut self, (air, chip): (AirRef<SC>, C)) -> Self
461 where
462 C: Chip<(), CpuBackend<SC>>,
463 {
464 let ctx = chip.generate_proving_ctx(());
465 tracing::debug!("Generated air proving context for {}", air.name());
466 self.air_ctxs.push((air, ctx));
467
468 self
469 }
470
471 pub fn finalize(mut self) -> Self {
472 if let Some(memory_tester) = self.memory.take() {
473 let mut memory_controller = memory_tester.controller;
474 let is_persistent = memory_controller.continuation_enabled();
475 let mut memory = memory_tester.memory;
476 let touched_memory = memory.finalize::<Val<SC>>(is_persistent);
477 let range_checker = memory_controller.range_checker.clone();
479 for mem_chip in memory_tester.chip_for_block.into_values() {
480 self = self.load_periphery((mem_chip.air, mem_chip));
481 }
482 let mem_inventory = MemoryAirInventory::new(
483 memory_controller.memory_bridge(),
484 memory_controller.memory_config(),
485 range_checker.bus(),
486 is_persistent.then_some((
487 PermutationCheckBus::new(MEMORY_MERKLE_BUS),
488 PermutationCheckBus::new(POSEIDON2_DIRECT_BUS),
489 )),
490 );
491 let ctxs = memory_controller
492 .generate_proving_ctx(memory.access_adapter_records, touched_memory);
493 for (air, ctx) in zip_eq(mem_inventory.into_airs(), ctxs)
494 .filter(|(_, ctx)| ctx.main_trace_height() > 0)
495 {
496 self.air_ctxs.push((air, ctx));
497 }
498 if let Some(hasher_chip) = memory_controller.hasher_chip {
499 let air: AirRef<SC> = match hasher_chip.as_ref() {
500 Poseidon2PeripheryChip::Register0(chip) => chip.air.clone(),
501 Poseidon2PeripheryChip::Register1(chip) => chip.air.clone(),
502 };
503 self = self.load_periphery_ref((air, hasher_chip));
504 }
505 self = self.load_periphery((range_checker.air, range_checker));
507 }
508 self
509 }
510
511 pub fn load_air_proving_ctx(
512 mut self,
513 air_proving_ctx: (AirRef<SC>, AirProvingContext<CpuBackend<SC>>),
514 ) -> Self {
515 self.air_ctxs.push(air_proving_ctx);
516 self
517 }
518
519 pub fn load_and_prank_trace<E, A, C, P>(
520 mut self,
521 harness: TestChipHarness<Val<SC>, E, A, C, MatrixRecordArena<Val<SC>>>,
522 modify_trace: P,
523 ) -> Self
524 where
525 A: AnyRap<SC> + 'static,
526 C: Chip<MatrixRecordArena<Val<SC>>, CpuBackend<SC>>,
527 P: Fn(&mut RowMajorMatrix<Val<SC>>),
528 {
529 let arena = harness.arena;
530 let mut ctx = harness.chip.generate_proving_ctx(arena);
531 let trace: Arc<RowMajorMatrix<Val<SC>>> = Option::take(&mut ctx.common_main).unwrap();
532 let mut trace = Arc::into_inner(trace).unwrap();
533 modify_trace(&mut trace);
534 ctx.common_main = Some(Arc::new(trace));
535 self.air_ctxs.push((Arc::new(harness.air), ctx));
536 self
537 }
538
539 pub fn test<E, P: Fn() -> E>(
542 self, engine_provider: P,
544 ) -> Result<VerificationData<SC>, VerificationError>
545 where
546 E: StarkEngine<SC = SC, PB = CpuBackend<SC>, PD = CpuDevice<SC>>,
547 {
548 assert!(self.memory.is_none(), "Memory must be finalized");
549 let (airs, ctxs): (Vec<_>, Vec<_>) = self.air_ctxs.into_iter().unzip();
550 engine_provider().run_test_impl(airs, ctxs)
551 }
552}
553
554impl VmChipTester<BabyBearPoseidon2Config> {
555 pub fn simple_test(
556 self,
557 ) -> Result<VerificationData<BabyBearPoseidon2Config>, VerificationError> {
558 self.test(|| BabyBearPoseidon2Engine::new(FriParameters::new_for_testing(1)))
559 }
560
561 pub fn simple_test_with_expected_error(self, expected_error: VerificationError) {
562 let msg = format!(
563 "Expected verification to fail with {:?}, but it didn't",
564 &expected_error
565 );
566 let result = self.simple_test();
567 assert_eq!(result.err(), Some(expected_error), "{msg}");
568 }
569}
570
571impl VmChipTester<BabyBearBlake3Config> {
572 pub fn simple_test(self) -> Result<VerificationData<BabyBearBlake3Config>, VerificationError> {
573 self.test(|| BabyBearBlake3Engine::new(FriParameters::new_for_testing(1)))
574 }
575
576 pub fn simple_test_with_expected_error(self, expected_error: VerificationError) {
577 let msg = format!(
578 "Expected verification to fail with {:?}, but it didn't",
579 &expected_error
580 );
581 let result = self.simple_test();
582 assert_eq!(result.err(), Some(expected_error), "{msg}");
583 }
584}