openvm_circuit/utils/
stark_utils.rs

1use openvm_instructions::exe::VmExe;
2use openvm_stark_backend::{
3    config::{Com, Val},
4    engine::VerificationData,
5    p3_field::PrimeField32,
6};
7use openvm_stark_sdk::{
8    config::{baby_bear_poseidon2::BabyBearPoseidon2Config, setup_tracing, FriParameters},
9    engine::{StarkFriEngine, VerificationDataWithFriParams},
10    p3_baby_bear::BabyBear,
11};
12
13use crate::{
14    arch::{
15        debug_proving_ctx, execution_mode::Segment, vm::VirtualMachine, Executor, ExitCode,
16        MeteredExecutor, PreflightExecutionOutput, PreflightExecutor, Streams, VmBuilder,
17        VmCircuitConfig, VmConfig, VmExecutionConfig,
18    },
19    system::memory::{MemoryImage, CHUNK},
20};
21
22cfg_if::cfg_if! {
23    if #[cfg(feature = "cuda")] {
24        pub use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine as TestStarkEngine, chip::cpu_proving_ctx_to_gpu};
25        use crate::arch::DenseRecordArena;
26        pub type TestRecordArena = DenseRecordArena;
27    } else {
28        pub use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Engine as TestStarkEngine;
29        use crate::arch::MatrixRecordArena;
30        pub type TestRecordArena = MatrixRecordArena<BabyBear>;
31    }
32}
33type RA = TestRecordArena;
34
35// NOTE on trait bounds: the compiler cannot figure out Val<SC>=BabyBear without the
36// VmExecutionConfig and VmCircuitConfig bounds even though VmProverBuilder already includes them.
37// The compiler also seems to need the extra VC even though VC=VB::VmConfig
38pub fn air_test<VB, VC>(builder: VB, config: VC, exe: impl Into<VmExe<BabyBear>>)
39where
40    VB: VmBuilder<TestStarkEngine, VmConfig = VC, RecordArena = RA>,
41    VC: VmExecutionConfig<BabyBear>
42        + VmCircuitConfig<BabyBearPoseidon2Config>
43        + VmConfig<BabyBearPoseidon2Config>,
44    <VC as VmExecutionConfig<BabyBear>>::Executor:
45        Executor<BabyBear> + MeteredExecutor<BabyBear> + PreflightExecutor<BabyBear, RA>,
46{
47    air_test_with_min_segments(builder, config, exe, Streams::default(), 1);
48}
49
50/// Executes and proves the VM and returns the final memory state.
51pub fn air_test_with_min_segments<VB, VC>(
52    builder: VB,
53    config: VC,
54    exe: impl Into<VmExe<BabyBear>>,
55    input: impl Into<Streams<BabyBear>>,
56    min_segments: usize,
57) -> Option<MemoryImage>
58where
59    VB: VmBuilder<TestStarkEngine, VmConfig = VC, RecordArena = RA>,
60    VC: VmExecutionConfig<BabyBear>
61        + VmCircuitConfig<BabyBearPoseidon2Config>
62        + VmConfig<BabyBearPoseidon2Config>,
63    <VC as VmExecutionConfig<BabyBear>>::Executor:
64        Executor<BabyBear> + MeteredExecutor<BabyBear> + PreflightExecutor<BabyBear, RA>,
65{
66    let mut log_blowup = 1;
67    while config.as_ref().max_constraint_degree > (1 << log_blowup) + 1 {
68        log_blowup += 1;
69    }
70    let fri_params = FriParameters::new_for_testing(log_blowup);
71    #[cfg(feature = "cuda")]
72    let debug = std::env::var("OPENVM_SKIP_DEBUG") != Result::Ok(String::from("1"));
73    #[cfg(not(feature = "cuda"))]
74    let debug = true;
75    let (final_memory, _) = air_test_impl::<TestStarkEngine, VB>(
76        fri_params,
77        builder,
78        config,
79        exe,
80        input,
81        min_segments,
82        debug,
83    )
84    .unwrap();
85    final_memory
86}
87
88/// Executes and proves the VM and returns the final memory state.
89/// If `debug` is true, runs the debug prover.
90//
91// Same implementation as VmLocalProver, but we need to do something special to run the debug prover
92#[allow(clippy::type_complexity)]
93pub fn air_test_impl<E, VB>(
94    fri_params: FriParameters,
95    builder: VB,
96    config: VB::VmConfig,
97    exe: impl Into<VmExe<Val<E::SC>>>,
98    input: impl Into<Streams<Val<E::SC>>>,
99    min_segments: usize,
100    debug: bool,
101) -> eyre::Result<(
102    Option<MemoryImage>,
103    Vec<VerificationDataWithFriParams<E::SC>>,
104)>
105where
106    E: StarkFriEngine,
107    Val<E::SC>: PrimeField32,
108    VB: VmBuilder<E>,
109    <VB::VmConfig as VmExecutionConfig<Val<E::SC>>>::Executor: Executor<Val<E::SC>>
110        + MeteredExecutor<Val<E::SC>>
111        + PreflightExecutor<Val<E::SC>, VB::RecordArena>,
112    Com<E::SC>: AsRef<[Val<E::SC>; CHUNK]> + From<[Val<E::SC>; CHUNK]>,
113{
114    setup_tracing();
115    let engine = E::new(fri_params);
116    let (mut vm, pk) = VirtualMachine::<E, VB>::new_with_keygen(engine, builder, config)?;
117    let vk = pk.get_vk();
118    let exe = exe.into();
119    let input = input.into();
120    let metered_ctx = vm.build_metered_ctx();
121    let (segments, _) = vm
122        .metered_interpreter(&exe)?
123        .execute_metered(input.clone(), metered_ctx)?;
124    let cached_program_trace = vm.commit_program_on_device(&exe.program);
125    vm.load_program(cached_program_trace);
126    let mut preflight_interpreter = vm.preflight_interpreter(&exe)?;
127
128    let mut state = Some(vm.create_initial_state(&exe, input));
129    let mut proofs = Vec::new();
130    let mut exit_code = None;
131    for segment in segments {
132        let Segment {
133            instret_start,
134            num_insns,
135            trace_heights,
136        } = segment;
137        assert_eq!(state.as_ref().unwrap().instret, instret_start);
138        let from_state = Option::take(&mut state).unwrap();
139        vm.transport_init_memory_to_device(&from_state.memory);
140        let PreflightExecutionOutput {
141            system_records,
142            record_arenas,
143            to_state,
144        } = vm.execute_preflight(
145            &mut preflight_interpreter,
146            from_state,
147            Some(num_insns),
148            &trace_heights,
149        )?;
150        state = Some(to_state);
151        exit_code = system_records.exit_code;
152
153        let ctx = vm.generate_proving_ctx(system_records, record_arenas)?;
154        if debug {
155            debug_proving_ctx(&vm, &pk, &ctx);
156        }
157        let proof = vm.engine.prove(vm.pk(), ctx);
158        proofs.push(proof);
159    }
160    assert!(proofs.len() >= min_segments);
161    vm.verify(&vk, &proofs)
162        .expect("segment proofs should verify");
163    let state = state.unwrap();
164    let final_memory = (exit_code == Some(ExitCode::Success as u32)).then_some(state.memory.memory);
165    let vdata = proofs
166        .into_iter()
167        .map(|proof| VerificationDataWithFriParams {
168            data: VerificationData {
169                vk: vk.clone(),
170                proof,
171            },
172            fri_params: vm.engine.fri_params(),
173        })
174        .collect();
175
176    Ok((final_memory, vdata))
177}