openvm_circuit_primitives/encoder/
mod.rs

1use std::ops::RangeInclusive;
2
3use openvm_stark_backend::{
4    interaction::InteractionBuilder,
5    p3_field::{Field, FieldAlgebra},
6};
7
8use crate::SubAir;
9
/// Efficient encoding of circuit selectors.
///
/// This encoder represents selectors as points in a k-dimensional space where each
/// coordinate is between 0 and `max_degree`, and their sum doesn't exceed `max_degree`.
/// This approach allows encoding many selectors with significantly fewer columns
/// than the traditional approach of using one boolean column per selector.
#[derive(Clone, Debug)]
pub struct Encoder {
    /// Number of variables (columns) used to encode the flags; this is the dimension `k`
    /// of the point space.
    var_cnt: usize,
    /// The number of flags, excluding the invalid/dummy flag.
    flag_cnt: usize,
    /// Maximal degree of the flag expressions.
    /// The maximal degree of the equalities in the AIR, however, **is one higher:** that is,
    /// `max_flag_degree + 1`.
    max_flag_degree: u32,
    /// Every point of the k-dimensional space whose coordinates sum to at most
    /// `max_flag_degree`, in the order generated by `Encoder::new`. Flags are assigned to
    /// points by index; points beyond the last flag are unused.
    pts: Vec<Vec<u32>>,
    /// Whether the zero point (0,...,0), stored first in `pts`, is reserved for invalid/dummy
    /// rows. When set, flag `i` maps to `pts[i + 1]` instead of `pts[i]`.
    reserve_invalid: bool,
}
30
31impl Encoder {
32    /// Create a new encoder for a given number of flags and maximum degree.
33    /// The flags will correspond to points in F^k, where k is the number of variables.
34    /// The zero point is reserved for the dummy row.
35    /// `max_degree` is the upper bound for the flag expressions, but the `eval` function
36    /// of the encoder itself will use some constraints of degree `max_degree + 1`.
37    /// `reserve_invalid` indicates if the encoder should reserve the (0, ..., 0) point as an invalid/dummy flag.
38    pub fn new(cnt: usize, max_degree: u32, reserve_invalid: bool) -> Self {
39        // Calculate binomial coefficient (d+k choose k) to determine how many points we can encode
40        let binomial = |x: u32| {
41            let mut res = 1;
42            for i in 1..=max_degree {
43                res = res * (x + i) / i;
44            }
45            res
46        };
47        // Find minimum k (number of variables) needed to encode cnt flags
48        let k = (0..)
49            .find(|&x| binomial(x) >= cnt as u32 + reserve_invalid as u32)
50            .unwrap() as usize;
51
52        // Generate all points where coordinates sum to at most max_degree
53        let mut cur = vec![0u32; k];
54        let mut sum = 0;
55        let mut pts = Vec::new();
56        loop {
57            pts.push(cur.clone());
58            if cur[0] == max_degree {
59                break;
60            }
61            let mut i = k - 1;
62            while sum == max_degree {
63                sum -= cur[i];
64                cur[i] = 0;
65                i -= 1;
66            }
67            sum += 1;
68            cur[i] += 1;
69        }
70        Self {
71            var_cnt: k,
72            flag_cnt: cnt,
73            max_flag_degree: max_degree,
74            pts,
75            reserve_invalid,
76        }
77    }
78
79    /// Construct the multivariate Lagrange polynomial for a specific point
80    /// This polynomial equals 1 at the given point and 0 at all other points
81    /// in our solution set
82    fn expression_for_point<AB: InteractionBuilder>(
83        &self,
84        pt: &[u32],
85        vars: &[AB::Var],
86    ) -> AB::Expr {
87        assert_eq!(self.var_cnt, pt.len(), "wrong point dimension");
88        assert_eq!(self.var_cnt, vars.len(), "wrong number of variables");
89        let mut expr = AB::Expr::ONE;
90        let mut denom = AB::F::ONE;
91
92        // First part: product for each coordinate
93        for (i, &coord) in pt.iter().enumerate() {
94            for j in 0..coord {
95                expr *= vars[i] - AB::Expr::from_canonical_u32(j);
96                denom *= AB::F::from_canonical_u32(coord - j);
97            }
98        }
99
100        // Second part: ensure the sum doesn't exceed max_degree
101        {
102            let sum: u32 = pt.iter().sum();
103            let var_sum = vars.iter().fold(AB::Expr::ZERO, |acc, &v| acc + v);
104            for j in 0..(self.max_flag_degree - sum) {
105                expr *= AB::Expr::from_canonical_u32(self.max_flag_degree - j) - var_sum.clone();
106                denom *= AB::F::from_canonical_u32(j + 1);
107            }
108        }
109        expr * denom.inverse()
110    }
111
112    /// Get the polynomial expression that equals 1 when the variables encode the flag at index flag_idx
113    pub fn get_flag_expr<AB: InteractionBuilder>(
114        &self,
115        flag_idx: usize,
116        vars: &[AB::Var],
117    ) -> AB::Expr {
118        assert!(flag_idx < self.flag_cnt, "flag index out of range");
119        self.expression_for_point::<AB>(&self.pts[flag_idx + self.reserve_invalid as usize], vars)
120    }
121
122    /// Get the point coordinates that correspond to the flag at index flag_idx
123    pub fn get_flag_pt(&self, flag_idx: usize) -> Vec<u32> {
124        assert!(flag_idx < self.flag_cnt, "flag index out of range");
125        self.pts[flag_idx + self.reserve_invalid as usize].clone()
126    }
127
128    /// Returns an expression that is 1 if the variables encode a valid flag and 0 if they encode the invalid point
129    pub fn is_valid<AB: InteractionBuilder>(&self, vars: &[AB::Var]) -> AB::Expr {
130        AB::Expr::ONE - self.expression_for_point::<AB>(&self.pts[0], vars)
131    }
132
133    /// Returns all flag expressions for the given variables
134    pub fn flags<AB: InteractionBuilder>(&self, vars: &[AB::Var]) -> Vec<AB::Expr> {
135        (0..self.flag_cnt)
136            .map(|i| self.get_flag_expr::<AB>(i, vars))
137            .collect()
138    }
139
140    /// Returns the sum of expressions for all unused points
141    /// This is used to ensure that variables encode only valid flags
142    pub fn sum_of_unused<AB: InteractionBuilder>(&self, vars: &[AB::Var]) -> AB::Expr {
143        let mut expr = AB::Expr::ZERO;
144        for i in (self.flag_cnt + self.reserve_invalid as usize)..self.pts.len() {
145            expr += self.expression_for_point::<AB>(&self.pts[i], vars);
146        }
147        expr
148    }
149
150    /// Returns the number of variables used for encoding
151    pub fn width(&self) -> usize {
152        self.var_cnt
153    }
154
155    /// Returns an expression that is 1 if `flag_idxs` contains the encoded flag and 0 otherwise
156    pub fn contains_flag<AB: InteractionBuilder>(
157        &self,
158        vars: &[AB::Var],
159        flag_idxs: &[usize],
160    ) -> AB::Expr {
161        flag_idxs.iter().fold(AB::Expr::ZERO, |acc, flag_idx| {
162            acc + self.get_flag_expr::<AB>(*flag_idx, vars)
163        })
164    }
165
166    /// Returns an expression that is 1 if (l..=r) contains the encoded flag and 0 otherwise
167    pub fn contains_flag_range<AB: InteractionBuilder>(
168        &self,
169        vars: &[AB::Var],
170        range: RangeInclusive<usize>,
171    ) -> AB::Expr {
172        self.contains_flag::<AB>(vars, &range.collect::<Vec<_>>())
173    }
174
175    /// Returns an expression that is 0 if `flag_idxs_vals` doesn't contain the encoded flag
176    /// and the corresponding val if it does
177    /// `flag_idxs_vals` is a list of tuples (flag_idx, val)
178    pub fn flag_with_val<AB: InteractionBuilder>(
179        &self,
180        vars: &[AB::Var],
181        flag_idx_vals: &[(usize, usize)],
182    ) -> AB::Expr {
183        flag_idx_vals
184            .iter()
185            .fold(AB::Expr::ZERO, |acc, (flag_idx, val)| {
186                acc + self.get_flag_expr::<AB>(*flag_idx, vars)
187                    * AB::Expr::from_canonical_usize(*val)
188            })
189    }
190}
191
192impl<AB: InteractionBuilder> SubAir<AB> for Encoder {
193    type AirContext<'a>
194        = &'a [AB::Var]
195    where
196        AB: 'a,
197        AB::Var: 'a,
198        AB::Expr: 'a;
199
200    fn eval<'a>(&'a self, builder: &'a mut AB, local: &'a [AB::Var])
201    where
202        AB: 'a,
203        AB::Expr: 'a,
204    {
205        assert_eq!(local.len(), self.var_cnt, "wrong number of variables");
206
207        // Helper function to create the product (x-0)(x-1)...(x-max_degree)
208        let falling_factorial = |lin: AB::Expr| {
209            let mut res = AB::Expr::ONE;
210            for i in 0..=self.max_flag_degree {
211                res *= lin.clone() - AB::Expr::from_canonical_u32(i);
212            }
213            res
214        };
215        // All x_i are from 0 to max_degree
216        for &var in local.iter() {
217            builder.assert_zero(falling_factorial(var.into()))
218        }
219        // Sum of all x_i is from 0 to max_degree
220        builder.assert_zero(falling_factorial(
221            local.iter().fold(AB::Expr::ZERO, |acc, &x| acc + x),
222        ));
223        // This constraint guarantees that the encoded point either:
224        // 1. Is the zero point (0,...,0) if reserved for invalid/dummy rows, or
225        // 2. Represents one of our defined selectors (flag_idx from 0 to flag_cnt-1)
226        // It works by requiring the sum of Lagrange polynomials for all unused points to be zero,
227        // which forces the current point to be one of our explicitly defined selector patterns
228        builder.assert_zero(self.sum_of_unused::<AB>(local));
229    }
230}