1use std::sync::Arc;
2
3use itertools::zip_eq;
4use openvm_circuit_primitives::var_range::{
5 SharedVariableRangeCheckerChip, VariableRangeCheckerBus, VariableRangeCheckerChip,
6};
7use openvm_instructions::{instruction::Instruction, riscv::RV32_REGISTER_AS, NATIVE_AS};
8use openvm_stark_backend::{
9 config::{StarkGenericConfig, Val},
10 engine::VerificationData,
11 interaction::PermutationCheckBus,
12 p3_field::{Field, PrimeField32},
13 p3_matrix::dense::RowMajorMatrix,
14 p3_util::log2_strict_usize,
15 prover::{
16 cpu::{CpuBackend, CpuDevice},
17 types::AirProvingContext,
18 },
19 rap::AnyRap,
20 verifier::VerificationError,
21 AirRef, Chip,
22};
23use openvm_stark_sdk::{
24 config::{
25 baby_bear_blake3::{BabyBearBlake3Config, BabyBearBlake3Engine},
26 baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine},
27 setup_tracing_with_log_level, FriParameters,
28 },
29 engine::{StarkEngine, StarkFriEngine},
30 p3_baby_bear::BabyBear,
31};
32use rand::{rngs::StdRng, RngCore, SeedableRng};
33use tracing::Level;
34
35use crate::{
36 arch::{
37 testing::{
38 execution::air::ExecutionDummyAir,
39 program::{air::ProgramDummyAir, ProgramTester},
40 ExecutionTester, MemoryTester, TestBuilder, TestChipHarness, EXECUTION_BUS, MEMORY_BUS,
41 MEMORY_MERKLE_BUS, POSEIDON2_DIRECT_BUS, RANGE_CHECKER_BUS, READ_INSTRUCTION_BUS,
42 },
43 vm_poseidon2_config, Arena, ExecutionBridge, ExecutionBus, ExecutionState,
44 MatrixRecordArena, MemoryConfig, PreflightExecutor, Streams, VmStateMut,
45 },
46 system::{
47 memory::{
48 adapter::records::arena_size_bound,
49 offline_checker::{MemoryBridge, MemoryBus},
50 online::TracingMemory,
51 MemoryAirInventory, MemoryController, SharedMemoryHelper, CHUNK,
52 },
53 poseidon2::Poseidon2PeripheryChip,
54 program::ProgramBus,
55 SystemPort,
56 },
57};
58
59pub struct VmChipTestBuilder<F: Field> {
60 pub memory: MemoryTester<F>,
61 pub streams: Streams<F>,
62 pub rng: StdRng,
63 pub execution: ExecutionTester<F>,
64 pub program: ProgramTester<F>,
65 internal_rng: StdRng,
66 custom_pvs: Vec<Option<F>>,
67 default_register: usize,
68 default_pointer: usize,
69}
70
71impl<F> TestBuilder<F> for VmChipTestBuilder<F>
72where
73 F: PrimeField32,
74{
75 fn execute<E, RA>(&mut self, executor: &mut E, arena: &mut RA, instruction: &Instruction<F>)
76 where
77 E: PreflightExecutor<F, RA>,
78 RA: Arena,
79 {
80 let initial_pc = self.next_elem_size_u32();
81 self.execute_with_pc(executor, arena, instruction, initial_pc);
82 }
83
84 fn execute_with_pc<E, RA>(
85 &mut self,
86 executor: &mut E,
87 arena: &mut RA,
88 instruction: &Instruction<F>,
89 initial_pc: u32,
90 ) where
91 E: PreflightExecutor<F, RA>,
92 RA: Arena,
93 {
94 let initial_state = ExecutionState {
95 pc: initial_pc,
96 timestamp: self.memory.memory.timestamp(),
97 };
98 tracing::debug!("initial_timestamp={}", self.memory.memory.timestamp());
99
100 let mut pc = initial_pc;
101 let state_mut = VmStateMut {
102 pc: &mut pc,
103 memory: &mut self.memory.memory,
104 streams: &mut self.streams,
105 rng: &mut self.rng,
106 custom_pvs: &mut self.custom_pvs,
107 ctx: arena,
108 #[cfg(feature = "metrics")]
109 metrics: &mut Default::default(),
110 };
111 executor
112 .execute(state_mut, instruction)
113 .expect("Expected the execution not to fail");
114 let final_state = ExecutionState {
115 pc,
116 timestamp: self.memory.memory.timestamp(),
117 };
118
119 self.program.execute(instruction, &initial_state);
120 self.execution.execute(initial_state, final_state);
121 }
122
123 fn read<const N: usize>(&mut self, address_space: usize, pointer: usize) -> [F; N] {
124 self.memory.read(address_space, pointer)
125 }
126
127 fn write<const N: usize>(&mut self, address_space: usize, pointer: usize, value: [F; N]) {
128 self.memory.write(address_space, pointer, value);
129 }
130
131 fn write_usize<const N: usize>(
132 &mut self,
133 address_space: usize,
134 pointer: usize,
135 value: [usize; N],
136 ) {
137 self.memory
138 .write(address_space, pointer, value.map(F::from_canonical_usize));
139 }
140
141 fn write_cell(&mut self, address_space: usize, pointer: usize, value: F) {
142 self.write(address_space, pointer, [value]);
143 }
144
145 fn read_cell(&mut self, address_space: usize, pointer: usize) -> F {
146 self.read::<1>(address_space, pointer)[0]
147 }
148
149 fn address_bits(&self) -> usize {
150 self.memory.controller.memory_config().pointer_max_bits
151 }
152
153 fn last_to_pc(&self) -> F {
154 self.execution.last_to_pc()
155 }
156
157 fn last_from_pc(&self) -> F {
158 self.execution.last_from_pc()
159 }
160
161 fn execution_final_state(&self) -> ExecutionState<F> {
162 self.execution.records.last().unwrap().final_state
163 }
164
165 fn streams_mut(&mut self) -> &mut Streams<F> {
166 &mut self.streams
167 }
168
169 fn get_default_register(&mut self, increment: usize) -> usize {
170 self.default_register += increment;
171 self.default_register - increment
172 }
173
174 fn get_default_pointer(&mut self, increment: usize) -> usize {
175 self.default_pointer += increment;
176 self.default_pointer - increment
177 }
178
179 fn write_heap_pointer_default(
180 &mut self,
181 reg_increment: usize,
182 pointer_increment: usize,
183 ) -> (usize, usize) {
184 let register = self.get_default_register(reg_increment);
185 let pointer = self.get_default_pointer(pointer_increment);
186 self.write(1, register, pointer.to_le_bytes().map(F::from_canonical_u8));
187 (register, pointer)
188 }
189
190 fn write_heap_default<const NUM_LIMBS: usize>(
191 &mut self,
192 reg_increment: usize,
193 pointer_increment: usize,
194 writes: Vec<[F; NUM_LIMBS]>,
195 ) -> (usize, usize) {
196 let register = self.get_default_register(reg_increment);
197 let pointer = self.get_default_pointer(pointer_increment);
198 self.write_heap(register, pointer, writes);
199 (register, pointer)
200 }
201}
202
203impl<F: PrimeField32> VmChipTestBuilder<F> {
204 pub fn new(
205 controller: MemoryController<F>,
206 memory: TracingMemory,
207 streams: Streams<F>,
208 rng: StdRng,
209 execution_bus: ExecutionBus,
210 program_bus: ProgramBus,
211 internal_rng: StdRng,
212 ) -> Self {
213 setup_tracing_with_log_level(Level::WARN);
214 Self {
215 memory: MemoryTester::new(controller, memory),
216 streams,
217 rng,
218 custom_pvs: Vec::new(),
219 execution: ExecutionTester::new(execution_bus),
220 program: ProgramTester::new(program_bus),
221 internal_rng,
222 default_register: 0,
223 default_pointer: 0,
224 }
225 }
226
227 fn next_elem_size_u32(&mut self) -> u32 {
228 self.internal_rng.next_u32() % (1 << (F::bits() - 2))
229 }
230
231 pub fn set_num_public_values(&mut self, num_public_values: usize) {
232 self.custom_pvs.resize(num_public_values, None);
233 }
234
235 fn write_heap<const NUM_LIMBS: usize>(
236 &mut self,
237 register: usize,
238 pointer: usize,
239 writes: Vec<[F; NUM_LIMBS]>,
240 ) {
241 self.write(
242 1usize,
243 register,
244 pointer.to_le_bytes().map(F::from_canonical_u8),
245 );
246 if NUM_LIMBS.is_power_of_two() {
247 for (i, &write) in writes.iter().enumerate() {
248 self.write(2usize, pointer + i * NUM_LIMBS, write);
249 }
250 } else {
251 for (i, &write) in writes.iter().enumerate() {
252 let ptr = pointer + i * NUM_LIMBS;
253 for j in (0..NUM_LIMBS).step_by(4) {
254 self.write::<4>(2usize, ptr + j, write[j..j + 4].try_into().unwrap());
255 }
256 }
257 }
258 }
259
260 pub fn system_port(&self) -> SystemPort {
261 SystemPort {
262 execution_bus: self.execution.bus,
263 program_bus: self.program.bus,
264 memory_bridge: self.memory_bridge(),
265 }
266 }
267
268 pub fn execution_bridge(&self) -> ExecutionBridge {
269 ExecutionBridge::new(self.execution.bus, self.program.bus)
270 }
271
272 pub fn execution_bus(&self) -> ExecutionBus {
273 self.execution.bus
274 }
275
276 pub fn program_bus(&self) -> ProgramBus {
277 self.program.bus
278 }
279
280 pub fn memory_bus(&self) -> MemoryBus {
281 self.memory.controller.memory_bus
282 }
283
284 pub fn range_checker(&self) -> SharedVariableRangeCheckerChip {
285 self.memory.controller.range_checker.clone()
286 }
287
288 pub fn memory_bridge(&self) -> MemoryBridge {
289 self.memory.controller.memory_bridge()
290 }
291
292 pub fn memory_helper(&self) -> SharedMemoryHelper<F> {
293 self.memory.controller.helper()
294 }
295}
296
297pub type TestSC = BabyBearBlake3Config;
299
300impl VmChipTestBuilder<BabyBear> {
301 pub fn build(self) -> VmChipTester<TestSC> {
302 let tester = VmChipTester {
303 memory: Some(self.memory),
304 ..Default::default()
305 };
306 let tester =
307 tester.load_periphery((ExecutionDummyAir::new(self.execution.bus), self.execution));
308 tester.load_periphery((ProgramDummyAir::new(self.program.bus), self.program))
309 }
310 pub fn build_babybear_poseidon2(self) -> VmChipTester<BabyBearPoseidon2Config> {
311 let tester = VmChipTester {
312 memory: Some(self.memory),
313 ..Default::default()
314 };
315 let tester =
316 tester.load_periphery((ExecutionDummyAir::new(self.execution.bus), self.execution));
317 tester.load_periphery((ProgramDummyAir::new(self.program.bus), self.program))
318 }
319}
320
321impl<F: PrimeField32> VmChipTestBuilder<F> {
322 pub fn default_persistent() -> Self {
323 let mut mem_config = MemoryConfig::default();
324 mem_config.addr_spaces[RV32_REGISTER_AS as usize].num_cells = 1 << 29;
325 mem_config.addr_spaces[NATIVE_AS as usize].num_cells = 0;
326 Self::persistent(mem_config)
327 }
328
329 pub fn default_native() -> Self {
330 Self::volatile(MemoryConfig::aggregation())
331 }
332
333 fn range_checker_and_memory(
334 mem_config: &MemoryConfig,
335 init_block_size: usize,
336 ) -> (SharedVariableRangeCheckerChip, TracingMemory) {
337 let range_checker = Arc::new(VariableRangeCheckerChip::new(VariableRangeCheckerBus::new(
338 RANGE_CHECKER_BUS,
339 mem_config.decomp,
340 )));
341 let max_access_adapter_n = log2_strict_usize(mem_config.max_access_adapter_n);
342 let arena_size_bound = arena_size_bound(&vec![1 << 16; max_access_adapter_n]);
343 let memory = TracingMemory::new(mem_config, init_block_size, arena_size_bound);
344
345 (range_checker, memory)
346 }
347
348 pub fn persistent(mem_config: MemoryConfig) -> Self {
349 setup_tracing_with_log_level(Level::INFO);
350 let (range_checker, memory) = Self::range_checker_and_memory(&mem_config, CHUNK);
351 let hasher_chip = Arc::new(Poseidon2PeripheryChip::new(
352 vm_poseidon2_config(),
353 POSEIDON2_DIRECT_BUS,
354 3,
355 ));
356 let memory_controller = MemoryController::with_persistent_memory(
357 MemoryBus::new(MEMORY_BUS),
358 mem_config,
359 range_checker,
360 PermutationCheckBus::new(MEMORY_MERKLE_BUS),
361 PermutationCheckBus::new(POSEIDON2_DIRECT_BUS),
362 hasher_chip,
363 );
364 Self {
365 memory: MemoryTester::new(memory_controller, memory),
366 streams: Default::default(),
367 rng: StdRng::seed_from_u64(0),
368 custom_pvs: Vec::new(),
369 execution: ExecutionTester::new(ExecutionBus::new(EXECUTION_BUS)),
370 program: ProgramTester::new(ProgramBus::new(READ_INSTRUCTION_BUS)),
371 internal_rng: StdRng::seed_from_u64(0),
372 default_register: 0,
373 default_pointer: 0,
374 }
375 }
376
377 pub fn volatile(mem_config: MemoryConfig) -> Self {
378 setup_tracing_with_log_level(Level::INFO);
379 let (range_checker, memory) = Self::range_checker_and_memory(&mem_config, 1);
380 let memory_controller = MemoryController::with_volatile_memory(
381 MemoryBus::new(MEMORY_BUS),
382 mem_config,
383 range_checker,
384 );
385 Self {
386 memory: MemoryTester::new(memory_controller, memory),
387 streams: Default::default(),
388 rng: StdRng::seed_from_u64(0),
389 custom_pvs: Vec::new(),
390 execution: ExecutionTester::new(ExecutionBus::new(EXECUTION_BUS)),
391 program: ProgramTester::new(ProgramBus::new(READ_INSTRUCTION_BUS)),
392 internal_rng: StdRng::seed_from_u64(0),
393 default_register: 0,
394 default_pointer: 0,
395 }
396 }
397}
398
399impl<F: PrimeField32> Default for VmChipTestBuilder<F> {
400 fn default() -> Self {
401 let mut mem_config = MemoryConfig::default();
402 mem_config.addr_spaces[RV32_REGISTER_AS as usize].num_cells = 1 << 29;
405 mem_config.addr_spaces[NATIVE_AS as usize].num_cells = 0;
406 Self::volatile(mem_config)
407 }
408}
409
410pub struct VmChipTester<SC: StarkGenericConfig> {
411 pub memory: Option<MemoryTester<Val<SC>>>,
412 pub air_ctxs: Vec<(AirRef<SC>, AirProvingContext<CpuBackend<SC>>)>,
413}
414
415impl<SC: StarkGenericConfig> Default for VmChipTester<SC> {
416 fn default() -> Self {
417 Self {
418 memory: None,
419 air_ctxs: vec![],
420 }
421 }
422}
423
424impl<SC: StarkGenericConfig> VmChipTester<SC>
425where
426 Val<SC>: PrimeField32,
427{
428 pub fn load<E, A, C>(
429 mut self,
430 harness: TestChipHarness<Val<SC>, E, A, C, MatrixRecordArena<Val<SC>>>,
431 ) -> Self
432 where
433 A: AnyRap<SC> + 'static,
434 C: Chip<MatrixRecordArena<Val<SC>>, CpuBackend<SC>>,
435 {
436 let arena = harness.arena;
437 let rows_used = arena.trace_offset.div_ceil(arena.width);
438 if rows_used > 0 {
439 let air = Arc::new(harness.air) as AirRef<SC>;
440 let ctx = harness.chip.generate_proving_ctx(arena);
441 tracing::debug!("Generated air proving context for {}", air.name());
442 self.air_ctxs.push((air, ctx));
443 }
444
445 self
446 }
447
448 pub fn load_periphery<A, C>(self, (air, chip): (A, C)) -> Self
449 where
450 A: AnyRap<SC> + 'static,
451 C: Chip<(), CpuBackend<SC>>,
452 {
453 let air = Arc::new(air) as AirRef<SC>;
454 self.load_periphery_ref((air, chip))
455 }
456
457 pub fn load_periphery_ref<C>(mut self, (air, chip): (AirRef<SC>, C)) -> Self
458 where
459 C: Chip<(), CpuBackend<SC>>,
460 {
461 let ctx = chip.generate_proving_ctx(());
462 tracing::debug!("Generated air proving context for {}", air.name());
463 self.air_ctxs.push((air, ctx));
464
465 self
466 }
467
468 pub fn finalize(mut self) -> Self {
469 if let Some(memory_tester) = self.memory.take() {
470 let mut memory_controller = memory_tester.controller;
471 let is_persistent = memory_controller.continuation_enabled();
472 let mut memory = memory_tester.memory;
473 let touched_memory = memory.finalize::<Val<SC>>(is_persistent);
474 let range_checker = memory_controller.range_checker.clone();
476 for mem_chip in memory_tester.chip_for_block.into_values() {
477 self = self.load_periphery((mem_chip.air, mem_chip));
478 }
479 let mem_inventory = MemoryAirInventory::new(
480 memory_controller.memory_bridge(),
481 memory_controller.memory_config(),
482 range_checker.bus(),
483 is_persistent.then_some((
484 PermutationCheckBus::new(MEMORY_MERKLE_BUS),
485 PermutationCheckBus::new(POSEIDON2_DIRECT_BUS),
486 )),
487 );
488 let ctxs = memory_controller
489 .generate_proving_ctx(memory.access_adapter_records, touched_memory);
490 for (air, ctx) in zip_eq(mem_inventory.into_airs(), ctxs)
491 .filter(|(_, ctx)| ctx.main_trace_height() > 0)
492 {
493 self.air_ctxs.push((air, ctx));
494 }
495 if let Some(hasher_chip) = memory_controller.hasher_chip {
496 let air: AirRef<SC> = match hasher_chip.as_ref() {
497 Poseidon2PeripheryChip::Register0(chip) => chip.air.clone(),
498 Poseidon2PeripheryChip::Register1(chip) => chip.air.clone(),
499 };
500 self = self.load_periphery_ref((air, hasher_chip));
501 }
502 self = self.load_periphery((range_checker.air, range_checker));
504 }
505 self
506 }
507
508 pub fn load_air_proving_ctx(
509 mut self,
510 air_proving_ctx: (AirRef<SC>, AirProvingContext<CpuBackend<SC>>),
511 ) -> Self {
512 self.air_ctxs.push(air_proving_ctx);
513 self
514 }
515
516 pub fn load_and_prank_trace<E, A, C, P>(
517 mut self,
518 harness: TestChipHarness<Val<SC>, E, A, C, MatrixRecordArena<Val<SC>>>,
519 modify_trace: P,
520 ) -> Self
521 where
522 A: AnyRap<SC> + 'static,
523 C: Chip<MatrixRecordArena<Val<SC>>, CpuBackend<SC>>,
524 P: Fn(&mut RowMajorMatrix<Val<SC>>),
525 {
526 let arena = harness.arena;
527 let mut ctx = harness.chip.generate_proving_ctx(arena);
528 let trace: Arc<RowMajorMatrix<Val<SC>>> = Option::take(&mut ctx.common_main).unwrap();
529 let mut trace = Arc::into_inner(trace).unwrap();
530 modify_trace(&mut trace);
531 ctx.common_main = Some(Arc::new(trace));
532 self.air_ctxs.push((Arc::new(harness.air), ctx));
533 self
534 }
535
536 pub fn test<E, P: Fn() -> E>(
539 self, engine_provider: P,
541 ) -> Result<VerificationData<SC>, VerificationError>
542 where
543 E: StarkEngine<SC = SC, PB = CpuBackend<SC>, PD = CpuDevice<SC>>,
544 {
545 assert!(self.memory.is_none(), "Memory must be finalized");
546 let (airs, ctxs): (Vec<_>, Vec<_>) = self.air_ctxs.into_iter().unzip();
547 engine_provider().run_test_impl(airs, ctxs)
548 }
549}
550
551impl VmChipTester<BabyBearPoseidon2Config> {
552 pub fn simple_test(
553 self,
554 ) -> Result<VerificationData<BabyBearPoseidon2Config>, VerificationError> {
555 self.test(|| BabyBearPoseidon2Engine::new(FriParameters::new_for_testing(1)))
556 }
557
558 pub fn simple_test_with_expected_error(self, expected_error: VerificationError) {
559 let msg = format!(
560 "Expected verification to fail with {:?}, but it didn't",
561 &expected_error
562 );
563 let result = self.simple_test();
564 assert_eq!(result.err(), Some(expected_error), "{}", msg);
565 }
566}
567
568impl VmChipTester<BabyBearBlake3Config> {
569 pub fn simple_test(self) -> Result<VerificationData<BabyBearBlake3Config>, VerificationError> {
570 self.test(|| BabyBearBlake3Engine::new(FriParameters::new_for_testing(1)))
571 }
572
573 pub fn simple_test_with_expected_error(self, expected_error: VerificationError) {
574 let msg = format!(
575 "Expected verification to fail with {:?}, but it didn't",
576 &expected_error
577 );
578 let result = self.simple_test();
579 assert_eq!(result.err(), Some(expected_error), "{}", msg);
580 }
581}