openvm_sdk/prover/
root.rs

1use getset::Getters;
2use itertools::zip_eq;
3use openvm_circuit::arch::{
4    GenerationError, PreflightExecutionOutput, SingleSegmentVmProver, Streams, VirtualMachine,
5    VirtualMachineError, VmInstance,
6};
7use openvm_continuations::verifier::root::types::RootVmVerifierInput;
8use openvm_native_circuit::{NativeConfig, NativeCpuBuilder, NATIVE_MAX_TRACE_HEIGHTS};
9use openvm_native_recursion::hints::Hintable;
10use openvm_stark_sdk::{
11    config::{baby_bear_poseidon2_root::BabyBearPoseidon2RootEngine, FriParameters},
12    engine::StarkEngine,
13    openvm_stark_backend::proof::Proof,
14};
15
16use crate::{
17    keygen::{perm::AirIdPermutation, RootVerifierProvingKey},
18    prover::vm::new_local_prover,
19    RootSC, F, SC,
20};
21
22/// Local prover for a root verifier.
23#[derive(Getters)]
24pub struct RootVerifierLocalProver {
25    /// The proving key in `inner` should always have ordering of AIRs in the sorted order by fixed
26    /// trace heights outside of the `prove` function.
27    // This is CPU-only for now because it uses RootSC
28    inner: VmInstance<BabyBearPoseidon2RootEngine, NativeCpuBuilder>,
29    /// The constant trace heights, ordered by AIR ID (the original ordering from VmConfig).
30    #[getset(get = "pub")]
31    fixed_air_heights: Vec<u32>,
32    air_id_perm: AirIdPermutation,
33    air_id_inv_perm: AirIdPermutation,
34}
35
36impl RootVerifierLocalProver {
37    pub fn new(root_verifier_pk: &RootVerifierProvingKey) -> Result<Self, VirtualMachineError> {
38        let inner = new_local_prover(
39            NativeCpuBuilder,
40            &root_verifier_pk.vm_pk,
41            root_verifier_pk.root_committed_exe.exe.clone(),
42        )?;
43        let fixed_air_heights = root_verifier_pk.air_heights.clone();
44        let air_id_perm = AirIdPermutation::compute(&fixed_air_heights);
45        let mut inverse_perm = vec![0usize; air_id_perm.perm.len()];
46        for (i, &perm_i) in air_id_perm.perm.iter().enumerate() {
47            inverse_perm[perm_i] = i;
48        }
49        let air_id_inv_perm = AirIdPermutation { perm: inverse_perm };
50
51        Ok(Self {
52            inner,
53            fixed_air_heights,
54            air_id_perm,
55            air_id_inv_perm,
56        })
57    }
58
59    pub fn vm_config(&self) -> &NativeConfig {
60        self.inner.vm.config()
61    }
62
63    #[allow(dead_code)]
64    pub(crate) fn fri_params(&self) -> &FriParameters {
65        &self.inner.vm.engine.fri_params
66    }
67
68    pub fn execute_for_air_heights(
69        &mut self,
70        input: RootVmVerifierInput<SC>,
71    ) -> Result<Vec<u32>, VirtualMachineError> {
72        let exe = self.inner.exe().clone();
73        // See `SingleSegmentVmProver::prove` for explanation
74        let vm = &mut self.inner.vm;
75        Self::permute_pk(vm, &self.air_id_inv_perm);
76        assert!(!vm.config().as_ref().continuation_enabled);
77        let input = input.write();
78        let state = vm.create_initial_state(&exe, input);
79        vm.transport_init_memory_to_device(&state.memory);
80        let PreflightExecutionOutput {
81            system_records,
82            record_arenas,
83            ..
84        } = vm.execute_preflight(
85            &mut self.inner.interpreter,
86            state,
87            None,
88            NATIVE_MAX_TRACE_HEIGHTS,
89        )?;
90        // Note[jpw]: we could in theory extract trace heights from just preflight execution, but
91        // that requires special logic in the chips so we will just generate the traces for now
92        let ctx = vm.generate_proving_ctx(system_records, record_arenas)?;
93        let air_heights = ctx
94            .per_air
95            .iter()
96            .map(|(_, air_ctx)| air_ctx.main_trace_height() as u32)
97            .collect();
98        Self::permute_pk(vm, &self.air_id_perm);
99        Ok(air_heights)
100    }
101
102    // ATTENTION: this must exactly match the permutation done in
103    // `AggStarkProvingKey::dummy_proof_and_keygen` except on DeviceMultiStarkProvingKey.
104    fn permute_pk(
105        vm: &mut VirtualMachine<BabyBearPoseidon2RootEngine, NativeCpuBuilder>,
106        perm: &AirIdPermutation,
107    ) {
108        perm.permute(&mut vm.pk_mut().per_air);
109        for thc in &mut vm.pk_mut().trace_height_constraints {
110            perm.permute(&mut thc.coefficients);
111        }
112    }
113}
114
115impl SingleSegmentVmProver<RootSC> for RootVerifierLocalProver {
116    // @dev: If this implementation is generalized to prover backends not using MatrixRecordArena,
117    // then it must be ensured that:
118    // - the Native extension chips can ensure that, if the record arenas have
119    //   `force_matrix_dimensions()` set, then the record arena capacity heights must equal the
120    //   trace matrix heights.
121    // - any chips that do not use record arenas (currently system memory chips) have a way to force
122    //   trace heights as well. We currently use the fact that all non-system periphery chips have
123    //   fixed height (in particular, there is no Poseidon2PeripheryChip).
124    fn prove(
125        &mut self,
126        input: impl Into<Streams<F>>,
127        _: &[u32],
128    ) -> Result<Proof<RootSC>, VirtualMachineError> {
129        assert!(!self.vm_config().as_ref().continuation_enabled);
130        // The following is unrolled from SingleSegmentVmProver for VmLocalProver and
131        // VirtualMachine::prove to add special logic around ensuring trace heights are fixed and
132        // then reordering the trace matrices so the heights are sorted.
133        self.inner.reset_state(input);
134        let state = self
135            .inner
136            .state_mut()
137            .take()
138            .expect("State should always be present");
139        let vm = &mut self.inner.vm;
140        // The root_verifier_pk has the AIRs ordered by the fixed AIR height sorted ordering, but
141        // execute_preflight and generate_proving_ctx still expect the original AIR ID ordering from
142        // VmConfig, so we apply the inverse permutation here, and then undo it after tracegen. This
143        // could maybe be replaced by only changing `executor_idx_to_air_idx`, but applying the
144        // permutation is conceptually simpler to track.
145        Self::permute_pk(vm, &self.air_id_inv_perm);
146        assert!(!vm.config().as_ref().continuation_enabled);
147        vm.transport_init_memory_to_device(&state.memory);
148
149        let trace_heights = &self.fixed_air_heights;
150        let PreflightExecutionOutput {
151            system_records,
152            mut record_arenas,
153            to_state,
154        } = vm.execute_preflight(&mut self.inner.interpreter, state, None, trace_heights)?;
155        // record_arenas are created with capacity specified by trace_heights. we must ensure
156        // `generate_proving_ctx` does not resize the trace matrices to make them smaller:
157        for ra in &mut record_arenas {
158            ra.force_matrix_dimensions();
159        }
160        vm.override_system_trace_heights(trace_heights);
161
162        let mut ctx = vm.generate_proving_ctx(system_records, record_arenas)?;
163        // Sanity check: ensure all generated trace matrices actually match the fixed heights.
164        for (air_idx, (fixed_height, (idx, air_ctx))) in
165            zip_eq(trace_heights, &ctx.per_air).enumerate()
166        {
167            let fixed_height = *fixed_height as usize;
168            if air_idx != *idx {
169                return Err(GenerationError::ForceTraceHeightIncorrect {
170                    air_idx,
171                    actual: 0,
172                    expected: fixed_height,
173                }
174                .into());
175            }
176            if fixed_height != air_ctx.main_trace_height() {
177                return Err(GenerationError::ForceTraceHeightIncorrect {
178                    air_idx,
179                    actual: air_ctx.main_trace_height(),
180                    expected: fixed_height,
181                }
182                .into());
183            }
184        }
185        // Reorder the AIRs by heights.
186        self.air_id_perm.permute(&mut ctx.per_air);
187        for (i, (air_idx, _)) in ctx.per_air.iter_mut().enumerate() {
188            *air_idx = i;
189        }
190        // We also undo the permutation on pk because `prove` needs pk and ctx ordering to match.
191        Self::permute_pk(vm, &self.air_id_perm);
192        let proof = vm.engine.prove(vm.pk(), ctx);
193        *self.inner.state_mut() = Some(to_state);
194        Ok(proof)
195    }
196}