openvm_stark_backend/prover/
helper.rs

//! Helper methods intended for use in tests.
2use std::sync::Arc;
3
4use itertools::izip;
5use p3_matrix::{dense::RowMajorMatrix, Matrix};
6
7use crate::{
8    config::{StarkGenericConfig, Val},
9    prover::types::{AirProofInput, AirProofRawInput},
10};
11
12/// Test helper trait for AirProofInput
13/// Don't use this trait in production code
14pub trait AirProofInputTestHelper<SC: StarkGenericConfig> {
15    fn cached_traces_no_pis(
16        cached_traces: Vec<RowMajorMatrix<Val<SC>>>,
17        common_trace: RowMajorMatrix<Val<SC>>,
18    ) -> Self;
19}
20
21impl<SC: StarkGenericConfig> AirProofInputTestHelper<SC> for AirProofInput<SC> {
22    fn cached_traces_no_pis(
23        cached_traces: Vec<RowMajorMatrix<Val<SC>>>,
24        common_trace: RowMajorMatrix<Val<SC>>,
25    ) -> Self {
26        Self {
27            cached_mains_pdata: vec![],
28            raw: AirProofRawInput {
29                cached_mains: cached_traces.into_iter().map(Arc::new).collect(),
30                common_main: Some(common_trace),
31                public_values: vec![],
32            },
33        }
34    }
35}
36impl<SC: StarkGenericConfig> AirProofInput<SC> {
37    pub fn simple(trace: RowMajorMatrix<Val<SC>>, public_values: Vec<Val<SC>>) -> Self {
38        Self {
39            cached_mains_pdata: vec![],
40            raw: AirProofRawInput {
41                cached_mains: vec![],
42                common_main: Some(trace),
43                public_values,
44            },
45        }
46    }
47    pub fn simple_no_pis(trace: RowMajorMatrix<Val<SC>>) -> Self {
48        Self::simple(trace, vec![])
49    }
50
51    pub fn multiple_simple(
52        traces: Vec<RowMajorMatrix<Val<SC>>>,
53        public_values: Vec<Vec<Val<SC>>>,
54    ) -> Vec<Self> {
55        izip!(traces, public_values)
56            .map(|(trace, pis)| AirProofInput::simple(trace, pis))
57            .collect()
58    }
59
60    pub fn multiple_simple_no_pis(traces: Vec<RowMajorMatrix<Val<SC>>>) -> Vec<Self> {
61        traces
62            .into_iter()
63            .map(AirProofInput::simple_no_pis)
64            .collect()
65    }
66    /// Return the height of the main trace.
67    pub fn main_trace_height(&self) -> usize {
68        if self.raw.cached_mains.is_empty() {
69            // An AIR must have a main trace.
70            self.raw.common_main.as_ref().unwrap().height()
71        } else {
72            self.raw.cached_mains[0].height()
73        }
74    }
75}