// openvm_stark_backend/prover/cpu/quotient/evaluator.rs
1use std::ops::{Add, Mul, Neg, Sub};
2
3use derivative::Derivative;
4use p3_field::FieldAlgebra;
5
6use crate::{
7 air_builders::symbolic::{
8 symbolic_expression::SymbolicEvaluator,
9 symbolic_variable::{Entry, SymbolicVariable},
10 SymbolicExpressionDag,
11 },
12 config::{PackedChallenge, PackedVal, StarkGenericConfig, Val},
13};
14
15pub(super) struct ViewPair<T> {
16 local: Vec<T>,
17 next: Option<Vec<T>>,
18}
19
20impl<T> ViewPair<T> {
21 pub fn new(local: Vec<T>, next: Option<Vec<T>>) -> Self {
22 Self { local, next }
23 }
24
25 pub unsafe fn get(&self, row_offset: usize, column_idx: usize) -> &T {
27 match row_offset {
28 0 => self.local.get_unchecked(column_idx),
29 1 => self
30 .next
31 .as_ref()
32 .unwrap_unchecked()
33 .get_unchecked(column_idx),
34 _ => panic!("row offset {row_offset} not supported"),
35 }
36 }
37}
38
39pub(super) struct ProverConstraintEvaluator<'a, SC: StarkGenericConfig> {
42 pub preprocessed: ViewPair<PackedVal<SC>>,
43 pub partitioned_main: Vec<ViewPair<PackedVal<SC>>>,
44 pub after_challenge: Vec<ViewPair<PackedChallenge<SC>>>,
45 pub challenges: &'a [Vec<PackedChallenge<SC>>],
46 pub is_first_row: PackedVal<SC>,
47 pub is_last_row: PackedVal<SC>,
48 pub is_transition: PackedVal<SC>,
49 pub public_values: &'a [Val<SC>],
50 pub exposed_values_after_challenge: &'a [Vec<PackedChallenge<SC>>],
51}
52
53#[derive(Derivative, Copy)]
56#[derivative(Clone(bound = ""))]
57enum PackedExpr<SC: StarkGenericConfig> {
58 Val(PackedVal<SC>),
59 Challenge(PackedChallenge<SC>),
60}
61
62impl<SC: StarkGenericConfig> Add for PackedExpr<SC> {
63 type Output = Self;
64
65 fn add(self, other: Self) -> Self {
66 match (self, other) {
67 (PackedExpr::Val(x), PackedExpr::Val(y)) => PackedExpr::Val(x + y),
68 (PackedExpr::Val(x), PackedExpr::Challenge(y)) => PackedExpr::Challenge(y + x),
69 (PackedExpr::Challenge(x), PackedExpr::Val(y)) => PackedExpr::Challenge(x + y),
70 (PackedExpr::Challenge(x), PackedExpr::Challenge(y)) => PackedExpr::Challenge(x + y),
71 }
72 }
73}
74
75impl<SC: StarkGenericConfig> Sub for PackedExpr<SC> {
76 type Output = Self;
77
78 fn sub(self, other: Self) -> Self {
79 match (self, other) {
80 (PackedExpr::Val(x), PackedExpr::Val(y)) => PackedExpr::Val(x - y),
81 (PackedExpr::Val(x), PackedExpr::Challenge(y)) => {
82 let x: PackedChallenge<SC> = x.into();
83 PackedExpr::Challenge(x - y)
85 }
86 (PackedExpr::Challenge(x), PackedExpr::Val(y)) => PackedExpr::Challenge(x - y),
87 (PackedExpr::Challenge(x), PackedExpr::Challenge(y)) => PackedExpr::Challenge(x - y),
88 }
89 }
90}
91
92impl<SC: StarkGenericConfig> Mul for PackedExpr<SC> {
93 type Output = Self;
94
95 fn mul(self, other: Self) -> Self {
96 match (self, other) {
97 (PackedExpr::Val(x), PackedExpr::Val(y)) => PackedExpr::Val(x * y),
98 (PackedExpr::Val(x), PackedExpr::Challenge(y)) => PackedExpr::Challenge(y * x),
99 (PackedExpr::Challenge(x), PackedExpr::Val(y)) => PackedExpr::Challenge(x * y),
100 (PackedExpr::Challenge(x), PackedExpr::Challenge(y)) => PackedExpr::Challenge(x * y),
101 }
102 }
103}
104
105impl<SC: StarkGenericConfig> Neg for PackedExpr<SC> {
106 type Output = Self;
107
108 fn neg(self) -> Self {
109 match self {
110 PackedExpr::Val(x) => PackedExpr::Val(-x),
111 PackedExpr::Challenge(x) => PackedExpr::Challenge(-x),
112 }
113 }
114}
115
116impl<SC> SymbolicEvaluator<Val<SC>, PackedExpr<SC>> for ProverConstraintEvaluator<'_, SC>
117where
118 SC: StarkGenericConfig,
119{
120 fn eval_const(&self, c: Val<SC>) -> PackedExpr<SC> {
121 PackedExpr::Val(c.into())
122 }
123 fn eval_is_first_row(&self) -> PackedExpr<SC> {
124 PackedExpr::Val(self.is_first_row)
125 }
126 fn eval_is_last_row(&self) -> PackedExpr<SC> {
127 PackedExpr::Val(self.is_last_row)
128 }
129 fn eval_is_transition(&self) -> PackedExpr<SC> {
130 PackedExpr::Val(self.is_transition)
131 }
132
133 fn eval_var(&self, symbolic_var: SymbolicVariable<Val<SC>>) -> PackedExpr<SC> {
137 let index = symbolic_var.index;
138 match symbolic_var.entry {
139 Entry::Preprocessed { offset } => unsafe {
140 PackedExpr::Val(*self.preprocessed.get(offset, index))
141 },
142 Entry::Main { part_index, offset } => unsafe {
143 PackedExpr::Val(*self.partitioned_main[part_index].get(offset, index))
144 },
145 Entry::Public => unsafe {
146 PackedExpr::Val((*self.public_values.get_unchecked(index)).into())
147 },
148 Entry::Permutation { offset } => unsafe {
149 let perm = self.after_challenge.get_unchecked(0);
150 PackedExpr::Challenge(*perm.get(offset, index))
151 },
152 Entry::Challenge => unsafe {
153 PackedExpr::Challenge(*self.challenges.get_unchecked(0).get_unchecked(index))
154 },
155 Entry::Exposed => unsafe {
156 PackedExpr::Challenge(
157 *self
158 .exposed_values_after_challenge
159 .get_unchecked(0)
160 .get_unchecked(index),
161 )
162 },
163 }
164 }
165}
166
167impl<SC: StarkGenericConfig> ProverConstraintEvaluator<'_, SC> {
168 pub fn accumulate(
171 &self,
172 constraints: &SymbolicExpressionDag<Val<SC>>,
173 alpha_powers: &[PackedChallenge<SC>],
174 ) -> PackedChallenge<SC> {
175 let evaluated_nodes = self.eval_nodes(&constraints.nodes);
176 let mut accumulator = PackedChallenge::<SC>::ZERO;
177 for (&alpha_pow, &node_idx) in alpha_powers.iter().zip(&constraints.constraint_idx) {
178 match evaluated_nodes[node_idx] {
179 PackedExpr::Val(x) => accumulator += alpha_pow * x,
180 PackedExpr::Challenge(x) => accumulator += alpha_pow * x,
181 }
182 }
183 accumulator
184 }
185}