// openvm_stark_backend/prover/cpu/quotient/evaluator.rs

1use std::ops::{Add, Mul, Neg, Sub};
2
3use derivative::Derivative;
4use p3_field::FieldAlgebra;
5
6use crate::{
7    air_builders::symbolic::{
8        symbolic_expression::SymbolicEvaluator,
9        symbolic_variable::{Entry, SymbolicVariable},
10        SymbolicExpressionDag,
11    },
12    config::{PackedChallenge, PackedVal, StarkGenericConfig, Val},
13};
14
15pub(super) struct ViewPair<T> {
16    local: Vec<T>,
17    next: Option<Vec<T>>,
18}
19
20impl<T> ViewPair<T> {
21    pub fn new(local: Vec<T>, next: Option<Vec<T>>) -> Self {
22        Self { local, next }
23    }
24
25    /// SAFETY: no matrix bounds checks are done.
26    pub unsafe fn get(&self, row_offset: usize, column_idx: usize) -> &T {
27        match row_offset {
28            0 => self.local.get_unchecked(column_idx),
29            1 => self
30                .next
31                .as_ref()
32                .unwrap_unchecked()
33                .get_unchecked(column_idx),
34            _ => panic!("row offset {row_offset} not supported"),
35        }
36    }
37}
38
39/// A struct for quotient polynomial evaluation. This evaluates `WIDTH` rows of the quotient polynomial
40/// simultaneously using SIMD (if target arch allows it) via `PackedVal` and `PackedChallenge` types.
41pub(super) struct ProverConstraintEvaluator<'a, SC: StarkGenericConfig> {
42    pub preprocessed: ViewPair<PackedVal<SC>>,
43    pub partitioned_main: Vec<ViewPair<PackedVal<SC>>>,
44    pub after_challenge: Vec<ViewPair<PackedChallenge<SC>>>,
45    pub challenges: &'a [Vec<PackedChallenge<SC>>],
46    pub is_first_row: PackedVal<SC>,
47    pub is_last_row: PackedVal<SC>,
48    pub is_transition: PackedVal<SC>,
49    pub public_values: &'a [Val<SC>],
50    pub exposed_values_after_challenge: &'a [Vec<PackedChallenge<SC>>],
51}
52
53/// In order to avoid extension field arithmetic as much as possible, we evaluate into
54/// the smallest packed expression possible.
55#[derive(Derivative, Copy)]
56#[derivative(Clone(bound = ""))]
57enum PackedExpr<SC: StarkGenericConfig> {
58    Val(PackedVal<SC>),
59    Challenge(PackedChallenge<SC>),
60}
61
62impl<SC: StarkGenericConfig> Add for PackedExpr<SC> {
63    type Output = Self;
64
65    fn add(self, other: Self) -> Self {
66        match (self, other) {
67            (PackedExpr::Val(x), PackedExpr::Val(y)) => PackedExpr::Val(x + y),
68            (PackedExpr::Val(x), PackedExpr::Challenge(y)) => PackedExpr::Challenge(y + x),
69            (PackedExpr::Challenge(x), PackedExpr::Val(y)) => PackedExpr::Challenge(x + y),
70            (PackedExpr::Challenge(x), PackedExpr::Challenge(y)) => PackedExpr::Challenge(x + y),
71        }
72    }
73}
74
75impl<SC: StarkGenericConfig> Sub for PackedExpr<SC> {
76    type Output = Self;
77
78    fn sub(self, other: Self) -> Self {
79        match (self, other) {
80            (PackedExpr::Val(x), PackedExpr::Val(y)) => PackedExpr::Val(x - y),
81            (PackedExpr::Val(x), PackedExpr::Challenge(y)) => {
82                let x: PackedChallenge<SC> = x.into();
83                // We could alternative do (-y) + x
84                PackedExpr::Challenge(x - y)
85            }
86            (PackedExpr::Challenge(x), PackedExpr::Val(y)) => PackedExpr::Challenge(x - y),
87            (PackedExpr::Challenge(x), PackedExpr::Challenge(y)) => PackedExpr::Challenge(x - y),
88        }
89    }
90}
91
92impl<SC: StarkGenericConfig> Mul for PackedExpr<SC> {
93    type Output = Self;
94
95    fn mul(self, other: Self) -> Self {
96        match (self, other) {
97            (PackedExpr::Val(x), PackedExpr::Val(y)) => PackedExpr::Val(x * y),
98            (PackedExpr::Val(x), PackedExpr::Challenge(y)) => PackedExpr::Challenge(y * x),
99            (PackedExpr::Challenge(x), PackedExpr::Val(y)) => PackedExpr::Challenge(x * y),
100            (PackedExpr::Challenge(x), PackedExpr::Challenge(y)) => PackedExpr::Challenge(x * y),
101        }
102    }
103}
104
105impl<SC: StarkGenericConfig> Neg for PackedExpr<SC> {
106    type Output = Self;
107
108    fn neg(self) -> Self {
109        match self {
110            PackedExpr::Val(x) => PackedExpr::Val(-x),
111            PackedExpr::Challenge(x) => PackedExpr::Challenge(-x),
112        }
113    }
114}
115
116impl<SC> SymbolicEvaluator<Val<SC>, PackedExpr<SC>> for ProverConstraintEvaluator<'_, SC>
117where
118    SC: StarkGenericConfig,
119{
120    fn eval_const(&self, c: Val<SC>) -> PackedExpr<SC> {
121        PackedExpr::Val(c.into())
122    }
123    fn eval_is_first_row(&self) -> PackedExpr<SC> {
124        PackedExpr::Val(self.is_first_row)
125    }
126    fn eval_is_last_row(&self) -> PackedExpr<SC> {
127        PackedExpr::Val(self.is_last_row)
128    }
129    fn eval_is_transition(&self) -> PackedExpr<SC> {
130        PackedExpr::Val(self.is_transition)
131    }
132
133    /// SAFETY: we only use this trait implementation when we have already done
134    /// a previous scan to ensure all matrix bounds are satisfied,
135    /// so no bounds checks are done here.
136    fn eval_var(&self, symbolic_var: SymbolicVariable<Val<SC>>) -> PackedExpr<SC> {
137        let index = symbolic_var.index;
138        match symbolic_var.entry {
139            Entry::Preprocessed { offset } => unsafe {
140                PackedExpr::Val(*self.preprocessed.get(offset, index))
141            },
142            Entry::Main { part_index, offset } => unsafe {
143                PackedExpr::Val(*self.partitioned_main[part_index].get(offset, index))
144            },
145            Entry::Public => unsafe {
146                PackedExpr::Val((*self.public_values.get_unchecked(index)).into())
147            },
148            Entry::Permutation { offset } => unsafe {
149                let perm = self.after_challenge.get_unchecked(0);
150                PackedExpr::Challenge(*perm.get(offset, index))
151            },
152            Entry::Challenge => unsafe {
153                PackedExpr::Challenge(*self.challenges.get_unchecked(0).get_unchecked(index))
154            },
155            Entry::Exposed => unsafe {
156                PackedExpr::Challenge(
157                    *self
158                        .exposed_values_after_challenge
159                        .get_unchecked(0)
160                        .get_unchecked(index),
161                )
162            },
163        }
164    }
165}
166
167impl<SC: StarkGenericConfig> ProverConstraintEvaluator<'_, SC> {
168    /// `alpha_powers` are in **reversed** order, with highest power coming first.
169    // Note: this could be split into multiple functions if additional constraints need to be folded in
170    pub fn accumulate(
171        &self,
172        constraints: &SymbolicExpressionDag<Val<SC>>,
173        alpha_powers: &[PackedChallenge<SC>],
174    ) -> PackedChallenge<SC> {
175        let evaluated_nodes = self.eval_nodes(&constraints.nodes);
176        let mut accumulator = PackedChallenge::<SC>::ZERO;
177        for (&alpha_pow, &node_idx) in alpha_powers.iter().zip(&constraints.constraint_idx) {
178            match evaluated_nodes[node_idx] {
179                PackedExpr::Val(x) => accumulator += alpha_pow * x,
180                PackedExpr::Challenge(x) => accumulator += alpha_pow * x,
181            }
182        }
183        accumulator
184    }
185}