1use openvm_instructions::exe::VmExe;
2use openvm_stark_backend::{
3 config::{Com, Val},
4 engine::VerificationData,
5 p3_field::PrimeField32,
6};
7use openvm_stark_sdk::{
8 config::{baby_bear_poseidon2::BabyBearPoseidon2Config, setup_tracing, FriParameters},
9 engine::{StarkFriEngine, VerificationDataWithFriParams},
10 p3_baby_bear::BabyBear,
11};
12
13use crate::{
14 arch::{
15 debug_proving_ctx, execution_mode::Segment, vm::VirtualMachine, Executor, ExitCode,
16 MeteredExecutor, PreflightExecutionOutput, PreflightExecutor, Streams, VmBuilder,
17 VmCircuitConfig, VmConfig, VmExecutionConfig,
18 },
19 system::memory::{MemoryImage, CHUNK},
20};
21
// Selects the proving engine (`TestStarkEngine`) and record arena (`TestRecordArena`)
// used by the test helpers in this module.
// With the `cuda` feature: the GPU engine from `openvm_cuda_backend` and `DenseRecordArena`.
// Without it: the CPU `BabyBearPoseidon2Engine` from the stark SDK and
// `MatrixRecordArena<BabyBear>`.
cfg_if::cfg_if! {
    if #[cfg(feature = "cuda")] {
        // GPU backend; `cpu_proving_ctx_to_gpu` is re-exported for users of this module.
        pub use openvm_cuda_backend::{engine::GpuBabyBearPoseidon2Engine as TestStarkEngine, chip::cpu_proving_ctx_to_gpu};
        use crate::arch::DenseRecordArena;
        pub type TestRecordArena = DenseRecordArena;
    } else {
        // CPU backend.
        pub use openvm_stark_sdk::config::baby_bear_poseidon2::BabyBearPoseidon2Engine as TestStarkEngine;
        use crate::arch::MatrixRecordArena;
        pub type TestRecordArena = MatrixRecordArena<BabyBear>;
    }
}
// Short private alias for the selected record arena, used in trait bounds.
type RA = TestRecordArena;
34
35pub fn air_test<VB, VC>(builder: VB, config: VC, exe: impl Into<VmExe<BabyBear>>)
39where
40 VB: VmBuilder<TestStarkEngine, VmConfig = VC, RecordArena = RA>,
41 VC: VmExecutionConfig<BabyBear>
42 + VmCircuitConfig<BabyBearPoseidon2Config>
43 + VmConfig<BabyBearPoseidon2Config>,
44 <VC as VmExecutionConfig<BabyBear>>::Executor:
45 Executor<BabyBear> + MeteredExecutor<BabyBear> + PreflightExecutor<BabyBear, RA>,
46{
47 air_test_with_min_segments(builder, config, exe, Streams::default(), 1);
48}
49
50pub fn air_test_with_min_segments<VB, VC>(
52 builder: VB,
53 config: VC,
54 exe: impl Into<VmExe<BabyBear>>,
55 input: impl Into<Streams<BabyBear>>,
56 min_segments: usize,
57) -> Option<MemoryImage>
58where
59 VB: VmBuilder<TestStarkEngine, VmConfig = VC, RecordArena = RA>,
60 VC: VmExecutionConfig<BabyBear>
61 + VmCircuitConfig<BabyBearPoseidon2Config>
62 + VmConfig<BabyBearPoseidon2Config>,
63 <VC as VmExecutionConfig<BabyBear>>::Executor:
64 Executor<BabyBear> + MeteredExecutor<BabyBear> + PreflightExecutor<BabyBear, RA>,
65{
66 let mut log_blowup = 1;
67 while config.as_ref().max_constraint_degree > (1 << log_blowup) + 1 {
68 log_blowup += 1;
69 }
70 let fri_params = FriParameters::new_for_testing(log_blowup);
71 let debug = std::env::var("OPENVM_SKIP_DEBUG") != Result::Ok(String::from("1"));
72 let (final_memory, _) = air_test_impl::<TestStarkEngine, VB>(
73 fri_params,
74 builder,
75 config,
76 exe,
77 input,
78 min_segments,
79 debug,
80 )
81 .unwrap();
82 final_memory
83}
84
85#[allow(clippy::type_complexity)]
90pub fn air_test_impl<E, VB>(
91 fri_params: FriParameters,
92 builder: VB,
93 config: VB::VmConfig,
94 exe: impl Into<VmExe<Val<E::SC>>>,
95 input: impl Into<Streams<Val<E::SC>>>,
96 min_segments: usize,
97 debug: bool,
98) -> eyre::Result<(
99 Option<MemoryImage>,
100 Vec<VerificationDataWithFriParams<E::SC>>,
101)>
102where
103 E: StarkFriEngine,
104 Val<E::SC>: PrimeField32,
105 VB: VmBuilder<E>,
106 <VB::VmConfig as VmExecutionConfig<Val<E::SC>>>::Executor: Executor<Val<E::SC>>
107 + MeteredExecutor<Val<E::SC>>
108 + PreflightExecutor<Val<E::SC>, VB::RecordArena>,
109 Com<E::SC>: AsRef<[Val<E::SC>; CHUNK]> + From<[Val<E::SC>; CHUNK]>,
110{
111 setup_tracing();
112 let engine = E::new(fri_params);
113 let (mut vm, pk) = VirtualMachine::<E, VB>::new_with_keygen(engine, builder, config)?;
114 let vk = pk.get_vk();
115 let exe = exe.into();
116 let input = input.into();
117 let metered_ctx = vm.build_metered_ctx(&exe);
118 let (segments, _) = vm
119 .metered_interpreter(&exe)?
120 .execute_metered(input.clone(), metered_ctx)?;
121 let cached_program_trace = vm.commit_program_on_device(&exe.program);
122 vm.load_program(cached_program_trace);
123 let mut preflight_interpreter = vm.preflight_interpreter(&exe)?;
124
125 let mut state = Some(vm.create_initial_state(&exe, input));
126 let mut proofs = Vec::new();
127 let mut exit_code = None;
128 for segment in segments {
129 let Segment {
130 instret_start,
131 num_insns,
132 trace_heights,
133 } = segment;
134 assert_eq!(state.as_ref().unwrap().instret(), instret_start);
135 let from_state = Option::take(&mut state).unwrap();
136 vm.transport_init_memory_to_device(&from_state.memory);
137 let PreflightExecutionOutput {
138 system_records,
139 record_arenas,
140 to_state,
141 } = vm.execute_preflight(
142 &mut preflight_interpreter,
143 from_state,
144 Some(num_insns),
145 &trace_heights,
146 )?;
147 state = Some(to_state);
148 exit_code = system_records.exit_code;
149
150 let ctx = vm.generate_proving_ctx(system_records, record_arenas)?;
151 if debug {
152 debug_proving_ctx(&vm, &pk, &ctx);
153 }
154 let proof = vm.engine.prove(vm.pk(), ctx);
155 proofs.push(proof);
156 }
157 assert!(proofs.len() >= min_segments);
158 vm.verify(&vk, &proofs)
159 .expect("segment proofs should verify");
160 let state = state.unwrap();
161 let final_memory = (exit_code == Some(ExitCode::Success as u32)).then_some(state.memory.memory);
162 let vdata = proofs
163 .into_iter()
164 .map(|proof| VerificationDataWithFriParams {
165 data: VerificationData {
166 vk: vk.clone(),
167 proof,
168 },
169 fri_params: vm.engine.fri_params(),
170 })
171 .collect();
172
173 Ok((final_memory, vdata))
174}