1use openvm_instructions::exe::VmExe;
2use openvm_stark_backend::{
3 config::{Com, Val},
4 engine::VerificationData,
5 p3_field::PrimeField32,
6};
7use openvm_stark_sdk::{
8 config::{baby_bear_poseidon2::BabyBearPoseidon2Config, setup_tracing, FriParameters},
9 engine::{StarkFriEngine, VerificationDataWithFriParams},
10 p3_baby_bear::BabyBear,
11};
12
13use crate::{
14 arch::{
15 debug_proving_ctx, execution_mode::Segment, vm::VirtualMachine, Executor, ExitCode,
16 MeteredExecutor, PreflightExecutionOutput, PreflightExecutor, Streams, VmBuilder,
17 VmCircuitConfig, VmConfig, VmExecutionConfig,
18 },
19 system::memory::{MemoryImage, CHUNK},
20};
21
22cfg_if::cfg_if! {
23 if #[cfg(feature = "cuda")] {
24 pub use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine as TestStarkEngine, chip::cpu_proving_ctx_to_gpu};
25 use crate::arch::DenseRecordArena;
26 pub type TestRecordArena = DenseRecordArena;
27 } else {
28 pub use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Engine as TestStarkEngine;
29 use crate::arch::MatrixRecordArena;
30 pub type TestRecordArena = MatrixRecordArena<BabyBear>;
31 }
32}
33type RA = TestRecordArena;
34
35pub fn air_test<VB, VC>(builder: VB, config: VC, exe: impl Into<VmExe<BabyBear>>)
39where
40 VB: VmBuilder<TestStarkEngine, VmConfig = VC, RecordArena = RA>,
41 VC: VmExecutionConfig<BabyBear>
42 + VmCircuitConfig<BabyBearPoseidon2Config>
43 + VmConfig<BabyBearPoseidon2Config>,
44 <VC as VmExecutionConfig<BabyBear>>::Executor:
45 Executor<BabyBear> + MeteredExecutor<BabyBear> + PreflightExecutor<BabyBear, RA>,
46{
47 air_test_with_min_segments(builder, config, exe, Streams::default(), 1);
48}
49
50pub fn air_test_with_min_segments<VB, VC>(
52 builder: VB,
53 config: VC,
54 exe: impl Into<VmExe<BabyBear>>,
55 input: impl Into<Streams<BabyBear>>,
56 min_segments: usize,
57) -> Option<MemoryImage>
58where
59 VB: VmBuilder<TestStarkEngine, VmConfig = VC, RecordArena = RA>,
60 VC: VmExecutionConfig<BabyBear>
61 + VmCircuitConfig<BabyBearPoseidon2Config>
62 + VmConfig<BabyBearPoseidon2Config>,
63 <VC as VmExecutionConfig<BabyBear>>::Executor:
64 Executor<BabyBear> + MeteredExecutor<BabyBear> + PreflightExecutor<BabyBear, RA>,
65{
66 let mut log_blowup = 1;
67 while config.as_ref().max_constraint_degree > (1 << log_blowup) + 1 {
68 log_blowup += 1;
69 }
70 let fri_params = FriParameters::new_for_testing(log_blowup);
71 #[cfg(feature = "cuda")]
72 let debug = std::env::var("OPENVM_SKIP_DEBUG") != Result::Ok(String::from("1"));
73 #[cfg(not(feature = "cuda"))]
74 let debug = true;
75 let (final_memory, _) = air_test_impl::<TestStarkEngine, VB>(
76 fri_params,
77 builder,
78 config,
79 exe,
80 input,
81 min_segments,
82 debug,
83 )
84 .unwrap();
85 final_memory
86}
87
88#[allow(clippy::type_complexity)]
93pub fn air_test_impl<E, VB>(
94 fri_params: FriParameters,
95 builder: VB,
96 config: VB::VmConfig,
97 exe: impl Into<VmExe<Val<E::SC>>>,
98 input: impl Into<Streams<Val<E::SC>>>,
99 min_segments: usize,
100 debug: bool,
101) -> eyre::Result<(
102 Option<MemoryImage>,
103 Vec<VerificationDataWithFriParams<E::SC>>,
104)>
105where
106 E: StarkFriEngine,
107 Val<E::SC>: PrimeField32,
108 VB: VmBuilder<E>,
109 <VB::VmConfig as VmExecutionConfig<Val<E::SC>>>::Executor: Executor<Val<E::SC>>
110 + MeteredExecutor<Val<E::SC>>
111 + PreflightExecutor<Val<E::SC>, VB::RecordArena>,
112 Com<E::SC>: AsRef<[Val<E::SC>; CHUNK]> + From<[Val<E::SC>; CHUNK]>,
113{
114 setup_tracing();
115 let engine = E::new(fri_params);
116 let (mut vm, pk) = VirtualMachine::<E, VB>::new_with_keygen(engine, builder, config)?;
117 let vk = pk.get_vk();
118 let exe = exe.into();
119 let input = input.into();
120 let metered_ctx = vm.build_metered_ctx();
121 let (segments, _) = vm
122 .metered_interpreter(&exe)?
123 .execute_metered(input.clone(), metered_ctx)?;
124 let cached_program_trace = vm.commit_program_on_device(&exe.program);
125 vm.load_program(cached_program_trace);
126 let mut preflight_interpreter = vm.preflight_interpreter(&exe)?;
127
128 let mut state = Some(vm.create_initial_state(&exe, input));
129 let mut proofs = Vec::new();
130 let mut exit_code = None;
131 for segment in segments {
132 let Segment {
133 instret_start,
134 num_insns,
135 trace_heights,
136 } = segment;
137 assert_eq!(state.as_ref().unwrap().instret, instret_start);
138 let from_state = Option::take(&mut state).unwrap();
139 vm.transport_init_memory_to_device(&from_state.memory);
140 let PreflightExecutionOutput {
141 system_records,
142 record_arenas,
143 to_state,
144 } = vm.execute_preflight(
145 &mut preflight_interpreter,
146 from_state,
147 Some(num_insns),
148 &trace_heights,
149 )?;
150 state = Some(to_state);
151 exit_code = system_records.exit_code;
152
153 let ctx = vm.generate_proving_ctx(system_records, record_arenas)?;
154 if debug {
155 debug_proving_ctx(&vm, &pk, &ctx);
156 }
157 let proof = vm.engine.prove(vm.pk(), ctx);
158 proofs.push(proof);
159 }
160 assert!(proofs.len() >= min_segments);
161 vm.verify(&vk, &proofs)
162 .expect("segment proofs should verify");
163 let state = state.unwrap();
164 let final_memory = (exit_code == Some(ExitCode::Success as u32)).then_some(state.memory.memory);
165 let vdata = proofs
166 .into_iter()
167 .map(|proof| VerificationDataWithFriParams {
168 data: VerificationData {
169 vk: vk.clone(),
170 proof,
171 },
172 fri_params: vm.engine.fri_params(),
173 })
174 .collect();
175
176 Ok((final_memory, vdata))
177}