1use openvm_instructions::exe::VmExe;
2use openvm_stark_backend::{
3 config::{Com, Val},
4 engine::VerificationData,
5 p3_field::PrimeField32,
6};
7use openvm_stark_sdk::{
8 config::{baby_bear_poseidon2::BabyBearPoseidon2Config, setup_tracing, FriParameters},
9 engine::{StarkFriEngine, VerificationDataWithFriParams},
10 p3_baby_bear::BabyBear,
11};
12
13#[cfg(feature = "aot")]
14use crate::arch::{SystemConfig, VmState};
15#[cfg(feature = "aot")]
16use crate::system::memory::online::GuestMemory;
17use crate::{
18 arch::{
19 debug_proving_ctx, execution_mode::Segment, vm::VirtualMachine, Executor, ExitCode,
20 MeteredExecutor, PreflightExecutionOutput, PreflightExecutor, Streams, VmBuilder,
21 VmCircuitConfig, VmConfig, VmExecutionConfig,
22 },
23 system::memory::{MemoryImage, CHUNK},
24};
25
26cfg_if::cfg_if! {
27 if #[cfg(feature = "cuda")] {
28 pub use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine as TestStarkEngine, chip::cpu_proving_ctx_to_gpu};
29 use crate::arch::DenseRecordArena;
30 pub type TestRecordArena = DenseRecordArena;
31 } else {
32 pub use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Engine as TestStarkEngine;
33 use crate::arch::MatrixRecordArena;
34 pub type TestRecordArena = MatrixRecordArena<BabyBear>;
35 }
36}
37type RA = TestRecordArena;
38
39pub fn air_test<VB, VC>(builder: VB, config: VC, exe: impl Into<VmExe<BabyBear>>)
43where
44 VB: VmBuilder<TestStarkEngine, VmConfig = VC, RecordArena = RA>,
45 VC: VmExecutionConfig<BabyBear>
46 + VmCircuitConfig<BabyBearPoseidon2Config>
47 + VmConfig<BabyBearPoseidon2Config>,
48 <VC as VmExecutionConfig<BabyBear>>::Executor:
49 Executor<BabyBear> + MeteredExecutor<BabyBear> + PreflightExecutor<BabyBear, RA>,
50{
51 air_test_with_min_segments(builder, config, exe, Streams::default(), 1);
52}
53
54pub fn air_test_with_min_segments<VB, VC>(
56 builder: VB,
57 config: VC,
58 exe: impl Into<VmExe<BabyBear>>,
59 input: impl Into<Streams<BabyBear>>,
60 min_segments: usize,
61) -> Option<MemoryImage>
62where
63 VB: VmBuilder<TestStarkEngine, VmConfig = VC, RecordArena = RA>,
64 VC: VmExecutionConfig<BabyBear>
65 + VmCircuitConfig<BabyBearPoseidon2Config>
66 + VmConfig<BabyBearPoseidon2Config>,
67 <VC as VmExecutionConfig<BabyBear>>::Executor:
68 Executor<BabyBear> + MeteredExecutor<BabyBear> + PreflightExecutor<BabyBear, RA>,
69{
70 let mut log_blowup = 1;
71 while config.as_ref().max_constraint_degree > (1 << log_blowup) + 1 {
72 log_blowup += 1;
73 }
74 let fri_params = FriParameters::new_for_testing(log_blowup);
75 let debug = std::env::var("OPENVM_SKIP_DEBUG") != Result::Ok(String::from("1"));
76 let (final_memory, _) = air_test_impl::<TestStarkEngine, VB>(
77 fri_params,
78 builder,
79 config,
80 exe,
81 input,
82 min_segments,
83 debug,
84 )
85 .unwrap();
86 final_memory
87}
88
/// Cross-checks ahead-of-time (AOT) execution against the naive interpreter.
///
/// Runs `exe` on `input` four ways: pure interpreter, pure AOT, metered interpreter and
/// metered AOT. It then asserts that each AOT run agrees with its interpreter counterpart on:
/// - the final program counter,
/// - every byte of address space 1 (NOTE(review): presumably the register address space —
///   confirm against the memory configuration),
/// - for the metered runs, the number of segments and each segment's `instret_start`,
///   `num_insns` and `trace_heights`.
///
/// # Errors
/// Returns an error if building any interpreter/AOT instance or any execution fails.
///
/// # Panics
/// Panics (via `assert_eq!`) on the first mismatch between interpreter and AOT results.
#[cfg(feature = "aot")]
pub fn check_aot_equivalence<E, VB>(
    vm: &VirtualMachine<E, VB>,
    config: &VB::VmConfig,
    exe: &VmExe<Val<E::SC>>,
    input: &Streams<Val<E::SC>>,
) -> eyre::Result<()>
where
    E: StarkFriEngine,
    Val<E::SC>: PrimeField32,
    VB: VmBuilder<E>,
    <VB::VmConfig as VmExecutionConfig<Val<E::SC>>>::Executor: Executor<Val<E::SC>>
        + MeteredExecutor<Val<E::SC>>
        + PreflightExecutor<Val<E::SC>, VB::RecordArena>,
    Com<E::SC>: AsRef<[Val<E::SC>; CHUNK]> + From<[Val<E::SC>; CHUNK]>,
{
    let system_config: &SystemConfig = config.as_ref();
    let num_cells = system_config.memory_config.addr_spaces[1].num_cells;

    // Shared comparison for both execution modes: final pc, then every byte of address
    // space 1. `mode` only labels the failure message.
    let assert_vm_state_eq = |mode: &str,
                              interp: &VmState<Val<E::SC>, GuestMemory>,
                              aot: &VmState<Val<E::SC>, GuestMemory>| {
        assert_eq!(
            interp.pc(),
            aot.pc(),
            "{mode}: final pc differs between interpreter and AOT"
        );
        for r in 0..num_cells {
            let ptr = r as u32;
            // SAFETY: `ptr` ranges over `0..num_cells`, the configured size of address
            // space 1, so each 1-byte read stays inside that address space.
            // NOTE(review): confirm this matches `GuestMemory::read`'s safety contract.
            let (lhs, rhs) = unsafe {
                (
                    interp.memory.read::<u8, 1>(1, ptr),
                    aot.memory.read::<u8, 1>(1, ptr),
                )
            };
            assert_eq!(
                lhs, rhs,
                "{mode}: byte {r} of address space 1 differs between interpreter and AOT"
            );
        }
    };

    // Pure (unmetered) execution.
    {
        let interp_state = vm.naive_interpreter(exe)?.execute(input.clone(), None)?;
        let aot_state = vm.get_aot_instance(exe)?.execute(input.clone(), None)?;
        assert_vm_state_eq("pure", &interp_state, &aot_state);
    }

    println!("Checking metered AOT equivalence");
    // Metered execution: final state plus segmentation must agree.
    {
        let metered_ctx = vm.build_metered_ctx(exe);
        let (aot_segments, aot_state) = vm
            .get_metered_aot_instance(exe)?
            .execute_metered(input.clone(), metered_ctx.clone())?;
        let (segments, interp_state) = vm
            .naive_metered_interpreter(exe)?
            .execute_metered(input.clone(), metered_ctx)?;

        assert_vm_state_eq("metered", &interp_state, &aot_state);

        assert_eq!(
            segments.len(),
            aot_segments.len(),
            "metered: segment count differs between interpreter and AOT"
        );
        for (i, (interp_seg, aot_seg)) in segments.iter().zip(&aot_segments).enumerate() {
            assert_eq!(
                interp_seg.instret_start, aot_seg.instret_start,
                "segment {i}: instret_start differs between interpreter and AOT"
            );
            assert_eq!(
                interp_seg.num_insns, aot_seg.num_insns,
                "segment {i}: num_insns differs between interpreter and AOT"
            );
            assert_eq!(
                interp_seg.trace_heights, aot_seg.trace_heights,
                "segment {i}: trace_heights differ between interpreter and AOT"
            );
        }
    }

    Ok(())
}
169
170#[allow(clippy::type_complexity)]
175pub fn air_test_impl<E, VB>(
176 fri_params: FriParameters,
177 builder: VB,
178 config: VB::VmConfig,
179 exe: impl Into<VmExe<Val<E::SC>>>,
180 input: impl Into<Streams<Val<E::SC>>>,
181 min_segments: usize,
182 debug: bool,
183) -> eyre::Result<(
184 Option<MemoryImage>,
185 Vec<VerificationDataWithFriParams<E::SC>>,
186)>
187where
188 E: StarkFriEngine,
189 Val<E::SC>: PrimeField32,
190 VB: VmBuilder<E>,
191 <VB::VmConfig as VmExecutionConfig<Val<E::SC>>>::Executor: Executor<Val<E::SC>>
192 + MeteredExecutor<Val<E::SC>>
193 + PreflightExecutor<Val<E::SC>, VB::RecordArena>,
194 Com<E::SC>: AsRef<[Val<E::SC>; CHUNK]> + From<[Val<E::SC>; CHUNK]>,
195{
196 setup_tracing();
197 let engine = E::new(fri_params);
198 let (mut vm, pk) = VirtualMachine::<E, VB>::new_with_keygen(engine, builder, config.clone())?;
199 let vk = pk.get_vk();
200 let exe = exe.into();
201 let input = input.into();
202 let metered_ctx = vm.build_metered_ctx(&exe);
203
204 #[cfg(feature = "aot")]
205 check_aot_equivalence(&vm, &config, &exe, &input)?;
206
207 let (segments, _) = vm
208 .metered_interpreter(&exe)?
209 .execute_metered(input.clone(), metered_ctx.clone())?;
210
211 let cached_program_trace = vm.commit_program_on_device(&exe.program);
212 vm.load_program(cached_program_trace);
213 let mut preflight_interpreter = vm.preflight_interpreter(&exe)?;
214
215 let mut state = Some(vm.create_initial_state(&exe, input));
216 let mut proofs = Vec::new();
217 let mut exit_code = None;
218 for segment in segments {
219 let Segment {
220 num_insns,
221 trace_heights,
222 ..
223 } = segment;
224 let from_state = Option::take(&mut state).unwrap();
225 vm.transport_init_memory_to_device(&from_state.memory);
226 let PreflightExecutionOutput {
227 system_records,
228 record_arenas,
229 to_state,
230 } = vm.execute_preflight(
231 &mut preflight_interpreter,
232 from_state,
233 Some(num_insns),
234 &trace_heights,
235 )?;
236 state = Some(to_state);
237 exit_code = system_records.exit_code;
238
239 let ctx = vm.generate_proving_ctx(system_records, record_arenas)?;
240 if debug {
241 debug_proving_ctx(&vm, &pk, &ctx);
242 }
243 let proof = vm.engine.prove(vm.pk(), ctx);
244 proofs.push(proof);
245 }
246 assert!(proofs.len() >= min_segments);
247 vm.verify(&vk, &proofs)
248 .expect("segment proofs should verify");
249 let state = state.unwrap();
250 let final_memory = (exit_code == Some(ExitCode::Success as u32)).then_some(state.memory.memory);
251 let vdata = proofs
252 .into_iter()
253 .map(|proof| VerificationDataWithFriParams {
254 data: VerificationData {
255 vk: vk.clone(),
256 proof,
257 },
258 fri_params: vm.engine.fri_params(),
259 })
260 .collect();
261
262 Ok((final_memory, vdata))
263}