openvm_circuit/arch/testing/
cpu.rs

1use std::sync::Arc;
2
3use itertools::zip_eq;
4use openvm_circuit_primitives::var_range::{
5    SharedVariableRangeCheckerChip, VariableRangeCheckerBus, VariableRangeCheckerChip,
6};
7use openvm_instructions::{instruction::Instruction, riscv::RV32_REGISTER_AS, NATIVE_AS};
8use openvm_stark_backend::{
9    config::{StarkGenericConfig, Val},
10    engine::VerificationData,
11    interaction::PermutationCheckBus,
12    p3_matrix::dense::RowMajorMatrix,
13    p3_util::log2_strict_usize,
14    prover::{
15        cpu::{CpuBackend, CpuDevice},
16        types::AirProvingContext,
17    },
18    rap::AnyRap,
19    verifier::VerificationError,
20    AirRef, Chip,
21};
22use openvm_stark_sdk::{
23    config::{
24        baby_bear_blake3::{BabyBearBlake3Config, BabyBearBlake3Engine},
25        baby_bear_poseidon2::{BabyBearPoseidon2Config, BabyBearPoseidon2Engine},
26        setup_tracing_with_log_level, FriParameters,
27    },
28    engine::{StarkEngine, StarkFriEngine},
29    p3_baby_bear::BabyBear,
30};
31use rand::{rngs::StdRng, RngCore, SeedableRng};
32use tracing::Level;
33
34use crate::{
35    arch::{
36        testing::{
37            execution::air::ExecutionDummyAir,
38            program::{air::ProgramDummyAir, ProgramTester},
39            ExecutionTester, MemoryTester, TestBuilder, TestChipHarness, EXECUTION_BUS, MEMORY_BUS,
40            MEMORY_MERKLE_BUS, POSEIDON2_DIRECT_BUS, RANGE_CHECKER_BUS, READ_INSTRUCTION_BUS,
41        },
42        vm_poseidon2_config, Arena, ExecutionBridge, ExecutionBus, ExecutionState,
43        MatrixRecordArena, MemoryConfig, PreflightExecutor, Streams, VmField, VmStateMut,
44    },
45    system::{
46        memory::{
47            adapter::records::arena_size_bound,
48            offline_checker::{MemoryBridge, MemoryBus},
49            online::TracingMemory,
50            MemoryAirInventory, MemoryController, SharedMemoryHelper, CHUNK,
51        },
52        poseidon2::Poseidon2PeripheryChip,
53        program::ProgramBus,
54        SystemPort,
55    },
56};
57
58pub struct VmChipTestBuilder<F: VmField> {
59    pub memory: MemoryTester<F>,
60    pub streams: Streams<F>,
61    pub rng: StdRng,
62    pub execution: ExecutionTester<F>,
63    pub program: ProgramTester<F>,
64    internal_rng: StdRng,
65    custom_pvs: Vec<Option<F>>,
66    default_register: usize,
67    default_pointer: usize,
68}
69
70impl<F> TestBuilder<F> for VmChipTestBuilder<F>
71where
72    F: VmField,
73{
74    fn execute<E, RA>(&mut self, executor: &mut E, arena: &mut RA, instruction: &Instruction<F>)
75    where
76        E: PreflightExecutor<F, RA>,
77        RA: Arena,
78    {
79        let initial_pc = self.next_elem_size_u32();
80        self.execute_with_pc(executor, arena, instruction, initial_pc);
81    }
82
83    fn execute_with_pc<E, RA>(
84        &mut self,
85        executor: &mut E,
86        arena: &mut RA,
87        instruction: &Instruction<F>,
88        initial_pc: u32,
89    ) where
90        E: PreflightExecutor<F, RA>,
91        RA: Arena,
92    {
93        let initial_state = ExecutionState {
94            pc: initial_pc,
95            timestamp: self.memory.memory.timestamp(),
96        };
97        tracing::debug!("initial_timestamp={}", self.memory.memory.timestamp());
98
99        let mut pc = initial_pc;
100        let state_mut = VmStateMut {
101            pc: &mut pc,
102            memory: &mut self.memory.memory,
103            streams: &mut self.streams,
104            rng: &mut self.rng,
105            custom_pvs: &mut self.custom_pvs,
106            ctx: arena,
107            #[cfg(feature = "metrics")]
108            metrics: &mut Default::default(),
109        };
110        executor
111            .execute(state_mut, instruction)
112            .expect("Expected the execution not to fail");
113        let final_state = ExecutionState {
114            pc,
115            timestamp: self.memory.memory.timestamp(),
116        };
117
118        self.program.execute(instruction, &initial_state);
119        self.execution.execute(initial_state, final_state);
120    }
121
122    fn read<const N: usize>(&mut self, address_space: usize, pointer: usize) -> [F; N] {
123        self.memory.read(address_space, pointer)
124    }
125
126    fn write<const N: usize>(&mut self, address_space: usize, pointer: usize, value: [F; N]) {
127        self.memory.write(address_space, pointer, value);
128    }
129
130    fn write_usize<const N: usize>(
131        &mut self,
132        address_space: usize,
133        pointer: usize,
134        value: [usize; N],
135    ) {
136        self.memory
137            .write(address_space, pointer, value.map(F::from_usize));
138    }
139
140    fn write_cell(&mut self, address_space: usize, pointer: usize, value: F) {
141        self.write(address_space, pointer, [value]);
142    }
143
144    fn read_cell(&mut self, address_space: usize, pointer: usize) -> F {
145        self.read::<1>(address_space, pointer)[0]
146    }
147
148    fn address_bits(&self) -> usize {
149        self.memory.controller.memory_config().pointer_max_bits
150    }
151
152    fn last_to_pc(&self) -> F {
153        self.execution.last_to_pc()
154    }
155
156    fn last_from_pc(&self) -> F {
157        self.execution.last_from_pc()
158    }
159
160    fn execution_final_state(&self) -> ExecutionState<F> {
161        self.execution.records.last().unwrap().final_state
162    }
163
164    fn streams_mut(&mut self) -> &mut Streams<F> {
165        &mut self.streams
166    }
167
168    fn get_default_register(&mut self, increment: usize) -> usize {
169        self.default_register += increment;
170        self.default_register - increment
171    }
172
173    fn get_default_pointer(&mut self, increment: usize) -> usize {
174        self.default_pointer += increment;
175        self.default_pointer - increment
176    }
177
178    fn write_heap_pointer_default(
179        &mut self,
180        reg_increment: usize,
181        pointer_increment: usize,
182    ) -> (usize, usize) {
183        let register = self.get_default_register(reg_increment);
184        let pointer = self.get_default_pointer(pointer_increment);
185        self.write(1, register, pointer.to_le_bytes().map(F::from_u8));
186        (register, pointer)
187    }
188
189    fn write_heap_default<const NUM_LIMBS: usize>(
190        &mut self,
191        reg_increment: usize,
192        pointer_increment: usize,
193        writes: Vec<[F; NUM_LIMBS]>,
194    ) -> (usize, usize) {
195        let register = self.get_default_register(reg_increment);
196        let pointer = self.get_default_pointer(pointer_increment);
197        self.write_heap(register, pointer, writes);
198        (register, pointer)
199    }
200}
201
202impl<F: VmField> VmChipTestBuilder<F> {
203    pub fn new(
204        controller: MemoryController<F>,
205        memory: TracingMemory,
206        streams: Streams<F>,
207        rng: StdRng,
208        execution_bus: ExecutionBus,
209        program_bus: ProgramBus,
210        internal_rng: StdRng,
211    ) -> Self {
212        setup_tracing_with_log_level(Level::WARN);
213        Self {
214            memory: MemoryTester::new(controller, memory),
215            streams,
216            rng,
217            custom_pvs: Vec::new(),
218            execution: ExecutionTester::new(execution_bus),
219            program: ProgramTester::new(program_bus),
220            internal_rng,
221            default_register: 0,
222            default_pointer: 0,
223        }
224    }
225
226    fn next_elem_size_u32(&mut self) -> u32 {
227        self.internal_rng.next_u32() % (1 << (F::bits() - 2))
228    }
229
230    pub fn set_num_public_values(&mut self, num_public_values: usize) {
231        self.custom_pvs.resize(num_public_values, None);
232    }
233
234    fn write_heap<const NUM_LIMBS: usize>(
235        &mut self,
236        register: usize,
237        pointer: usize,
238        writes: Vec<[F; NUM_LIMBS]>,
239    ) {
240        self.write(1usize, register, pointer.to_le_bytes().map(F::from_u8));
241        if NUM_LIMBS.is_power_of_two() {
242            for (i, &write) in writes.iter().enumerate() {
243                self.write(2usize, pointer + i * NUM_LIMBS, write);
244            }
245        } else {
246            for (i, &write) in writes.iter().enumerate() {
247                let ptr = pointer + i * NUM_LIMBS;
248                for j in (0..NUM_LIMBS).step_by(4) {
249                    self.write::<4>(2usize, ptr + j, write[j..j + 4].try_into().unwrap());
250                }
251            }
252        }
253    }
254
255    pub fn system_port(&self) -> SystemPort {
256        SystemPort {
257            execution_bus: self.execution.bus,
258            program_bus: self.program.bus,
259            memory_bridge: self.memory_bridge(),
260        }
261    }
262
263    pub fn execution_bridge(&self) -> ExecutionBridge {
264        ExecutionBridge::new(self.execution.bus, self.program.bus)
265    }
266
267    pub fn execution_bus(&self) -> ExecutionBus {
268        self.execution.bus
269    }
270
271    pub fn program_bus(&self) -> ProgramBus {
272        self.program.bus
273    }
274
275    pub fn memory_bus(&self) -> MemoryBus {
276        self.memory.controller.memory_bus
277    }
278
279    pub fn range_checker(&self) -> SharedVariableRangeCheckerChip {
280        self.memory.controller.range_checker.clone()
281    }
282
283    pub fn memory_bridge(&self) -> MemoryBridge {
284        self.memory.controller.memory_bridge()
285    }
286
287    pub fn memory_helper(&self) -> SharedMemoryHelper<F> {
288        self.memory.controller.helper()
289    }
290}
291
292// Use Blake3 as hash for faster tests.
293pub type TestSC = BabyBearBlake3Config;
294
295impl VmChipTestBuilder<BabyBear> {
296    pub fn build(self) -> VmChipTester<TestSC> {
297        let tester = VmChipTester {
298            memory: Some(self.memory),
299            ..Default::default()
300        };
301        let tester =
302            tester.load_periphery((ExecutionDummyAir::new(self.execution.bus), self.execution));
303        tester.load_periphery((ProgramDummyAir::new(self.program.bus), self.program))
304    }
305    pub fn build_babybear_poseidon2(self) -> VmChipTester<BabyBearPoseidon2Config> {
306        let tester = VmChipTester {
307            memory: Some(self.memory),
308            ..Default::default()
309        };
310        let tester =
311            tester.load_periphery((ExecutionDummyAir::new(self.execution.bus), self.execution));
312        tester.load_periphery((ProgramDummyAir::new(self.program.bus), self.program))
313    }
314}
315
316impl<F: VmField> VmChipTestBuilder<F> {
317    pub fn default_persistent() -> Self {
318        let mut mem_config = MemoryConfig::default();
319        mem_config.addr_spaces[RV32_REGISTER_AS as usize].num_cells = 1 << 29;
320        mem_config.addr_spaces[NATIVE_AS as usize].num_cells = 0;
321        Self::persistent(mem_config)
322    }
323
324    pub fn default_native() -> Self {
325        Self::volatile(MemoryConfig::aggregation())
326    }
327
328    fn range_checker_and_memory(
329        mem_config: &MemoryConfig,
330        init_block_size: usize,
331    ) -> (SharedVariableRangeCheckerChip, TracingMemory) {
332        let range_checker = Arc::new(VariableRangeCheckerChip::new(VariableRangeCheckerBus::new(
333            RANGE_CHECKER_BUS,
334            mem_config.decomp,
335        )));
336        let max_access_adapter_n = log2_strict_usize(mem_config.max_access_adapter_n);
337        let arena_size_bound = arena_size_bound(&vec![1 << 16; max_access_adapter_n]);
338        let memory = TracingMemory::new(mem_config, init_block_size, arena_size_bound);
339
340        (range_checker, memory)
341    }
342
343    pub fn persistent(mem_config: MemoryConfig) -> Self {
344        setup_tracing_with_log_level(Level::INFO);
345        let (range_checker, memory) = Self::range_checker_and_memory(&mem_config, CHUNK);
346        let hasher_chip = Arc::new(Poseidon2PeripheryChip::new(
347            vm_poseidon2_config(),
348            POSEIDON2_DIRECT_BUS,
349            3,
350        ));
351        let memory_controller = MemoryController::with_persistent_memory(
352            MemoryBus::new(MEMORY_BUS),
353            mem_config,
354            range_checker,
355            PermutationCheckBus::new(MEMORY_MERKLE_BUS),
356            PermutationCheckBus::new(POSEIDON2_DIRECT_BUS),
357            hasher_chip,
358        );
359        Self {
360            memory: MemoryTester::new(memory_controller, memory),
361            streams: Default::default(),
362            rng: StdRng::seed_from_u64(0),
363            custom_pvs: Vec::new(),
364            execution: ExecutionTester::new(ExecutionBus::new(EXECUTION_BUS)),
365            program: ProgramTester::new(ProgramBus::new(READ_INSTRUCTION_BUS)),
366            internal_rng: StdRng::seed_from_u64(0),
367            default_register: 0,
368            default_pointer: 0,
369        }
370    }
371
372    pub fn volatile(mem_config: MemoryConfig) -> Self {
373        setup_tracing_with_log_level(Level::INFO);
374        let (range_checker, memory) = Self::range_checker_and_memory(&mem_config, 1);
375        let memory_controller = MemoryController::with_volatile_memory(
376            MemoryBus::new(MEMORY_BUS),
377            mem_config,
378            range_checker,
379        );
380        Self {
381            memory: MemoryTester::new(memory_controller, memory),
382            streams: Default::default(),
383            rng: StdRng::seed_from_u64(0),
384            custom_pvs: Vec::new(),
385            execution: ExecutionTester::new(ExecutionBus::new(EXECUTION_BUS)),
386            program: ProgramTester::new(ProgramBus::new(READ_INSTRUCTION_BUS)),
387            internal_rng: StdRng::seed_from_u64(0),
388            default_register: 0,
389            default_pointer: 0,
390        }
391    }
392}
393
394impl<F: VmField> Default for VmChipTestBuilder<F> {
395    fn default() -> Self {
396        let mut mem_config = MemoryConfig::default();
397        // TODO[jpw]: this is because old tests use `gen_pointer` on address space 1; this can be
398        // removed when tests are updated.
399        mem_config.addr_spaces[RV32_REGISTER_AS as usize].num_cells = 1 << 29;
400        mem_config.addr_spaces[NATIVE_AS as usize].num_cells = 0;
401        Self::volatile(mem_config)
402    }
403}
404
405pub struct VmChipTester<SC: StarkGenericConfig>
406where
407    Val<SC>: VmField,
408{
409    pub memory: Option<MemoryTester<Val<SC>>>,
410    pub air_ctxs: Vec<(AirRef<SC>, AirProvingContext<CpuBackend<SC>>)>,
411}
412
413impl<SC> Default for VmChipTester<SC>
414where
415    SC: StarkGenericConfig,
416    Val<SC>: VmField,
417{
418    fn default() -> Self {
419        Self {
420            memory: None,
421            air_ctxs: vec![],
422        }
423    }
424}
425
426impl<SC> VmChipTester<SC>
427where
428    SC: StarkGenericConfig,
429    Val<SC>: VmField,
430{
431    pub fn load<E, A, C>(
432        mut self,
433        harness: TestChipHarness<Val<SC>, E, A, C, MatrixRecordArena<Val<SC>>>,
434    ) -> Self
435    where
436        A: AnyRap<SC> + 'static,
437        C: Chip<MatrixRecordArena<Val<SC>>, CpuBackend<SC>>,
438    {
439        let arena = harness.arena;
440        let rows_used = arena.trace_offset.div_ceil(arena.width);
441        if rows_used > 0 {
442            let air = Arc::new(harness.air) as AirRef<SC>;
443            let ctx = harness.chip.generate_proving_ctx(arena);
444            tracing::debug!("Generated air proving context for {}", air.name());
445            self.air_ctxs.push((air, ctx));
446        }
447
448        self
449    }
450
451    pub fn load_periphery<A, C>(self, (air, chip): (A, C)) -> Self
452    where
453        A: AnyRap<SC> + 'static,
454        C: Chip<(), CpuBackend<SC>>,
455    {
456        let air = Arc::new(air) as AirRef<SC>;
457        self.load_periphery_ref((air, chip))
458    }
459
460    pub fn load_periphery_ref<C>(mut self, (air, chip): (AirRef<SC>, C)) -> Self
461    where
462        C: Chip<(), CpuBackend<SC>>,
463    {
464        let ctx = chip.generate_proving_ctx(());
465        tracing::debug!("Generated air proving context for {}", air.name());
466        self.air_ctxs.push((air, ctx));
467
468        self
469    }
470
471    pub fn finalize(mut self) -> Self {
472        if let Some(memory_tester) = self.memory.take() {
473            let mut memory_controller = memory_tester.controller;
474            let is_persistent = memory_controller.continuation_enabled();
475            let mut memory = memory_tester.memory;
476            let touched_memory = memory.finalize::<Val<SC>>(is_persistent);
477            // Balance memory boundaries
478            let range_checker = memory_controller.range_checker.clone();
479            for mem_chip in memory_tester.chip_for_block.into_values() {
480                self = self.load_periphery((mem_chip.air, mem_chip));
481            }
482            let mem_inventory = MemoryAirInventory::new(
483                memory_controller.memory_bridge(),
484                memory_controller.memory_config(),
485                range_checker.bus(),
486                is_persistent.then_some((
487                    PermutationCheckBus::new(MEMORY_MERKLE_BUS),
488                    PermutationCheckBus::new(POSEIDON2_DIRECT_BUS),
489                )),
490            );
491            let ctxs = memory_controller
492                .generate_proving_ctx(memory.access_adapter_records, touched_memory);
493            for (air, ctx) in zip_eq(mem_inventory.into_airs(), ctxs)
494                .filter(|(_, ctx)| ctx.main_trace_height() > 0)
495            {
496                self.air_ctxs.push((air, ctx));
497            }
498            if let Some(hasher_chip) = memory_controller.hasher_chip {
499                let air: AirRef<SC> = match hasher_chip.as_ref() {
500                    Poseidon2PeripheryChip::Register0(chip) => chip.air.clone(),
501                    Poseidon2PeripheryChip::Register1(chip) => chip.air.clone(),
502                };
503                self = self.load_periphery_ref((air, hasher_chip));
504            }
505            // this must be last because other trace generation mutates its state
506            self = self.load_periphery((range_checker.air, range_checker));
507        }
508        self
509    }
510
511    pub fn load_air_proving_ctx(
512        mut self,
513        air_proving_ctx: (AirRef<SC>, AirProvingContext<CpuBackend<SC>>),
514    ) -> Self {
515        self.air_ctxs.push(air_proving_ctx);
516        self
517    }
518
519    pub fn load_and_prank_trace<E, A, C, P>(
520        mut self,
521        harness: TestChipHarness<Val<SC>, E, A, C, MatrixRecordArena<Val<SC>>>,
522        modify_trace: P,
523    ) -> Self
524    where
525        A: AnyRap<SC> + 'static,
526        C: Chip<MatrixRecordArena<Val<SC>>, CpuBackend<SC>>,
527        P: Fn(&mut RowMajorMatrix<Val<SC>>),
528    {
529        let arena = harness.arena;
530        let mut ctx = harness.chip.generate_proving_ctx(arena);
531        let trace: Arc<RowMajorMatrix<Val<SC>>> = Option::take(&mut ctx.common_main).unwrap();
532        let mut trace = Arc::into_inner(trace).unwrap();
533        modify_trace(&mut trace);
534        ctx.common_main = Some(Arc::new(trace));
535        self.air_ctxs.push((Arc::new(harness.air), ctx));
536        self
537    }
538
539    /// Given a function to produce an engine from the max trace height,
540    /// runs a simple test on that engine
541    pub fn test<E, P: Fn() -> E>(
542        self, // do no take ownership so it's easier to prank
543        engine_provider: P,
544    ) -> Result<VerificationData<SC>, VerificationError>
545    where
546        E: StarkEngine<SC = SC, PB = CpuBackend<SC>, PD = CpuDevice<SC>>,
547    {
548        assert!(self.memory.is_none(), "Memory must be finalized");
549        let (airs, ctxs): (Vec<_>, Vec<_>) = self.air_ctxs.into_iter().unzip();
550        engine_provider().run_test_impl(airs, ctxs)
551    }
552}
553
554impl VmChipTester<BabyBearPoseidon2Config> {
555    pub fn simple_test(
556        self,
557    ) -> Result<VerificationData<BabyBearPoseidon2Config>, VerificationError> {
558        self.test(|| BabyBearPoseidon2Engine::new(FriParameters::new_for_testing(1)))
559    }
560
561    pub fn simple_test_with_expected_error(self, expected_error: VerificationError) {
562        let msg = format!(
563            "Expected verification to fail with {:?}, but it didn't",
564            &expected_error
565        );
566        let result = self.simple_test();
567        assert_eq!(result.err(), Some(expected_error), "{msg}");
568    }
569}
570
571impl VmChipTester<BabyBearBlake3Config> {
572    pub fn simple_test(self) -> Result<VerificationData<BabyBearBlake3Config>, VerificationError> {
573        self.test(|| BabyBearBlake3Engine::new(FriParameters::new_for_testing(1)))
574    }
575
576    pub fn simple_test_with_expected_error(self, expected_error: VerificationError) {
577        let msg = format!(
578            "Expected verification to fail with {:?}, but it didn't",
579            &expected_error
580        );
581        let result = self.simple_test();
582        assert_eq!(result.err(), Some(expected_error), "{msg}");
583    }
584}