openvm_stark_backend/poly/
uni.rs

1//! Copied from starkware-libs/stwo under Apache-2.0 license.
2use std::{
3    iter::Sum,
4    ops::{Add, Deref, Mul, Neg, Sub},
5};
6
7use p3_field::Field;
8
9#[derive(Debug, Clone)]
10pub struct UnivariatePolynomial<F> {
11    coeffs: Vec<F>,
12}
13
14impl<F: Field> UnivariatePolynomial<F> {
15    /// Creates a new univariate polynomial from a vector of coefficients.
16    pub fn from_coeffs(coeffs: Vec<F>) -> Self {
17        let mut polynomial = Self { coeffs };
18        polynomial.trim_leading_zeroes();
19        polynomial
20    }
21
22    pub fn zero() -> Self {
23        Self { coeffs: vec![] }
24    }
25
26    fn one() -> Self {
27        Self {
28            coeffs: vec![F::ONE],
29        }
30    }
31
32    fn is_zero(&self) -> bool {
33        self.coeffs.iter().all(F::is_zero)
34    }
35
36    pub fn evaluate(&self, x: F) -> F {
37        self.coeffs
38            .iter()
39            .rfold(F::ZERO, |acc, coeff| acc * x + *coeff)
40    }
41
42    pub fn degree(&self) -> usize {
43        self.coeffs.iter().rposition(|&v| !v.is_zero()).unwrap_or(0)
44    }
45
46    /// Interpolates `points` via Lagrange interpolation.
47    ///
48    /// # Panics
49    ///
50    /// Panics if `points` contains duplicate x-coordinates.
51    pub fn from_interpolation(points: &[(F, F)]) -> Self {
52        let mut coeffs = Self::zero();
53
54        for (i, &(xi, yi)) in points.iter().enumerate() {
55            let mut num = UnivariatePolynomial::one();
56            let mut denom = F::ONE;
57
58            for (j, &(xj, _)) in points.iter().enumerate() {
59                if i != j {
60                    num = num * (Self::identity() - xj.into());
61                    denom *= xi - xj;
62                }
63            }
64
65            let selector = num * denom.inverse();
66            coeffs = coeffs + selector * yi;
67        }
68
69        coeffs.trim_leading_zeroes();
70        coeffs
71    }
72
73    fn identity() -> Self {
74        Self {
75            coeffs: vec![F::ZERO, F::ONE],
76        }
77    }
78
79    fn trim_leading_zeroes(&mut self) {
80        if let Some(non_zero_idx) = self.coeffs.iter().rposition(|&coeff| !coeff.is_zero()) {
81            self.coeffs.truncate(non_zero_idx + 1);
82        } else {
83            self.coeffs.clear();
84        }
85    }
86
87    pub fn into_coeffs(self) -> Vec<F> {
88        self.coeffs
89    }
90}
91
92impl<F: Field> Default for UnivariatePolynomial<F> {
93    fn default() -> Self {
94        Self::zero()
95    }
96}
97
98impl<F: Field> From<F> for UnivariatePolynomial<F> {
99    fn from(value: F) -> Self {
100        Self::from_coeffs(vec![value])
101    }
102}
103
104impl<F: Field> Mul<F> for UnivariatePolynomial<F> {
105    type Output = Self;
106
107    fn mul(mut self, rhs: F) -> Self {
108        for coeff in &mut self.coeffs {
109            *coeff *= rhs;
110        }
111        self
112    }
113}
114
115impl<F: Field> Mul for UnivariatePolynomial<F> {
116    type Output = Self;
117
118    fn mul(mut self, mut rhs: Self) -> Self {
119        if self.is_zero() || rhs.is_zero() {
120            return Self::zero();
121        }
122
123        self.trim_leading_zeroes();
124        rhs.trim_leading_zeroes();
125
126        let mut res = vec![F::ZERO; self.coeffs.len() + rhs.coeffs.len() - 1];
127
128        for (i, coeff_a) in self.coeffs.into_iter().enumerate() {
129            for (j, coeff_b) in rhs.coeffs.iter().enumerate() {
130                res[i + j] += coeff_a * *coeff_b;
131            }
132        }
133
134        Self::from_coeffs(res)
135    }
136}
137
138impl<F: Field> Add for UnivariatePolynomial<F> {
139    type Output = Self;
140
141    fn add(self, rhs: Self) -> Self {
142        let n = self.coeffs.len().max(rhs.coeffs.len());
143        let mut coeffs = Vec::with_capacity(n);
144
145        for i in 0..n {
146            let a = self.coeffs.get(i).copied().unwrap_or(F::ZERO);
147            let b = rhs.coeffs.get(i).copied().unwrap_or(F::ZERO);
148            coeffs.push(a + b);
149        }
150
151        Self { coeffs }
152    }
153}
154
155impl<F: Field> Sub for UnivariatePolynomial<F> {
156    type Output = Self;
157
158    fn sub(self, rhs: Self) -> Self {
159        self + (-rhs)
160    }
161}
162
163impl<F: Field> Neg for UnivariatePolynomial<F> {
164    type Output = Self;
165
166    fn neg(self) -> Self {
167        Self {
168            coeffs: self.coeffs.into_iter().map(|v| -v).collect(),
169        }
170    }
171}
172
173impl<F: Field> Deref for UnivariatePolynomial<F> {
174    type Target = [F];
175
176    fn deref(&self) -> &Self::Target {
177        &self.coeffs
178    }
179}
180
181/// Evaluates a polynomial represented by coefficients in a slice at a given point `x`.
182pub fn evaluate_on_slice<F: Field>(coeffs: &[F], x: F) -> F {
183    coeffs.iter().rfold(F::ZERO, |acc, &coeff| acc * x + coeff)
184}
185
186/// Returns `v_0 + alpha * v_1 + ... + alpha^(n-1) * v_{n-1}`.
187pub fn random_linear_combination<F: Field>(v: &[F], alpha: F) -> F {
188    evaluate_on_slice(v, alpha)
189}
190
/// Projective fraction: a `numerator` / `denominator` pair that is stored as given.
///
/// Nothing here reduces the fraction or performs a division; addition only multiplies and
/// adds the components.
#[derive(Debug, Clone, Copy)]
pub struct Fraction<T> {
    pub numerator: T,
    pub denominator: T,
}

impl<T> Fraction<T> {
    /// Creates the fraction `numerator / denominator` without validating either component.
    pub const fn new(numerator: T, denominator: T) -> Self {
        Self {
            numerator,
            denominator,
        }
    }
}

impl<T: Clone + Add<Output = T> + Mul<Output = T>> Add for Fraction<T> {
    type Output = Fraction<T>;

    /// Adds two fractions by cross-multiplication: `a/b + c/d = (d*a + b*c) / (b*d)`.
    ///
    /// Only the two denominators are cloned, because each is needed twice; the numerators
    /// are moved into the result. This matters when `T` is expensive to clone.
    fn add(self, rhs: Self) -> Fraction<T> {
        let Fraction {
            numerator: lhs_num,
            denominator: lhs_den,
        } = self;
        let Fraction {
            numerator: rhs_num,
            denominator: rhs_den,
        } = rhs;

        Fraction {
            numerator: rhs_den.clone() * lhs_num + lhs_den.clone() * rhs_num,
            denominator: lhs_den * rhs_den,
        }
    }
}
218
219impl<F: Field> Fraction<F> {
220    const ZERO: Self = Self::new(F::ZERO, F::ONE);
221
222    pub fn is_zero(&self) -> bool {
223        self.numerator.is_zero() && !self.denominator.is_zero()
224    }
225}
226
227impl<F: Field> Sum for Fraction<F> {
228    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
229        iter.fold(Self::ZERO, |a, b| a + b)
230    }
231}
232
#[cfg(test)]
mod tests {
    use std::iter::zip;

    use itertools::Itertools;
    use p3_baby_bear::BabyBear;
    use p3_field::FieldAlgebra;

    use super::*;

    /// Builds a `Vec<BabyBear>` from a comma-separated list of `u32` expressions.
    macro_rules! bbvec {
        [$($x:expr),*] => {
            vec![$(BabyBear::from_canonical_u32($x)),*]
        }
    }

    /// The interpolated polynomial must pass through every input point.
    #[test]
    fn test_interpolate() {
        let xs = bbvec![5, 1, 3, 9];
        let ys = bbvec![1, 2, 3, 4];
        let points = zip(xs.iter().copied(), ys.iter().copied()).collect_vec();

        let poly = UnivariatePolynomial::from_interpolation(&points);

        for &(x, y) in &points {
            assert_eq!(poly.evaluate(x), y, "mismatch for x={x}");
        }
    }

    /// `evaluate` must agree with the directly expanded `c0 + c1*x + c2*x^2`.
    #[test]
    fn test_eval() {
        let coeffs = bbvec![9, 2, 3];
        let x = BabyBear::from_canonical_u32(7);
        let expected = coeffs[0] + coeffs[1] * x + coeffs[2] * x.square();

        let poly = UnivariatePolynomial::from_coeffs(coeffs.clone());

        assert_eq!(poly.evaluate(x), expected);
    }

    /// 1/3 + 2/6 must equal 2/3 once the projective result is divided out.
    #[test]
    fn test_fractional_addition() {
        let three = BabyBear::from_canonical_u32(3);
        let six = BabyBear::from_canonical_u32(6);

        let sum = Fraction::new(BabyBear::ONE, three) + Fraction::new(BabyBear::TWO, six);

        assert_eq!(sum.numerator / sum.denominator, BabyBear::TWO / three);
    }

    /// `degree` across the zero polynomial, constants, and inputs with high-degree zeros.
    #[test]
    fn test_degree() {
        let cases: [(Vec<BabyBear>, usize, &str); 5] = [
            // Zero polynomial (no terms) reports degree 0.
            (vec![], 0, "Zero polynomial should have degree 0"),
            // Constant term only.
            (bbvec![5], 0, "Constant polynomial should have degree 0"),
            // Linear polynomial 3x + 5.
            (bbvec![5, 3], 1, "Linear polynomial should have degree 1"),
            // Zero coefficients after the highest non-zero one must be ignored.
            (
                bbvec![2, 0, 4, 0, 0],
                2,
                "Quadratic polynomial with trailing zeros should have degree 2",
            ),
            // Interior zeros do not lower the degree.
            (
                bbvec![1, 0, 0, 0, 5],
                4,
                "Polynomial of degree 4 should have degree 4",
            ),
        ];

        for (coeffs, expected_degree, message) in cases {
            let poly = UnivariatePolynomial::from_coeffs(coeffs);
            assert_eq!(poly.degree(), expected_degree, "{message}");
        }
    }
}