openvm_circuit/utils/
stark_utils.rs

1use openvm_instructions::exe::VmExe;
2use openvm_stark_backend::{
3    config::{Com, Val},
4    engine::VerificationData,
5    p3_field::PrimeField32,
6};
7use openvm_stark_sdk::{
8    config::{baby_bear_poseidon2::BabyBearPoseidon2Config, setup_tracing, FriParameters},
9    engine::{StarkFriEngine, VerificationDataWithFriParams},
10    p3_baby_bear::BabyBear,
11};
12
13#[cfg(feature = "aot")]
14use crate::arch::{SystemConfig, VmState};
15#[cfg(feature = "aot")]
16use crate::system::memory::online::GuestMemory;
17use crate::{
18    arch::{
19        debug_proving_ctx, execution_mode::Segment, vm::VirtualMachine, Executor, ExitCode,
20        MeteredExecutor, PreflightExecutionOutput, PreflightExecutor, Streams, VmBuilder,
21        VmCircuitConfig, VmConfig, VmExecutionConfig,
22    },
23    system::memory::{MemoryImage, CHUNK},
24};
25
26cfg_if::cfg_if! {
27    if #[cfg(feature = "cuda")] {
28        pub use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine as TestStarkEngine, chip::cpu_proving_ctx_to_gpu};
29        use crate::arch::DenseRecordArena;
30        pub type TestRecordArena = DenseRecordArena;
31    } else {
32        pub use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Engine as TestStarkEngine;
33        use crate::arch::MatrixRecordArena;
34        pub type TestRecordArena = MatrixRecordArena<BabyBear>;
35    }
36}
37type RA = TestRecordArena;
38
39// NOTE on trait bounds: the compiler cannot figure out Val<SC>=BabyBear without the
40// VmExecutionConfig and VmCircuitConfig bounds even though VmProverBuilder already includes them.
41// The compiler also seems to need the extra VC even though VC=VB::VmConfig
42pub fn air_test<VB, VC>(builder: VB, config: VC, exe: impl Into<VmExe<BabyBear>>)
43where
44    VB: VmBuilder<TestStarkEngine, VmConfig = VC, RecordArena = RA>,
45    VC: VmExecutionConfig<BabyBear>
46        + VmCircuitConfig<BabyBearPoseidon2Config>
47        + VmConfig<BabyBearPoseidon2Config>,
48    <VC as VmExecutionConfig<BabyBear>>::Executor:
49        Executor<BabyBear> + MeteredExecutor<BabyBear> + PreflightExecutor<BabyBear, RA>,
50{
51    air_test_with_min_segments(builder, config, exe, Streams::default(), 1);
52}
53
54/// Executes and proves the VM and returns the final memory state.
55pub fn air_test_with_min_segments<VB, VC>(
56    builder: VB,
57    config: VC,
58    exe: impl Into<VmExe<BabyBear>>,
59    input: impl Into<Streams<BabyBear>>,
60    min_segments: usize,
61) -> Option<MemoryImage>
62where
63    VB: VmBuilder<TestStarkEngine, VmConfig = VC, RecordArena = RA>,
64    VC: VmExecutionConfig<BabyBear>
65        + VmCircuitConfig<BabyBearPoseidon2Config>
66        + VmConfig<BabyBearPoseidon2Config>,
67    <VC as VmExecutionConfig<BabyBear>>::Executor:
68        Executor<BabyBear> + MeteredExecutor<BabyBear> + PreflightExecutor<BabyBear, RA>,
69{
70    let mut log_blowup = 1;
71    while config.as_ref().max_constraint_degree > (1 << log_blowup) + 1 {
72        log_blowup += 1;
73    }
74    let fri_params = FriParameters::new_for_testing(log_blowup);
75    let debug = std::env::var("OPENVM_SKIP_DEBUG") != Result::Ok(String::from("1"));
76    let (final_memory, _) = air_test_impl::<TestStarkEngine, VB>(
77        fri_params,
78        builder,
79        config,
80        exe,
81        input,
82        min_segments,
83        debug,
84    )
85    .unwrap();
86    final_memory
87}
88
/// Checks that the AOT instance of `vm` agrees with the interpreter on `exe` and `input`, for
/// both pure and metered execution.
///
/// In each mode the final program counter and every cell of address space 1 (read as `u8`,
/// for addresses `0..num_cells` taken from the memory config) are compared. For metered
/// execution the produced segments must additionally agree in count and, pairwise, in
/// `instret_start`, `num_insns` and `trace_heights`.
///
/// # Errors
///
/// Returns an error if an interpreter or AOT instance cannot be created, or if metered
/// execution fails.
///
/// # Panics
///
/// Panics if pure execution fails, or if any compared value differs between the interpreter
/// and the AOT instance.
#[cfg(feature = "aot")]
pub fn check_aot_equivalence<E, VB>(
    vm: &VirtualMachine<E, VB>,
    config: &VB::VmConfig,
    exe: &VmExe<Val<E::SC>>,
    input: &Streams<Val<E::SC>>,
) -> eyre::Result<()>
where
    E: StarkFriEngine,
    Val<E::SC>: PrimeField32,
    VB: VmBuilder<E>,
    <VB::VmConfig as VmExecutionConfig<Val<E::SC>>>::Executor: Executor<Val<E::SC>>
        + MeteredExecutor<Val<E::SC>>
        + PreflightExecutor<Val<E::SC>, VB::RecordArena>,
    Com<E::SC>: AsRef<[Val<E::SC>; CHUNK]> + From<[Val<E::SC>; CHUNK]>,
{
    let system_config: &SystemConfig = config.as_ref();
    let num_cells = system_config.memory_config.addr_spaces[1].num_cells;

    // Shared comparison of two final VM states; `mode` only labels the panic message.
    let assert_vm_state_eq = |mode: &str,
                              lhs: &VmState<Val<E::SC>, GuestMemory>,
                              rhs: &VmState<Val<E::SC>, GuestMemory>| {
        assert_eq!(
            lhs.pc(),
            rhs.pc(),
            "{mode}: final pc differs between interpreter and AOT"
        );
        for addr in 0..num_cells {
            // SAFETY (assumed, TODO confirm against `GuestMemory::read`): `addr` stays below the
            // configured `num_cells` of address space 1 and the cells are read as `u8`.
            let a = unsafe { lhs.memory.read::<u8, 1>(1, addr as u32) };
            let b = unsafe { rhs.memory.read::<u8, 1>(1, addr as u32) };
            assert_eq!(
                a, b,
                "{mode}: address space 1 differs at address {addr} between interpreter and AOT"
            );
        }
    };

    /*
    Assertions for Pure Execution AOT
    */
    {
        let interp_state_pure = vm
            .naive_interpreter(exe)?
            .execute(input.clone(), None)
            .expect("Failed to execute");

        let aot_state_pure = vm
            .get_aot_instance(exe)?
            .execute(input.clone(), None)
            .expect("Failed to execute");

        assert_vm_state_eq("pure execution", &interp_state_pure, &aot_state_pure);
    }

    /*
    Assertions for Metered AOT
    */
    println!("Checking metered AOT equivalence");
    {
        let metered_ctx = vm.build_metered_ctx(exe);
        let (aot_segments, aot_state_metered) = vm
            .get_metered_aot_instance(exe)?
            .execute_metered(input.clone(), metered_ctx.clone())?;

        // Last use of `metered_ctx`, so it is moved rather than cloned.
        let (segments, interp_state_metered) = vm
            .naive_metered_interpreter(exe)?
            .execute_metered(input.clone(), metered_ctx)?;

        assert_vm_state_eq(
            "metered execution",
            &interp_state_metered,
            &aot_state_metered,
        );

        assert_eq!(
            segments.len(),
            aot_segments.len(),
            "metered execution: segment count differs between interpreter and AOT"
        );
        // Lengths are equal at this point, so `zip` visits every segment pair.
        for (idx, (seg, aot_seg)) in segments.iter().zip(&aot_segments).enumerate() {
            assert_eq!(
                seg.instret_start, aot_seg.instret_start,
                "segment {idx}: instret_start differs"
            );
            assert_eq!(
                seg.num_insns, aot_seg.num_insns,
                "segment {idx}: num_insns differs"
            );
            assert_eq!(
                seg.trace_heights, aot_seg.trace_heights,
                "segment {idx}: trace_heights differs"
            );
        }
    }

    Ok(())
}
169
170/// Executes and proves the VM and returns the final memory state.
171/// If `debug` is true, runs the debug prover.
172//
173// Same implementation as VmLocalProver, but we need to do something special to run the debug prover
174#[allow(clippy::type_complexity)]
175pub fn air_test_impl<E, VB>(
176    fri_params: FriParameters,
177    builder: VB,
178    config: VB::VmConfig,
179    exe: impl Into<VmExe<Val<E::SC>>>,
180    input: impl Into<Streams<Val<E::SC>>>,
181    min_segments: usize,
182    debug: bool,
183) -> eyre::Result<(
184    Option<MemoryImage>,
185    Vec<VerificationDataWithFriParams<E::SC>>,
186)>
187where
188    E: StarkFriEngine,
189    Val<E::SC>: PrimeField32,
190    VB: VmBuilder<E>,
191    <VB::VmConfig as VmExecutionConfig<Val<E::SC>>>::Executor: Executor<Val<E::SC>>
192        + MeteredExecutor<Val<E::SC>>
193        + PreflightExecutor<Val<E::SC>, VB::RecordArena>,
194    Com<E::SC>: AsRef<[Val<E::SC>; CHUNK]> + From<[Val<E::SC>; CHUNK]>,
195{
196    setup_tracing();
197    let engine = E::new(fri_params);
198    let (mut vm, pk) = VirtualMachine::<E, VB>::new_with_keygen(engine, builder, config.clone())?;
199    let vk = pk.get_vk();
200    let exe = exe.into();
201    let input = input.into();
202    let metered_ctx = vm.build_metered_ctx(&exe);
203
204    #[cfg(feature = "aot")]
205    check_aot_equivalence(&vm, &config, &exe, &input)?;
206
207    let (segments, _) = vm
208        .metered_interpreter(&exe)?
209        .execute_metered(input.clone(), metered_ctx.clone())?;
210
211    let cached_program_trace = vm.commit_program_on_device(&exe.program);
212    vm.load_program(cached_program_trace);
213    let mut preflight_interpreter = vm.preflight_interpreter(&exe)?;
214
215    let mut state = Some(vm.create_initial_state(&exe, input));
216    let mut proofs = Vec::new();
217    let mut exit_code = None;
218    for segment in segments {
219        let Segment {
220            num_insns,
221            trace_heights,
222            ..
223        } = segment;
224        let from_state = Option::take(&mut state).unwrap();
225        vm.transport_init_memory_to_device(&from_state.memory);
226        let PreflightExecutionOutput {
227            system_records,
228            record_arenas,
229            to_state,
230        } = vm.execute_preflight(
231            &mut preflight_interpreter,
232            from_state,
233            Some(num_insns),
234            &trace_heights,
235        )?;
236        state = Some(to_state);
237        exit_code = system_records.exit_code;
238
239        let ctx = vm.generate_proving_ctx(system_records, record_arenas)?;
240        if debug {
241            debug_proving_ctx(&vm, &pk, &ctx);
242        }
243        let proof = vm.engine.prove(vm.pk(), ctx);
244        proofs.push(proof);
245    }
246    assert!(proofs.len() >= min_segments);
247    vm.verify(&vk, &proofs)
248        .expect("segment proofs should verify");
249    let state = state.unwrap();
250    let final_memory = (exit_code == Some(ExitCode::Success as u32)).then_some(state.memory.memory);
251    let vdata = proofs
252        .into_iter()
253        .map(|proof| VerificationDataWithFriParams {
254            data: VerificationData {
255                vk: vk.clone(),
256                proof,
257            },
258            fri_params: vm.engine.fri_params(),
259        })
260        .collect();
261
262    Ok((final_memory, vdata))
263}