openvm_stark_sdk/
cost_estimate.rs

use std::{marker::PhantomData, ops::Add};

use openvm_stark_backend::{
    config::{Com, StarkGenericConfig, Val},
    keygen::types::StarkVerifyingKey,
    p3_field::FieldExtensionAlgebra,
};

use crate::config::FriParameters;

/// Properties of a multi-trace circuit necessary to estimate verifier cost.
#[derive(Clone, Copy, Debug)]
pub struct VerifierCostParameters {
    /// Total number of base field columns across all AIR traces before challenge.
    pub num_main_columns: usize,
    /// Total number of base field columns across all AIR traces for the logup permutation.
    pub num_perm_columns: usize,
    /// Log (base 2) of the maximum height of an AIR trace.
    pub log_max_height: usize,
    /// Degree of the quotient polynomial. This is `max_constraint_degree - 1`.
    pub quotient_degree: usize,
}
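
// A minimal illustration of how these parameters might be filled in. The function name and
// all numbers below are hypothetical (not taken from any real circuit): a single 2^20-row
// trace with 500 main columns, one logup partial-sum column over a degree-4 extension
// (4 base columns), and maximum constraint degree 4, i.e. quotient degree 3.
#[allow(dead_code)]
fn example_cost_parameters() -> VerifierCostParameters {
    VerifierCostParameters {
        num_main_columns: 500,
        num_perm_columns: 4,
        log_max_height: 20,
        quotient_degree: 3,
    }
}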

/// Mmcs batch verification consists of hashing the leaf and then a normal Merkle proof.
/// We separate the cost of hashing (which requires proper padding to be a cryptographic hash) from the cost of
/// the 2-to-1 compression function on the hash digest, because in tree proofs the internal layers do not need to use
/// a compression function with padding.
///
/// Currently the estimate ignores the additional details of hashing in matrices of different heights.
#[derive(Clone, Copy, Debug)]
pub struct MmcsVerifyBatchCostEstimate {
    /// Hash cost in terms of the number of field elements to hash. Converting to the true hash cost
    /// depends on the rate of the cryptographic hash.
    pub num_f_to_hash: usize,
    /// Number of calls of the 2-to-1 compression function.
    pub num_compress: usize,
}

impl MmcsVerifyBatchCostEstimate {
    /// `width` is the number of base field columns.
    /// `max_log_height_lde` is the log height of the committed matrix in the MMCS (which includes the blowup).
    pub fn from_dim(width: usize, max_log_height_lde: usize) -> Self {
        Self {
            num_f_to_hash: width,
            num_compress: max_log_height_lde,
        }
    }
}
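
// Illustrative only: a rough conversion of the counts above into permutation calls,
// assuming a sponge hash with the given `rate` (e.g. 8 field elements per absorption for
// Poseidon2 over BabyBear) and a 2-to-1 compression costing one permutation per call.
// These assumptions and this helper are not part of the crate; adjust for the actual hash.
#[allow(dead_code)]
fn estimated_permutations(est: &MmcsVerifyBatchCostEstimate, rate: usize) -> usize {
    // Leaf hashing: one absorption per `rate` field elements, rounded up.
    let leaf_absorptions = est.num_f_to_hash.div_ceil(rate);
    // Each internal Merkle layer costs one compression call.
    leaf_absorptions + est.num_compress
}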

impl Add for MmcsVerifyBatchCostEstimate {
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        Self {
            num_f_to_hash: self.num_f_to_hash + rhs.num_f_to_hash,
            num_compress: self.num_compress + rhs.num_compress,
        }
    }
}

#[derive(Clone, Copy, Debug)]
pub struct FriOpenInputCostEstimate {
    /// Cost from MMCS batch verification.
    pub mmcs: MmcsVerifyBatchCostEstimate,
    /// Number of operations of the form $+ \alpha^? \frac{M_j(\zeta) - y_{ij}}{\zeta - z_i}$ in the reduced opening evaluation.
    pub num_ro_eval: usize,
}
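
// A sketch of one term of the reduced opening described above, written against a generic
// field to show the arithmetic being counted by `num_ro_eval`. Treating all inputs as the
// same field `F` is a simplification (in the actual verifier some of these values live in
// the extension field), and this helper is illustrative rather than part of the estimator.
#[allow(dead_code)]
fn reduced_opening_term<F: openvm_stark_backend::p3_field::Field>(
    alpha_pow: F,
    m_at_zeta: F,
    y: F,
    zeta: F,
    z: F,
) -> F {
    // alpha^k * (M_j(zeta) - y_ij) / (zeta - z_i)
    alpha_pow * (m_at_zeta - y) * (zeta - z).inverse()
}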

impl FriOpenInputCostEstimate {
    /// `width` is the number of base field columns.
    /// `max_log_height` is log_2 of the trace height, before blowup.
    /// `num_points` is the number of points to open.
    pub fn new(
        width: usize,
        max_log_height: usize,
        num_points: usize,
        fri_params: FriParameters,
    ) -> Self {
        // One batch opening per query, against the committed LDE matrix (which includes the blowup).
        let mut mmcs =
            MmcsVerifyBatchCostEstimate::from_dim(width, max_log_height + fri_params.log_blowup);
        mmcs.num_compress *= fri_params.num_queries;
        mmcs.num_f_to_hash *= fri_params.num_queries;
        let num_ro_eval = width * num_points * fri_params.num_queries;
        Self { mmcs, num_ro_eval }
    }
}

impl Add for FriOpenInputCostEstimate {
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        Self {
            mmcs: self.mmcs + rhs.mmcs,
            num_ro_eval: self.num_ro_eval + rhs.num_ro_eval,
        }
    }
}

pub struct FriQueryCostEstimate {
    /// Cost from MMCS batch verification.
    pub mmcs: MmcsVerifyBatchCostEstimate,
    /// Number of single FRI fold evaluations: `e0 + (beta - xs[0]) * (e1 - e0) / (xs[1] - xs[0])`.
    pub num_fri_folds: usize,
}
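
// A sketch of the single fold evaluation counted by `num_fri_folds`, again over a generic
// field for illustration (in the actual verifier some of these values are extension-field
// elements). This helper is not used by the estimator; it only spells out the formula
// quoted above.
#[allow(dead_code)]
fn fri_fold_once<F: openvm_stark_backend::p3_field::Field>(
    e0: F,
    e1: F,
    beta: F,
    xs: [F; 2],
) -> F {
    // e0 + (beta - xs[0]) * (e1 - e0) / (xs[1] - xs[0])
    e0 + (beta - xs[0]) * (e1 - e0) * (xs[1] - xs[0]).inverse()
}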

impl FriQueryCostEstimate {
    /// `max_log_height` is log_2 of the trace height, before blowup.
    pub fn new(max_log_height: usize, fri_params: FriParameters) -> Self {
        // Per query: each of the `max_log_height` fold rounds opens a pair of sibling values,
        // with Merkle path lengths shrinking as the committed codeword halves each round
        // (approximated here by an arithmetic series).
        let mut mmcs = MmcsVerifyBatchCostEstimate {
            num_f_to_hash: 2 * max_log_height,
            num_compress: max_log_height * (max_log_height + fri_params.log_blowup - 1) / 2,
        };
        mmcs.num_compress *= fri_params.num_queries;
        mmcs.num_f_to_hash *= fri_params.num_queries;
        let num_fri_folds = max_log_height * fri_params.num_queries;
        Self {
            mmcs,
            num_fri_folds,
        }
    }
}

impl Add for FriQueryCostEstimate {
    type Output = Self;

    fn add(self, rhs: Self) -> Self::Output {
        Self {
            mmcs: self.mmcs + rhs.mmcs,
            num_fri_folds: self.num_fri_folds + rhs.num_fri_folds,
        }
    }
}

pub struct FriVerifierCostEstimate {
    pub open_input: FriOpenInputCostEstimate,
    pub query: FriQueryCostEstimate,
    /// We currently ignore the constraint evaluation cost because it does not scale with the number of FRI queries.
    pub constraint_eval: PhantomData<usize>,
}

impl FriVerifierCostEstimate {
    pub fn new(
        params: VerifierCostParameters,
        fri_params: FriParameters,
        ext_degree: usize,
    ) -> Self {
        // Go through different rounds: preprocessed, main, permutation, quotient

        // TODO: ignoring preprocessed trace opening for now

        // Main
        // Currently assumes opening at just zeta, omega * zeta
        let mut open_input = FriOpenInputCostEstimate::new(
            params.num_main_columns,
            params.log_max_height,
            2,
            fri_params,
        );
        let mut query = FriQueryCostEstimate::new(params.log_max_height, fri_params);

        // Permutation
        // Currently assumes opening at just zeta, omega * zeta
        open_input = open_input
            + FriOpenInputCostEstimate::new(
                params.num_perm_columns,
                params.log_max_height,
                2,
                fri_params,
            );
        query = query + FriQueryCostEstimate::new(params.log_max_height, fri_params);

        // Add quotient polynomial opening contribution
        // Quotient only opens at single point zeta
        open_input = open_input
            + FriOpenInputCostEstimate::new(
                params.quotient_degree * ext_degree,
                params.log_max_height,
                1,
                fri_params,
            );
        query = query + FriQueryCostEstimate::new(params.log_max_height, fri_params);

        Self {
            open_input,
            query,
            constraint_eval: PhantomData,
        }
    }

    pub fn from_vk<SC: StarkGenericConfig>(
        vks: &[&StarkVerifyingKey<Val<SC>, Com<SC>>],
        fri_params: FriParameters,
        log_max_height: usize,
    ) -> Self {
        let num_main_columns: usize = vks
            .iter()
            .map(|vk| {
                vk.params.width.common_main + vk.params.width.cached_mains.iter().sum::<usize>()
            })
            .sum();
        let ext_degree = <SC::Challenge as FieldExtensionAlgebra<Val<SC>>>::D;
        let num_perm_columns: usize = vks
            .iter()
            .map(|vk| vk.params.width.after_challenge.iter().sum::<usize>())
            .sum::<usize>()
            * ext_degree;
        let quotient_degree = vks.iter().map(|vk| vk.quotient_degree).max().unwrap_or(0) as usize;
        Self::new(
            VerifierCostParameters {
                num_main_columns,
                num_perm_columns,
                log_max_height,
                quotient_degree,
            },
            fri_params,
            ext_degree,
        )
    }
}
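
// A hypothetical end-to-end sketch of using the estimator: given FRI parameters (built however
// your configuration builds them) and the verifying keys of all AIRs, report the total number
// of 2-to-1 compressions the verifier performs. The function name and the choice of metric are
// illustrative, not part of the crate's API.
#[allow(dead_code)]
fn total_verifier_compressions<SC: StarkGenericConfig>(
    vks: &[&StarkVerifyingKey<Val<SC>, Com<SC>>],
    fri_params: FriParameters,
    log_max_height: usize,
) -> usize {
    let estimate = FriVerifierCostEstimate::from_vk::<SC>(vks, fri_params, log_max_height);
    // Compressions come from both the input openings and the FRI query phase.
    estimate.open_input.mmcs.num_compress + estimate.query.mmcs.num_compress
}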