1use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3use group::ff::Field;
4
/// A field value whose inversion is deferred.
///
/// The value is kept as a fraction until [`Assigned::evaluate`] is called; only then is the
/// numerator multiplied by the inverse of the denominator. A denominator of zero is treated
/// as the value zero.
#[derive(Clone, Copy, Debug)]
pub enum Assigned<F> {
    /// The value zero.
    Zero,
    /// A plain field element; evaluating it requires no inversion.
    Trivial(F),
    /// A fraction stored as `(numerator, denominator)`.
    Rational(F, F),
}
19
20impl<F: Field> From<&Assigned<F>> for Assigned<F> {
21 fn from(val: &Assigned<F>) -> Self {
22 *val
23 }
24}
25
26impl<F: Field> From<&F> for Assigned<F> {
27 fn from(numerator: &F) -> Self {
28 Assigned::Trivial(*numerator)
29 }
30}
31
32impl<F: Field> From<F> for Assigned<F> {
33 fn from(numerator: F) -> Self {
34 Assigned::Trivial(numerator)
35 }
36}
37
38impl<F: Field> From<(F, F)> for Assigned<F> {
39 fn from((numerator, denominator): (F, F)) -> Self {
40 Assigned::Rational(numerator, denominator)
41 }
42}
43
44impl<F: Field> PartialEq for Assigned<F> {
45 fn eq(&self, other: &Self) -> bool {
46 match (self, other) {
47 (Self::Zero, Self::Zero) => true,
49 (Self::Zero, x) | (x, Self::Zero) => x.is_zero_vartime(),
50
51 (Self::Rational(_, denominator), x) | (x, Self::Rational(_, denominator))
53 if denominator.is_zero_vartime() =>
54 {
55 x.is_zero_vartime()
56 }
57
58 (Self::Trivial(lhs), Self::Trivial(rhs)) => lhs == rhs,
60 (Self::Trivial(x), Self::Rational(numerator, denominator))
61 | (Self::Rational(numerator, denominator), Self::Trivial(x)) => {
62 &(*x * denominator) == numerator
63 }
64 (
65 Self::Rational(lhs_numerator, lhs_denominator),
66 Self::Rational(rhs_numerator, rhs_denominator),
67 ) => *lhs_numerator * rhs_denominator == *lhs_denominator * rhs_numerator,
68 }
69 }
70}
71
/// Opts [`Assigned`] into `Eq`. The marker trait adds no methods of its own, so equality
/// is defined entirely by the `PartialEq` implementation.
impl<F: Field> Eq for Assigned<F> {}
73
74impl<F: Field> Neg for Assigned<F> {
75 type Output = Assigned<F>;
76 fn neg(self) -> Self::Output {
77 match self {
78 Self::Zero => Self::Zero,
79 Self::Trivial(numerator) => Self::Trivial(-numerator),
80 Self::Rational(numerator, denominator) => Self::Rational(-numerator, denominator),
81 }
82 }
83}
84
85impl<F: Field> Add for Assigned<F> {
86 type Output = Assigned<F>;
87 fn add(self, rhs: Assigned<F>) -> Assigned<F> {
88 match (self, rhs) {
89 (Self::Zero, _) => rhs,
91 (_, Self::Zero) => self,
92
93 (Self::Rational(_, denominator), other) | (other, Self::Rational(_, denominator))
95 if denominator.is_zero_vartime() =>
96 {
97 other
98 }
99
100 (Self::Trivial(lhs), Self::Trivial(rhs)) => Self::Trivial(lhs + rhs),
102 (Self::Rational(numerator, denominator), Self::Trivial(other))
103 | (Self::Trivial(other), Self::Rational(numerator, denominator)) => {
104 Self::Rational(numerator + denominator * other, denominator)
105 }
106 (
107 Self::Rational(lhs_numerator, lhs_denominator),
108 Self::Rational(rhs_numerator, rhs_denominator),
109 ) => Self::Rational(
110 lhs_numerator * rhs_denominator + lhs_denominator * rhs_numerator,
111 lhs_denominator * rhs_denominator,
112 ),
113 }
114 }
115}
116
117impl<F: Field> Add<F> for Assigned<F> {
118 type Output = Assigned<F>;
119 fn add(self, rhs: F) -> Assigned<F> {
120 self + Self::Trivial(rhs)
121 }
122}
123
124impl<F: Field> Add<F> for &Assigned<F> {
125 type Output = Assigned<F>;
126 fn add(self, rhs: F) -> Assigned<F> {
127 *self + rhs
128 }
129}
130
131impl<F: Field> Add<&Assigned<F>> for Assigned<F> {
132 type Output = Assigned<F>;
133 fn add(self, rhs: &Self) -> Assigned<F> {
134 self + *rhs
135 }
136}
137
138impl<F: Field> Add<Assigned<F>> for &Assigned<F> {
139 type Output = Assigned<F>;
140 fn add(self, rhs: Assigned<F>) -> Assigned<F> {
141 *self + rhs
142 }
143}
144
145impl<F: Field> Add<&Assigned<F>> for &Assigned<F> {
146 type Output = Assigned<F>;
147 fn add(self, rhs: &Assigned<F>) -> Assigned<F> {
148 *self + *rhs
149 }
150}
151
152impl<F: Field> AddAssign for Assigned<F> {
153 fn add_assign(&mut self, rhs: Self) {
154 *self = *self + rhs;
155 }
156}
157
158impl<F: Field> AddAssign<&Assigned<F>> for Assigned<F> {
159 fn add_assign(&mut self, rhs: &Self) {
160 *self = *self + rhs;
161 }
162}
163
164impl<F: Field> Sub for Assigned<F> {
165 type Output = Assigned<F>;
166 fn sub(self, rhs: Assigned<F>) -> Assigned<F> {
167 self + (-rhs)
168 }
169}
170
171impl<F: Field> Sub<F> for Assigned<F> {
172 type Output = Assigned<F>;
173 fn sub(self, rhs: F) -> Assigned<F> {
174 self + (-rhs)
175 }
176}
177
178impl<F: Field> Sub<F> for &Assigned<F> {
179 type Output = Assigned<F>;
180 fn sub(self, rhs: F) -> Assigned<F> {
181 *self - rhs
182 }
183}
184
185impl<F: Field> Sub<&Assigned<F>> for Assigned<F> {
186 type Output = Assigned<F>;
187 fn sub(self, rhs: &Self) -> Assigned<F> {
188 self - *rhs
189 }
190}
191
192impl<F: Field> Sub<Assigned<F>> for &Assigned<F> {
193 type Output = Assigned<F>;
194 fn sub(self, rhs: Assigned<F>) -> Assigned<F> {
195 *self - rhs
196 }
197}
198
199impl<F: Field> Sub<&Assigned<F>> for &Assigned<F> {
200 type Output = Assigned<F>;
201 fn sub(self, rhs: &Assigned<F>) -> Assigned<F> {
202 *self - *rhs
203 }
204}
205
206impl<F: Field> SubAssign for Assigned<F> {
207 fn sub_assign(&mut self, rhs: Self) {
208 *self = *self - rhs;
209 }
210}
211
212impl<F: Field> SubAssign<&Assigned<F>> for Assigned<F> {
213 fn sub_assign(&mut self, rhs: &Self) {
214 *self = *self - rhs;
215 }
216}
217
218impl<F: Field> Mul for Assigned<F> {
219 type Output = Assigned<F>;
220 fn mul(self, rhs: Assigned<F>) -> Assigned<F> {
221 match (self, rhs) {
222 (Self::Zero, _) | (_, Self::Zero) => Self::Zero,
223 (Self::Trivial(lhs), Self::Trivial(rhs)) => Self::Trivial(lhs * rhs),
224 (Self::Rational(numerator, denominator), Self::Trivial(other))
225 | (Self::Trivial(other), Self::Rational(numerator, denominator)) => {
226 Self::Rational(numerator * other, denominator)
227 }
228 (
229 Self::Rational(lhs_numerator, lhs_denominator),
230 Self::Rational(rhs_numerator, rhs_denominator),
231 ) => Self::Rational(
232 lhs_numerator * rhs_numerator,
233 lhs_denominator * rhs_denominator,
234 ),
235 }
236 }
237}
238
239impl<F: Field> Mul<F> for Assigned<F> {
240 type Output = Assigned<F>;
241 fn mul(self, rhs: F) -> Assigned<F> {
242 self * Self::Trivial(rhs)
243 }
244}
245
246impl<F: Field> Mul<&Assigned<F>> for Assigned<F> {
247 type Output = Assigned<F>;
248 fn mul(self, rhs: &Assigned<F>) -> Assigned<F> {
249 self * *rhs
250 }
251}
252
253impl<F: Field> MulAssign for Assigned<F> {
254 fn mul_assign(&mut self, rhs: Self) {
255 *self = *self * rhs;
256 }
257}
258
259impl<F: Field> MulAssign<&Assigned<F>> for Assigned<F> {
260 fn mul_assign(&mut self, rhs: &Self) {
261 *self = *self * rhs;
262 }
263}
264
265impl<F: Field> Assigned<F> {
266 pub fn numerator(&self) -> F {
268 match self {
269 Self::Zero => F::zero(),
270 Self::Trivial(x) => *x,
271 Self::Rational(numerator, _) => *numerator,
272 }
273 }
274
275 pub fn denominator(&self) -> Option<F> {
277 match self {
278 Self::Zero => None,
279 Self::Trivial(_) => None,
280 Self::Rational(_, denominator) => Some(*denominator),
281 }
282 }
283
284 pub fn is_zero_vartime(&self) -> bool {
286 match self {
287 Self::Zero => true,
288 Self::Trivial(x) => x.is_zero_vartime(),
289 Self::Rational(numerator, denominator) => {
291 numerator.is_zero_vartime() || denominator.is_zero_vartime()
292 }
293 }
294 }
295
296 #[must_use]
298 pub fn double(&self) -> Self {
299 match self {
300 Self::Zero => Self::Zero,
301 Self::Trivial(x) => Self::Trivial(x.double()),
302 Self::Rational(numerator, denominator) => {
303 Self::Rational(numerator.double(), *denominator)
304 }
305 }
306 }
307
308 #[must_use]
310 pub fn square(&self) -> Self {
311 match self {
312 Self::Zero => Self::Zero,
313 Self::Trivial(x) => Self::Trivial(x.square()),
314 Self::Rational(numerator, denominator) => {
315 Self::Rational(numerator.square(), denominator.square())
316 }
317 }
318 }
319
320 #[must_use]
322 pub fn cube(&self) -> Self {
323 self.square() * self
324 }
325
326 pub fn invert(&self) -> Self {
328 match self {
329 Self::Zero => Self::Zero,
330 Self::Trivial(x) => Self::Rational(F::one(), *x),
331 Self::Rational(numerator, denominator) => Self::Rational(*denominator, *numerator),
332 }
333 }
334
335 pub fn evaluate(self) -> F {
340 match self {
341 Self::Zero => F::zero(),
342 Self::Trivial(x) => x,
343 Self::Rational(numerator, denominator) => {
344 if denominator == F::one() {
345 numerator
346 } else {
347 numerator * denominator.invert().unwrap_or(F::zero())
348 }
349 }
350 }
351 }
352}
353
#[cfg(test)]
mod tests {
    use pasta_curves::Fp;

    use super::Assigned;

    // In the comments below `(n, d)` denotes the rational n/d. Every test checks that
    // `(1, 0)` behaves like zero rather than following the usual rules for fractions.

    /// The rational `(1, 0)`.
    fn one_over_zero() -> Assigned<Fp> {
        Assigned::Rational(Fp::one(), Fp::zero())
    }

    #[test]
    fn add_trivial_to_inv0_rational() {
        let two = Assigned::Trivial(Fp::from(2));
        let inv0 = one_over_zero();
        let expected = two.evaluate();

        // 2 + (1,0) = 2 + 0 = 2, in either operand order.
        assert_eq!((two + inv0).evaluate(), expected);
        assert_eq!((inv0 + two).evaluate(), expected);
    }

    #[test]
    fn add_rational_to_inv0_rational() {
        let half = Assigned::Rational(Fp::one(), Fp::from(2));
        let inv0 = one_over_zero();
        let expected = half.evaluate();

        // (1,2) + (1,0) = (1,2) + 0 = (1,2), in either operand order.
        assert_eq!((half + inv0).evaluate(), expected);
        assert_eq!((inv0 + half).evaluate(), expected);
    }

    #[test]
    fn sub_trivial_from_inv0_rational() {
        let two = Assigned::Trivial(Fp::from(2));
        let inv0 = one_over_zero();

        // (1,0) - 2 = 0 - 2 = -2
        assert_eq!((inv0 - two).evaluate(), (-two).evaluate());

        // 2 - (1,0) = 2 - 0 = 2
        assert_eq!((two - inv0).evaluate(), two.evaluate());
    }

    #[test]
    fn sub_rational_from_inv0_rational() {
        let half = Assigned::Rational(Fp::one(), Fp::from(2));
        let inv0 = one_over_zero();

        // (1,0) - (1,2) = 0 - (1,2) = -(1,2)
        assert_eq!((inv0 - half).evaluate(), (-half).evaluate());

        // (1,2) - (1,0) = (1,2) - 0 = (1,2)
        assert_eq!((half - inv0).evaluate(), half.evaluate());
    }

    #[test]
    fn mul_rational_by_inv0_rational() {
        let half = Assigned::Rational(Fp::one(), Fp::from(2));
        let inv0 = one_over_zero();
        let zero = Fp::zero();

        // (1,2) * (1,0) = (1,2) * 0 = 0
        assert_eq!((half * inv0).evaluate(), zero);

        // (1,0) * (1,2) = 0 * (1,2) = 0
        assert_eq!((inv0 * half).evaluate(), zero);
    }
}
430
// Property test: applying a random sequence of operators to `Assigned` values and then
// evaluating must give the same result as evaluating the values first and applying the
// same operators to the resulting field elements.
#[cfg(test)]
mod proptests {
    use std::{
        cmp,
        convert::TryFrom,
        ops::{Add, Mul, Neg, Sub},
    };

    use group::ff::Field;
    use pasta_curves::{arithmetic::FieldExt, Fp};
    use proptest::{collection::vec, prelude::*, sample::select};

    use super::Assigned;

    /// Unary operations available on both a plain field element and an `Assigned` value,
    /// so one operator sequence can be replayed on either representation.
    trait UnaryOperand: Neg<Output = Self> {
        fn double(&self) -> Self;
        fn square(&self) -> Self;
        fn cube(&self) -> Self;
        /// Inversion in which a value without an inverse maps to zero.
        fn inv0(&self) -> Self;
    }

    /// Plain field elements: each method forwards to the same-named method on `F`.
    impl<F: Field> UnaryOperand for F {
        fn double(&self) -> Self {
            self.double()
        }

        fn square(&self) -> Self {
            self.square()
        }

        fn cube(&self) -> Self {
            self.cube()
        }

        fn inv0(&self) -> Self {
            // No inverse available => fall back to zero.
            self.invert().unwrap_or(F::zero())
        }
    }

    /// Deferred values: each method forwards to the same-named method on `Assigned`,
    /// with `inv0` mapped onto `Assigned::invert`.
    impl<F: Field> UnaryOperand for Assigned<F> {
        fn double(&self) -> Self {
            self.double()
        }

        fn square(&self) -> Self {
            self.square()
        }

        fn cube(&self) -> Self {
            self.cube()
        }

        fn inv0(&self) -> Self {
            self.invert()
        }
    }

    /// A unary operator that can be chosen at random by the test.
    #[derive(Clone, Debug)]
    enum UnaryOperator {
        Neg,
        Double,
        Square,
        Cube,
        Inv0,
    }

    /// Every unary operator, used as the pool that `select` samples from.
    const UNARY_OPERATORS: &[UnaryOperator] = &[
        UnaryOperator::Neg,
        UnaryOperator::Double,
        UnaryOperator::Square,
        UnaryOperator::Cube,
        UnaryOperator::Inv0,
    ];

    impl UnaryOperator {
        /// Applies this operator to `a` and returns the result.
        fn apply<F: UnaryOperand>(&self, a: F) -> F {
            match self {
                Self::Neg => -a,
                Self::Double => a.double(),
                Self::Square => a.square(),
                Self::Cube => a.cube(),
                Self::Inv0 => a.inv0(),
            }
        }
    }

    /// Marker for types that support `+`, `-` and `*` with themselves.
    trait BinaryOperand: Sized + Add<Output = Self> + Sub<Output = Self> + Mul<Output = Self> {}
    impl<F: Field> BinaryOperand for F {}
    impl<F: Field> BinaryOperand for Assigned<F> {}

    /// A binary operator that can be chosen at random by the test.
    #[derive(Clone, Debug)]
    enum BinaryOperator {
        Add,
        Sub,
        Mul,
    }

    /// Every binary operator, used as the pool that `select` samples from.
    const BINARY_OPERATORS: &[BinaryOperator] = &[
        BinaryOperator::Add,
        BinaryOperator::Sub,
        BinaryOperator::Mul,
    ];

    impl BinaryOperator {
        /// Applies this operator to `a` (left operand) and `b` (right operand).
        fn apply<F: BinaryOperand>(&self, a: F, b: F) -> F {
            match self {
                Self::Add => a + b,
                Self::Sub => a - b,
                Self::Mul => a * b,
            }
        }
    }

    /// Either kind of operator, so both can share one shuffled list.
    #[derive(Clone, Debug)]
    enum Operator {
        Unary(UnaryOperator),
        Binary(BinaryOperator),
    }

    // Strategy: a field element built from an arbitrary `u64`.
    prop_compose! {
        fn arb_element()(val in any::<u64>()) -> Fp {
            Fp::from(val)
        }
    }

    // Strategy: an arbitrary element wrapped as `Assigned::Trivial`.
    prop_compose! {
        fn arb_trivial()(element in arb_element()) -> Assigned<Fp> {
            Assigned::Trivial(element)
        }
    }

    // Strategy: an `Assigned::Rational` whose denominator is forced to zero with weight 1
    // against weight 2 for an arbitrary element, so the x/0 case is exercised often.
    prop_compose! {
        fn arb_rational()(
            numerator in arb_element(),
            denominator in prop_oneof![
                1 => Just(Fp::zero()),
                2 => arb_element(),
            ],
        ) -> Assigned<Fp> {
            Assigned::Rational(numerator, denominator)
        }
    }

    // Strategy: `num_unary` unary operators followed by `num_binary` binary operators
    // (callers shuffle the list afterwards).
    prop_compose! {
        fn arb_operators(num_unary: usize, num_binary: usize)(
            unary in vec(select(UNARY_OPERATORS), num_unary),
            binary in vec(select(BINARY_OPERATORS), num_binary),
        ) -> Vec<Operator> {
            unary.into_iter()
                .map(Operator::Unary)
                .chain(binary.into_iter().map(Operator::Binary))
                .collect()
        }
    }

    // Strategy: a list of values together with a shuffled list of operators to apply.
    prop_compose! {
        fn arb_testcase()(
            num_unary in 0usize..5,
            num_binary in 0usize..5,
        )(
            values in vec(
                prop_oneof![
                    1 => Just(Assigned::Zero),
                    2 => arb_trivial(),
                    2 => arb_rational(),
                ],
                // Ensure that:
                // - there is at least one value to apply unary operators to;
                // - there is one more value than binary operators, so every binary
                //   operator has a pair of operands when applied sequentially.
                cmp::max(if num_unary > 0 { 1 } else { 0 }, num_binary + 1)),
            operations in arb_operators(num_unary, num_binary).prop_shuffle(),
        ) -> (Vec<Assigned<Fp>>, Vec<Operator>) {
            (values, operations)
        }
    }

    proptest! {
        #[test]
        fn operation_commutativity((values, operations) in arb_testcase()) {
            // Evaluate the values up front to obtain the plain field elements.
            let elements: Vec<_> = values.iter().cloned().map(|v| v.evaluate()).collect();

            // Applies the operator list to a list of operands of either representation.
            fn evaluate<F: UnaryOperand + BinaryOperand>(
                items: Vec<F>,
                operators: &[Operator],
            ) -> F {
                let mut ops = operators.iter();

                // Fold the items pairwise. Each call of the closure consumes operators
                // until it reaches a binary one: unary operators met on the way are
                // applied to the accumulator, and the binary operator combines it with
                // the next item. The test case has exactly one more value than binary
                // operators, so the operator list cannot run out here.
                let mut res = items.into_iter().reduce(|mut a, b| loop {
                    match ops.next() {
                        Some(Operator::Unary(op)) => a = op.apply(a),
                        Some(Operator::Binary(op)) => break op.apply(a, b),
                        None => unreachable!(),
                    }
                }).unwrap();

                // Apply whatever unary operators remain (there was only a single item,
                // or unary operators followed the last binary one). All binary operators
                // have been consumed by the fold above.
                loop {
                    match ops.next() {
                        Some(Operator::Unary(op)) => res = op.apply(res),
                        Some(Operator::Binary(_)) => unreachable!(),
                        None => break res,
                    }
                }
            }
            let deferred_result = evaluate(values, &operations);
            let evaluated_result = evaluate(elements, &operations);

            // Deferring the inversion must commute with the list of operations.
            assert_eq!(deferred_result.evaluate(), evaluated_result);
        }
    }
}