halo2_proofs/poly/
evaluator.rs

1use std::{
2    cmp,
3    collections::{HashMap, HashSet},
4    fmt,
5    hash::{Hash, Hasher},
6    marker::PhantomData,
7    ops::{Add, Mul, MulAssign, Neg, Sub},
8    sync::Arc,
9};
10
11use group::ff::Field;
12use pasta_curves::arithmetic::FieldExt;
13
14use super::{
15    Basis, Coeff, EvaluationDomain, ExtendedLagrangeCoeff, LagrangeCoeff, Polynomial, Rotation,
16};
17use crate::{arithmetic::parallelize, multicore};
18
19/// Returns `(chunk_size, num_chunks)` suitable for processing the given polynomial length
20/// in the current parallelization environment.
21fn get_chunk_params(poly_len: usize) -> (usize, usize) {
22    // Check the level of parallelization we have available.
23    let num_threads = multicore::current_num_threads();
24    // We scale the number of chunks by a constant factor, to ensure that if not all
25    // threads are available, we can achieve more uniform throughput and don't end up
26    // waiting on a couple of threads to process the last chunks.
27    let num_chunks = num_threads * 4;
28    // Calculate the ideal chunk size for the desired throughput. We use ceiling
29    // division to ensure the minimum chunk size is 1.
30    //     chunk_size = ceil(poly_len / num_chunks)
31    let chunk_size = (poly_len + num_chunks - 1) / num_chunks;
32    // Now re-calculate num_chunks from the actual chunk size.
33    //     num_chunks = ceil(poly_len / chunk_size)
34    let num_chunks = (poly_len + chunk_size - 1) / chunk_size;
35
36    (chunk_size, num_chunks)
37}
38
/// A reference to a polynomial registered with an [`Evaluator`].
///
/// A leaf identifies a registered polynomial by its registration index, together with
/// the [`Rotation`] at which that polynomial is queried. The `E` type parameter ties the
/// leaf to the evaluator that produced it (see [`new_evaluator`]).
#[derive(Clone, Copy)]
pub(crate) struct AstLeaf<E, B: Basis> {
    /// Position of the polynomial in the owning evaluator's list of registered polynomials.
    index: usize,
    /// Rotation at which the polynomial is queried.
    rotation: Rotation,
    /// Carries the evaluator context type `E` and the basis `B` without storing values.
    _evaluator: PhantomData<(E, B)>,
}
46
47impl<E, B: Basis> fmt::Debug for AstLeaf<E, B> {
48    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49        f.debug_struct("AstLeaf")
50            .field("index", &self.index)
51            .field("rotation", &self.rotation)
52            .finish()
53    }
54}
55
56impl<E, B: Basis> PartialEq for AstLeaf<E, B> {
57    fn eq(&self, rhs: &Self) -> bool {
58        // We compare rotations by offset, which doesn't account for equivalent rotations.
59        self.index.eq(&rhs.index) && self.rotation.0.eq(&rhs.rotation.0)
60    }
61}
62
63impl<E, B: Basis> Eq for AstLeaf<E, B> {}
64
65impl<E, B: Basis> Hash for AstLeaf<E, B> {
66    fn hash<H: Hasher>(&self, state: &mut H) {
67        self.index.hash(state);
68        self.rotation.0.hash(state);
69    }
70}
71
72impl<E, B: Basis> AstLeaf<E, B> {
73    /// Produces a new `AstLeaf` node corresponding to the underlying polynomial at a
74    /// _new_ rotation. Existing rotations applied to this leaf node are ignored and the
75    /// returned polynomial is not rotated _relative_ to the previous structure.
76    pub(crate) fn with_rotation(&self, rotation: Rotation) -> Self {
77        AstLeaf {
78            index: self.index,
79            rotation,
80            _evaluator: PhantomData::default(),
81        }
82    }
83}
84
/// An evaluation context for polynomial operations.
///
/// This context enables us to de-duplicate queries of circuit columns (and the rotations
/// they might require), by storing a list of all the underlying polynomials involved in
/// any query (which are almost certainly column polynomials). We use the context like so:
///
/// - We register each underlying polynomial with the evaluator, which returns a reference
///   to it as an [`AstLeaf`].
/// - The references are then used to build up an [`Ast`] that represents the overall
///   operations to be applied to the polynomials.
/// - Finally, we call [`Evaluator::evaluate`] passing in the [`Ast`].
pub(crate) struct Evaluator<E, F: Field, B: Basis> {
    /// Registered polynomials; an [`AstLeaf`]'s `index` is a position in this list.
    polys: Vec<Polynomial<F, B>>,
    /// The context value passed to [`new_evaluator`]. It is not otherwise used here; it
    /// only fixes the type `E` shared by this evaluator and the leaves it hands out.
    _context: E,
}
100
101/// Constructs a new `Evaluator`.
102///
103/// The `context` parameter is used to provide type safety for evaluators. It ensures that
104/// an evaluator will only be used to evaluate [`Ast`]s containing [`AstLeaf`]s obtained
105/// from itself. It should be set to the empty closure `|| {}`, because anonymous closures
106/// all have unique types.
107pub(crate) fn new_evaluator<E: Fn() + Clone, F: Field, B: Basis>(context: E) -> Evaluator<E, F, B> {
108    Evaluator {
109        polys: vec![],
110        _context: context,
111    }
112}
113
114impl<E, F: Field, B: Basis> Evaluator<E, F, B> {
115    /// Registers the given polynomial for use in this evaluation context.
116    ///
117    /// This API treats each registered polynomial as unique, even if the same polynomial
118    /// is added multiple times.
119    pub(crate) fn register_poly(&mut self, poly: Polynomial<F, B>) -> AstLeaf<E, B> {
120        let index = self.polys.len();
121        self.polys.push(poly);
122
123        AstLeaf {
124            index,
125            rotation: Rotation::cur(),
126            _evaluator: PhantomData::default(),
127        }
128    }
129
130    /// Evaluates the given polynomial operation against this context.
131    pub(crate) fn evaluate(
132        &self,
133        ast: &Ast<E, F, B>,
134        domain: &EvaluationDomain<F>,
135    ) -> Polynomial<F, B>
136    where
137        E: Copy + Send + Sync,
138        F: FieldExt,
139        B: BasisOps,
140    {
141        // Traverse `ast` to collect the used leaves.
142        fn collect_rotations<E: Copy, F: Field, B: Basis>(
143            ast: &Ast<E, F, B>,
144        ) -> HashSet<AstLeaf<E, B>> {
145            match ast {
146                Ast::Poly(leaf) => vec![*leaf].into_iter().collect(),
147                Ast::Add(a, b) | Ast::Mul(AstMul(a, b)) => {
148                    let lhs = collect_rotations(a);
149                    let rhs = collect_rotations(b);
150                    lhs.union(&rhs).cloned().collect()
151                }
152                Ast::Scale(a, _) => collect_rotations(a),
153                Ast::DistributePowers(terms, _) => terms
154                    .iter()
155                    .flat_map(|term| collect_rotations(term).into_iter())
156                    .collect(),
157                Ast::LinearTerm(_) | Ast::ConstantTerm(_) => HashSet::default(),
158            }
159        }
160        let leaves = collect_rotations(ast);
161
162        // Produce the rotated polynomials.
163        let rotated: HashMap<_, _> = leaves
164            .iter()
165            .cloned()
166            .map(|leaf| {
167                (
168                    leaf,
169                    B::rotate(domain, &self.polys[leaf.index], leaf.rotation),
170                )
171            })
172            .collect();
173
174        // We're working in a single basis, so all polynomials are the same length.
175        let poly_len = self.polys.first().unwrap().len();
176        let (chunk_size, num_chunks) = get_chunk_params(poly_len);
177
178        // Split each rotated polynomial into chunks.
179        let chunks: Vec<HashMap<_, _>> = (0..num_chunks)
180            .map(|i| {
181                rotated
182                    .iter()
183                    .map(|(leaf, poly)| {
184                        (
185                            *leaf,
186                            poly.chunks(chunk_size)
187                                .nth(i)
188                                .expect("num_chunks was calculated correctly"),
189                        )
190                    })
191                    .collect()
192            })
193            .collect();
194
195        struct AstContext<'a, E, F: FieldExt, B: Basis> {
196            domain: &'a EvaluationDomain<F>,
197            poly_len: usize,
198            chunk_size: usize,
199            chunk_index: usize,
200            leaves: &'a HashMap<AstLeaf<E, B>, &'a [F]>,
201        }
202
203        fn recurse<E, F: FieldExt, B: BasisOps>(
204            ast: &Ast<E, F, B>,
205            ctx: &AstContext<'_, E, F, B>,
206        ) -> Vec<F> {
207            match ast {
208                Ast::Poly(leaf) => ctx.leaves.get(leaf).expect("We prepared this").to_vec(),
209                Ast::Add(a, b) => {
210                    let mut lhs = recurse(a, ctx);
211                    let rhs = recurse(b, ctx);
212                    for (lhs, rhs) in lhs.iter_mut().zip(rhs.iter()) {
213                        *lhs += *rhs;
214                    }
215                    lhs
216                }
217                Ast::Mul(AstMul(a, b)) => {
218                    let mut lhs = recurse(a, ctx);
219                    let rhs = recurse(b, ctx);
220                    for (lhs, rhs) in lhs.iter_mut().zip(rhs.iter()) {
221                        *lhs *= *rhs;
222                    }
223                    lhs
224                }
225                Ast::Scale(a, scalar) => {
226                    let mut lhs = recurse(a, ctx);
227                    for lhs in lhs.iter_mut() {
228                        *lhs *= scalar;
229                    }
230                    lhs
231                }
232                Ast::DistributePowers(terms, base) => terms.iter().fold(
233                    B::constant_term(ctx.poly_len, ctx.chunk_size, ctx.chunk_index, F::zero()),
234                    |mut acc, term| {
235                        let term = recurse(term, ctx);
236                        for (acc, term) in acc.iter_mut().zip(term) {
237                            *acc *= base;
238                            *acc += term;
239                        }
240                        acc
241                    },
242                ),
243                Ast::LinearTerm(scalar) => B::linear_term(
244                    ctx.domain,
245                    ctx.poly_len,
246                    ctx.chunk_size,
247                    ctx.chunk_index,
248                    *scalar,
249                ),
250                Ast::ConstantTerm(scalar) => {
251                    B::constant_term(ctx.poly_len, ctx.chunk_size, ctx.chunk_index, *scalar)
252                }
253            }
254        }
255
256        // Apply `ast` to each chunk in parallel, writing the result into an output
257        // polynomial.
258        let mut result = B::empty_poly(domain);
259        multicore::scope(|scope| {
260            for (chunk_index, (out, leaves)) in
261                result.chunks_mut(chunk_size).zip(chunks.iter()).enumerate()
262            {
263                scope.spawn(move |_| {
264                    let ctx = AstContext {
265                        domain,
266                        poly_len,
267                        chunk_size,
268                        chunk_index,
269                        leaves,
270                    };
271                    out.copy_from_slice(&recurse(ast, &ctx));
272                });
273            }
274        });
275        result
276    }
277}
278
/// Struct representing the [`Ast::Mul`] case.
///
/// This struct exists to make the internals of this case private so that we don't
/// accidentally construct this case directly. Multiplication is only implemented for the
/// evaluation bases ([`LagrangeCoeff`] and [`ExtendedLagrangeCoeff`]), where it is
/// evaluated pointwise. It is deliberately not available for [`Coeff`], where pointwise
/// multiplication of coefficients would not be polynomial multiplication.
///
/// The two fields are the left and right operands.
#[derive(Clone)]
pub(crate) struct AstMul<E, F: Field, B: Basis>(Arc<Ast<E, F, B>>, Arc<Ast<E, F, B>>);
286
287impl<E, F: Field, B: Basis> fmt::Debug for AstMul<E, F, B> {
288    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
289        f.debug_tuple("AstMul")
290            .field(&self.0)
291            .field(&self.1)
292            .finish()
293    }
294}
295
/// A polynomial operation backed by an [`Evaluator`].
#[derive(Clone)]
pub(crate) enum Ast<E, F: Field, B: Basis> {
    /// A registered polynomial, queried at the leaf's rotation.
    Poly(AstLeaf<E, B>),
    /// The sum of two operations.
    Add(Arc<Ast<E, F, B>>, Arc<Ast<E, F, B>>),
    /// The pointwise product of two operations; see [`AstMul`] for which bases allow it.
    Mul(AstMul<E, F, B>),
    /// An operation multiplied by a constant field element.
    Scale(Arc<Ast<E, F, B>>, F),
    /// Represents a linear combination of a vector of nodes and the powers of a
    /// field element, where the nodes are ordered from highest to lowest degree
    /// terms.
    ///
    /// For terms `t_0, ..., t_{n-1}` and base `b` this is evaluated Horner-style as
    /// `t_0 * b^{n-1} + t_1 * b^{n-2} + ... + t_{n-1}`.
    DistributePowers(Arc<Vec<Ast<E, F, B>>>, F),
    /// The degree-1 term of a polynomial.
    ///
    /// The field element is the coefficient of `X` in the monomial (standard) form of the
    /// polynomial, regardless of the basis `B` in which the AST is evaluated.
    LinearTerm(F),
    /// The degree-0 term of a polynomial.
    ///
    /// The field element is the same in both the standard and evaluation bases.
    ConstantTerm(F),
}
317
318impl<E, F: Field, B: Basis> Ast<E, F, B> {
319    pub fn distribute_powers<I: IntoIterator<Item = Self>>(i: I, base: F) -> Self {
320        Ast::DistributePowers(Arc::new(i.into_iter().collect()), base)
321    }
322}
323
324impl<E, F: Field, B: Basis> fmt::Debug for Ast<E, F, B> {
325    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
326        match self {
327            Self::Poly(leaf) => f.debug_tuple("Poly").field(leaf).finish(),
328            Self::Add(lhs, rhs) => f.debug_tuple("Add").field(lhs).field(rhs).finish(),
329            Self::Mul(x) => f.debug_tuple("Mul").field(x).finish(),
330            Self::Scale(base, scalar) => f.debug_tuple("Scale").field(base).field(scalar).finish(),
331            Self::DistributePowers(terms, base) => f
332                .debug_tuple("DistributePowers")
333                .field(terms)
334                .field(base)
335                .finish(),
336            Self::LinearTerm(x) => f.debug_tuple("LinearTerm").field(x).finish(),
337            Self::ConstantTerm(x) => f.debug_tuple("ConstantTerm").field(x).finish(),
338        }
339    }
340}
341
342impl<E, F: Field, B: Basis> From<AstLeaf<E, B>> for Ast<E, F, B> {
343    fn from(leaf: AstLeaf<E, B>) -> Self {
344        Ast::Poly(leaf)
345    }
346}
347
348impl<E, F: Field, B: Basis> Ast<E, F, B> {
349    pub(crate) fn one() -> Self {
350        Self::ConstantTerm(F::one())
351    }
352}
353
354impl<E, F: Field, B: Basis> Neg for Ast<E, F, B> {
355    type Output = Ast<E, F, B>;
356
357    fn neg(self) -> Self::Output {
358        Ast::Scale(Arc::new(self), -F::one())
359    }
360}
361
362impl<E: Clone, F: Field, B: Basis> Neg for &Ast<E, F, B> {
363    type Output = Ast<E, F, B>;
364
365    fn neg(self) -> Self::Output {
366        -(self.clone())
367    }
368}
369
370impl<E, F: Field, B: Basis> Add for Ast<E, F, B> {
371    type Output = Ast<E, F, B>;
372
373    fn add(self, other: Self) -> Self::Output {
374        Ast::Add(Arc::new(self), Arc::new(other))
375    }
376}
377
378impl<'a, E: Clone, F: Field, B: Basis> Add<&'a Ast<E, F, B>> for &'a Ast<E, F, B> {
379    type Output = Ast<E, F, B>;
380
381    fn add(self, other: &'a Ast<E, F, B>) -> Self::Output {
382        self.clone() + other.clone()
383    }
384}
385
386impl<E, F: Field, B: Basis> Add<AstLeaf<E, B>> for Ast<E, F, B> {
387    type Output = Ast<E, F, B>;
388
389    fn add(self, other: AstLeaf<E, B>) -> Self::Output {
390        Ast::Add(Arc::new(self), Arc::new(other.into()))
391    }
392}
393
394impl<E, F: Field, B: Basis> Sub for Ast<E, F, B> {
395    type Output = Ast<E, F, B>;
396
397    fn sub(self, other: Self) -> Self::Output {
398        self + (-other)
399    }
400}
401
402impl<'a, E: Clone, F: Field, B: Basis> Sub<&'a Ast<E, F, B>> for &'a Ast<E, F, B> {
403    type Output = Ast<E, F, B>;
404
405    fn sub(self, other: &'a Ast<E, F, B>) -> Self::Output {
406        self + &(-other)
407    }
408}
409
410impl<E, F: Field, B: Basis> Sub<AstLeaf<E, B>> for Ast<E, F, B> {
411    type Output = Ast<E, F, B>;
412
413    fn sub(self, other: AstLeaf<E, B>) -> Self::Output {
414        self + (-Ast::from(other))
415    }
416}
417
418impl<E, F: Field> Mul for Ast<E, F, LagrangeCoeff> {
419    type Output = Ast<E, F, LagrangeCoeff>;
420
421    fn mul(self, other: Self) -> Self::Output {
422        Ast::Mul(AstMul(Arc::new(self), Arc::new(other)))
423    }
424}
425
426impl<'a, E: Clone, F: Field> Mul<&'a Ast<E, F, LagrangeCoeff>> for &'a Ast<E, F, LagrangeCoeff> {
427    type Output = Ast<E, F, LagrangeCoeff>;
428
429    fn mul(self, other: &'a Ast<E, F, LagrangeCoeff>) -> Self::Output {
430        self.clone() * other.clone()
431    }
432}
433
434impl<E, F: Field> Mul<AstLeaf<E, LagrangeCoeff>> for Ast<E, F, LagrangeCoeff> {
435    type Output = Ast<E, F, LagrangeCoeff>;
436
437    fn mul(self, other: AstLeaf<E, LagrangeCoeff>) -> Self::Output {
438        Ast::Mul(AstMul(Arc::new(self), Arc::new(other.into())))
439    }
440}
441
442impl<E, F: Field> Mul for Ast<E, F, ExtendedLagrangeCoeff> {
443    type Output = Ast<E, F, ExtendedLagrangeCoeff>;
444
445    fn mul(self, other: Self) -> Self::Output {
446        Ast::Mul(AstMul(Arc::new(self), Arc::new(other)))
447    }
448}
449
450impl<'a, E: Clone, F: Field> Mul<&'a Ast<E, F, ExtendedLagrangeCoeff>>
451    for &'a Ast<E, F, ExtendedLagrangeCoeff>
452{
453    type Output = Ast<E, F, ExtendedLagrangeCoeff>;
454
455    fn mul(self, other: &'a Ast<E, F, ExtendedLagrangeCoeff>) -> Self::Output {
456        self.clone() * other.clone()
457    }
458}
459
460impl<E, F: Field> Mul<AstLeaf<E, ExtendedLagrangeCoeff>> for Ast<E, F, ExtendedLagrangeCoeff> {
461    type Output = Ast<E, F, ExtendedLagrangeCoeff>;
462
463    fn mul(self, other: AstLeaf<E, ExtendedLagrangeCoeff>) -> Self::Output {
464        Ast::Mul(AstMul(Arc::new(self), Arc::new(other.into())))
465    }
466}
467
468impl<E, F: Field, B: Basis> Mul<F> for Ast<E, F, B> {
469    type Output = Ast<E, F, B>;
470
471    fn mul(self, other: F) -> Self::Output {
472        Ast::Scale(Arc::new(self), other)
473    }
474}
475
476impl<E: Clone, F: Field, B: Basis> Mul<F> for &Ast<E, F, B> {
477    type Output = Ast<E, F, B>;
478
479    fn mul(self, other: F) -> Self::Output {
480        Ast::Scale(Arc::new(self.clone()), other)
481    }
482}
483
484impl<E: Clone, F: Field> MulAssign for Ast<E, F, ExtendedLagrangeCoeff> {
485    fn mul_assign(&mut self, rhs: Self) {
486        *self = self.clone().mul(rhs)
487    }
488}
489
/// Operations which can be performed over a given basis.
///
/// The chunk-producing methods (`constant_term`, `linear_term`) return the part of the
/// full-length polynomial covering indices
/// `chunk_size * chunk_index .. min(chunk_size * (chunk_index + 1), poly_len)`. That is
/// chunk `chunk_index` of the chunking used by `Evaluator::evaluate`, so the final chunk
/// may be shorter than `chunk_size`.
pub(crate) trait BasisOps: Basis {
    /// Returns a freshly allocated polynomial in this basis, sized for `domain`.
    /// `Evaluator::evaluate` uses it as the output buffer and overwrites every entry.
    fn empty_poly<F: FieldExt>(domain: &EvaluationDomain<F>) -> Polynomial<F, Self>;
    /// Returns chunk `chunk_index` of the constant polynomial `scalar`, represented in
    /// this basis.
    fn constant_term<F: FieldExt>(
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F>;
    /// Returns chunk `chunk_index` of the polynomial `scalar * X`, represented in this
    /// basis.
    fn linear_term<F: FieldExt>(
        domain: &EvaluationDomain<F>,
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F>;
    /// Returns `poly` rotated by `rotation`. Not every basis supports this; the `Coeff`
    /// implementation panics.
    fn rotate<F: FieldExt>(
        domain: &EvaluationDomain<F>,
        poly: &Polynomial<F, Self>,
        rotation: Rotation,
    ) -> Polynomial<F, Self>;
}
512
513impl BasisOps for Coeff {
514    fn empty_poly<F: FieldExt>(domain: &EvaluationDomain<F>) -> Polynomial<F, Self> {
515        domain.empty_coeff()
516    }
517
518    fn constant_term<F: FieldExt>(
519        poly_len: usize,
520        chunk_size: usize,
521        chunk_index: usize,
522        scalar: F,
523    ) -> Vec<F> {
524        let mut chunk = vec![F::zero(); cmp::min(chunk_size, poly_len - chunk_size * chunk_index)];
525        if chunk_index == 0 {
526            chunk[0] = scalar;
527        }
528        chunk
529    }
530
531    fn linear_term<F: FieldExt>(
532        _: &EvaluationDomain<F>,
533        poly_len: usize,
534        chunk_size: usize,
535        chunk_index: usize,
536        scalar: F,
537    ) -> Vec<F> {
538        let mut chunk = vec![F::zero(); cmp::min(chunk_size, poly_len - chunk_size * chunk_index)];
539        // If the chunk size is 1 (e.g. if we have a small k and many threads), then the
540        // linear coefficient is the second chunk. Otherwise, the chunk size is greater
541        // than one, and the linear coefficient is the second element of the first chunk.
542        // Note that we check against the original chunk size, not the potentially-short
543        // actual size of the current chunk, because we want to know whether the size of
544        // the previous chunk was 1.
545        if chunk_size == 1 && chunk_index == 1 {
546            chunk[0] = scalar;
547        } else if chunk_index == 0 {
548            chunk[1] = scalar;
549        }
550        chunk
551    }
552
553    fn rotate<F: FieldExt>(
554        _: &EvaluationDomain<F>,
555        _: &Polynomial<F, Self>,
556        _: Rotation,
557    ) -> Polynomial<F, Self> {
558        panic!("Can't rotate polynomials in the standard basis")
559    }
560}
561
562impl BasisOps for LagrangeCoeff {
563    fn empty_poly<F: FieldExt>(domain: &EvaluationDomain<F>) -> Polynomial<F, Self> {
564        domain.empty_lagrange()
565    }
566
567    fn constant_term<F: FieldExt>(
568        poly_len: usize,
569        chunk_size: usize,
570        chunk_index: usize,
571        scalar: F,
572    ) -> Vec<F> {
573        vec![scalar; cmp::min(chunk_size, poly_len - chunk_size * chunk_index)]
574    }
575
576    fn linear_term<F: FieldExt>(
577        domain: &EvaluationDomain<F>,
578        poly_len: usize,
579        chunk_size: usize,
580        chunk_index: usize,
581        scalar: F,
582    ) -> Vec<F> {
583        // Take every power of omega within the chunk, and multiply by scalar.
584        let omega = domain.get_omega();
585        let start = chunk_size * chunk_index;
586        (0..cmp::min(chunk_size, poly_len - start))
587            .scan(omega.pow_vartime(&[start as u64]) * scalar, |acc, _| {
588                let ret = *acc;
589                *acc *= omega;
590                Some(ret)
591            })
592            .collect()
593    }
594
595    fn rotate<F: FieldExt>(
596        _: &EvaluationDomain<F>,
597        poly: &Polynomial<F, Self>,
598        rotation: Rotation,
599    ) -> Polynomial<F, Self> {
600        poly.rotate(rotation)
601    }
602}
603
604impl BasisOps for ExtendedLagrangeCoeff {
605    fn empty_poly<F: FieldExt>(domain: &EvaluationDomain<F>) -> Polynomial<F, Self> {
606        domain.empty_extended()
607    }
608
609    fn constant_term<F: FieldExt>(
610        poly_len: usize,
611        chunk_size: usize,
612        chunk_index: usize,
613        scalar: F,
614    ) -> Vec<F> {
615        vec![scalar; cmp::min(chunk_size, poly_len - chunk_size * chunk_index)]
616    }
617
618    fn linear_term<F: FieldExt>(
619        domain: &EvaluationDomain<F>,
620        poly_len: usize,
621        chunk_size: usize,
622        chunk_index: usize,
623        scalar: F,
624    ) -> Vec<F> {
625        // Take every power of the extended omega within the chunk, and multiply by scalar.
626        let omega = domain.get_extended_omega();
627        let start = chunk_size * chunk_index;
628        (0..cmp::min(chunk_size, poly_len - start))
629            .scan(
630                omega.pow_vartime(&[start as u64]) * F::ZETA * scalar,
631                |acc, _| {
632                    let ret = *acc;
633                    *acc *= omega;
634                    Some(ret)
635                },
636            )
637            .collect()
638    }
639
640    fn rotate<F: FieldExt>(
641        domain: &EvaluationDomain<F>,
642        poly: &Polynomial<F, Self>,
643        rotation: Rotation,
644    ) -> Polynomial<F, Self> {
645        domain.rotate_extended(poly, rotation)
646    }
647}
648
#[cfg(test)]
mod tests {
    use pasta_curves::pallas;

    use super::{get_chunk_params, new_evaluator, Ast, BasisOps, Evaluator};
    use crate::poly::{Coeff, EvaluationDomain, ExtendedLagrangeCoeff, LagrangeCoeff};

    /// Regression test: evaluating constant and linear terms must not panic when the last
    /// chunk of the polynomial is shorter than `chunk_size`.
    #[test]
    fn short_chunk_regression_test() {
        // Pick the smallest polynomial length that is guaranteed to produce a short chunk
        // on this machine, i.e. one where `chunk_size * num_chunks` overshoots the length.
        let short_chunk_k = (1..16u32).find(|&k| {
            let (chunk_size, num_chunks) = get_chunk_params(1 << k);
            (1 << k) < chunk_size * num_chunks
        });
        let k = match short_chunk_k {
            Some(k) => k,
            None => {
                // We are on a machine with a power-of-two number of threads, and cannot
                // trigger the bug.
                eprintln!(
                    "can't find a polynomial length for short_chunk_regression_test; skipping"
                );
                return;
            }
        };
        eprintln!("Testing short-chunk regression with k = {}", k);

        /// Registers one trivial polynomial for `EvaluationDomain::new(1, k)`, then
        /// evaluates a constant term and a linear term against it.
        fn test_case<E: Copy + Send + Sync, B: BasisOps>(
            k: u32,
            mut evaluator: Evaluator<E, pallas::Base, B>,
        ) {
            let domain = EvaluationDomain::new(1, k);
            evaluator.register_poly(B::empty_poly(&domain));

            // With the bug present, these will panic.
            let _ = evaluator.evaluate(&Ast::ConstantTerm(pallas::Base::zero()), &domain);
            let _ = evaluator.evaluate(&Ast::LinearTerm(pallas::Base::zero()), &domain);
        }

        test_case(k, new_evaluator::<_, _, Coeff>(|| {}));
        test_case(k, new_evaluator::<_, _, LagrangeCoeff>(|| {}));
        test_case(k, new_evaluator::<_, _, ExtendedLagrangeCoeff>(|| {}));
    }
}