1use std::{
2 cmp,
3 collections::{HashMap, HashSet},
4 fmt,
5 hash::{Hash, Hasher},
6 marker::PhantomData,
7 ops::{Add, Mul, MulAssign, Neg, Sub},
8 sync::Arc,
9};
10
11use group::ff::Field;
12use pasta_curves::arithmetic::FieldExt;
13
14use super::{
15 Basis, Coeff, EvaluationDomain, ExtendedLagrangeCoeff, LagrangeCoeff, Polynomial, Rotation,
16};
17use crate::{arithmetic::parallelize, multicore};
18
19fn get_chunk_params(poly_len: usize) -> (usize, usize) {
22 let num_threads = multicore::current_num_threads();
24 let num_chunks = num_threads * 4;
28 let chunk_size = (poly_len + num_chunks - 1) / num_chunks;
32 let num_chunks = (poly_len + chunk_size - 1) / chunk_size;
35
36 (chunk_size, num_chunks)
37}
38
/// A reference to a polynomial registered with an `Evaluator`, together with the
/// rotation to apply when it is evaluated.
///
/// `E` is the context type of the `Evaluator` that created the leaf (see
/// `Evaluator::register_poly`); `B` is the basis of the referenced polynomial.
/// Neither is stored, hence the `PhantomData`.
#[derive(Clone, Copy)]
pub(crate) struct AstLeaf<E, B: Basis> {
    /// Position of the referenced polynomial in the evaluator's `polys` list.
    index: usize,
    /// Rotation applied to the polynomial before it is combined with other terms.
    rotation: Rotation,
    _evaluator: PhantomData<(E, B)>,
}
46
47impl<E, B: Basis> fmt::Debug for AstLeaf<E, B> {
48 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
49 f.debug_struct("AstLeaf")
50 .field("index", &self.index)
51 .field("rotation", &self.rotation)
52 .finish()
53 }
54}
55
56impl<E, B: Basis> PartialEq for AstLeaf<E, B> {
57 fn eq(&self, rhs: &Self) -> bool {
58 self.index.eq(&rhs.index) && self.rotation.0.eq(&rhs.rotation.0)
60 }
61}
62
63impl<E, B: Basis> Eq for AstLeaf<E, B> {}
64
65impl<E, B: Basis> Hash for AstLeaf<E, B> {
66 fn hash<H: Hasher>(&self, state: &mut H) {
67 self.index.hash(state);
68 self.rotation.0.hash(state);
69 }
70}
71
72impl<E, B: Basis> AstLeaf<E, B> {
73 pub(crate) fn with_rotation(&self, rotation: Rotation) -> Self {
77 AstLeaf {
78 index: self.index,
79 rotation,
80 _evaluator: PhantomData::default(),
81 }
82 }
83}
84
/// Holds a list of polynomials in basis `B` and evaluates `Ast`s built from leaves that
/// refer to them.
///
/// `E` is the type of the context value given to `new_evaluator`. It is stored but never
/// read in this module, and it appears in the type of every `AstLeaf` this evaluator
/// hands out.
pub(crate) struct Evaluator<E, F: Field, B: Basis> {
    /// Registered polynomials, indexed by `AstLeaf::index`.
    polys: Vec<Polynomial<F, B>>,
    _context: E,
}
100
101pub(crate) fn new_evaluator<E: Fn() + Clone, F: Field, B: Basis>(context: E) -> Evaluator<E, F, B> {
108 Evaluator {
109 polys: vec![],
110 _context: context,
111 }
112}
113
114impl<E, F: Field, B: Basis> Evaluator<E, F, B> {
115 pub(crate) fn register_poly(&mut self, poly: Polynomial<F, B>) -> AstLeaf<E, B> {
120 let index = self.polys.len();
121 self.polys.push(poly);
122
123 AstLeaf {
124 index,
125 rotation: Rotation::cur(),
126 _evaluator: PhantomData::default(),
127 }
128 }
129
130 pub(crate) fn evaluate(
132 &self,
133 ast: &Ast<E, F, B>,
134 domain: &EvaluationDomain<F>,
135 ) -> Polynomial<F, B>
136 where
137 E: Copy + Send + Sync,
138 F: FieldExt,
139 B: BasisOps,
140 {
141 fn collect_rotations<E: Copy, F: Field, B: Basis>(
143 ast: &Ast<E, F, B>,
144 ) -> HashSet<AstLeaf<E, B>> {
145 match ast {
146 Ast::Poly(leaf) => vec![*leaf].into_iter().collect(),
147 Ast::Add(a, b) | Ast::Mul(AstMul(a, b)) => {
148 let lhs = collect_rotations(a);
149 let rhs = collect_rotations(b);
150 lhs.union(&rhs).cloned().collect()
151 }
152 Ast::Scale(a, _) => collect_rotations(a),
153 Ast::DistributePowers(terms, _) => terms
154 .iter()
155 .flat_map(|term| collect_rotations(term).into_iter())
156 .collect(),
157 Ast::LinearTerm(_) | Ast::ConstantTerm(_) => HashSet::default(),
158 }
159 }
160 let leaves = collect_rotations(ast);
161
162 let rotated: HashMap<_, _> = leaves
164 .iter()
165 .cloned()
166 .map(|leaf| {
167 (
168 leaf,
169 B::rotate(domain, &self.polys[leaf.index], leaf.rotation),
170 )
171 })
172 .collect();
173
174 let poly_len = self.polys.first().unwrap().len();
176 let (chunk_size, num_chunks) = get_chunk_params(poly_len);
177
178 let chunks: Vec<HashMap<_, _>> = (0..num_chunks)
180 .map(|i| {
181 rotated
182 .iter()
183 .map(|(leaf, poly)| {
184 (
185 *leaf,
186 poly.chunks(chunk_size)
187 .nth(i)
188 .expect("num_chunks was calculated correctly"),
189 )
190 })
191 .collect()
192 })
193 .collect();
194
195 struct AstContext<'a, E, F: FieldExt, B: Basis> {
196 domain: &'a EvaluationDomain<F>,
197 poly_len: usize,
198 chunk_size: usize,
199 chunk_index: usize,
200 leaves: &'a HashMap<AstLeaf<E, B>, &'a [F]>,
201 }
202
203 fn recurse<E, F: FieldExt, B: BasisOps>(
204 ast: &Ast<E, F, B>,
205 ctx: &AstContext<'_, E, F, B>,
206 ) -> Vec<F> {
207 match ast {
208 Ast::Poly(leaf) => ctx.leaves.get(leaf).expect("We prepared this").to_vec(),
209 Ast::Add(a, b) => {
210 let mut lhs = recurse(a, ctx);
211 let rhs = recurse(b, ctx);
212 for (lhs, rhs) in lhs.iter_mut().zip(rhs.iter()) {
213 *lhs += *rhs;
214 }
215 lhs
216 }
217 Ast::Mul(AstMul(a, b)) => {
218 let mut lhs = recurse(a, ctx);
219 let rhs = recurse(b, ctx);
220 for (lhs, rhs) in lhs.iter_mut().zip(rhs.iter()) {
221 *lhs *= *rhs;
222 }
223 lhs
224 }
225 Ast::Scale(a, scalar) => {
226 let mut lhs = recurse(a, ctx);
227 for lhs in lhs.iter_mut() {
228 *lhs *= scalar;
229 }
230 lhs
231 }
232 Ast::DistributePowers(terms, base) => terms.iter().fold(
233 B::constant_term(ctx.poly_len, ctx.chunk_size, ctx.chunk_index, F::zero()),
234 |mut acc, term| {
235 let term = recurse(term, ctx);
236 for (acc, term) in acc.iter_mut().zip(term) {
237 *acc *= base;
238 *acc += term;
239 }
240 acc
241 },
242 ),
243 Ast::LinearTerm(scalar) => B::linear_term(
244 ctx.domain,
245 ctx.poly_len,
246 ctx.chunk_size,
247 ctx.chunk_index,
248 *scalar,
249 ),
250 Ast::ConstantTerm(scalar) => {
251 B::constant_term(ctx.poly_len, ctx.chunk_size, ctx.chunk_index, *scalar)
252 }
253 }
254 }
255
256 let mut result = B::empty_poly(domain);
259 multicore::scope(|scope| {
260 for (chunk_index, (out, leaves)) in
261 result.chunks_mut(chunk_size).zip(chunks.iter()).enumerate()
262 {
263 scope.spawn(move |_| {
264 let ctx = AstContext {
265 domain,
266 poly_len,
267 chunk_size,
268 chunk_index,
269 leaves,
270 };
271 out.copy_from_slice(&recurse(ast, &ctx));
272 });
273 }
274 });
275 result
276 }
277}
278
/// The product of two ASTs.
///
/// The tuple fields are private, so code outside this module can only obtain one through
/// the `Mul` impls in this module. Those exist only for `LagrangeCoeff` and
/// `ExtendedLagrangeCoeff`.
#[derive(Clone)]
pub(crate) struct AstMul<E, F: Field, B: Basis>(Arc<Ast<E, F, B>>, Arc<Ast<E, F, B>>);
286
287impl<E, F: Field, B: Basis> fmt::Debug for AstMul<E, F, B> {
288 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
289 f.debug_tuple("AstMul")
290 .field(&self.0)
291 .field(&self.1)
292 .finish()
293 }
294}
295
/// An expression over polynomials registered with an `Evaluator`, all in basis `B`.
///
/// Sub-expressions are shared through `Arc`, so cloning a node is shallow.
#[derive(Clone)]
pub(crate) enum Ast<E, F: Field, B: Basis> {
    /// A registered polynomial, possibly rotated.
    Poly(AstLeaf<E, B>),
    /// The sum of two expressions.
    Add(Arc<Ast<E, F, B>>, Arc<Ast<E, F, B>>),
    /// The product of two expressions.
    Mul(AstMul<E, F, B>),
    /// An expression multiplied by a scalar.
    Scale(Arc<Ast<E, F, B>>, F),
    /// Combines the terms with powers of the scalar `y` using Horner's rule. For terms
    /// `[t_0, ..., t_{n-1}]` it evaluates to `sum_i t_i * y^(n-1-i)`, and to zero when
    /// there are no terms.
    DistributePowers(Arc<Vec<Ast<E, F, B>>>, F),
    /// The polynomial `scalar * X`.
    LinearTerm(F),
    /// The constant polynomial `scalar`.
    ConstantTerm(F),
}
317
318impl<E, F: Field, B: Basis> Ast<E, F, B> {
319 pub fn distribute_powers<I: IntoIterator<Item = Self>>(i: I, base: F) -> Self {
320 Ast::DistributePowers(Arc::new(i.into_iter().collect()), base)
321 }
322}
323
324impl<E, F: Field, B: Basis> fmt::Debug for Ast<E, F, B> {
325 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
326 match self {
327 Self::Poly(leaf) => f.debug_tuple("Poly").field(leaf).finish(),
328 Self::Add(lhs, rhs) => f.debug_tuple("Add").field(lhs).field(rhs).finish(),
329 Self::Mul(x) => f.debug_tuple("Mul").field(x).finish(),
330 Self::Scale(base, scalar) => f.debug_tuple("Scale").field(base).field(scalar).finish(),
331 Self::DistributePowers(terms, base) => f
332 .debug_tuple("DistributePowers")
333 .field(terms)
334 .field(base)
335 .finish(),
336 Self::LinearTerm(x) => f.debug_tuple("LinearTerm").field(x).finish(),
337 Self::ConstantTerm(x) => f.debug_tuple("ConstantTerm").field(x).finish(),
338 }
339 }
340}
341
342impl<E, F: Field, B: Basis> From<AstLeaf<E, B>> for Ast<E, F, B> {
343 fn from(leaf: AstLeaf<E, B>) -> Self {
344 Ast::Poly(leaf)
345 }
346}
347
348impl<E, F: Field, B: Basis> Ast<E, F, B> {
349 pub(crate) fn one() -> Self {
350 Self::ConstantTerm(F::one())
351 }
352}
353
354impl<E, F: Field, B: Basis> Neg for Ast<E, F, B> {
355 type Output = Ast<E, F, B>;
356
357 fn neg(self) -> Self::Output {
358 Ast::Scale(Arc::new(self), -F::one())
359 }
360}
361
362impl<E: Clone, F: Field, B: Basis> Neg for &Ast<E, F, B> {
363 type Output = Ast<E, F, B>;
364
365 fn neg(self) -> Self::Output {
366 -(self.clone())
367 }
368}
369
370impl<E, F: Field, B: Basis> Add for Ast<E, F, B> {
371 type Output = Ast<E, F, B>;
372
373 fn add(self, other: Self) -> Self::Output {
374 Ast::Add(Arc::new(self), Arc::new(other))
375 }
376}
377
378impl<'a, E: Clone, F: Field, B: Basis> Add<&'a Ast<E, F, B>> for &'a Ast<E, F, B> {
379 type Output = Ast<E, F, B>;
380
381 fn add(self, other: &'a Ast<E, F, B>) -> Self::Output {
382 self.clone() + other.clone()
383 }
384}
385
386impl<E, F: Field, B: Basis> Add<AstLeaf<E, B>> for Ast<E, F, B> {
387 type Output = Ast<E, F, B>;
388
389 fn add(self, other: AstLeaf<E, B>) -> Self::Output {
390 Ast::Add(Arc::new(self), Arc::new(other.into()))
391 }
392}
393
394impl<E, F: Field, B: Basis> Sub for Ast<E, F, B> {
395 type Output = Ast<E, F, B>;
396
397 fn sub(self, other: Self) -> Self::Output {
398 self + (-other)
399 }
400}
401
402impl<'a, E: Clone, F: Field, B: Basis> Sub<&'a Ast<E, F, B>> for &'a Ast<E, F, B> {
403 type Output = Ast<E, F, B>;
404
405 fn sub(self, other: &'a Ast<E, F, B>) -> Self::Output {
406 self + &(-other)
407 }
408}
409
410impl<E, F: Field, B: Basis> Sub<AstLeaf<E, B>> for Ast<E, F, B> {
411 type Output = Ast<E, F, B>;
412
413 fn sub(self, other: AstLeaf<E, B>) -> Self::Output {
414 self + (-Ast::from(other))
415 }
416}
417
418impl<E, F: Field> Mul for Ast<E, F, LagrangeCoeff> {
419 type Output = Ast<E, F, LagrangeCoeff>;
420
421 fn mul(self, other: Self) -> Self::Output {
422 Ast::Mul(AstMul(Arc::new(self), Arc::new(other)))
423 }
424}
425
426impl<'a, E: Clone, F: Field> Mul<&'a Ast<E, F, LagrangeCoeff>> for &'a Ast<E, F, LagrangeCoeff> {
427 type Output = Ast<E, F, LagrangeCoeff>;
428
429 fn mul(self, other: &'a Ast<E, F, LagrangeCoeff>) -> Self::Output {
430 self.clone() * other.clone()
431 }
432}
433
434impl<E, F: Field> Mul<AstLeaf<E, LagrangeCoeff>> for Ast<E, F, LagrangeCoeff> {
435 type Output = Ast<E, F, LagrangeCoeff>;
436
437 fn mul(self, other: AstLeaf<E, LagrangeCoeff>) -> Self::Output {
438 Ast::Mul(AstMul(Arc::new(self), Arc::new(other.into())))
439 }
440}
441
442impl<E, F: Field> Mul for Ast<E, F, ExtendedLagrangeCoeff> {
443 type Output = Ast<E, F, ExtendedLagrangeCoeff>;
444
445 fn mul(self, other: Self) -> Self::Output {
446 Ast::Mul(AstMul(Arc::new(self), Arc::new(other)))
447 }
448}
449
450impl<'a, E: Clone, F: Field> Mul<&'a Ast<E, F, ExtendedLagrangeCoeff>>
451 for &'a Ast<E, F, ExtendedLagrangeCoeff>
452{
453 type Output = Ast<E, F, ExtendedLagrangeCoeff>;
454
455 fn mul(self, other: &'a Ast<E, F, ExtendedLagrangeCoeff>) -> Self::Output {
456 self.clone() * other.clone()
457 }
458}
459
460impl<E, F: Field> Mul<AstLeaf<E, ExtendedLagrangeCoeff>> for Ast<E, F, ExtendedLagrangeCoeff> {
461 type Output = Ast<E, F, ExtendedLagrangeCoeff>;
462
463 fn mul(self, other: AstLeaf<E, ExtendedLagrangeCoeff>) -> Self::Output {
464 Ast::Mul(AstMul(Arc::new(self), Arc::new(other.into())))
465 }
466}
467
468impl<E, F: Field, B: Basis> Mul<F> for Ast<E, F, B> {
469 type Output = Ast<E, F, B>;
470
471 fn mul(self, other: F) -> Self::Output {
472 Ast::Scale(Arc::new(self), other)
473 }
474}
475
476impl<E: Clone, F: Field, B: Basis> Mul<F> for &Ast<E, F, B> {
477 type Output = Ast<E, F, B>;
478
479 fn mul(self, other: F) -> Self::Output {
480 Ast::Scale(Arc::new(self.clone()), other)
481 }
482}
483
484impl<E: Clone, F: Field> MulAssign for Ast<E, F, ExtendedLagrangeCoeff> {
485 fn mul_assign(&mut self, rhs: Self) {
486 *self = self.clone().mul(rhs)
487 }
488}
489
/// Basis-specific operations needed by `Evaluator::evaluate`.
///
/// The chunked methods (`constant_term`, `linear_term`) return the values of a polynomial
/// of length `poly_len` at positions
/// `chunk_size * chunk_index .. min(chunk_size * (chunk_index + 1), poly_len)`.
pub(crate) trait BasisOps: Basis {
    /// Returns a fresh polynomial of this basis's length for `domain`. `evaluate` uses it
    /// as the output buffer and overwrites every chunk of it.
    fn empty_poly<F: FieldExt>(domain: &EvaluationDomain<F>) -> Polynomial<F, Self>;
    /// Returns chunk `chunk_index` of the constant polynomial `scalar`.
    fn constant_term<F: FieldExt>(
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F>;
    /// Returns chunk `chunk_index` of the polynomial `scalar * X`.
    fn linear_term<F: FieldExt>(
        domain: &EvaluationDomain<F>,
        poly_len: usize,
        chunk_size: usize,
        chunk_index: usize,
        scalar: F,
    ) -> Vec<F>;
    /// Returns `poly` rotated by `rotation` (implementations may panic for rotations they
    /// do not support).
    fn rotate<F: FieldExt>(
        domain: &EvaluationDomain<F>,
        poly: &Polynomial<F, Self>,
        rotation: Rotation,
    ) -> Polynomial<F, Self>;
}
512
513impl BasisOps for Coeff {
514 fn empty_poly<F: FieldExt>(domain: &EvaluationDomain<F>) -> Polynomial<F, Self> {
515 domain.empty_coeff()
516 }
517
518 fn constant_term<F: FieldExt>(
519 poly_len: usize,
520 chunk_size: usize,
521 chunk_index: usize,
522 scalar: F,
523 ) -> Vec<F> {
524 let mut chunk = vec![F::zero(); cmp::min(chunk_size, poly_len - chunk_size * chunk_index)];
525 if chunk_index == 0 {
526 chunk[0] = scalar;
527 }
528 chunk
529 }
530
531 fn linear_term<F: FieldExt>(
532 _: &EvaluationDomain<F>,
533 poly_len: usize,
534 chunk_size: usize,
535 chunk_index: usize,
536 scalar: F,
537 ) -> Vec<F> {
538 let mut chunk = vec![F::zero(); cmp::min(chunk_size, poly_len - chunk_size * chunk_index)];
539 if chunk_size == 1 && chunk_index == 1 {
546 chunk[0] = scalar;
547 } else if chunk_index == 0 {
548 chunk[1] = scalar;
549 }
550 chunk
551 }
552
553 fn rotate<F: FieldExt>(
554 _: &EvaluationDomain<F>,
555 _: &Polynomial<F, Self>,
556 _: Rotation,
557 ) -> Polynomial<F, Self> {
558 panic!("Can't rotate polynomials in the standard basis")
559 }
560}
561
562impl BasisOps for LagrangeCoeff {
563 fn empty_poly<F: FieldExt>(domain: &EvaluationDomain<F>) -> Polynomial<F, Self> {
564 domain.empty_lagrange()
565 }
566
567 fn constant_term<F: FieldExt>(
568 poly_len: usize,
569 chunk_size: usize,
570 chunk_index: usize,
571 scalar: F,
572 ) -> Vec<F> {
573 vec![scalar; cmp::min(chunk_size, poly_len - chunk_size * chunk_index)]
574 }
575
576 fn linear_term<F: FieldExt>(
577 domain: &EvaluationDomain<F>,
578 poly_len: usize,
579 chunk_size: usize,
580 chunk_index: usize,
581 scalar: F,
582 ) -> Vec<F> {
583 let omega = domain.get_omega();
585 let start = chunk_size * chunk_index;
586 (0..cmp::min(chunk_size, poly_len - start))
587 .scan(omega.pow_vartime(&[start as u64]) * scalar, |acc, _| {
588 let ret = *acc;
589 *acc *= omega;
590 Some(ret)
591 })
592 .collect()
593 }
594
595 fn rotate<F: FieldExt>(
596 _: &EvaluationDomain<F>,
597 poly: &Polynomial<F, Self>,
598 rotation: Rotation,
599 ) -> Polynomial<F, Self> {
600 poly.rotate(rotation)
601 }
602}
603
604impl BasisOps for ExtendedLagrangeCoeff {
605 fn empty_poly<F: FieldExt>(domain: &EvaluationDomain<F>) -> Polynomial<F, Self> {
606 domain.empty_extended()
607 }
608
609 fn constant_term<F: FieldExt>(
610 poly_len: usize,
611 chunk_size: usize,
612 chunk_index: usize,
613 scalar: F,
614 ) -> Vec<F> {
615 vec![scalar; cmp::min(chunk_size, poly_len - chunk_size * chunk_index)]
616 }
617
618 fn linear_term<F: FieldExt>(
619 domain: &EvaluationDomain<F>,
620 poly_len: usize,
621 chunk_size: usize,
622 chunk_index: usize,
623 scalar: F,
624 ) -> Vec<F> {
625 let omega = domain.get_extended_omega();
627 let start = chunk_size * chunk_index;
628 (0..cmp::min(chunk_size, poly_len - start))
629 .scan(
630 omega.pow_vartime(&[start as u64]) * F::ZETA * scalar,
631 |acc, _| {
632 let ret = *acc;
633 *acc *= omega;
634 Some(ret)
635 },
636 )
637 .collect()
638 }
639
640 fn rotate<F: FieldExt>(
641 domain: &EvaluationDomain<F>,
642 poly: &Polynomial<F, Self>,
643 rotation: Rotation,
644 ) -> Polynomial<F, Self> {
645 domain.rotate_extended(poly, rotation)
646 }
647}
648
#[cfg(test)]
mod tests {
    use pasta_curves::pallas;

    use super::{get_chunk_params, new_evaluator, Ast, BasisOps, Evaluator};
    use crate::poly::{Coeff, EvaluationDomain, ExtendedLagrangeCoeff, LagrangeCoeff};

    #[test]
    fn short_chunk_regression_test() {
        // Find the smallest k whose chunking overshoots the polynomial length, i.e. whose
        // last chunk is shorter than `chunk_size`.
        let found = (1..16).find(|&k| {
            let poly_len = 1 << k;
            let (chunk_size, num_chunks) = get_chunk_params(poly_len);
            poly_len < chunk_size * num_chunks
        });
        let k = match found {
            Some(k) => k,
            None => {
                // E.g. with a power-of-two thread count every chunk divides the length
                // evenly, so no short chunk can be produced on this machine.
                eprintln!(
                    "can't find a polynomial length for short_chunk_regression_test; skipping"
                );
                return;
            }
        };
        eprintln!("Testing short-chunk regression with k = {}", k);

        fn test_case<E: Copy + Send + Sync, B: BasisOps>(
            k: u32,
            mut evaluator: Evaluator<E, pallas::Base, B>,
        ) {
            let domain = EvaluationDomain::new(1, k);
            // `evaluate` takes the polynomial length from the first registered polynomial.
            evaluator.register_poly(B::empty_poly(&domain));

            // Neither AST references a leaf; this only checks that chunked evaluation of
            // the short final chunk does not panic.
            let constant = Ast::ConstantTerm(pallas::Base::zero());
            let _ = evaluator.evaluate(&constant, &domain);
            let linear = Ast::LinearTerm(pallas::Base::zero());
            let _ = evaluator.evaluate(&linear, &domain);
        }

        test_case(k, new_evaluator::<_, _, Coeff>(|| {}));
        test_case(k, new_evaluator::<_, _, LagrangeCoeff>(|| {}));
        test_case(k, new_evaluator::<_, _, ExtendedLagrangeCoeff>(|| {}));
    }
}
699}