openvm_sdk/prover/
root.rs1use getset::Getters;
2use itertools::zip_eq;
3use openvm_circuit::arch::{
4 GenerationError, PreflightExecutionOutput, SingleSegmentVmProver, Streams, VirtualMachine,
5 VirtualMachineError, VmInstance,
6};
7use openvm_continuations::verifier::root::types::RootVmVerifierInput;
8use openvm_native_circuit::{NativeConfig, NativeCpuBuilder, NATIVE_MAX_TRACE_HEIGHTS};
9use openvm_native_recursion::hints::Hintable;
10use openvm_stark_sdk::{
11 config::{baby_bear_poseidon2_root::BabyBearPoseidon2RootEngine, FriParameters},
12 engine::StarkEngine,
13 openvm_stark_backend::proof::Proof,
14};
15
16use crate::{
17 keygen::{perm::AirIdPermutation, RootVerifierProvingKey},
18 prover::vm::new_local_prover,
19 RootSC, F, SC,
20};
21
22#[derive(Getters)]
24pub struct RootVerifierLocalProver {
25 inner: VmInstance<BabyBearPoseidon2RootEngine, NativeCpuBuilder>,
29 #[getset(get = "pub")]
31 fixed_air_heights: Vec<u32>,
32 air_id_perm: AirIdPermutation,
33 air_id_inv_perm: AirIdPermutation,
34}
35
36impl RootVerifierLocalProver {
37 pub fn new(root_verifier_pk: &RootVerifierProvingKey) -> Result<Self, VirtualMachineError> {
38 let inner = new_local_prover(
39 NativeCpuBuilder,
40 &root_verifier_pk.vm_pk,
41 root_verifier_pk.root_committed_exe.exe.clone(),
42 )?;
43 let fixed_air_heights = root_verifier_pk.air_heights.clone();
44 let air_id_perm = AirIdPermutation::compute(&fixed_air_heights);
45 let mut inverse_perm = vec![0usize; air_id_perm.perm.len()];
46 for (i, &perm_i) in air_id_perm.perm.iter().enumerate() {
47 inverse_perm[perm_i] = i;
48 }
49 let air_id_inv_perm = AirIdPermutation { perm: inverse_perm };
50
51 Ok(Self {
52 inner,
53 fixed_air_heights,
54 air_id_perm,
55 air_id_inv_perm,
56 })
57 }
58
59 pub fn vm_config(&self) -> &NativeConfig {
60 self.inner.vm.config()
61 }
62
63 #[allow(dead_code)]
64 pub(crate) fn fri_params(&self) -> &FriParameters {
65 &self.inner.vm.engine.fri_params
66 }
67
68 pub fn execute_for_air_heights(
69 &mut self,
70 input: RootVmVerifierInput<SC>,
71 ) -> Result<Vec<u32>, VirtualMachineError> {
72 let exe = self.inner.exe().clone();
73 let vm = &mut self.inner.vm;
75 Self::permute_pk(vm, &self.air_id_inv_perm);
76 assert!(!vm.config().as_ref().continuation_enabled);
77 let input = input.write();
78 let state = vm.create_initial_state(&exe, input);
79 vm.transport_init_memory_to_device(&state.memory);
80 let PreflightExecutionOutput {
81 system_records,
82 record_arenas,
83 ..
84 } = vm.execute_preflight(
85 &mut self.inner.interpreter,
86 state,
87 None,
88 NATIVE_MAX_TRACE_HEIGHTS,
89 )?;
90 let ctx = vm.generate_proving_ctx(system_records, record_arenas)?;
93 let air_heights = ctx
94 .per_air
95 .iter()
96 .map(|(_, air_ctx)| air_ctx.main_trace_height() as u32)
97 .collect();
98 Self::permute_pk(vm, &self.air_id_perm);
99 Ok(air_heights)
100 }
101
102 fn permute_pk(
105 vm: &mut VirtualMachine<BabyBearPoseidon2RootEngine, NativeCpuBuilder>,
106 perm: &AirIdPermutation,
107 ) {
108 perm.permute(&mut vm.pk_mut().per_air);
109 for thc in &mut vm.pk_mut().trace_height_constraints {
110 perm.permute(&mut thc.coefficients);
111 }
112 }
113}
114
115impl SingleSegmentVmProver<RootSC> for RootVerifierLocalProver {
116 fn prove(
125 &mut self,
126 input: impl Into<Streams<F>>,
127 _: &[u32],
128 ) -> Result<Proof<RootSC>, VirtualMachineError> {
129 assert!(!self.vm_config().as_ref().continuation_enabled);
130 self.inner.reset_state(input);
134 let state = self
135 .inner
136 .state_mut()
137 .take()
138 .expect("State should always be present");
139 let vm = &mut self.inner.vm;
140 Self::permute_pk(vm, &self.air_id_inv_perm);
146 assert!(!vm.config().as_ref().continuation_enabled);
147 vm.transport_init_memory_to_device(&state.memory);
148
149 let trace_heights = &self.fixed_air_heights;
150 let PreflightExecutionOutput {
151 system_records,
152 mut record_arenas,
153 to_state,
154 } = vm.execute_preflight(&mut self.inner.interpreter, state, None, trace_heights)?;
155 for ra in &mut record_arenas {
158 ra.force_matrix_dimensions();
159 }
160 vm.override_system_trace_heights(trace_heights);
161
162 let mut ctx = vm.generate_proving_ctx(system_records, record_arenas)?;
163 for (air_idx, (fixed_height, (idx, air_ctx))) in
165 zip_eq(trace_heights, &ctx.per_air).enumerate()
166 {
167 let fixed_height = *fixed_height as usize;
168 if air_idx != *idx {
169 return Err(GenerationError::ForceTraceHeightIncorrect {
170 air_idx,
171 actual: 0,
172 expected: fixed_height,
173 }
174 .into());
175 }
176 if fixed_height != air_ctx.main_trace_height() {
177 return Err(GenerationError::ForceTraceHeightIncorrect {
178 air_idx,
179 actual: air_ctx.main_trace_height(),
180 expected: fixed_height,
181 }
182 .into());
183 }
184 }
185 self.air_id_perm.permute(&mut ctx.per_air);
187 for (i, (air_idx, _)) in ctx.per_air.iter_mut().enumerate() {
188 *air_idx = i;
189 }
190 Self::permute_pk(vm, &self.air_id_perm);
192 let proof = vm.engine.prove(vm.pk(), ctx);
193 *self.inner.state_mut() = Some(to_state);
194 Ok(proof)
195 }
196}