openvm_sdk/prover/
app.rs

1use std::sync::{Arc, OnceLock};
2
3#[cfg(feature = "async")]
4pub use async_prover::*;
5use getset::Getters;
6use itertools::Itertools;
7use openvm_circuit::{
8    arch::{
9        hasher::poseidon2::{vm_poseidon2_hasher, Poseidon2Hasher},
10        instructions::exe::VmExe,
11        verify_segments, ContinuationVmProof, ContinuationVmProver, Executor, MeteredExecutor,
12        PreflightExecutor, VerifiedExecutionPayload, VirtualMachine, VirtualMachineError,
13        VmBuilder, VmExecutionConfig, VmInstance, VmVerificationError,
14    },
15    system::memory::CHUNK,
16};
17use openvm_stark_backend::{
18    config::{Com, Val},
19    keygen::types::MultiStarkVerifyingKey,
20    p3_field::PrimeField32,
21};
22use openvm_stark_sdk::{
23    config::baby_bear_poseidon2::BabyBearPoseidon2Engine,
24    engine::{StarkEngine, StarkFriEngine},
25};
26use tracing::instrument;
27
28use crate::{
29    commit::{AppExecutionCommit, CommitBytes},
30    keygen::AppVerifyingKey,
31    prover::vm::{new_local_prover, types::VmProvingKey},
32    util::check_max_constraint_degrees,
33    StdIn, F, SC,
34};
35
36#[derive(Getters)]
37pub struct AppProver<E, VB>
38where
39    E: StarkEngine,
40    VB: VmBuilder<E>,
41{
42    pub program_name: Option<String>,
43    #[getset(get = "pub")]
44    instance: VmInstance<E, VB>,
45    #[getset(get = "pub")]
46    app_vm_vk: MultiStarkVerifyingKey<E::SC>,
47    #[getset(get = "pub")]
48    leaf_verifier_program_commit: Com<E::SC>,
49
50    app_execution_commit: OnceLock<AppExecutionCommit>,
51}
52
53impl<E, VB> AppProver<E, VB>
54where
55    E: StarkFriEngine,
56    VB: VmBuilder<E>,
57    Val<E::SC>: PrimeField32,
58    Com<E::SC>: AsRef<[Val<E::SC>; CHUNK]> + From<[Val<E::SC>; CHUNK]> + Into<[Val<E::SC>; CHUNK]>,
59{
60    /// Creates a new [AppProver] instance. This method will re-commit the `exe` program on device.
61    /// If a cached version of the program already exists on device, then directly use the
62    /// [`Self::new_from_instance`] constructor.
63    ///
64    /// The `leaf_verifier_program_commit` is the commitment to the program of the leaf verifier
65    /// that verifies the App VM circuit. It can be found in the `AppProvingKey`.
66    pub fn new(
67        vm_builder: VB,
68        app_vm_pk: &VmProvingKey<E::SC, VB::VmConfig>,
69        app_exe: Arc<VmExe<Val<E::SC>>>,
70        leaf_verifier_program_commit: Com<E::SC>,
71    ) -> Result<Self, VirtualMachineError> {
72        let instance = new_local_prover(vm_builder, app_vm_pk, app_exe)?;
73        let app_vm_vk = app_vm_pk.vm_pk.get_vk();
74
75        Ok(Self::new_from_instance(
76            instance,
77            app_vm_vk,
78            leaf_verifier_program_commit,
79        ))
80    }
81
82    pub fn new_from_instance(
83        instance: VmInstance<E, VB>,
84        app_vm_vk: MultiStarkVerifyingKey<E::SC>,
85        leaf_verifier_program_commit: Com<E::SC>,
86    ) -> Self {
87        Self {
88            program_name: None,
89            instance,
90            app_vm_vk,
91            leaf_verifier_program_commit,
92            app_execution_commit: OnceLock::new(),
93        }
94    }
95
96    pub fn set_program_name(&mut self, program_name: impl AsRef<str>) -> &mut Self {
97        self.program_name = Some(program_name.as_ref().to_string());
98        self
99    }
100    pub fn with_program_name(mut self, program_name: impl AsRef<str>) -> Self {
101        self.set_program_name(program_name);
102        self
103    }
104
105    /// Returns [AppExecutionCommit], which is a commitment to **both** the App VM and the App
106    /// VmExe.
107    pub fn app_commit(&self) -> AppExecutionCommit {
108        *self.app_execution_commit.get_or_init(|| {
109            AppExecutionCommit::compute::<E::SC>(
110                &self.instance().vm.config().as_ref().memory_config,
111                self.instance().exe(),
112                self.instance().program_commitment().clone(),
113                self.leaf_verifier_program_commit.clone(),
114            )
115        })
116    }
117
118    pub fn app_program_commit(&self) -> Com<E::SC> {
119        self.instance().program_commitment().clone()
120    }
121
122    /// Generates proof for every continuation segment
123    #[instrument(
124        name = "app_prove",
125        skip_all,
126        fields(group = self.program_name.as_ref().unwrap_or(&"app_proof".to_string()))
127    )]
128    pub fn prove(
129        &mut self,
130        input: StdIn<Val<E::SC>>,
131    ) -> Result<ContinuationVmProof<E::SC>, VirtualMachineError>
132    where
133        <VB::VmConfig as VmExecutionConfig<Val<E::SC>>>::Executor: Executor<Val<E::SC>>
134            + MeteredExecutor<Val<E::SC>>
135            + PreflightExecutor<Val<E::SC>, VB::RecordArena>,
136    {
137        assert!(self.vm_config().as_ref().continuation_enabled);
138        check_max_constraint_degrees(
139            self.vm_config().as_ref(),
140            &self.instance.vm.engine.fri_params(),
141        );
142        #[cfg(feature = "metrics")]
143        metrics::counter!("fri.log_blowup")
144            .absolute(self.instance.vm.engine.fri_params().log_blowup as u64);
145        ContinuationVmProver::prove(&mut self.instance, input)
146    }
147
148    /// Generates proof for every continuation segment
149    ///
150    /// This function internally calls [verify_segments] to verify the result before returning the
151    /// proof.
152    ///
153    /// **Note**: This function calls [`app_commit`](Self::app_commit), which is computationally
154    /// intensive if it is the first time it is called within an `AppProver` instance.
155    #[instrument(name = "app_prove_and_verify", skip_all)]
156    pub fn prove_and_verify(
157        &mut self,
158        input: StdIn<Val<E::SC>>,
159    ) -> Result<ContinuationVmProof<E::SC>, VirtualMachineError>
160    where
161        <VB::VmConfig as VmExecutionConfig<Val<E::SC>>>::Executor: Executor<Val<E::SC>>
162            + MeteredExecutor<Val<E::SC>>
163            + PreflightExecutor<Val<E::SC>, VB::RecordArena>,
164    {
165        let proofs = self.prove(input)?;
166        // We skip verification of the user public values proof here because it is directly computed
167        // from the merkle tree above
168        let res = verify_segments(
169            &self.instance.vm.engine,
170            &self.app_vm_vk,
171            &proofs.per_segment,
172        )?;
173        let app_exe_commit_u32s = self.app_commit().app_exe_commit.to_u32_digest();
174        let exe_commit_u32s = res.exe_commit.map(|x| x.as_canonical_u32());
175        if exe_commit_u32s != app_exe_commit_u32s {
176            return Err(VmVerificationError::ExeCommitMismatch {
177                expected: app_exe_commit_u32s,
178                actual: exe_commit_u32s,
179            }
180            .into());
181        }
182        Ok(proofs)
183    }
184
185    /// App Exe
186    pub fn exe(&self) -> Arc<VmExe<Val<E::SC>>> {
187        self.instance.exe().clone()
188    }
189
190    /// App VM
191    pub fn vm(&self) -> &VirtualMachine<E, VB> {
192        &self.instance.vm
193    }
194
195    /// App VM config
196    pub fn vm_config(&self) -> &VB::VmConfig {
197        self.instance.vm.config()
198    }
199}
200
201/// The payload of a verified guest VM execution with user public values extracted and
202/// verified.
203pub struct VerifiedAppArtifacts {
204    /// The Merklelized hash of:
205    /// - Program code commitment (commitment of the cached trace)
206    /// - Merkle root of the initial memory
207    /// - Starting program counter (`pc_start`)
208    ///
209    /// The Merklelization uses Poseidon2 as a cryptographic hash function (for the leaves)
210    /// and a cryptographic compression function (for internal nodes).
211    pub app_exe_commit: CommitBytes,
212    pub user_public_values: Vec<u8>,
213}
214
215/// Verifies the [ContinuationVmProof], which is a collection of STARK proofs as well as
216/// additional Merkle proof for user public values.
217///
218/// This function verifies the STARK proofs and additional conditions to ensure that the
219/// `proof` is a valid proof of guest VM execution that terminates successfully (exit code 0)
220/// _with respect to_ a commitment to some VM executable.
221/// It is the responsibility of the caller to check that the commitment matches the expected
222/// VM executable.
223pub fn verify_app_proof(
224    app_vk: &AppVerifyingKey,
225    proof: &ContinuationVmProof<SC>,
226) -> Result<VerifiedAppArtifacts, VmVerificationError> {
227    static POSEIDON2_HASHER: OnceLock<Poseidon2Hasher<F>> = OnceLock::new();
228    let engine = BabyBearPoseidon2Engine::new(app_vk.fri_params);
229    let VerifiedExecutionPayload {
230        exe_commit,
231        final_memory_root,
232    } = verify_segments(&engine, &app_vk.vk, &proof.per_segment)?;
233
234    proof.user_public_values.verify(
235        POSEIDON2_HASHER.get_or_init(vm_poseidon2_hasher),
236        app_vk.memory_dimensions,
237        final_memory_root,
238    )?;
239
240    let app_exe_commit = CommitBytes::from_u32_digest(&exe_commit.map(|x| x.as_canonical_u32()));
241    // The user public values address space has cells have type u8
242    let user_public_values = proof
243        .user_public_values
244        .public_values
245        .iter()
246        .map(|x| x.as_canonical_u32().try_into().unwrap())
247        .collect_vec();
248    Ok(VerifiedAppArtifacts {
249        app_exe_commit,
250        user_public_values,
251    })
252}
253
#[cfg(feature = "async")]
mod async_prover {
    use derivative::Derivative;
    use openvm_circuit::{
        arch::ExecutionError, system::memory::merkle::public_values::UserPublicValuesProof,
    };
    use openvm_stark_sdk::config::FriParameters;
    use tokio::{spawn, sync::Semaphore, task::spawn_blocking};
    use tracing::{info_span, instrument, Instrument};

    use super::*;

    /// Thread-safe asynchronous app prover.
    ///
    /// Holds everything needed to build per-thread [AppProver] instances on demand (see
    /// [`Self::local`]) plus a shared semaphore that bounds how many segments are proven at the
    /// same time. Cloning shares the proving key, the executable and the semaphore.
    #[derive(Derivative, Getters)]
    #[derivative(Clone)]
    pub struct AsyncAppProver<E, VB>
    where
        E: StarkEngine,
        VB: VmBuilder<E>,
    {
        /// Optional human-readable program name. When set, it is used as the `group` field of
        /// the proving tracing span; otherwise `"app_proof"` is used.
        pub program_name: Option<String>,
        /// Builder cloned into every local [AppProver].
        #[getset(get = "pub")]
        vm_builder: VB,
        /// Shared proving key of the App VM.
        #[getset(get = "pub")]
        app_vm_pk: Arc<VmProvingKey<E::SC, VB::VmConfig>>,
        /// Shared App executable.
        app_exe: Arc<VmExe<Val<E::SC>>>,
        /// Commitment to the program of the leaf verifier that verifies the App VM circuit.
        #[getset(get = "pub")]
        leaf_verifier_program_commit: Com<E::SC>,

        /// Each segment proving task holds one permit while it proves.
        semaphore: Arc<Semaphore>,
    }

    impl<E, VB> AsyncAppProver<E, VB>
    where
        E: StarkFriEngine + 'static,
        VB: VmBuilder<E> + Clone + Send + Sync + 'static,
        VB::VmConfig: Send + Sync,
        <VB::VmConfig as VmExecutionConfig<Val<E::SC>>>::Executor: Executor<Val<E::SC>>
            + MeteredExecutor<Val<E::SC>>
            + PreflightExecutor<Val<E::SC>, VB::RecordArena>,
        Val<E::SC>: PrimeField32,
        Com<E::SC>:
            AsRef<[Val<E::SC>; CHUNK]> + From<[Val<E::SC>; CHUNK]> + Into<[Val<E::SC>; CHUNK]>,
    {
        /// Creates a new [AsyncAppProver].
        ///
        /// `max_concurrency` is the number of permits of the internal semaphore, i.e. the
        /// maximum number of segments that are proven concurrently by [`Self::prove`].
        ///
        /// # Errors
        /// The current implementation always returns `Ok`; the `Result` is part of the
        /// signature only.
        pub fn new(
            vm_builder: VB,
            app_vm_pk: Arc<VmProvingKey<E::SC, VB::VmConfig>>,
            app_exe: Arc<VmExe<Val<E::SC>>>,
            leaf_verifier_program_commit: Com<E::SC>,
            max_concurrency: usize,
        ) -> Result<Self, VirtualMachineError> {
            Ok(Self {
                program_name: None,
                vm_builder,
                app_vm_pk,
                app_exe,
                leaf_verifier_program_commit,
                semaphore: Arc::new(Semaphore::new(max_concurrency)),
            })
        }

        /// Sets the program name in place and returns `self` for chaining.
        pub fn set_program_name(&mut self, program_name: impl AsRef<str>) -> &mut Self {
            self.program_name = Some(program_name.as_ref().to_string());
            self
        }

        /// Builder-style variant of [`Self::set_program_name`] that consumes and returns the
        /// prover.
        pub fn with_program_name(mut self, program_name: impl AsRef<str>) -> Self {
            self.set_program_name(program_name);
            self
        }

        /// App Exe
        pub fn exe(&self) -> Arc<VmExe<Val<E::SC>>> {
            self.app_exe.clone()
        }

        /// App VM config
        pub fn vm_config(&self) -> &VB::VmConfig {
            &self.app_vm_pk.vm_config
        }

        /// FRI parameters stored in the App VM proving key.
        pub fn fri_params(&self) -> FriParameters {
            self.app_vm_pk.fri_params
        }

        /// Creates an [AppProver] within a particular thread. The returned [AppProver] is not
        /// thread-safe and should **not** be moved between threads.
        ///
        /// # Errors
        /// Propagates the [VirtualMachineError] returned by [`AppProver::new`].
        pub fn local(&self) -> Result<AppProver<E, VB>, VirtualMachineError> {
            AppProver::new(
                self.vm_builder.clone(),
                &self.app_vm_pk,
                self.app_exe.clone(),
                self.leaf_verifier_program_commit.clone(),
            )
        }

        /// Proves every continuation segment, proving segments concurrently, and verifies the
        /// result before returning it.
        ///
        /// Steps, in order:
        /// 1. Build a local [AppProver], compute the app commit and run metered execution to
        ///    obtain the segment boundaries.
        /// 2. Re-execute with the pure interpreter up to the start of each segment and spawn one
        ///    task per segment. Each task waits for a semaphore permit and then proves its
        ///    segment inside `spawn_blocking`, using its own local [AppProver].
        /// 3. Run the pure interpreter to termination, check that the final instruction count
        ///    matches metered execution, and compute the user public values proof from the
        ///    final memory.
        /// 4. Await the segment proofs in segment order, verify them and check that the
        ///    executable commitment matches the app commit.
        ///
        /// # Errors
        /// Returns an error if creating a local prover, execution, proving, joining a task or
        /// verification fails; if pure and metered execution disagree on the terminal
        /// instruction count ([`ExecutionError::DidNotTerminate`]); or if the executable
        /// commitments differ ([`VmVerificationError::ExeCommitMismatch`]).
        ///
        /// # Panics
        /// Panics if continuations are not enabled in the VM config, or if one of the internal
        /// invariants guarded by `expect` below is violated.
        #[instrument(
            name = "app proof",
            skip_all,
            // Borrow the name as `&str` so no `String` is allocated for the default label.
            fields(group = self.program_name.as_deref().unwrap_or("app_proof"))
        )]
        pub async fn prove(
            self,
            input: StdIn<Val<E::SC>>,
        ) -> eyre::Result<ContinuationVmProof<E::SC>> {
            assert!(self.vm_config().as_ref().continuation_enabled);
            check_max_constraint_degrees(self.vm_config().as_ref(), &self.fri_params());
            #[cfg(feature = "metrics")]
            metrics::counter!("fri.log_blowup").absolute(self.fri_params().log_blowup as u64);

            // PERF[jpw]: it is possible to create metered_interpreter without creating vm. The
            // latter is more convenient, but does unnecessary setup (e.g., transfer pk to
            // device). Also, app_commit should be cached.
            let mut local_prover = self.local()?;
            let app_commit = local_prover.app_commit();
            local_prover.instance.reset_state(input.clone());
            let mut state = local_prover
                .instance
                .state_mut()
                .take()
                .expect("VM state should be present right after reset_state");
            let vm = &mut local_prover.instance.vm;
            let metered_ctx = vm.build_metered_ctx(&self.app_exe);
            let metered_interpreter = vm.metered_interpreter(&self.app_exe)?;
            let (segments, _) = metered_interpreter.execute_metered(input, metered_ctx)?;
            drop(metered_interpreter);
            let pure_interpreter = vm.interpreter(&self.app_exe)?;
            let mut tasks = Vec::with_capacity(segments.len());
            // With no segments this is `u64::MAX`, so the terminal check below reports an error.
            let terminal_instret = segments
                .last()
                .map(|s| s.instret_start + s.num_insns)
                .unwrap_or(u64::MAX);
            for (seg_idx, segment) in segments.into_iter().enumerate() {
                tracing::info!(
                    %seg_idx,
                    instret = state.instret(),
                    %segment.instret_start,
                    pc = state.pc(),
                    "Re-executing",
                );
                // Advance the pure execution to the first instruction of this segment.
                let num_insns = segment
                    .instret_start
                    .checked_sub(state.instret())
                    .expect("segment start must not be behind the re-executed state");
                state = pure_interpreter.execute_from_state(state, Some(num_insns))?;

                let semaphore = self.semaphore.clone();
                let async_worker = self.clone();
                let start_state = state.clone();
                let task = spawn(
                    async move {
                        // The permit is held until the blocking proving job has finished.
                        let _permit = semaphore.acquire().await?;
                        let span = tracing::Span::current();
                        spawn_blocking(move || {
                            let _span = span.enter();
                            info_span!("prove_segment", segment = seg_idx).in_scope(
                                || -> eyre::Result<_> {
                                    // We need a separate span so the metric label includes
                                    // "segment" from the enclosing `prove_segment` span
                                    let _prove_span = info_span!(
                                        "vm_prove",
                                        thread_id = ?std::thread::current().id()
                                    )
                                    .entered();
                                    let mut worker = async_worker.local()?;
                                    let instance = &mut worker.instance;
                                    let vm = &mut instance.vm;
                                    let preflight_interpreter = &mut instance.interpreter;
                                    let (segment_proof, _) = vm.prove(
                                        preflight_interpreter,
                                        start_state,
                                        Some(segment.num_insns),
                                        &segment.trace_heights,
                                    )?;
                                    Ok(segment_proof)
                                },
                            )
                        })
                        .await?
                    }
                    .in_current_span(),
                );
                tasks.push(task);
            }
            // Finish execution to termination
            state = pure_interpreter.execute_from_state(state, None)?;
            if state.instret() != terminal_instret {
                tracing::warn!(
                    "Pure execution terminal instret={}, metered execution terminal instret={}",
                    state.instret(),
                    terminal_instret
                );
                // This should never happen
                return Err(ExecutionError::DidNotTerminate.into());
            }
            let final_memory = &state.memory.memory;
            let user_public_values = UserPublicValuesProof::compute(
                vm.config().as_ref().memory_config.memory_dimensions(),
                vm.config().as_ref().num_public_values,
                &vm_poseidon2_hasher(),
                final_memory,
            );

            // Tasks are awaited in spawn order, so proofs end up in segment order.
            let mut proofs = Vec::with_capacity(tasks.len());
            for task in tasks {
                let proof = task.await??;
                proofs.push(proof);
            }
            let cont_proof = ContinuationVmProof {
                per_segment: proofs,
                user_public_values,
            };

            // We skip verification of the user public values proof here because it is directly
            // computed from the merkle tree above
            let engine = E::new(self.fri_params());
            let res = verify_segments(
                &engine,
                &self.app_vm_pk.vm_pk.get_vk(),
                &cont_proof.per_segment,
            )?;
            let app_exe_commit_u32s = app_commit.app_exe_commit.to_u32_digest();
            let exe_commit_u32s = res.exe_commit.map(|x| x.as_canonical_u32());
            if exe_commit_u32s != app_exe_commit_u32s {
                return Err(VmVerificationError::ExeCommitMismatch {
                    expected: app_exe_commit_u32s,
                    actual: exe_commit_u32s,
                }
                .into());
            }
            Ok(cont_proof)
        }
    }
}