1use std::ops::RangeInclusive;
23use openvm_stark_backend::{
4 interaction::InteractionBuilder,
5 p3_field::{Field, FieldAlgebra},
6};
78use crate::SubAir;
/// Compact encoding of mutually exclusive circuit selectors.
///
/// Instead of spending one boolean column per selector, every selector is assigned a
/// distinct point with `k` non-negative integer coordinates whose sum is at most the
/// configured maximum degree. The coordinates occupy `k` trace columns, and each selector
/// is recovered as a multivariate Lagrange polynomial over that point set, so many
/// selectors fit into comparatively few columns.
#[derive(Clone, Debug)]
pub struct Encoder {
    /// `k`: how many trace columns (variables) the encoding occupies.
    var_cnt: usize,
    /// How many real flags are encoded; the invalid/dummy point is not counted here.
    flag_cnt: usize,
    /// Degree of every flag expression.
    /// Constraints emitted by the AIR are **one degree higher**, i.e. `max_flag_degree + 1`.
    max_flag_degree: u32,
    /// Every point of the `k`-dimensional encoding space that may be assigned to a flag.
    pts: Vec<Vec<u32>>,
    /// When set, the all-zero point `(0, ..., 0)` is set aside to mark invalid/dummy rows.
    reserve_invalid: bool,
}
3031impl Encoder {
32/// Create a new encoder for a given number of flags and maximum degree.
33 /// The flags will correspond to points in F^k, where k is the number of variables.
34 /// The zero point is reserved for the dummy row.
35 /// `max_degree` is the upper bound for the flag expressions, but the `eval` function
36 /// of the encoder itself will use some constraints of degree `max_degree + 1`.
37 /// `reserve_invalid` indicates if the encoder should reserve the (0, ..., 0) point as an invalid/dummy flag.
38pub fn new(cnt: usize, max_degree: u32, reserve_invalid: bool) -> Self {
39// Calculate binomial coefficient (d+k choose k) to determine how many points we can encode
40let binomial = |x: u32| {
41let mut res = 1;
42for i in 1..=max_degree {
43 res = res * (x + i) / i;
44 }
45 res
46 };
47// Find minimum k (number of variables) needed to encode cnt flags
48let k = (0..)
49 .find(|&x| binomial(x) >= cnt as u32 + reserve_invalid as u32)
50 .unwrap() as usize;
5152// Generate all points where coordinates sum to at most max_degree
53let mut cur = vec![0u32; k];
54let mut sum = 0;
55let mut pts = Vec::new();
56loop {
57 pts.push(cur.clone());
58if cur[0] == max_degree {
59break;
60 }
61let mut i = k - 1;
62while sum == max_degree {
63 sum -= cur[i];
64 cur[i] = 0;
65 i -= 1;
66 }
67 sum += 1;
68 cur[i] += 1;
69 }
70Self {
71 var_cnt: k,
72 flag_cnt: cnt,
73 max_flag_degree: max_degree,
74 pts,
75 reserve_invalid,
76 }
77 }
7879/// Construct the multivariate Lagrange polynomial for a specific point
80 /// This polynomial equals 1 at the given point and 0 at all other points
81 /// in our solution set
82fn expression_for_point<AB: InteractionBuilder>(
83&self,
84 pt: &[u32],
85 vars: &[AB::Var],
86 ) -> AB::Expr {
87assert_eq!(self.var_cnt, pt.len(), "wrong point dimension");
88assert_eq!(self.var_cnt, vars.len(), "wrong number of variables");
89let mut expr = AB::Expr::ONE;
90let mut denom = AB::F::ONE;
9192// First part: product for each coordinate
93for (i, &coord) in pt.iter().enumerate() {
94for j in 0..coord {
95 expr *= vars[i] - AB::Expr::from_canonical_u32(j);
96 denom *= AB::F::from_canonical_u32(coord - j);
97 }
98 }
99100// Second part: ensure the sum doesn't exceed max_degree
101{
102let sum: u32 = pt.iter().sum();
103let var_sum = vars.iter().fold(AB::Expr::ZERO, |acc, &v| acc + v);
104for j in 0..(self.max_flag_degree - sum) {
105 expr *= AB::Expr::from_canonical_u32(self.max_flag_degree - j) - var_sum.clone();
106 denom *= AB::F::from_canonical_u32(j + 1);
107 }
108 }
109 expr * denom.inverse()
110 }
111112/// Get the polynomial expression that equals 1 when the variables encode the flag at index flag_idx
113pub fn get_flag_expr<AB: InteractionBuilder>(
114&self,
115 flag_idx: usize,
116 vars: &[AB::Var],
117 ) -> AB::Expr {
118assert!(flag_idx < self.flag_cnt, "flag index out of range");
119self.expression_for_point::<AB>(&self.pts[flag_idx + self.reserve_invalid as usize], vars)
120 }
121122/// Get the point coordinates that correspond to the flag at index flag_idx
123pub fn get_flag_pt(&self, flag_idx: usize) -> Vec<u32> {
124assert!(flag_idx < self.flag_cnt, "flag index out of range");
125self.pts[flag_idx + self.reserve_invalid as usize].clone()
126 }
127128/// Returns an expression that is 1 if the variables encode a valid flag and 0 if they encode the invalid point
129pub fn is_valid<AB: InteractionBuilder>(&self, vars: &[AB::Var]) -> AB::Expr {
130 AB::Expr::ONE - self.expression_for_point::<AB>(&self.pts[0], vars)
131 }
132133/// Returns all flag expressions for the given variables
134pub fn flags<AB: InteractionBuilder>(&self, vars: &[AB::Var]) -> Vec<AB::Expr> {
135 (0..self.flag_cnt)
136 .map(|i| self.get_flag_expr::<AB>(i, vars))
137 .collect()
138 }
139140/// Returns the sum of expressions for all unused points
141 /// This is used to ensure that variables encode only valid flags
142pub fn sum_of_unused<AB: InteractionBuilder>(&self, vars: &[AB::Var]) -> AB::Expr {
143let mut expr = AB::Expr::ZERO;
144for i in (self.flag_cnt + self.reserve_invalid as usize)..self.pts.len() {
145 expr += self.expression_for_point::<AB>(&self.pts[i], vars);
146 }
147 expr
148 }
149150/// Returns the number of variables used for encoding
151pub fn width(&self) -> usize {
152self.var_cnt
153 }
154155/// Returns an expression that is 1 if `flag_idxs` contains the encoded flag and 0 otherwise
156pub fn contains_flag<AB: InteractionBuilder>(
157&self,
158 vars: &[AB::Var],
159 flag_idxs: &[usize],
160 ) -> AB::Expr {
161 flag_idxs.iter().fold(AB::Expr::ZERO, |acc, flag_idx| {
162 acc + self.get_flag_expr::<AB>(*flag_idx, vars)
163 })
164 }
165166/// Returns an expression that is 1 if (l..=r) contains the encoded flag and 0 otherwise
167pub fn contains_flag_range<AB: InteractionBuilder>(
168&self,
169 vars: &[AB::Var],
170 range: RangeInclusive<usize>,
171 ) -> AB::Expr {
172self.contains_flag::<AB>(vars, &range.collect::<Vec<_>>())
173 }
174175/// Returns an expression that is 0 if `flag_idxs_vals` doesn't contain the encoded flag
176 /// and the corresponding val if it does
177 /// `flag_idxs_vals` is a list of tuples (flag_idx, val)
178pub fn flag_with_val<AB: InteractionBuilder>(
179&self,
180 vars: &[AB::Var],
181 flag_idx_vals: &[(usize, usize)],
182 ) -> AB::Expr {
183 flag_idx_vals
184 .iter()
185 .fold(AB::Expr::ZERO, |acc, (flag_idx, val)| {
186 acc + self.get_flag_expr::<AB>(*flag_idx, vars)
187 * AB::Expr::from_canonical_usize(*val)
188 })
189 }
190}
191192impl<AB: InteractionBuilder> SubAir<AB> for Encoder {
193type AirContext<'a>
194 = &'a [AB::Var]
195where
196AB: 'a,
197 AB::Var: 'a,
198 AB::Expr: 'a;
199200fn eval<'a>(&'a self, builder: &'a mut AB, local: &'a [AB::Var])
201where
202AB: 'a,
203 AB::Expr: 'a,
204 {
205assert_eq!(local.len(), self.var_cnt, "wrong number of variables");
206207// Helper function to create the product (x-0)(x-1)...(x-max_degree)
208let falling_factorial = |lin: AB::Expr| {
209let mut res = AB::Expr::ONE;
210for i in 0..=self.max_flag_degree {
211 res *= lin.clone() - AB::Expr::from_canonical_u32(i);
212 }
213 res
214 };
215// All x_i are from 0 to max_degree
216for &var in local.iter() {
217 builder.assert_zero(falling_factorial(var.into()))
218 }
219// Sum of all x_i is from 0 to max_degree
220builder.assert_zero(falling_factorial(
221 local.iter().fold(AB::Expr::ZERO, |acc, &x| acc + x),
222 ));
223// This constraint guarantees that the encoded point either:
224 // 1. Is the zero point (0,...,0) if reserved for invalid/dummy rows, or
225 // 2. Represents one of our defined selectors (flag_idx from 0 to flag_cnt-1)
226 // It works by requiring the sum of Lagrange polynomials for all unused points to be zero,
227 // which forces the current point to be one of our explicitly defined selector patterns
228builder.assert_zero(self.sum_of_unused::<AB>(local));
229 }
230}