halo2_axiom/plonk/
assigned.rs

1use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3use group::ff::Field;
4
/// The value assigned to a single cell of a circuit.
///
/// Values are kept in fractional form so that the backend can postpone field
/// inversions and carry them out together with batch inversion.
///
/// A fraction whose denominator is zero is treated as the value zero.
#[derive(Debug, Clone, Copy)]
pub enum Assigned<F> {
    /// The zero element of the field.
    Zero,
    /// A plain field element; evaluating it needs no inversion.
    Trivial(F),
    /// A fraction, stored as `(numerator, denominator)` and left unevaluated so
    /// that its inversion can be batched.
    Rational(F, F),
}
19
20impl<F: Field> From<&Assigned<F>> for Assigned<F> {
21    fn from(val: &Assigned<F>) -> Self {
22        *val
23    }
24}
25
26impl<F: Field> From<&F> for Assigned<F> {
27    fn from(numerator: &F) -> Self {
28        Assigned::Trivial(*numerator)
29    }
30}
31
32impl<F: Field> From<F> for Assigned<F> {
33    fn from(numerator: F) -> Self {
34        Assigned::Trivial(numerator)
35    }
36}
37
38impl<F: Field> From<(F, F)> for Assigned<F> {
39    fn from((numerator, denominator): (F, F)) -> Self {
40        Assigned::Rational(numerator, denominator)
41    }
42}
43
44impl<F: Field> PartialEq for Assigned<F> {
45    fn eq(&self, other: &Self) -> bool {
46        match (self, other) {
47            // At least one side is directly zero.
48            (Self::Zero, Self::Zero) => true,
49            (Self::Zero, x) | (x, Self::Zero) => x.is_zero_vartime(),
50
51            // One side is x/0 which maps to zero.
52            (Self::Rational(_, denominator), x) | (x, Self::Rational(_, denominator))
53                if denominator.is_zero_vartime() =>
54            {
55                x.is_zero_vartime()
56            }
57
58            // Okay, we need to do some actual math...
59            (Self::Trivial(lhs), Self::Trivial(rhs)) => lhs == rhs,
60            (Self::Trivial(x), Self::Rational(numerator, denominator))
61            | (Self::Rational(numerator, denominator), Self::Trivial(x)) => {
62                &(*x * denominator) == numerator
63            }
64            (
65                Self::Rational(lhs_numerator, lhs_denominator),
66                Self::Rational(rhs_numerator, rhs_denominator),
67            ) => *lhs_numerator * rhs_denominator == *lhs_denominator * rhs_numerator,
68        }
69    }
70}
71
72impl<F: Field> Eq for Assigned<F> {}
73
74impl<F: Field> Neg for Assigned<F> {
75    type Output = Assigned<F>;
76    fn neg(self) -> Self::Output {
77        match self {
78            Self::Zero => Self::Zero,
79            Self::Trivial(numerator) => Self::Trivial(-numerator),
80            Self::Rational(numerator, denominator) => Self::Rational(-numerator, denominator),
81        }
82    }
83}
84
85impl<F: Field> Neg for &Assigned<F> {
86    type Output = Assigned<F>;
87    fn neg(self) -> Self::Output {
88        -*self
89    }
90}
91
92impl<F: Field> Add for Assigned<F> {
93    type Output = Assigned<F>;
94    fn add(self, rhs: Assigned<F>) -> Assigned<F> {
95        match (self, rhs) {
96            // One side is directly zero.
97            (Self::Zero, _) => rhs,
98            (_, Self::Zero) => self,
99
100            // One side is x/0 which maps to zero.
101            (Self::Rational(_, denominator), other) | (other, Self::Rational(_, denominator))
102                if denominator.is_zero_vartime() =>
103            {
104                other
105            }
106
107            // Okay, we need to do some actual math...
108            (Self::Trivial(lhs), Self::Trivial(rhs)) => Self::Trivial(lhs + rhs),
109            (Self::Rational(numerator, denominator), Self::Trivial(other))
110            | (Self::Trivial(other), Self::Rational(numerator, denominator)) => {
111                Self::Rational(numerator + denominator * other, denominator)
112            }
113            (
114                Self::Rational(lhs_numerator, lhs_denominator),
115                Self::Rational(rhs_numerator, rhs_denominator),
116            ) => Self::Rational(
117                lhs_numerator * rhs_denominator + lhs_denominator * rhs_numerator,
118                lhs_denominator * rhs_denominator,
119            ),
120        }
121    }
122}
123
124impl<F: Field> Add<F> for Assigned<F> {
125    type Output = Assigned<F>;
126    fn add(self, rhs: F) -> Assigned<F> {
127        self + Self::Trivial(rhs)
128    }
129}
130
131impl<F: Field> Add<F> for &Assigned<F> {
132    type Output = Assigned<F>;
133    fn add(self, rhs: F) -> Assigned<F> {
134        *self + rhs
135    }
136}
137
138impl<F: Field> Add<&Assigned<F>> for Assigned<F> {
139    type Output = Assigned<F>;
140    fn add(self, rhs: &Self) -> Assigned<F> {
141        self + *rhs
142    }
143}
144
145impl<F: Field> Add<Assigned<F>> for &Assigned<F> {
146    type Output = Assigned<F>;
147    fn add(self, rhs: Assigned<F>) -> Assigned<F> {
148        *self + rhs
149    }
150}
151
152impl<F: Field> Add<&Assigned<F>> for &Assigned<F> {
153    type Output = Assigned<F>;
154    fn add(self, rhs: &Assigned<F>) -> Assigned<F> {
155        *self + *rhs
156    }
157}
158
159impl<F: Field> AddAssign for Assigned<F> {
160    fn add_assign(&mut self, rhs: Self) {
161        *self = *self + rhs;
162    }
163}
164
165impl<F: Field> AddAssign<&Assigned<F>> for Assigned<F> {
166    fn add_assign(&mut self, rhs: &Self) {
167        *self = *self + rhs;
168    }
169}
170
171impl<F: Field> Sub for Assigned<F> {
172    type Output = Assigned<F>;
173    fn sub(self, rhs: Assigned<F>) -> Assigned<F> {
174        self + (-rhs)
175    }
176}
177
178impl<F: Field> Sub<F> for Assigned<F> {
179    type Output = Assigned<F>;
180    fn sub(self, rhs: F) -> Assigned<F> {
181        self + (-rhs)
182    }
183}
184
185impl<F: Field> Sub<F> for &Assigned<F> {
186    type Output = Assigned<F>;
187    fn sub(self, rhs: F) -> Assigned<F> {
188        *self - rhs
189    }
190}
191
192impl<F: Field> Sub<&Assigned<F>> for Assigned<F> {
193    type Output = Assigned<F>;
194    fn sub(self, rhs: &Self) -> Assigned<F> {
195        self - *rhs
196    }
197}
198
199impl<F: Field> Sub<Assigned<F>> for &Assigned<F> {
200    type Output = Assigned<F>;
201    fn sub(self, rhs: Assigned<F>) -> Assigned<F> {
202        *self - rhs
203    }
204}
205
206impl<F: Field> Sub<&Assigned<F>> for &Assigned<F> {
207    type Output = Assigned<F>;
208    fn sub(self, rhs: &Assigned<F>) -> Assigned<F> {
209        *self - *rhs
210    }
211}
212
213impl<F: Field> SubAssign for Assigned<F> {
214    fn sub_assign(&mut self, rhs: Self) {
215        *self = *self - rhs;
216    }
217}
218
219impl<F: Field> SubAssign<&Assigned<F>> for Assigned<F> {
220    fn sub_assign(&mut self, rhs: &Self) {
221        *self = *self - rhs;
222    }
223}
224
225impl<F: Field> Mul for Assigned<F> {
226    type Output = Assigned<F>;
227    fn mul(self, rhs: Assigned<F>) -> Assigned<F> {
228        match (self, rhs) {
229            (Self::Zero, _) | (_, Self::Zero) => Self::Zero,
230            (Self::Trivial(lhs), Self::Trivial(rhs)) => Self::Trivial(lhs * rhs),
231            (Self::Rational(numerator, denominator), Self::Trivial(other))
232            | (Self::Trivial(other), Self::Rational(numerator, denominator)) => {
233                Self::Rational(numerator * other, denominator)
234            }
235            (
236                Self::Rational(lhs_numerator, lhs_denominator),
237                Self::Rational(rhs_numerator, rhs_denominator),
238            ) => Self::Rational(
239                lhs_numerator * rhs_numerator,
240                lhs_denominator * rhs_denominator,
241            ),
242        }
243    }
244}
245
246impl<F: Field> Mul<F> for Assigned<F> {
247    type Output = Assigned<F>;
248    fn mul(self, rhs: F) -> Assigned<F> {
249        self * Self::Trivial(rhs)
250    }
251}
252
253impl<F: Field> Mul<F> for &Assigned<F> {
254    type Output = Assigned<F>;
255    fn mul(self, rhs: F) -> Assigned<F> {
256        *self * rhs
257    }
258}
259
260impl<F: Field> Mul<&Assigned<F>> for Assigned<F> {
261    type Output = Assigned<F>;
262    fn mul(self, rhs: &Assigned<F>) -> Assigned<F> {
263        self * *rhs
264    }
265}
266
267impl<F: Field> MulAssign for Assigned<F> {
268    fn mul_assign(&mut self, rhs: Self) {
269        *self = *self * rhs;
270    }
271}
272
273impl<F: Field> MulAssign<&Assigned<F>> for Assigned<F> {
274    fn mul_assign(&mut self, rhs: &Self) {
275        *self = *self * rhs;
276    }
277}
278
279impl<F: Field> Assigned<F> {
280    /// Returns the numerator.
281    pub fn numerator(&self) -> F {
282        match self {
283            Self::Zero => F::ZERO,
284            Self::Trivial(x) => *x,
285            Self::Rational(numerator, _) => *numerator,
286        }
287    }
288
289    /// Returns the denominator, if non-trivial.
290    pub fn denominator(&self) -> Option<F> {
291        match self {
292            Self::Zero => None,
293            Self::Trivial(_) => None,
294            Self::Rational(_, denominator) => Some(*denominator),
295        }
296    }
297
298    /// Returns true iff this element is zero.
299    pub fn is_zero_vartime(&self) -> bool {
300        match self {
301            Self::Zero => true,
302            Self::Trivial(x) => x.is_zero_vartime(),
303            // Assigned maps x/0 -> 0.
304            Self::Rational(numerator, denominator) => {
305                numerator.is_zero_vartime() || denominator.is_zero_vartime()
306            }
307        }
308    }
309
310    /// Doubles this element.
311    #[must_use]
312    pub fn double(&self) -> Self {
313        match self {
314            Self::Zero => Self::Zero,
315            Self::Trivial(x) => Self::Trivial(x.double()),
316            Self::Rational(numerator, denominator) => {
317                Self::Rational(numerator.double(), *denominator)
318            }
319        }
320    }
321
322    /// Squares this element.
323    #[must_use]
324    pub fn square(&self) -> Self {
325        match self {
326            Self::Zero => Self::Zero,
327            Self::Trivial(x) => Self::Trivial(x.square()),
328            Self::Rational(numerator, denominator) => {
329                Self::Rational(numerator.square(), denominator.square())
330            }
331        }
332    }
333
334    /// Cubes this element.
335    #[must_use]
336    pub fn cube(&self) -> Self {
337        self.square() * self
338    }
339
340    /// Inverts this assigned value (taking the inverse of zero to be zero).
341    pub fn invert(&self) -> Self {
342        match self {
343            Self::Zero => Self::Zero,
344            Self::Trivial(x) => Self::Rational(F::ONE, *x),
345            Self::Rational(numerator, denominator) => Self::Rational(*denominator, *numerator),
346        }
347    }
348
349    /// Evaluates this assigned value directly, performing an unbatched inversion if
350    /// necessary.
351    ///
352    /// If the denominator is zero, this returns zero.
353    pub fn evaluate(self) -> F {
354        match self {
355            Self::Zero => F::ZERO,
356            Self::Trivial(x) => x,
357            Self::Rational(numerator, denominator) => {
358                if denominator == F::ONE {
359                    numerator
360                } else {
361                    numerator * denominator.invert().unwrap_or(F::ZERO)
362                }
363            }
364        }
365    }
366}
367
#[cfg(test)]
mod tests {
    use halo2curves::pasta::Fp;

    use super::Assigned;

    // Throughout, `(n, d)` denotes the rational with numerator `n` and denominator `d`.

    /// The trivial value `2`.
    fn two() -> Assigned<Fp> {
        Assigned::Trivial(Fp::from(2))
    }

    /// The rational `(1, 2)`.
    fn one_half() -> Assigned<Fp> {
        Assigned::Rational(Fp::one(), Fp::from(2))
    }

    /// The rational `(1, 0)`, which must behave like zero.
    fn one_over_zero() -> Assigned<Fp> {
        Assigned::Rational(Fp::one(), Fp::zero())
    }

    #[test]
    fn add_trivial_to_inv0_rational() {
        let (a, b) = (two(), one_over_zero());
        let expected = a.evaluate();

        // 2 + (1,0) = 2 + 0 = 2, whichever side the (1,0) is on.
        // The ordinary rules for adding rationals would get this wrong.
        assert_eq!((a + b).evaluate(), expected);
        assert_eq!((b + a).evaluate(), expected);
    }

    #[test]
    fn add_rational_to_inv0_rational() {
        let (a, b) = (one_half(), one_over_zero());
        let expected = a.evaluate();

        // (1,2) + (1,0) = (1,2) + 0 = (1,2), whichever side the (1,0) is on.
        // The ordinary rules for adding rationals would get this wrong.
        assert_eq!((a + b).evaluate(), expected);
        assert_eq!((b + a).evaluate(), expected);
    }

    #[test]
    fn sub_trivial_from_inv0_rational() {
        let (a, b) = (two(), one_over_zero());

        // (1,0) - 2 = 0 - 2 = -2
        // The ordinary rules for subtracting rationals would get this wrong.
        assert_eq!((b - a).evaluate(), (-a).evaluate());

        // 2 - (1,0) = 2 - 0 = 2
        assert_eq!((a - b).evaluate(), a.evaluate());
    }

    #[test]
    fn sub_rational_from_inv0_rational() {
        let (a, b) = (one_half(), one_over_zero());

        // (1,0) - (1,2) = 0 - (1,2) = -(1,2)
        // The ordinary rules for subtracting rationals would get this wrong.
        assert_eq!((b - a).evaluate(), (-a).evaluate());

        // (1,2) - (1,0) = (1,2) - 0 = (1,2)
        assert_eq!((a - b).evaluate(), a.evaluate());
    }

    #[test]
    fn mul_rational_by_inv0_rational() {
        let (a, b) = (one_half(), one_over_zero());
        let zero = Fp::zero();

        // (1,2) * (1,0) = (1,2) * 0 = 0
        assert_eq!((a * b).evaluate(), zero);

        // (1,0) * (1,2) = 0 * (1,2) = 0
        assert_eq!((b * a).evaluate(), zero);
    }
}
444
#[cfg(test)]
mod proptests {
    use std::{
        cmp,
        ops::{Add, Mul, Neg, Sub},
    };

    use group::ff::Field;
    use halo2curves::pasta::Fp;
    use proptest::{collection::vec, prelude::*, sample::select};

    use super::Assigned;

    /// Unary operations available on both raw field elements and [`Assigned`]
    /// values, so that the same operator list can be applied to either.
    trait UnaryOperand: Neg<Output = Self> {
        fn double(&self) -> Self;
        fn square(&self) -> Self;
        fn cube(&self) -> Self;
        /// Inversion that takes the inverse of zero to be zero.
        fn inv0(&self) -> Self;
    }

    impl<F: Field> UnaryOperand for F {
        fn double(&self) -> Self {
            self.double()
        }

        fn square(&self) -> Self {
            self.square()
        }

        fn cube(&self) -> Self {
            self.cube()
        }

        fn inv0(&self) -> Self {
            self.invert().unwrap_or(F::ZERO)
        }
    }

    impl<F: Field> UnaryOperand for Assigned<F> {
        fn double(&self) -> Self {
            self.double()
        }

        fn square(&self) -> Self {
            self.square()
        }

        fn cube(&self) -> Self {
            self.cube()
        }

        fn inv0(&self) -> Self {
            self.invert()
        }
    }

    /// The unary operators the property test may pick from.
    #[derive(Clone, Debug)]
    enum UnaryOperator {
        Neg,
        Double,
        Square,
        Cube,
        Inv0,
    }

    const UNARY_OPERATORS: &[UnaryOperator] = &[
        UnaryOperator::Neg,
        UnaryOperator::Double,
        UnaryOperator::Square,
        UnaryOperator::Cube,
        UnaryOperator::Inv0,
    ];

    impl UnaryOperator {
        /// Applies this operator to `a`.
        fn apply<F: UnaryOperand>(&self, a: F) -> F {
            match self {
                Self::Neg => -a,
                Self::Double => a.double(),
                Self::Square => a.square(),
                Self::Cube => a.cube(),
                Self::Inv0 => a.inv0(),
            }
        }
    }

    /// Binary operations available on both raw field elements and [`Assigned`] values.
    trait BinaryOperand: Sized + Add<Output = Self> + Sub<Output = Self> + Mul<Output = Self> {}
    impl<F: Field> BinaryOperand for F {}
    impl<F: Field> BinaryOperand for Assigned<F> {}

    /// The binary operators the property test may pick from.
    #[derive(Clone, Debug)]
    enum BinaryOperator {
        Add,
        Sub,
        Mul,
    }

    const BINARY_OPERATORS: &[BinaryOperator] = &[
        BinaryOperator::Add,
        BinaryOperator::Sub,
        BinaryOperator::Mul,
    ];

    impl BinaryOperator {
        /// Applies this operator to `a` and `b`, in that order.
        fn apply<F: BinaryOperand>(&self, a: F, b: F) -> F {
            match self {
                Self::Add => a + b,
                Self::Sub => a - b,
                Self::Mul => a * b,
            }
        }
    }

    #[derive(Clone, Debug)]
    enum Operator {
        Unary(UnaryOperator),
        Binary(BinaryOperator),
    }

    prop_compose! {
        /// Generates field elements from the narrow `u64` range, so that failing
        /// cases can be easily reduced.
        fn arb_element()(val in any::<u64>()) -> Fp {
            Fp::from(val)
        }
    }

    prop_compose! {
        fn arb_trivial()(element in arb_element()) -> Assigned<Fp> {
            Assigned::Trivial(element)
        }
    }

    prop_compose! {
        /// Generates rationals whose denominator is explicitly zero in one case
        /// out of three (weights 1 : 2), to represent a deferred inversion of zero.
        fn arb_rational()(
            numerator in arb_element(),
            denominator in prop_oneof![
                1 => Just(Fp::zero()),
                2 => arb_element(),
            ],
        ) -> Assigned<Fp> {
            Assigned::Rational(numerator, denominator)
        }
    }

    prop_compose! {
        fn arb_operators(num_unary: usize, num_binary: usize)(
            unary in vec(select(UNARY_OPERATORS), num_unary),
            binary in vec(select(BINARY_OPERATORS), num_binary),
        ) -> Vec<Operator> {
            unary.into_iter()
                .map(Operator::Unary)
                .chain(binary.into_iter().map(Operator::Binary))
                .collect()
        }
    }

    prop_compose! {
        fn arb_testcase()(
            num_unary in 0usize..5,
            num_binary in 0usize..5,
        )(
            values in vec(
                prop_oneof![
                    1 => Just(Assigned::Zero),
                    2 => arb_trivial(),
                    2 => arb_rational(),
                ],
                // Ensure that:
                // - we have at least one value to apply unary operators to.
                // - we can apply every binary operator pairwise sequentially.
                cmp::max(usize::from(num_unary > 0), num_binary + 1)),
            operations in arb_operators(num_unary, num_binary).prop_shuffle(),
        ) -> (Vec<Assigned<Fp>>, Vec<Operator>) {
            (values, operations)
        }
    }

    proptest! {
        #[test]
        fn operation_commutativity((values, operations) in arb_testcase()) {
            // Evaluate the values at the start.
            let elements: Vec<_> = values.iter().copied().map(Assigned::evaluate).collect();

            // Apply the operations to both the deferred and evaluated values.
            fn evaluate<F: UnaryOperand + BinaryOperand>(
                items: Vec<F>,
                operators: &[Operator],
            ) -> F {
                let mut ops = operators.iter();

                // Process all binary operators. We are guaranteed to have exactly as many
                // binary operators as we need calls to the reduction closure.
                let mut res = items
                    .into_iter()
                    .reduce(|mut a, b| loop {
                        match ops.next() {
                            Some(Operator::Unary(op)) => a = op.apply(a),
                            Some(Operator::Binary(op)) => break op.apply(a, b),
                            None => unreachable!(),
                        }
                    })
                    .expect("arb_testcase always generates at least one value");

                // Process any unary operators that weren't handled in the reduce() call
                // above (either if we only had one item, or there were unary operators
                // after the last binary operator). We are guaranteed to have no binary
                // operators remaining at this point.
                loop {
                    match ops.next() {
                        Some(Operator::Unary(op)) => res = op.apply(res),
                        Some(Operator::Binary(_)) => unreachable!(),
                        None => break res,
                    }
                }
            }
            let deferred_result = evaluate(values, &operations);
            let evaluated_result = evaluate(elements, &operations);

            // The two should be equal, i.e. deferred inversion should commute with the
            // list of operations. `prop_assert_eq!` reports the failure to proptest
            // as a test-case error instead of panicking.
            prop_assert_eq!(deferred_result.evaluate(), evaluated_result);
        }
    }
}