// halo2_proofs/plonk/assigned.rs

use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};

use group::ff::Field;

/// A value assigned to a cell within a circuit.
///
/// The value is kept as a fraction so that the backend can invert all of the
/// denominators together with batch inversion rather than one at a time.
///
/// Any fraction whose denominator is zero is treated as the assigned value zero.
#[derive(Debug, Clone, Copy)]
pub enum Assigned<F> {
    /// The field element zero.
    Zero,
    /// A value that can be evaluated without any inversion.
    Trivial(F),
    /// A `(numerator, denominator)` pair, kept unreduced so that inverting the
    /// denominator can be batched with other values.
    Rational(F, F),
}

20impl<F: Field> From<&Assigned<F>> for Assigned<F> {
21    fn from(val: &Assigned<F>) -> Self {
22        *val
23    }
24}
25
26impl<F: Field> From<&F> for Assigned<F> {
27    fn from(numerator: &F) -> Self {
28        Assigned::Trivial(*numerator)
29    }
30}
31
32impl<F: Field> From<F> for Assigned<F> {
33    fn from(numerator: F) -> Self {
34        Assigned::Trivial(numerator)
35    }
36}
37
38impl<F: Field> From<(F, F)> for Assigned<F> {
39    fn from((numerator, denominator): (F, F)) -> Self {
40        Assigned::Rational(numerator, denominator)
41    }
42}
43
44impl<F: Field> PartialEq for Assigned<F> {
45    fn eq(&self, other: &Self) -> bool {
46        match (self, other) {
47            // At least one side is directly zero.
48            (Self::Zero, Self::Zero) => true,
49            (Self::Zero, x) | (x, Self::Zero) => x.is_zero_vartime(),
50
51            // One side is x/0 which maps to zero.
52            (Self::Rational(_, denominator), x) | (x, Self::Rational(_, denominator))
53                if denominator.is_zero_vartime() =>
54            {
55                x.is_zero_vartime()
56            }
57
58            // Okay, we need to do some actual math...
59            (Self::Trivial(lhs), Self::Trivial(rhs)) => lhs == rhs,
60            (Self::Trivial(x), Self::Rational(numerator, denominator))
61            | (Self::Rational(numerator, denominator), Self::Trivial(x)) => {
62                &(*x * denominator) == numerator
63            }
64            (
65                Self::Rational(lhs_numerator, lhs_denominator),
66                Self::Rational(rhs_numerator, rhs_denominator),
67            ) => *lhs_numerator * rhs_denominator == *lhs_denominator * rhs_numerator,
68        }
69    }
70}
71
72impl<F: Field> Eq for Assigned<F> {}
73
74impl<F: Field> Neg for Assigned<F> {
75    type Output = Assigned<F>;
76    fn neg(self) -> Self::Output {
77        match self {
78            Self::Zero => Self::Zero,
79            Self::Trivial(numerator) => Self::Trivial(-numerator),
80            Self::Rational(numerator, denominator) => Self::Rational(-numerator, denominator),
81        }
82    }
83}
84
85impl<F: Field> Add for Assigned<F> {
86    type Output = Assigned<F>;
87    fn add(self, rhs: Assigned<F>) -> Assigned<F> {
88        match (self, rhs) {
89            // One side is directly zero.
90            (Self::Zero, _) => rhs,
91            (_, Self::Zero) => self,
92
93            // One side is x/0 which maps to zero.
94            (Self::Rational(_, denominator), other) | (other, Self::Rational(_, denominator))
95                if denominator.is_zero_vartime() =>
96            {
97                other
98            }
99
100            // Okay, we need to do some actual math...
101            (Self::Trivial(lhs), Self::Trivial(rhs)) => Self::Trivial(lhs + rhs),
102            (Self::Rational(numerator, denominator), Self::Trivial(other))
103            | (Self::Trivial(other), Self::Rational(numerator, denominator)) => {
104                Self::Rational(numerator + denominator * other, denominator)
105            }
106            (
107                Self::Rational(lhs_numerator, lhs_denominator),
108                Self::Rational(rhs_numerator, rhs_denominator),
109            ) => Self::Rational(
110                lhs_numerator * rhs_denominator + lhs_denominator * rhs_numerator,
111                lhs_denominator * rhs_denominator,
112            ),
113        }
114    }
115}
116
117impl<F: Field> Add<F> for Assigned<F> {
118    type Output = Assigned<F>;
119    fn add(self, rhs: F) -> Assigned<F> {
120        self + Self::Trivial(rhs)
121    }
122}
123
124impl<F: Field> Add<F> for &Assigned<F> {
125    type Output = Assigned<F>;
126    fn add(self, rhs: F) -> Assigned<F> {
127        *self + rhs
128    }
129}
130
131impl<F: Field> Add<&Assigned<F>> for Assigned<F> {
132    type Output = Assigned<F>;
133    fn add(self, rhs: &Self) -> Assigned<F> {
134        self + *rhs
135    }
136}
137
138impl<F: Field> Add<Assigned<F>> for &Assigned<F> {
139    type Output = Assigned<F>;
140    fn add(self, rhs: Assigned<F>) -> Assigned<F> {
141        *self + rhs
142    }
143}
144
145impl<F: Field> Add<&Assigned<F>> for &Assigned<F> {
146    type Output = Assigned<F>;
147    fn add(self, rhs: &Assigned<F>) -> Assigned<F> {
148        *self + *rhs
149    }
150}
151
152impl<F: Field> AddAssign for Assigned<F> {
153    fn add_assign(&mut self, rhs: Self) {
154        *self = *self + rhs;
155    }
156}
157
158impl<F: Field> AddAssign<&Assigned<F>> for Assigned<F> {
159    fn add_assign(&mut self, rhs: &Self) {
160        *self = *self + rhs;
161    }
162}
163
164impl<F: Field> Sub for Assigned<F> {
165    type Output = Assigned<F>;
166    fn sub(self, rhs: Assigned<F>) -> Assigned<F> {
167        self + (-rhs)
168    }
169}
170
171impl<F: Field> Sub<F> for Assigned<F> {
172    type Output = Assigned<F>;
173    fn sub(self, rhs: F) -> Assigned<F> {
174        self + (-rhs)
175    }
176}
177
178impl<F: Field> Sub<F> for &Assigned<F> {
179    type Output = Assigned<F>;
180    fn sub(self, rhs: F) -> Assigned<F> {
181        *self - rhs
182    }
183}
184
185impl<F: Field> Sub<&Assigned<F>> for Assigned<F> {
186    type Output = Assigned<F>;
187    fn sub(self, rhs: &Self) -> Assigned<F> {
188        self - *rhs
189    }
190}
191
192impl<F: Field> Sub<Assigned<F>> for &Assigned<F> {
193    type Output = Assigned<F>;
194    fn sub(self, rhs: Assigned<F>) -> Assigned<F> {
195        *self - rhs
196    }
197}
198
199impl<F: Field> Sub<&Assigned<F>> for &Assigned<F> {
200    type Output = Assigned<F>;
201    fn sub(self, rhs: &Assigned<F>) -> Assigned<F> {
202        *self - *rhs
203    }
204}
205
206impl<F: Field> SubAssign for Assigned<F> {
207    fn sub_assign(&mut self, rhs: Self) {
208        *self = *self - rhs;
209    }
210}
211
212impl<F: Field> SubAssign<&Assigned<F>> for Assigned<F> {
213    fn sub_assign(&mut self, rhs: &Self) {
214        *self = *self - rhs;
215    }
216}
217
218impl<F: Field> Mul for Assigned<F> {
219    type Output = Assigned<F>;
220    fn mul(self, rhs: Assigned<F>) -> Assigned<F> {
221        match (self, rhs) {
222            (Self::Zero, _) | (_, Self::Zero) => Self::Zero,
223            (Self::Trivial(lhs), Self::Trivial(rhs)) => Self::Trivial(lhs * rhs),
224            (Self::Rational(numerator, denominator), Self::Trivial(other))
225            | (Self::Trivial(other), Self::Rational(numerator, denominator)) => {
226                Self::Rational(numerator * other, denominator)
227            }
228            (
229                Self::Rational(lhs_numerator, lhs_denominator),
230                Self::Rational(rhs_numerator, rhs_denominator),
231            ) => Self::Rational(
232                lhs_numerator * rhs_numerator,
233                lhs_denominator * rhs_denominator,
234            ),
235        }
236    }
237}
238
239impl<F: Field> Mul<F> for Assigned<F> {
240    type Output = Assigned<F>;
241    fn mul(self, rhs: F) -> Assigned<F> {
242        self * Self::Trivial(rhs)
243    }
244}
245
246impl<F: Field> Mul<&Assigned<F>> for Assigned<F> {
247    type Output = Assigned<F>;
248    fn mul(self, rhs: &Assigned<F>) -> Assigned<F> {
249        self * *rhs
250    }
251}
252
253impl<F: Field> MulAssign for Assigned<F> {
254    fn mul_assign(&mut self, rhs: Self) {
255        *self = *self * rhs;
256    }
257}
258
259impl<F: Field> MulAssign<&Assigned<F>> for Assigned<F> {
260    fn mul_assign(&mut self, rhs: &Self) {
261        *self = *self * rhs;
262    }
263}
264
265impl<F: Field> Assigned<F> {
266    /// Returns the numerator.
267    pub fn numerator(&self) -> F {
268        match self {
269            Self::Zero => F::zero(),
270            Self::Trivial(x) => *x,
271            Self::Rational(numerator, _) => *numerator,
272        }
273    }
274
275    /// Returns the denominator, if non-trivial.
276    pub fn denominator(&self) -> Option<F> {
277        match self {
278            Self::Zero => None,
279            Self::Trivial(_) => None,
280            Self::Rational(_, denominator) => Some(*denominator),
281        }
282    }
283
284    /// Returns true iff this element is zero.
285    pub fn is_zero_vartime(&self) -> bool {
286        match self {
287            Self::Zero => true,
288            Self::Trivial(x) => x.is_zero_vartime(),
289            // Assigned maps x/0 -> 0.
290            Self::Rational(numerator, denominator) => {
291                numerator.is_zero_vartime() || denominator.is_zero_vartime()
292            }
293        }
294    }
295
296    /// Doubles this element.
297    #[must_use]
298    pub fn double(&self) -> Self {
299        match self {
300            Self::Zero => Self::Zero,
301            Self::Trivial(x) => Self::Trivial(x.double()),
302            Self::Rational(numerator, denominator) => {
303                Self::Rational(numerator.double(), *denominator)
304            }
305        }
306    }
307
308    /// Squares this element.
309    #[must_use]
310    pub fn square(&self) -> Self {
311        match self {
312            Self::Zero => Self::Zero,
313            Self::Trivial(x) => Self::Trivial(x.square()),
314            Self::Rational(numerator, denominator) => {
315                Self::Rational(numerator.square(), denominator.square())
316            }
317        }
318    }
319
320    /// Cubes this element.
321    #[must_use]
322    pub fn cube(&self) -> Self {
323        self.square() * self
324    }
325
326    /// Inverts this assigned value (taking the inverse of zero to be zero).
327    pub fn invert(&self) -> Self {
328        match self {
329            Self::Zero => Self::Zero,
330            Self::Trivial(x) => Self::Rational(F::one(), *x),
331            Self::Rational(numerator, denominator) => Self::Rational(*denominator, *numerator),
332        }
333    }
334
335    /// Evaluates this assigned value directly, performing an unbatched inversion if
336    /// necessary.
337    ///
338    /// If the denominator is zero, this returns zero.
339    pub fn evaluate(self) -> F {
340        match self {
341            Self::Zero => F::zero(),
342            Self::Trivial(x) => x,
343            Self::Rational(numerator, denominator) => {
344                if denominator == F::one() {
345                    numerator
346                } else {
347                    numerator * denominator.invert().unwrap_or(F::zero())
348                }
349            }
350        }
351    }
352}
353
#[cfg(test)]
mod tests {
    use pasta_curves::Fp;

    use super::Assigned;

    // The comments below write a rational as (numerator, denominator).

    /// The trivial value 2.
    fn two() -> Assigned<Fp> {
        Assigned::Trivial(Fp::from(2))
    }

    /// The rational (1,2).
    fn one_half() -> Assigned<Fp> {
        Assigned::Rational(Fp::one(), Fp::from(2))
    }

    /// The rational (1,0), which `Assigned` evaluates to zero.
    fn one_over_zero() -> Assigned<Fp> {
        Assigned::Rational(Fp::one(), Fp::zero())
    }

    #[test]
    fn add_trivial_to_inv0_rational() {
        let (a, b) = (two(), one_over_zero());

        // 2 + (1,0) = 2 + 0 = 2
        // This fails if addition is implemented using normal rules for rationals.
        let expected = a.evaluate();
        assert_eq!((a + b).evaluate(), expected);
        assert_eq!((b + a).evaluate(), expected);
    }

    #[test]
    fn add_rational_to_inv0_rational() {
        let (a, b) = (one_half(), one_over_zero());

        // (1,2) + (1,0) = (1,2) + 0 = (1,2)
        // This fails if addition is implemented using normal rules for rationals.
        let expected = a.evaluate();
        assert_eq!((a + b).evaluate(), expected);
        assert_eq!((b + a).evaluate(), expected);
    }

    #[test]
    fn sub_trivial_from_inv0_rational() {
        let (a, b) = (two(), one_over_zero());

        // (1,0) - 2 = 0 - 2 = -2
        // This fails if subtraction is implemented using normal rules for rationals.
        assert_eq!((b - a).evaluate(), (-a).evaluate());

        // 2 - (1,0) = 2 - 0 = 2
        assert_eq!((a - b).evaluate(), a.evaluate());
    }

    #[test]
    fn sub_rational_from_inv0_rational() {
        let (a, b) = (one_half(), one_over_zero());

        // (1,0) - (1,2) = 0 - (1,2) = -(1,2)
        // This fails if subtraction is implemented using normal rules for rationals.
        assert_eq!((b - a).evaluate(), (-a).evaluate());

        // (1,2) - (1,0) = (1,2) - 0 = (1,2)
        assert_eq!((a - b).evaluate(), a.evaluate());
    }

    #[test]
    fn mul_rational_by_inv0_rational() {
        let (a, b) = (one_half(), one_over_zero());

        // (1,2) * (1,0) = (1,2) * 0 = 0
        assert_eq!((a * b).evaluate(), Fp::zero());

        // (1,0) * (1,2) = 0 * (1,2) = 0
        assert_eq!((b * a).evaluate(), Fp::zero());
    }
}

#[cfg(test)]
mod proptests {
    use std::{
        convert::TryFrom,
        ops::{Add, Mul, Neg, Sub},
    };

    use group::ff::Field;
    use pasta_curves::{arithmetic::FieldExt, Fp};
    use proptest::{collection::vec, prelude::*, sample::select};

    use super::Assigned;

    /// Unary operations available on both bare field elements and `Assigned`
    /// values, so that one operator list can drive both representations.
    trait UnaryOperand: Neg<Output = Self> {
        fn double(&self) -> Self;
        fn square(&self) -> Self;
        fn cube(&self) -> Self;
        fn inv0(&self) -> Self;
    }

    impl<F: Field> UnaryOperand for F {
        fn double(&self) -> Self {
            self.double()
        }

        fn square(&self) -> Self {
            self.square()
        }

        fn cube(&self) -> Self {
            self.cube()
        }

        // Mirrors the `Assigned` convention that the inverse of zero is zero.
        fn inv0(&self) -> Self {
            self.invert().unwrap_or(F::zero())
        }
    }

    impl<F: Field> UnaryOperand for Assigned<F> {
        fn double(&self) -> Self {
            self.double()
        }

        fn square(&self) -> Self {
            self.square()
        }

        fn cube(&self) -> Self {
            self.cube()
        }

        fn inv0(&self) -> Self {
            self.invert()
        }
    }

    #[derive(Clone, Debug)]
    enum UnaryOperator {
        Neg,
        Double,
        Square,
        Cube,
        Inv0,
    }

    const UNARY_OPERATORS: &[UnaryOperator] = &[
        UnaryOperator::Neg,
        UnaryOperator::Double,
        UnaryOperator::Square,
        UnaryOperator::Cube,
        UnaryOperator::Inv0,
    ];

    impl UnaryOperator {
        fn apply<F: UnaryOperand>(&self, a: F) -> F {
            match self {
                Self::Neg => -a,
                Self::Double => a.double(),
                Self::Square => a.square(),
                Self::Cube => a.cube(),
                Self::Inv0 => a.inv0(),
            }
        }
    }

    trait BinaryOperand: Sized + Add<Output = Self> + Sub<Output = Self> + Mul<Output = Self> {}
    impl<F: Field> BinaryOperand for F {}
    impl<F: Field> BinaryOperand for Assigned<F> {}

    #[derive(Clone, Debug)]
    enum BinaryOperator {
        Add,
        Sub,
        Mul,
    }

    const BINARY_OPERATORS: &[BinaryOperator] = &[
        BinaryOperator::Add,
        BinaryOperator::Sub,
        BinaryOperator::Mul,
    ];

    impl BinaryOperator {
        fn apply<F: BinaryOperand>(&self, a: F, b: F) -> F {
            match self {
                Self::Add => a + b,
                Self::Sub => a - b,
                Self::Mul => a * b,
            }
        }
    }

    #[derive(Clone, Debug)]
    enum Operator {
        Unary(UnaryOperator),
        Binary(BinaryOperator),
    }

    prop_compose! {
        /// Generates narrow (`u64`-sized) field elements, which can be easily reduced.
        fn arb_element()(val in any::<u64>()) -> Fp {
            Fp::from(val)
        }
    }

    prop_compose! {
        fn arb_trivial()(element in arb_element()) -> Assigned<Fp> {
            Assigned::Trivial(element)
        }
    }

    prop_compose! {
        /// Generates about one third of the denominators as zero (weights 1:2),
        /// so that the deferred "x/0 maps to zero" rule is exercised often.
        fn arb_rational()(
            numerator in arb_element(),
            denominator in prop_oneof![
                1 => Just(Fp::zero()),
                2 => arb_element(),
            ],
        ) -> Assigned<Fp> {
            Assigned::Rational(numerator, denominator)
        }
    }

    prop_compose! {
        fn arb_operators(num_unary: usize, num_binary: usize)(
            unary in vec(select(UNARY_OPERATORS), num_unary),
            binary in vec(select(BINARY_OPERATORS), num_binary),
        ) -> Vec<Operator> {
            unary.into_iter()
                .map(Operator::Unary)
                .chain(binary.into_iter().map(Operator::Binary))
                .collect()
        }
    }

    prop_compose! {
        fn arb_testcase()(
            num_unary in 0usize..5,
            num_binary in 0usize..5,
        )(
            values in vec(
                prop_oneof![
                    1 => Just(Assigned::Zero),
                    2 => arb_trivial(),
                    2 => arb_rational(),
                ],
                // `num_binary + 1` values let us apply every binary operator
                // pairwise in sequence. The count is always at least one, so
                // unary operators always have a value to act on.
                num_binary + 1),
            operations in arb_operators(num_unary, num_binary).prop_shuffle(),
        ) -> (Vec<Assigned<Fp>>, Vec<Operator>) {
            (values, operations)
        }
    }

    proptest! {
        #[test]
        fn operation_commutativity((values, operations) in arb_testcase()) {
            // Evaluate the values at the start.
            let elements: Vec<_> = values.iter().cloned().map(|v| v.evaluate()).collect();

            // Apply the operations to both the deferred and evaluated values.
            fn evaluate<F: UnaryOperand + BinaryOperand>(
                items: Vec<F>,
                operators: &[Operator],
            ) -> F {
                let mut ops = operators.iter();

                // Process all binary operators. We are guaranteed to have exactly as many
                // binary operators as we need calls to the reduction closure.
                let mut res = items.into_iter().reduce(|mut a, b| loop {
                    match ops.next() {
                        Some(Operator::Unary(op)) => a = op.apply(a),
                        Some(Operator::Binary(op)) => break op.apply(a, b),
                        None => unreachable!(),
                    }
                }).unwrap();

                // Process any unary operators that weren't handled in the reduce() call
                // above (either if we only had one item, or there were unary operators
                // after the last binary operator). We are guaranteed to have no binary
                // operators remaining at this point.
                loop {
                    match ops.next() {
                        Some(Operator::Unary(op)) => res = op.apply(res),
                        Some(Operator::Binary(_)) => unreachable!(),
                        None => break res,
                    }
                }
            }
            let deferred_result = evaluate(values, &operations);
            let evaluated_result = evaluate(elements, &operations);

            // The two should be equal, i.e. deferred inversion should commute with the
            // list of operations.
            assert_eq!(deferred_result.evaluate(), evaluated_result);
        }
    }
}