openvm_circuit_primitives/encoder/
mod.rs1use std::ops::RangeInclusive;
2
3use openvm_stark_backend::{
4 interaction::InteractionBuilder,
5 p3_field::{Field, FieldAlgebra},
6};
7
8use crate::SubAir;
9
10#[cfg(all(test, feature = "cuda"))]
11mod tests;
12
/// Encodes a set of mutually exclusive flags as points with small non-negative integer
/// coordinates, so that each flag can be expressed as a polynomial of bounded degree in a
/// handful of variables instead of needing one boolean column per flag.
#[derive(Clone, Debug)]
pub struct Encoder {
    /// Number of variables (the dimension of every point in `pts`).
    var_cnt: usize,
    /// Number of flags that can be queried through this encoder.
    flag_cnt: usize,
    /// Upper bound on the coordinate sum of a point; also the degree of every flag expression,
    /// since each one is a product of exactly this many linear factors.
    max_flag_degree: u32,
    /// Every point with non-negative integer coordinates whose sum is at most
    /// `max_flag_degree`, in enumeration order. The all-zero point always comes first.
    pts: Vec<Vec<u32>>,
    /// When `true`, `pts[0]` (the all-zero point) is not assigned to any flag and flag `i`
    /// lives at `pts[i + 1]`; when `false`, flag `i` lives at `pts[i]`.
    reserve_invalid: bool,
}
34
35impl Encoder {
36 pub fn new(cnt: usize, max_degree: u32, reserve_invalid: bool) -> Self {
44 let binomial = |x: u32| {
46 let mut res = 1;
47 for i in 1..=max_degree {
48 res = res * (x + i) / i;
49 }
50 res
51 };
52 let k = (0..)
54 .find(|&x| binomial(x) >= cnt as u32 + reserve_invalid as u32)
55 .unwrap() as usize;
56
57 let mut cur = vec![0u32; k];
59 let mut sum = 0;
60 let mut pts = Vec::new();
61 loop {
62 pts.push(cur.clone());
63 if cur[0] == max_degree {
64 break;
65 }
66 let mut i = k - 1;
67 while sum == max_degree {
68 sum -= cur[i];
69 cur[i] = 0;
70 i -= 1;
71 }
72 sum += 1;
73 cur[i] += 1;
74 }
75 Self {
76 var_cnt: k,
77 flag_cnt: cnt,
78 max_flag_degree: max_degree,
79 pts,
80 reserve_invalid,
81 }
82 }
83
84 fn expression_for_point<AB: InteractionBuilder>(
88 &self,
89 pt: &[u32],
90 vars: &[AB::Var],
91 ) -> AB::Expr {
92 assert_eq!(self.var_cnt, pt.len(), "wrong point dimension");
93 assert_eq!(self.var_cnt, vars.len(), "wrong number of variables");
94 let mut expr = AB::Expr::ONE;
95 let mut denom = AB::F::ONE;
96
97 for (i, &coord) in pt.iter().enumerate() {
99 for j in 0..coord {
100 expr *= vars[i] - AB::Expr::from_canonical_u32(j);
101 denom *= AB::F::from_canonical_u32(coord - j);
102 }
103 }
104
105 {
107 let sum: u32 = pt.iter().sum();
108 let var_sum = vars.iter().fold(AB::Expr::ZERO, |acc, &v| acc + v);
109 for j in 0..(self.max_flag_degree - sum) {
110 expr *= AB::Expr::from_canonical_u32(self.max_flag_degree - j) - var_sum.clone();
111 denom *= AB::F::from_canonical_u32(j + 1);
112 }
113 }
114 expr * denom.inverse()
115 }
116
117 pub fn get_flag_expr<AB: InteractionBuilder>(
120 &self,
121 flag_idx: usize,
122 vars: &[AB::Var],
123 ) -> AB::Expr {
124 assert!(flag_idx < self.flag_cnt, "flag index out of range");
125 self.expression_for_point::<AB>(&self.pts[flag_idx + self.reserve_invalid as usize], vars)
126 }
127
128 pub fn get_flag_pt(&self, flag_idx: usize) -> Vec<u32> {
130 assert!(flag_idx < self.flag_cnt, "flag index out of range");
131 self.pts[flag_idx + self.reserve_invalid as usize].clone()
132 }
133
134 pub fn is_valid<AB: InteractionBuilder>(&self, vars: &[AB::Var]) -> AB::Expr {
137 AB::Expr::ONE - self.expression_for_point::<AB>(&self.pts[0], vars)
138 }
139
140 pub fn flags<AB: InteractionBuilder>(&self, vars: &[AB::Var]) -> Vec<AB::Expr> {
142 (0..self.flag_cnt)
143 .map(|i| self.get_flag_expr::<AB>(i, vars))
144 .collect()
145 }
146
147 pub fn sum_of_unused<AB: InteractionBuilder>(&self, vars: &[AB::Var]) -> AB::Expr {
150 let mut expr = AB::Expr::ZERO;
151 for i in (self.flag_cnt + self.reserve_invalid as usize)..self.pts.len() {
152 expr += self.expression_for_point::<AB>(&self.pts[i], vars);
153 }
154 expr
155 }
156
157 pub fn width(&self) -> usize {
159 self.var_cnt
160 }
161
162 pub fn contains_flag<AB: InteractionBuilder>(
164 &self,
165 vars: &[AB::Var],
166 flag_idxs: &[usize],
167 ) -> AB::Expr {
168 flag_idxs.iter().fold(AB::Expr::ZERO, |acc, flag_idx| {
169 acc + self.get_flag_expr::<AB>(*flag_idx, vars)
170 })
171 }
172
173 pub fn contains_flag_range<AB: InteractionBuilder>(
175 &self,
176 vars: &[AB::Var],
177 range: RangeInclusive<usize>,
178 ) -> AB::Expr {
179 self.contains_flag::<AB>(vars, &range.collect::<Vec<_>>())
180 }
181
182 pub fn flag_with_val<AB: InteractionBuilder>(
186 &self,
187 vars: &[AB::Var],
188 flag_idx_vals: &[(usize, usize)],
189 ) -> AB::Expr {
190 flag_idx_vals
191 .iter()
192 .fold(AB::Expr::ZERO, |acc, (flag_idx, val)| {
193 acc + self.get_flag_expr::<AB>(*flag_idx, vars)
194 * AB::Expr::from_canonical_usize(*val)
195 })
196 }
197}
198
199impl<AB: InteractionBuilder> SubAir<AB> for Encoder {
200 type AirContext<'a>
201 = &'a [AB::Var]
202 where
203 AB: 'a,
204 AB::Var: 'a,
205 AB::Expr: 'a;
206
207 fn eval<'a>(&'a self, builder: &'a mut AB, local: &'a [AB::Var])
208 where
209 AB: 'a,
210 AB::Expr: 'a,
211 {
212 assert_eq!(local.len(), self.var_cnt, "wrong number of variables");
213
214 let falling_factorial = |lin: AB::Expr| {
216 let mut res = AB::Expr::ONE;
217 for i in 0..=self.max_flag_degree {
218 res *= lin.clone() - AB::Expr::from_canonical_u32(i);
219 }
220 res
221 };
222 for &var in local.iter() {
224 builder.assert_zero(falling_factorial(var.into()))
225 }
226 builder.assert_zero(falling_factorial(
228 local.iter().fold(AB::Expr::ZERO, |acc, &x| acc + x),
229 ));
230 builder.assert_zero(self.sum_of_unused::<AB>(local));
236 }
237}