openvm_circuit_primitives/encoder/
mod.rs

1use std::ops::RangeInclusive;
2
3use openvm_stark_backend::{
4    interaction::InteractionBuilder,
5    p3_field::{Field, FieldAlgebra},
6};
7
8use crate::SubAir;
9
10#[cfg(all(test, feature = "cuda"))]
11mod tests;
12
/// Compact encoding of mutually exclusive circuit selectors.
///
/// Each selector is identified with a lattice point in a k-dimensional space whose
/// coordinates all lie between 0 and `max_degree` and whose coordinate sum is at most
/// `max_degree`. Compared with the classic "one boolean column per selector" layout,
/// this needs far fewer columns for the same number of selectors.
#[derive(Debug, Clone)]
pub struct Encoder {
    /// How many variables (trace columns) the encoding occupies.
    var_cnt: usize,
    /// How many real flags are encoded; the invalid/dummy flag is not counted.
    flag_cnt: usize,
    /// Upper bound on the degree of every flag expression.
    /// Note that the equalities enforced in the AIR are **one degree higher**, i.e.
    /// `max_flag_degree + 1`.
    max_flag_degree: u32,
    /// Every point of the k-dimensional space that is available for encoding flags.
    pts: Vec<Vec<u32>>,
    /// `true` when the zero point (0,...,0) is set aside for invalid/dummy rows.
    reserve_invalid: bool,
}
34
35impl Encoder {
36    /// Create a new encoder for a given number of flags and maximum degree.
37    /// The flags will correspond to points in F^k, where k is the number of variables.
38    /// The zero point is reserved for the dummy row.
39    /// `max_degree` is the upper bound for the flag expressions, but the `eval` function
40    /// of the encoder itself will use some constraints of degree `max_degree + 1`.
41    /// `reserve_invalid` indicates if the encoder should reserve the (0, ..., 0) point as an
42    /// invalid/dummy flag.
43    pub fn new(cnt: usize, max_degree: u32, reserve_invalid: bool) -> Self {
44        // Calculate binomial coefficient (d+k choose k) to determine how many points we can encode
45        let binomial = |x: u32| {
46            let mut res = 1;
47            for i in 1..=max_degree {
48                res = res * (x + i) / i;
49            }
50            res
51        };
52        // Find minimum k (number of variables) needed to encode cnt flags
53        let k = (0..)
54            .find(|&x| binomial(x) >= cnt as u32 + reserve_invalid as u32)
55            .unwrap() as usize;
56
57        // Generate all points where coordinates sum to at most max_degree
58        let mut cur = vec![0u32; k];
59        let mut sum = 0;
60        let mut pts = Vec::new();
61        loop {
62            pts.push(cur.clone());
63            if cur[0] == max_degree {
64                break;
65            }
66            let mut i = k - 1;
67            while sum == max_degree {
68                sum -= cur[i];
69                cur[i] = 0;
70                i -= 1;
71            }
72            sum += 1;
73            cur[i] += 1;
74        }
75        Self {
76            var_cnt: k,
77            flag_cnt: cnt,
78            max_flag_degree: max_degree,
79            pts,
80            reserve_invalid,
81        }
82    }
83
84    /// Construct the multivariate Lagrange polynomial for a specific point
85    /// This polynomial equals 1 at the given point and 0 at all other points
86    /// in our solution set
87    fn expression_for_point<AB: InteractionBuilder>(
88        &self,
89        pt: &[u32],
90        vars: &[AB::Var],
91    ) -> AB::Expr {
92        assert_eq!(self.var_cnt, pt.len(), "wrong point dimension");
93        assert_eq!(self.var_cnt, vars.len(), "wrong number of variables");
94        let mut expr = AB::Expr::ONE;
95        let mut denom = AB::F::ONE;
96
97        // First part: product for each coordinate
98        for (i, &coord) in pt.iter().enumerate() {
99            for j in 0..coord {
100                expr *= vars[i] - AB::Expr::from_canonical_u32(j);
101                denom *= AB::F::from_canonical_u32(coord - j);
102            }
103        }
104
105        // Second part: ensure the sum doesn't exceed max_degree
106        {
107            let sum: u32 = pt.iter().sum();
108            let var_sum = vars.iter().fold(AB::Expr::ZERO, |acc, &v| acc + v);
109            for j in 0..(self.max_flag_degree - sum) {
110                expr *= AB::Expr::from_canonical_u32(self.max_flag_degree - j) - var_sum.clone();
111                denom *= AB::F::from_canonical_u32(j + 1);
112            }
113        }
114        expr * denom.inverse()
115    }
116
117    /// Get the polynomial expression that equals 1 when the variables encode the flag at index
118    /// flag_idx
119    pub fn get_flag_expr<AB: InteractionBuilder>(
120        &self,
121        flag_idx: usize,
122        vars: &[AB::Var],
123    ) -> AB::Expr {
124        assert!(flag_idx < self.flag_cnt, "flag index out of range");
125        self.expression_for_point::<AB>(&self.pts[flag_idx + self.reserve_invalid as usize], vars)
126    }
127
128    /// Get the point coordinates that correspond to the flag at index flag_idx
129    pub fn get_flag_pt(&self, flag_idx: usize) -> Vec<u32> {
130        assert!(flag_idx < self.flag_cnt, "flag index out of range");
131        self.pts[flag_idx + self.reserve_invalid as usize].clone()
132    }
133
134    /// Returns an expression that is 1 if the variables encode a valid flag and 0 if they encode
135    /// the invalid point
136    pub fn is_valid<AB: InteractionBuilder>(&self, vars: &[AB::Var]) -> AB::Expr {
137        AB::Expr::ONE - self.expression_for_point::<AB>(&self.pts[0], vars)
138    }
139
140    /// Returns all flag expressions for the given variables
141    pub fn flags<AB: InteractionBuilder>(&self, vars: &[AB::Var]) -> Vec<AB::Expr> {
142        (0..self.flag_cnt)
143            .map(|i| self.get_flag_expr::<AB>(i, vars))
144            .collect()
145    }
146
147    /// Returns the sum of expressions for all unused points
148    /// This is used to ensure that variables encode only valid flags
149    pub fn sum_of_unused<AB: InteractionBuilder>(&self, vars: &[AB::Var]) -> AB::Expr {
150        let mut expr = AB::Expr::ZERO;
151        for i in (self.flag_cnt + self.reserve_invalid as usize)..self.pts.len() {
152            expr += self.expression_for_point::<AB>(&self.pts[i], vars);
153        }
154        expr
155    }
156
157    /// Returns the number of variables used for encoding
158    pub fn width(&self) -> usize {
159        self.var_cnt
160    }
161
162    /// Returns an expression that is 1 if `flag_idxs` contains the encoded flag and 0 otherwise
163    pub fn contains_flag<AB: InteractionBuilder>(
164        &self,
165        vars: &[AB::Var],
166        flag_idxs: &[usize],
167    ) -> AB::Expr {
168        flag_idxs.iter().fold(AB::Expr::ZERO, |acc, flag_idx| {
169            acc + self.get_flag_expr::<AB>(*flag_idx, vars)
170        })
171    }
172
173    /// Returns an expression that is 1 if (l..=r) contains the encoded flag and 0 otherwise
174    pub fn contains_flag_range<AB: InteractionBuilder>(
175        &self,
176        vars: &[AB::Var],
177        range: RangeInclusive<usize>,
178    ) -> AB::Expr {
179        self.contains_flag::<AB>(vars, &range.collect::<Vec<_>>())
180    }
181
182    /// Returns an expression that is 0 if `flag_idxs_vals` doesn't contain the encoded flag
183    /// and the corresponding val if it does
184    /// `flag_idxs_vals` is a list of tuples (flag_idx, val)
185    pub fn flag_with_val<AB: InteractionBuilder>(
186        &self,
187        vars: &[AB::Var],
188        flag_idx_vals: &[(usize, usize)],
189    ) -> AB::Expr {
190        flag_idx_vals
191            .iter()
192            .fold(AB::Expr::ZERO, |acc, (flag_idx, val)| {
193                acc + self.get_flag_expr::<AB>(*flag_idx, vars)
194                    * AB::Expr::from_canonical_usize(*val)
195            })
196    }
197}
198
199impl<AB: InteractionBuilder> SubAir<AB> for Encoder {
200    type AirContext<'a>
201        = &'a [AB::Var]
202    where
203        AB: 'a,
204        AB::Var: 'a,
205        AB::Expr: 'a;
206
207    fn eval<'a>(&'a self, builder: &'a mut AB, local: &'a [AB::Var])
208    where
209        AB: 'a,
210        AB::Expr: 'a,
211    {
212        assert_eq!(local.len(), self.var_cnt, "wrong number of variables");
213
214        // Helper function to create the product (x-0)(x-1)...(x-max_degree)
215        let falling_factorial = |lin: AB::Expr| {
216            let mut res = AB::Expr::ONE;
217            for i in 0..=self.max_flag_degree {
218                res *= lin.clone() - AB::Expr::from_canonical_u32(i);
219            }
220            res
221        };
222        // All x_i are from 0 to max_degree
223        for &var in local.iter() {
224            builder.assert_zero(falling_factorial(var.into()))
225        }
226        // Sum of all x_i is from 0 to max_degree
227        builder.assert_zero(falling_factorial(
228            local.iter().fold(AB::Expr::ZERO, |acc, &x| acc + x),
229        ));
230        // This constraint guarantees that the encoded point either:
231        // 1. Is the zero point (0,...,0) if reserved for invalid/dummy rows, or
232        // 2. Represents one of our defined selectors (flag_idx from 0 to flag_cnt-1)
233        // It works by requiring the sum of Lagrange polynomials for all unused points to be zero,
234        // which forces the current point to be one of our explicitly defined selector patterns
235        builder.assert_zero(self.sum_of_unused::<AB>(local));
236    }
237}