openvm_circuit/arch/testing/
cpu.rs

1use std::sync::Arc;
2
3use itertools::zip_eq;
4use openvm_circuit_primitives::var_range::{
5    SharedVariableRangeCheckerChip, VariableRangeCheckerBus, VariableRangeCheckerChip,
6};
7use openvm_instructions::{instruction::Instruction, riscv::RV32_REGISTER_AS, NATIVE_AS};
8use openvm_stark_backend::{
9    config::{StarkGenericConfig, Val},
10    engine::VerificationData,
11    interaction::PermutationCheckBus,
12    p3_field::{Field, PrimeField32},
13    p3_matrix::dense::RowMajorMatrix,
14    p3_util::log2_strict_usize,
15    prover::{
16        cpu::{CpuBackend, CpuDevice},
17        types::AirProvingContext,
18    },
19    rap::AnyRap,
20    verifier::VerificationError,
21    AirRef, Chip,
22};
23use openvm_stark_sdk::{
24    config::{
25        baby_bear_blake3::{BabyBearBlake3Config, BabyBearBlake3Engine},
26        baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine},
27        setup_tracing_with_log_level, FriParameters,
28    },
29    engine::{StarkEngine, StarkFriEngine},
30    p3_baby_bear::BabyBear,
31};
32use rand::{rngs::StdRng, RngCore, SeedableRng};
33use tracing::Level;
34
35use crate::{
36    arch::{
37        testing::{
38            execution::air::ExecutionDummyAir,
39            program::{air::ProgramDummyAir, ProgramTester},
40            ExecutionTester, MemoryTester, TestBuilder, TestChipHarness, EXECUTION_BUS, MEMORY_BUS,
41            MEMORY_MERKLE_BUS, POSEIDON2_DIRECT_BUS, RANGE_CHECKER_BUS, READ_INSTRUCTION_BUS,
42        },
43        vm_poseidon2_config, Arena, ExecutionBridge, ExecutionBus, ExecutionState,
44        MatrixRecordArena, MemoryConfig, PreflightExecutor, Streams, VmStateMut,
45    },
46    system::{
47        memory::{
48            adapter::records::arena_size_bound,
49            offline_checker::{MemoryBridge, MemoryBus},
50            online::TracingMemory,
51            MemoryAirInventory, MemoryController, SharedMemoryHelper, CHUNK,
52        },
53        poseidon2::Poseidon2PeripheryChip,
54        program::ProgramBus,
55        SystemPort,
56    },
57};
58
59pub struct VmChipTestBuilder<F: Field> {
60    pub memory: MemoryTester<F>,
61    pub streams: Streams<F>,
62    pub rng: StdRng,
63    pub execution: ExecutionTester<F>,
64    pub program: ProgramTester<F>,
65    internal_rng: StdRng,
66    custom_pvs: Vec<Option<F>>,
67    default_register: usize,
68    default_pointer: usize,
69}
70
71impl<F> TestBuilder<F> for VmChipTestBuilder<F>
72where
73    F: PrimeField32,
74{
75    fn execute<E, RA>(&mut self, executor: &mut E, arena: &mut RA, instruction: &Instruction<F>)
76    where
77        E: PreflightExecutor<F, RA>,
78        RA: Arena,
79    {
80        let initial_pc = self.next_elem_size_u32();
81        self.execute_with_pc(executor, arena, instruction, initial_pc);
82    }
83
84    fn execute_with_pc<E, RA>(
85        &mut self,
86        executor: &mut E,
87        arena: &mut RA,
88        instruction: &Instruction<F>,
89        initial_pc: u32,
90    ) where
91        E: PreflightExecutor<F, RA>,
92        RA: Arena,
93    {
94        let initial_state = ExecutionState {
95            pc: initial_pc,
96            timestamp: self.memory.memory.timestamp(),
97        };
98        tracing::debug!("initial_timestamp={}", self.memory.memory.timestamp());
99
100        let mut pc = initial_pc;
101        let state_mut = VmStateMut {
102            pc: &mut pc,
103            memory: &mut self.memory.memory,
104            streams: &mut self.streams,
105            rng: &mut self.rng,
106            custom_pvs: &mut self.custom_pvs,
107            ctx: arena,
108            #[cfg(feature = "metrics")]
109            metrics: &mut Default::default(),
110        };
111        executor
112            .execute(state_mut, instruction)
113            .expect("Expected the execution not to fail");
114        let final_state = ExecutionState {
115            pc,
116            timestamp: self.memory.memory.timestamp(),
117        };
118
119        self.program.execute(instruction, &initial_state);
120        self.execution.execute(initial_state, final_state);
121    }
122
123    fn read<const N: usize>(&mut self, address_space: usize, pointer: usize) -> [F; N] {
124        self.memory.read(address_space, pointer)
125    }
126
127    fn write<const N: usize>(&mut self, address_space: usize, pointer: usize, value: [F; N]) {
128        self.memory.write(address_space, pointer, value);
129    }
130
131    fn write_usize<const N: usize>(
132        &mut self,
133        address_space: usize,
134        pointer: usize,
135        value: [usize; N],
136    ) {
137        self.memory
138            .write(address_space, pointer, value.map(F::from_canonical_usize));
139    }
140
141    fn write_cell(&mut self, address_space: usize, pointer: usize, value: F) {
142        self.write(address_space, pointer, [value]);
143    }
144
145    fn read_cell(&mut self, address_space: usize, pointer: usize) -> F {
146        self.read::<1>(address_space, pointer)[0]
147    }
148
149    fn address_bits(&self) -> usize {
150        self.memory.controller.memory_config().pointer_max_bits
151    }
152
153    fn last_to_pc(&self) -> F {
154        self.execution.last_to_pc()
155    }
156
157    fn last_from_pc(&self) -> F {
158        self.execution.last_from_pc()
159    }
160
161    fn execution_final_state(&self) -> ExecutionState<F> {
162        self.execution.records.last().unwrap().final_state
163    }
164
165    fn streams_mut(&mut self) -> &mut Streams<F> {
166        &mut self.streams
167    }
168
169    fn get_default_register(&mut self, increment: usize) -> usize {
170        self.default_register += increment;
171        self.default_register - increment
172    }
173
174    fn get_default_pointer(&mut self, increment: usize) -> usize {
175        self.default_pointer += increment;
176        self.default_pointer - increment
177    }
178
179    fn write_heap_pointer_default(
180        &mut self,
181        reg_increment: usize,
182        pointer_increment: usize,
183    ) -> (usize, usize) {
184        let register = self.get_default_register(reg_increment);
185        let pointer = self.get_default_pointer(pointer_increment);
186        self.write(1, register, pointer.to_le_bytes().map(F::from_canonical_u8));
187        (register, pointer)
188    }
189
190    fn write_heap_default<const NUM_LIMBS: usize>(
191        &mut self,
192        reg_increment: usize,
193        pointer_increment: usize,
194        writes: Vec<[F; NUM_LIMBS]>,
195    ) -> (usize, usize) {
196        let register = self.get_default_register(reg_increment);
197        let pointer = self.get_default_pointer(pointer_increment);
198        self.write_heap(register, pointer, writes);
199        (register, pointer)
200    }
201}
202
203impl<F: PrimeField32> VmChipTestBuilder<F> {
204    pub fn new(
205        controller: MemoryController<F>,
206        memory: TracingMemory,
207        streams: Streams<F>,
208        rng: StdRng,
209        execution_bus: ExecutionBus,
210        program_bus: ProgramBus,
211        internal_rng: StdRng,
212    ) -> Self {
213        setup_tracing_with_log_level(Level::WARN);
214        Self {
215            memory: MemoryTester::new(controller, memory),
216            streams,
217            rng,
218            custom_pvs: Vec::new(),
219            execution: ExecutionTester::new(execution_bus),
220            program: ProgramTester::new(program_bus),
221            internal_rng,
222            default_register: 0,
223            default_pointer: 0,
224        }
225    }
226
227    fn next_elem_size_u32(&mut self) -> u32 {
228        self.internal_rng.next_u32() % (1 << (F::bits() - 2))
229    }
230
231    pub fn set_num_public_values(&mut self, num_public_values: usize) {
232        self.custom_pvs.resize(num_public_values, None);
233    }
234
235    fn write_heap<const NUM_LIMBS: usize>(
236        &mut self,
237        register: usize,
238        pointer: usize,
239        writes: Vec<[F; NUM_LIMBS]>,
240    ) {
241        self.write(
242            1usize,
243            register,
244            pointer.to_le_bytes().map(F::from_canonical_u8),
245        );
246        if NUM_LIMBS.is_power_of_two() {
247            for (i, &write) in writes.iter().enumerate() {
248                self.write(2usize, pointer + i * NUM_LIMBS, write);
249            }
250        } else {
251            for (i, &write) in writes.iter().enumerate() {
252                let ptr = pointer + i * NUM_LIMBS;
253                for j in (0..NUM_LIMBS).step_by(4) {
254                    self.write::<4>(2usize, ptr + j, write[j..j + 4].try_into().unwrap());
255                }
256            }
257        }
258    }
259
260    pub fn system_port(&self) -> SystemPort {
261        SystemPort {
262            execution_bus: self.execution.bus,
263            program_bus: self.program.bus,
264            memory_bridge: self.memory_bridge(),
265        }
266    }
267
268    pub fn execution_bridge(&self) -> ExecutionBridge {
269        ExecutionBridge::new(self.execution.bus, self.program.bus)
270    }
271
272    pub fn execution_bus(&self) -> ExecutionBus {
273        self.execution.bus
274    }
275
276    pub fn program_bus(&self) -> ProgramBus {
277        self.program.bus
278    }
279
280    pub fn memory_bus(&self) -> MemoryBus {
281        self.memory.controller.memory_bus
282    }
283
284    pub fn range_checker(&self) -> SharedVariableRangeCheckerChip {
285        self.memory.controller.range_checker.clone()
286    }
287
288    pub fn memory_bridge(&self) -> MemoryBridge {
289        self.memory.controller.memory_bridge()
290    }
291
292    pub fn memory_helper(&self) -> SharedMemoryHelper<F> {
293        self.memory.controller.helper()
294    }
295}
296
297// Use Blake3 as hash for faster tests.
298pub type TestSC = BabyBearBlake3Config;
299
300impl VmChipTestBuilder<BabyBear> {
301    pub fn build(self) -> VmChipTester<TestSC> {
302        let tester = VmChipTester {
303            memory: Some(self.memory),
304            ..Default::default()
305        };
306        let tester =
307            tester.load_periphery((ExecutionDummyAir::new(self.execution.bus), self.execution));
308        tester.load_periphery((ProgramDummyAir::new(self.program.bus), self.program))
309    }
310    pub fn build_babybear_poseidon2(self) -> VmChipTester<BabyBearPoseidon2Config> {
311        let tester = VmChipTester {
312            memory: Some(self.memory),
313            ..Default::default()
314        };
315        let tester =
316            tester.load_periphery((ExecutionDummyAir::new(self.execution.bus), self.execution));
317        tester.load_periphery((ProgramDummyAir::new(self.program.bus), self.program))
318    }
319}
320
321impl<F: PrimeField32> VmChipTestBuilder<F> {
322    pub fn default_persistent() -> Self {
323        let mut mem_config = MemoryConfig::default();
324        mem_config.addr_spaces[RV32_REGISTER_AS as usize].num_cells = 1 << 29;
325        mem_config.addr_spaces[NATIVE_AS as usize].num_cells = 0;
326        Self::persistent(mem_config)
327    }
328
329    pub fn default_native() -> Self {
330        Self::volatile(MemoryConfig::aggregation())
331    }
332
333    fn range_checker_and_memory(
334        mem_config: &MemoryConfig,
335        init_block_size: usize,
336    ) -> (SharedVariableRangeCheckerChip, TracingMemory) {
337        let range_checker = Arc::new(VariableRangeCheckerChip::new(VariableRangeCheckerBus::new(
338            RANGE_CHECKER_BUS,
339            mem_config.decomp,
340        )));
341        let max_access_adapter_n = log2_strict_usize(mem_config.max_access_adapter_n);
342        let arena_size_bound = arena_size_bound(&vec![1 << 16; max_access_adapter_n]);
343        let memory = TracingMemory::new(mem_config, init_block_size, arena_size_bound);
344
345        (range_checker, memory)
346    }
347
348    pub fn persistent(mem_config: MemoryConfig) -> Self {
349        setup_tracing_with_log_level(Level::INFO);
350        let (range_checker, memory) = Self::range_checker_and_memory(&mem_config, CHUNK);
351        let hasher_chip = Arc::new(Poseidon2PeripheryChip::new(
352            vm_poseidon2_config(),
353            POSEIDON2_DIRECT_BUS,
354            3,
355        ));
356        let memory_controller = MemoryController::with_persistent_memory(
357            MemoryBus::new(MEMORY_BUS),
358            mem_config,
359            range_checker,
360            PermutationCheckBus::new(MEMORY_MERKLE_BUS),
361            PermutationCheckBus::new(POSEIDON2_DIRECT_BUS),
362            hasher_chip,
363        );
364        Self {
365            memory: MemoryTester::new(memory_controller, memory),
366            streams: Default::default(),
367            rng: StdRng::seed_from_u64(0),
368            custom_pvs: Vec::new(),
369            execution: ExecutionTester::new(ExecutionBus::new(EXECUTION_BUS)),
370            program: ProgramTester::new(ProgramBus::new(READ_INSTRUCTION_BUS)),
371            internal_rng: StdRng::seed_from_u64(0),
372            default_register: 0,
373            default_pointer: 0,
374        }
375    }
376
377    pub fn volatile(mem_config: MemoryConfig) -> Self {
378        setup_tracing_with_log_level(Level::INFO);
379        let (range_checker, memory) = Self::range_checker_and_memory(&mem_config, 1);
380        let memory_controller = MemoryController::with_volatile_memory(
381            MemoryBus::new(MEMORY_BUS),
382            mem_config,
383            range_checker,
384        );
385        Self {
386            memory: MemoryTester::new(memory_controller, memory),
387            streams: Default::default(),
388            rng: StdRng::seed_from_u64(0),
389            custom_pvs: Vec::new(),
390            execution: ExecutionTester::new(ExecutionBus::new(EXECUTION_BUS)),
391            program: ProgramTester::new(ProgramBus::new(READ_INSTRUCTION_BUS)),
392            internal_rng: StdRng::seed_from_u64(0),
393            default_register: 0,
394            default_pointer: 0,
395        }
396    }
397}
398
399impl<F: PrimeField32> Default for VmChipTestBuilder<F> {
400    fn default() -> Self {
401        let mut mem_config = MemoryConfig::default();
402        // TODO[jpw]: this is because old tests use `gen_pointer` on address space 1; this can be
403        // removed when tests are updated.
404        mem_config.addr_spaces[RV32_REGISTER_AS as usize].num_cells = 1 << 29;
405        mem_config.addr_spaces[NATIVE_AS as usize].num_cells = 0;
406        Self::volatile(mem_config)
407    }
408}
409
410pub struct VmChipTester<SC: StarkGenericConfig> {
411    pub memory: Option<MemoryTester<Val<SC>>>,
412    pub air_ctxs: Vec<(AirRef<SC>, AirProvingContext<CpuBackend<SC>>)>,
413}
414
415impl<SC: StarkGenericConfig> Default for VmChipTester<SC> {
416    fn default() -> Self {
417        Self {
418            memory: None,
419            air_ctxs: vec![],
420        }
421    }
422}
423
424impl<SC: StarkGenericConfig> VmChipTester<SC>
425where
426    Val<SC>: PrimeField32,
427{
428    pub fn load<E, A, C>(
429        mut self,
430        harness: TestChipHarness<Val<SC>, E, A, C, MatrixRecordArena<Val<SC>>>,
431    ) -> Self
432    where
433        A: AnyRap<SC> + 'static,
434        C: Chip<MatrixRecordArena<Val<SC>>, CpuBackend<SC>>,
435    {
436        let arena = harness.arena;
437        let rows_used = arena.trace_offset.div_ceil(arena.width);
438        if rows_used > 0 {
439            let air = Arc::new(harness.air) as AirRef<SC>;
440            let ctx = harness.chip.generate_proving_ctx(arena);
441            tracing::debug!("Generated air proving context for {}", air.name());
442            self.air_ctxs.push((air, ctx));
443        }
444
445        self
446    }
447
448    pub fn load_periphery<A, C>(self, (air, chip): (A, C)) -> Self
449    where
450        A: AnyRap<SC> + 'static,
451        C: Chip<(), CpuBackend<SC>>,
452    {
453        let air = Arc::new(air) as AirRef<SC>;
454        self.load_periphery_ref((air, chip))
455    }
456
457    pub fn load_periphery_ref<C>(mut self, (air, chip): (AirRef<SC>, C)) -> Self
458    where
459        C: Chip<(), CpuBackend<SC>>,
460    {
461        let ctx = chip.generate_proving_ctx(());
462        tracing::debug!("Generated air proving context for {}", air.name());
463        self.air_ctxs.push((air, ctx));
464
465        self
466    }
467
468    pub fn finalize(mut self) -> Self {
469        if let Some(memory_tester) = self.memory.take() {
470            let mut memory_controller = memory_tester.controller;
471            let is_persistent = memory_controller.continuation_enabled();
472            let mut memory = memory_tester.memory;
473            let touched_memory = memory.finalize::<Val<SC>>(is_persistent);
474            // Balance memory boundaries
475            let range_checker = memory_controller.range_checker.clone();
476            for mem_chip in memory_tester.chip_for_block.into_values() {
477                self = self.load_periphery((mem_chip.air, mem_chip));
478            }
479            let mem_inventory = MemoryAirInventory::new(
480                memory_controller.memory_bridge(),
481                memory_controller.memory_config(),
482                range_checker.bus(),
483                is_persistent.then_some((
484                    PermutationCheckBus::new(MEMORY_MERKLE_BUS),
485                    PermutationCheckBus::new(POSEIDON2_DIRECT_BUS),
486                )),
487            );
488            let ctxs = memory_controller
489                .generate_proving_ctx(memory.access_adapter_records, touched_memory);
490            for (air, ctx) in zip_eq(mem_inventory.into_airs(), ctxs)
491                .filter(|(_, ctx)| ctx.main_trace_height() > 0)
492            {
493                self.air_ctxs.push((air, ctx));
494            }
495            if let Some(hasher_chip) = memory_controller.hasher_chip {
496                let air: AirRef<SC> = match hasher_chip.as_ref() {
497                    Poseidon2PeripheryChip::Register0(chip) => chip.air.clone(),
498                    Poseidon2PeripheryChip::Register1(chip) => chip.air.clone(),
499                };
500                self = self.load_periphery_ref((air, hasher_chip));
501            }
502            // this must be last because other trace generation mutates its state
503            self = self.load_periphery((range_checker.air, range_checker));
504        }
505        self
506    }
507
508    pub fn load_air_proving_ctx(
509        mut self,
510        air_proving_ctx: (AirRef<SC>, AirProvingContext<CpuBackend<SC>>),
511    ) -> Self {
512        self.air_ctxs.push(air_proving_ctx);
513        self
514    }
515
516    pub fn load_and_prank_trace<E, A, C, P>(
517        mut self,
518        harness: TestChipHarness<Val<SC>, E, A, C, MatrixRecordArena<Val<SC>>>,
519        modify_trace: P,
520    ) -> Self
521    where
522        A: AnyRap<SC> + 'static,
523        C: Chip<MatrixRecordArena<Val<SC>>, CpuBackend<SC>>,
524        P: Fn(&mut RowMajorMatrix<Val<SC>>),
525    {
526        let arena = harness.arena;
527        let mut ctx = harness.chip.generate_proving_ctx(arena);
528        let trace: Arc<RowMajorMatrix<Val<SC>>> = Option::take(&mut ctx.common_main).unwrap();
529        let mut trace = Arc::into_inner(trace).unwrap();
530        modify_trace(&mut trace);
531        ctx.common_main = Some(Arc::new(trace));
532        self.air_ctxs.push((Arc::new(harness.air), ctx));
533        self
534    }
535
536    /// Given a function to produce an engine from the max trace height,
537    /// runs a simple test on that engine
538    pub fn test<E, P: Fn() -> E>(
539        self, // do no take ownership so it's easier to prank
540        engine_provider: P,
541    ) -> Result<VerificationData<SC>, VerificationError>
542    where
543        E: StarkEngine<SC = SC, PB = CpuBackend<SC>, PD = CpuDevice<SC>>,
544    {
545        assert!(self.memory.is_none(), "Memory must be finalized");
546        let (airs, ctxs): (Vec<_>, Vec<_>) = self.air_ctxs.into_iter().unzip();
547        engine_provider().run_test_impl(airs, ctxs)
548    }
549}
550
551impl VmChipTester<BabyBearPoseidon2Config> {
552    pub fn simple_test(
553        self,
554    ) -> Result<VerificationData<BabyBearPoseidon2Config>, VerificationError> {
555        self.test(|| BabyBearPoseidon2Engine::new(FriParameters::new_for_testing(1)))
556    }
557
558    pub fn simple_test_with_expected_error(self, expected_error: VerificationError) {
559        let msg = format!(
560            "Expected verification to fail with {:?}, but it didn't",
561            &expected_error
562        );
563        let result = self.simple_test();
564        assert_eq!(result.err(), Some(expected_error), "{}", msg);
565    }
566}
567
568impl VmChipTester<BabyBearBlake3Config> {
569    pub fn simple_test(self) -> Result<VerificationData<BabyBearBlake3Config>, VerificationError> {
570        self.test(|| BabyBearBlake3Engine::new(FriParameters::new_for_testing(1)))
571    }
572
573    pub fn simple_test_with_expected_error(self, expected_error: VerificationError) {
574        let msg = format!(
575            "Expected verification to fail with {:?}, but it didn't",
576            &expected_error
577        );
578        let result = self.simple_test();
579        assert_eq!(result.err(), Some(expected_error), "{}", msg);
580    }
581}