// openvm_circuit/utils/stark_utils.rs

1use openvm_instructions::exe::VmExe;
2use openvm_stark_backend::{
3    config::{Com, Val},
4    engine::VerificationData,
5    p3_field::PrimeField32,
6};
7use openvm_stark_sdk::{
8    config::{baby_bear_poseidon2::BabyBearPoseidon2Config, setup_tracing, FriParameters},
9    engine::{StarkFriEngine, VerificationDataWithFriParams},
10    p3_baby_bear::BabyBear,
11};
12
13use crate::{
14    arch::{
15        debug_proving_ctx, execution_mode::Segment, vm::VirtualMachine, Executor, ExitCode,
16        MeteredExecutor, PreflightExecutionOutput, PreflightExecutor, Streams, VmBuilder,
17        VmCircuitConfig, VmConfig, VmExecutionConfig,
18    },
19    system::memory::{MemoryImage, CHUNK},
20};
21
22cfg_if::cfg_if! {
23    if #[cfg(feature = "cuda")] {
24        pub use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine as TestStarkEngine, chip::cpu_proving_ctx_to_gpu};
25        use crate::arch::DenseRecordArena;
26        pub type TestRecordArena = DenseRecordArena;
27    } else {
28        pub use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Engine as TestStarkEngine;
29        use crate::arch::MatrixRecordArena;
30        pub type TestRecordArena = MatrixRecordArena<BabyBear>;
31    }
32}
33type RA = TestRecordArena;
34
35// NOTE on trait bounds: the compiler cannot figure out Val<SC>=BabyBear without the
36// VmExecutionConfig and VmCircuitConfig bounds even though VmProverBuilder already includes them.
37// The compiler also seems to need the extra VC even though VC=VB::VmConfig
38pub fn air_test<VB, VC>(builder: VB, config: VC, exe: impl Into<VmExe<BabyBear>>)
39where
40    VB: VmBuilder<TestStarkEngine, VmConfig = VC, RecordArena = RA>,
41    VC: VmExecutionConfig<BabyBear>
42        + VmCircuitConfig<BabyBearPoseidon2Config>
43        + VmConfig<BabyBearPoseidon2Config>,
44    <VC as VmExecutionConfig<BabyBear>>::Executor:
45        Executor<BabyBear> + MeteredExecutor<BabyBear> + PreflightExecutor<BabyBear, RA>,
46{
47    air_test_with_min_segments(builder, config, exe, Streams::default(), 1);
48}
49
50/// Executes and proves the VM and returns the final memory state.
51pub fn air_test_with_min_segments<VB, VC>(
52    builder: VB,
53    config: VC,
54    exe: impl Into<VmExe<BabyBear>>,
55    input: impl Into<Streams<BabyBear>>,
56    min_segments: usize,
57) -> Option<MemoryImage>
58where
59    VB: VmBuilder<TestStarkEngine, VmConfig = VC, RecordArena = RA>,
60    VC: VmExecutionConfig<BabyBear>
61        + VmCircuitConfig<BabyBearPoseidon2Config>
62        + VmConfig<BabyBearPoseidon2Config>,
63    <VC as VmExecutionConfig<BabyBear>>::Executor:
64        Executor<BabyBear> + MeteredExecutor<BabyBear> + PreflightExecutor<BabyBear, RA>,
65{
66    let mut log_blowup = 1;
67    while config.as_ref().max_constraint_degree > (1 << log_blowup) + 1 {
68        log_blowup += 1;
69    }
70    let fri_params = FriParameters::new_for_testing(log_blowup);
71    let debug = std::env::var("OPENVM_SKIP_DEBUG") != Result::Ok(String::from("1"));
72    let (final_memory, _) = air_test_impl::<TestStarkEngine, VB>(
73        fri_params,
74        builder,
75        config,
76        exe,
77        input,
78        min_segments,
79        debug,
80    )
81    .unwrap();
82    final_memory
83}
84
85/// Executes and proves the VM and returns the final memory state.
86/// If `debug` is true, runs the debug prover.
87//
88// Same implementation as VmLocalProver, but we need to do something special to run the debug prover
89#[allow(clippy::type_complexity)]
90pub fn air_test_impl<E, VB>(
91    fri_params: FriParameters,
92    builder: VB,
93    config: VB::VmConfig,
94    exe: impl Into<VmExe<Val<E::SC>>>,
95    input: impl Into<Streams<Val<E::SC>>>,
96    min_segments: usize,
97    debug: bool,
98) -> eyre::Result<(
99    Option<MemoryImage>,
100    Vec<VerificationDataWithFriParams<E::SC>>,
101)>
102where
103    E: StarkFriEngine,
104    Val<E::SC>: PrimeField32,
105    VB: VmBuilder<E>,
106    <VB::VmConfig as VmExecutionConfig<Val<E::SC>>>::Executor: Executor<Val<E::SC>>
107        + MeteredExecutor<Val<E::SC>>
108        + PreflightExecutor<Val<E::SC>, VB::RecordArena>,
109    Com<E::SC>: AsRef<[Val<E::SC>; CHUNK]> + From<[Val<E::SC>; CHUNK]>,
110{
111    setup_tracing();
112    let engine = E::new(fri_params);
113    let (mut vm, pk) = VirtualMachine::<E, VB>::new_with_keygen(engine, builder, config)?;
114    let vk = pk.get_vk();
115    let exe = exe.into();
116    let input = input.into();
117    let metered_ctx = vm.build_metered_ctx(&exe);
118    let (segments, _) = vm
119        .metered_interpreter(&exe)?
120        .execute_metered(input.clone(), metered_ctx)?;
121    let cached_program_trace = vm.commit_program_on_device(&exe.program);
122    vm.load_program(cached_program_trace);
123    let mut preflight_interpreter = vm.preflight_interpreter(&exe)?;
124
125    let mut state = Some(vm.create_initial_state(&exe, input));
126    let mut proofs = Vec::new();
127    let mut exit_code = None;
128    for segment in segments {
129        let Segment {
130            instret_start,
131            num_insns,
132            trace_heights,
133        } = segment;
134        assert_eq!(state.as_ref().unwrap().instret(), instret_start);
135        let from_state = Option::take(&mut state).unwrap();
136        vm.transport_init_memory_to_device(&from_state.memory);
137        let PreflightExecutionOutput {
138            system_records,
139            record_arenas,
140            to_state,
141        } = vm.execute_preflight(
142            &mut preflight_interpreter,
143            from_state,
144            Some(num_insns),
145            &trace_heights,
146        )?;
147        state = Some(to_state);
148        exit_code = system_records.exit_code;
149
150        let ctx = vm.generate_proving_ctx(system_records, record_arenas)?;
151        if debug {
152            debug_proving_ctx(&vm, &pk, &ctx);
153        }
154        let proof = vm.engine.prove(vm.pk(), ctx);
155        proofs.push(proof);
156    }
157    assert!(proofs.len() >= min_segments);
158    vm.verify(&vk, &proofs)
159        .expect("segment proofs should verify");
160    let state = state.unwrap();
161    let final_memory = (exit_code == Some(ExitCode::Success as u32)).then_some(state.memory.memory);
162    let vdata = proofs
163        .into_iter()
164        .map(|proof| VerificationDataWithFriParams {
165            data: VerificationData {
166                vk: vk.clone(),
167                proof,
168            },
169            fri_params: vm.engine.fri_params(),
170        })
171        .collect();
172
173    Ok((final_memory, vdata))
174}