1use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
2
3use group::ff::Field;
4
/// A field element that may be held as an unevaluated fraction.
///
/// Keeping the numerator and denominator separate lets arithmetic proceed
/// without computing any field inverse; [`Assigned::evaluate`] performs the
/// final division. A zero denominator is treated as evaluating to zero
/// (inverse-or-zero semantics) rather than as an error.
#[derive(Clone, Copy, Debug)]
pub enum Assigned<F> {
    /// The field element zero.
    Zero,
    /// A value that needs no inversion to evaluate.
    Trivial(F),
    /// A value stored as a `(numerator, denominator)` fraction.
    Rational(F, F),
}
19
20impl<F: Field> From<&Assigned<F>> for Assigned<F> {
21 fn from(val: &Assigned<F>) -> Self {
22 *val
23 }
24}
25
26impl<F: Field> From<&F> for Assigned<F> {
27 fn from(numerator: &F) -> Self {
28 Assigned::Trivial(*numerator)
29 }
30}
31
32impl<F: Field> From<F> for Assigned<F> {
33 fn from(numerator: F) -> Self {
34 Assigned::Trivial(numerator)
35 }
36}
37
38impl<F: Field> From<(F, F)> for Assigned<F> {
39 fn from((numerator, denominator): (F, F)) -> Self {
40 Assigned::Rational(numerator, denominator)
41 }
42}
43
44impl<F: Field> PartialEq for Assigned<F> {
45 fn eq(&self, other: &Self) -> bool {
46 match (self, other) {
47 (Self::Zero, Self::Zero) => true,
49 (Self::Zero, x) | (x, Self::Zero) => x.is_zero_vartime(),
50
51 (Self::Rational(_, denominator), x) | (x, Self::Rational(_, denominator))
53 if denominator.is_zero_vartime() =>
54 {
55 x.is_zero_vartime()
56 }
57
58 (Self::Trivial(lhs), Self::Trivial(rhs)) => lhs == rhs,
60 (Self::Trivial(x), Self::Rational(numerator, denominator))
61 | (Self::Rational(numerator, denominator), Self::Trivial(x)) => {
62 &(*x * denominator) == numerator
63 }
64 (
65 Self::Rational(lhs_numerator, lhs_denominator),
66 Self::Rational(rhs_numerator, rhs_denominator),
67 ) => *lhs_numerator * rhs_denominator == *lhs_denominator * rhs_numerator,
68 }
69 }
70}
71
72impl<F: Field> Eq for Assigned<F> {}
73
74impl<F: Field> Neg for Assigned<F> {
75 type Output = Assigned<F>;
76 fn neg(self) -> Self::Output {
77 match self {
78 Self::Zero => Self::Zero,
79 Self::Trivial(numerator) => Self::Trivial(-numerator),
80 Self::Rational(numerator, denominator) => Self::Rational(-numerator, denominator),
81 }
82 }
83}
84
85impl<F: Field> Neg for &Assigned<F> {
86 type Output = Assigned<F>;
87 fn neg(self) -> Self::Output {
88 -*self
89 }
90}
91
92impl<F: Field> Add for Assigned<F> {
93 type Output = Assigned<F>;
94 fn add(self, rhs: Assigned<F>) -> Assigned<F> {
95 match (self, rhs) {
96 (Self::Zero, _) => rhs,
98 (_, Self::Zero) => self,
99
100 (Self::Rational(_, denominator), other) | (other, Self::Rational(_, denominator))
102 if denominator.is_zero_vartime() =>
103 {
104 other
105 }
106
107 (Self::Trivial(lhs), Self::Trivial(rhs)) => Self::Trivial(lhs + rhs),
109 (Self::Rational(numerator, denominator), Self::Trivial(other))
110 | (Self::Trivial(other), Self::Rational(numerator, denominator)) => {
111 Self::Rational(numerator + denominator * other, denominator)
112 }
113 (
114 Self::Rational(lhs_numerator, lhs_denominator),
115 Self::Rational(rhs_numerator, rhs_denominator),
116 ) => Self::Rational(
117 lhs_numerator * rhs_denominator + lhs_denominator * rhs_numerator,
118 lhs_denominator * rhs_denominator,
119 ),
120 }
121 }
122}
123
124impl<F: Field> Add<F> for Assigned<F> {
125 type Output = Assigned<F>;
126 fn add(self, rhs: F) -> Assigned<F> {
127 self + Self::Trivial(rhs)
128 }
129}
130
131impl<F: Field> Add<F> for &Assigned<F> {
132 type Output = Assigned<F>;
133 fn add(self, rhs: F) -> Assigned<F> {
134 *self + rhs
135 }
136}
137
138impl<F: Field> Add<&Assigned<F>> for Assigned<F> {
139 type Output = Assigned<F>;
140 fn add(self, rhs: &Self) -> Assigned<F> {
141 self + *rhs
142 }
143}
144
145impl<F: Field> Add<Assigned<F>> for &Assigned<F> {
146 type Output = Assigned<F>;
147 fn add(self, rhs: Assigned<F>) -> Assigned<F> {
148 *self + rhs
149 }
150}
151
152impl<F: Field> Add<&Assigned<F>> for &Assigned<F> {
153 type Output = Assigned<F>;
154 fn add(self, rhs: &Assigned<F>) -> Assigned<F> {
155 *self + *rhs
156 }
157}
158
159impl<F: Field> AddAssign for Assigned<F> {
160 fn add_assign(&mut self, rhs: Self) {
161 *self = *self + rhs;
162 }
163}
164
165impl<F: Field> AddAssign<&Assigned<F>> for Assigned<F> {
166 fn add_assign(&mut self, rhs: &Self) {
167 *self = *self + rhs;
168 }
169}
170
171impl<F: Field> Sub for Assigned<F> {
172 type Output = Assigned<F>;
173 fn sub(self, rhs: Assigned<F>) -> Assigned<F> {
174 self + (-rhs)
175 }
176}
177
178impl<F: Field> Sub<F> for Assigned<F> {
179 type Output = Assigned<F>;
180 fn sub(self, rhs: F) -> Assigned<F> {
181 self + (-rhs)
182 }
183}
184
185impl<F: Field> Sub<F> for &Assigned<F> {
186 type Output = Assigned<F>;
187 fn sub(self, rhs: F) -> Assigned<F> {
188 *self - rhs
189 }
190}
191
192impl<F: Field> Sub<&Assigned<F>> for Assigned<F> {
193 type Output = Assigned<F>;
194 fn sub(self, rhs: &Self) -> Assigned<F> {
195 self - *rhs
196 }
197}
198
199impl<F: Field> Sub<Assigned<F>> for &Assigned<F> {
200 type Output = Assigned<F>;
201 fn sub(self, rhs: Assigned<F>) -> Assigned<F> {
202 *self - rhs
203 }
204}
205
206impl<F: Field> Sub<&Assigned<F>> for &Assigned<F> {
207 type Output = Assigned<F>;
208 fn sub(self, rhs: &Assigned<F>) -> Assigned<F> {
209 *self - *rhs
210 }
211}
212
213impl<F: Field> SubAssign for Assigned<F> {
214 fn sub_assign(&mut self, rhs: Self) {
215 *self = *self - rhs;
216 }
217}
218
219impl<F: Field> SubAssign<&Assigned<F>> for Assigned<F> {
220 fn sub_assign(&mut self, rhs: &Self) {
221 *self = *self - rhs;
222 }
223}
224
225impl<F: Field> Mul for Assigned<F> {
226 type Output = Assigned<F>;
227 fn mul(self, rhs: Assigned<F>) -> Assigned<F> {
228 match (self, rhs) {
229 (Self::Zero, _) | (_, Self::Zero) => Self::Zero,
230 (Self::Trivial(lhs), Self::Trivial(rhs)) => Self::Trivial(lhs * rhs),
231 (Self::Rational(numerator, denominator), Self::Trivial(other))
232 | (Self::Trivial(other), Self::Rational(numerator, denominator)) => {
233 Self::Rational(numerator * other, denominator)
234 }
235 (
236 Self::Rational(lhs_numerator, lhs_denominator),
237 Self::Rational(rhs_numerator, rhs_denominator),
238 ) => Self::Rational(
239 lhs_numerator * rhs_numerator,
240 lhs_denominator * rhs_denominator,
241 ),
242 }
243 }
244}
245
246impl<F: Field> Mul<F> for Assigned<F> {
247 type Output = Assigned<F>;
248 fn mul(self, rhs: F) -> Assigned<F> {
249 self * Self::Trivial(rhs)
250 }
251}
252
253impl<F: Field> Mul<F> for &Assigned<F> {
254 type Output = Assigned<F>;
255 fn mul(self, rhs: F) -> Assigned<F> {
256 *self * rhs
257 }
258}
259
260impl<F: Field> Mul<&Assigned<F>> for Assigned<F> {
261 type Output = Assigned<F>;
262 fn mul(self, rhs: &Assigned<F>) -> Assigned<F> {
263 self * *rhs
264 }
265}
266
267impl<F: Field> MulAssign for Assigned<F> {
268 fn mul_assign(&mut self, rhs: Self) {
269 *self = *self * rhs;
270 }
271}
272
273impl<F: Field> MulAssign<&Assigned<F>> for Assigned<F> {
274 fn mul_assign(&mut self, rhs: &Self) {
275 *self = *self * rhs;
276 }
277}
278
279impl<F: Field> Assigned<F> {
280 pub fn numerator(&self) -> F {
282 match self {
283 Self::Zero => F::ZERO,
284 Self::Trivial(x) => *x,
285 Self::Rational(numerator, _) => *numerator,
286 }
287 }
288
289 pub fn denominator(&self) -> Option<F> {
291 match self {
292 Self::Zero => None,
293 Self::Trivial(_) => None,
294 Self::Rational(_, denominator) => Some(*denominator),
295 }
296 }
297
298 pub fn is_zero_vartime(&self) -> bool {
300 match self {
301 Self::Zero => true,
302 Self::Trivial(x) => x.is_zero_vartime(),
303 Self::Rational(numerator, denominator) => {
305 numerator.is_zero_vartime() || denominator.is_zero_vartime()
306 }
307 }
308 }
309
310 #[must_use]
312 pub fn double(&self) -> Self {
313 match self {
314 Self::Zero => Self::Zero,
315 Self::Trivial(x) => Self::Trivial(x.double()),
316 Self::Rational(numerator, denominator) => {
317 Self::Rational(numerator.double(), *denominator)
318 }
319 }
320 }
321
322 #[must_use]
324 pub fn square(&self) -> Self {
325 match self {
326 Self::Zero => Self::Zero,
327 Self::Trivial(x) => Self::Trivial(x.square()),
328 Self::Rational(numerator, denominator) => {
329 Self::Rational(numerator.square(), denominator.square())
330 }
331 }
332 }
333
334 #[must_use]
336 pub fn cube(&self) -> Self {
337 self.square() * self
338 }
339
340 pub fn invert(&self) -> Self {
342 match self {
343 Self::Zero => Self::Zero,
344 Self::Trivial(x) => Self::Rational(F::ONE, *x),
345 Self::Rational(numerator, denominator) => Self::Rational(*denominator, *numerator),
346 }
347 }
348
349 pub fn evaluate(self) -> F {
354 match self {
355 Self::Zero => F::ZERO,
356 Self::Trivial(x) => x,
357 Self::Rational(numerator, denominator) => {
358 if denominator == F::ONE {
359 numerator
360 } else {
361 numerator * denominator.invert().unwrap_or(F::ZERO)
362 }
363 }
364 }
365 }
366}
367
#[cfg(test)]
mod tests {
    use halo2curves::pasta::Fp;

    use super::Assigned;

    // Regression tests for arithmetic involving a fraction with a zero
    // denominator. Such a value evaluates to zero, and mixing it into
    // arithmetic must behave exactly as mixing in zero would.

    /// `1 / 0`: a fraction whose zero denominator makes it evaluate to zero.
    fn inv0_rational() -> Assigned<Fp> {
        Assigned::Rational(Fp::one(), Fp::zero())
    }

    #[test]
    fn add_trivial_to_inv0_rational() {
        let trivial = Assigned::Trivial(Fp::from(2));
        let zero_like = inv0_rational();

        // Adding something that evaluates to zero leaves the other side unchanged,
        // in either operand order.
        assert_eq!((trivial + zero_like).evaluate(), trivial.evaluate());
        assert_eq!((zero_like + trivial).evaluate(), trivial.evaluate());
    }

    #[test]
    fn add_rational_to_inv0_rational() {
        let half = Assigned::Rational(Fp::one(), Fp::from(2));
        let zero_like = inv0_rational();

        assert_eq!((half + zero_like).evaluate(), half.evaluate());
        assert_eq!((zero_like + half).evaluate(), half.evaluate());
    }

    #[test]
    fn sub_trivial_from_inv0_rational() {
        let trivial = Assigned::Trivial(Fp::from(2));
        let zero_like = inv0_rational();

        // 0 - x == -x
        assert_eq!((zero_like - trivial).evaluate(), (-trivial).evaluate());
        // x - 0 == x
        assert_eq!((trivial - zero_like).evaluate(), trivial.evaluate());
    }

    #[test]
    fn sub_rational_from_inv0_rational() {
        let half = Assigned::Rational(Fp::one(), Fp::from(2));
        let zero_like = inv0_rational();

        // 0 - x == -x
        assert_eq!((zero_like - half).evaluate(), (-half).evaluate());
        // x - 0 == x
        assert_eq!((half - zero_like).evaluate(), half.evaluate());
    }

    #[test]
    fn mul_rational_by_inv0_rational() {
        let half = Assigned::Rational(Fp::one(), Fp::from(2));
        let zero_like = inv0_rational();

        // Multiplying by something that evaluates to zero gives zero.
        assert_eq!((half * zero_like).evaluate(), Fp::zero());
        assert_eq!((zero_like * half).evaluate(), Fp::zero());
    }
}
444
#[cfg(test)]
mod proptests {
    use std::{
        cmp,
        ops::{Add, Mul, Neg, Sub},
    };

    use group::ff::Field;
    use halo2curves::pasta::Fp;
    use proptest::{collection::vec, prelude::*, sample::select};

    use super::Assigned;

    /// Unary operations shared by plain field elements and `Assigned` values,
    /// so one generic evaluator can run the same operation sequence on both.
    trait UnaryOperand: Neg<Output = Self> {
        fn double(&self) -> Self;
        fn square(&self) -> Self;
        fn cube(&self) -> Self;
        fn inv0(&self) -> Self;
    }

    // These delegate to the same-named methods provided through the `Field` bound.
    impl<F: Field> UnaryOperand for F {
        fn double(&self) -> Self {
            self.double()
        }

        fn square(&self) -> Self {
            self.square()
        }

        fn cube(&self) -> Self {
            self.cube()
        }

        // Inverse-or-zero: zero has no inverse, so it maps to zero.
        fn inv0(&self) -> Self {
            self.invert().unwrap_or(F::ZERO)
        }
    }

    // These delegate to `Assigned`'s inherent methods of the same name.
    impl<F: Field> UnaryOperand for Assigned<F> {
        fn double(&self) -> Self {
            self.double()
        }

        fn square(&self) -> Self {
            self.square()
        }

        fn cube(&self) -> Self {
            self.cube()
        }

        // `Assigned::invert` defers the inversion; evaluation applies inv0.
        fn inv0(&self) -> Self {
            self.invert()
        }
    }

    /// A unary operation that the property test can draw at random.
    #[derive(Clone, Debug)]
    enum UnaryOperator {
        Neg,
        Double,
        Square,
        Cube,
        Inv0,
    }

    /// Every `UnaryOperator`, used as the sampling pool for `select`.
    const UNARY_OPERATORS: &[UnaryOperator] = &[
        UnaryOperator::Neg,
        UnaryOperator::Double,
        UnaryOperator::Square,
        UnaryOperator::Cube,
        UnaryOperator::Inv0,
    ];

    impl UnaryOperator {
        /// Applies this operation to `a`.
        fn apply<F: UnaryOperand>(&self, a: F) -> F {
            match self {
                Self::Neg => -a,
                Self::Double => a.double(),
                Self::Square => a.square(),
                Self::Cube => a.cube(),
                Self::Inv0 => a.inv0(),
            }
        }
    }

    /// Binary arithmetic shared by plain field elements and `Assigned` values.
    trait BinaryOperand: Sized + Add<Output = Self> + Sub<Output = Self> + Mul<Output = Self> {}
    impl<F: Field> BinaryOperand for F {}
    impl<F: Field> BinaryOperand for Assigned<F> {}

    /// A binary operation that the property test can draw at random.
    #[derive(Clone, Debug)]
    enum BinaryOperator {
        Add,
        Sub,
        Mul,
    }

    /// Every `BinaryOperator`, used as the sampling pool for `select`.
    const BINARY_OPERATORS: &[BinaryOperator] = &[
        BinaryOperator::Add,
        BinaryOperator::Sub,
        BinaryOperator::Mul,
    ];

    impl BinaryOperator {
        /// Applies this operation to `a` and `b` (in that order).
        fn apply<F: BinaryOperand>(&self, a: F, b: F) -> F {
            match self {
                Self::Add => a + b,
                Self::Sub => a - b,
                Self::Mul => a * b,
            }
        }
    }

    /// One step of a randomly generated computation.
    #[derive(Clone, Debug)]
    enum Operator {
        Unary(UnaryOperator),
        Binary(BinaryOperator),
    }

    prop_compose! {
        // A field element built from an arbitrary `u64`.
        fn arb_element()(val in any::<u64>()) -> Fp {
            Fp::from(val)
        }
    }

    prop_compose! {
        // An `Assigned::Trivial` wrapping an arbitrary element.
        fn arb_trivial()(element in arb_element()) -> Assigned<Fp> {
            Assigned::Trivial(element)
        }
    }

    prop_compose! {
        // An `Assigned::Rational`; a third of the time (weight 1 vs 2) the
        // denominator is zero, to exercise the inverse-or-zero paths.
        fn arb_rational()(
            numerator in arb_element(),
            denominator in prop_oneof![
                1 => Just(Fp::zero()),
                2 => arb_element(),
            ],
        ) -> Assigned<Fp> {
            Assigned::Rational(numerator, denominator)
        }
    }

    prop_compose! {
        // `num_unary` random unary ops followed by `num_binary` random binary ops
        // (callers shuffle the result).
        fn arb_operators(num_unary: usize, num_binary: usize)(
            unary in vec(select(UNARY_OPERATORS), num_unary),
            binary in vec(select(BINARY_OPERATORS), num_binary),
        ) -> Vec<Operator> {
            unary.into_iter()
                .map(Operator::Unary)
                .chain(binary.into_iter().map(Operator::Binary))
                .collect()
        }
    }

    prop_compose! {
        // A list of values plus a shuffled list of operators to fold over them.
        fn arb_testcase()(
            num_unary in 0usize..5,
            num_binary in 0usize..5,
        )(
            values in vec(
                prop_oneof![
                    1 => Just(Assigned::Zero),
                    2 => arb_trivial(),
                    2 => arb_rational(),
                ],
                // Each binary op consumes one extra value, so generate
                // `num_binary + 1` values. NOTE(review): that is always >= 1, so
                // the `max` with `num_unary > 0` never changes the length.
                cmp::max(usize::from(num_unary > 0), num_binary + 1)),
            operations in arb_operators(num_unary, num_binary).prop_shuffle(),
        ) -> (Vec<Assigned<Fp>>, Vec<Operator>) {
            (values, operations)
        }
    }

    proptest! {
        // Checks that running the same operations on deferred `Assigned` values
        // and then evaluating gives the same result as evaluating each value
        // first and running the operations on plain field elements.
        #[test]
        fn operation_commutativity((values, operations) in arb_testcase()) {
            // Eagerly evaluated counterparts of `values`.
            let elements: Vec<_> = values.iter().cloned().map(|v| v.evaluate()).collect();

            // Runs `operators` in order over `items`. Unary ops act on the running
            // accumulator; each binary op combines the accumulator with the next item.
            fn evaluate<F: UnaryOperand + BinaryOperand>(
                items: Vec<F>,
                operators: &[Operator],
            ) -> F {
                let mut ops = operators.iter();

                // Fold the items left to right. Each fold step first applies any
                // pending unary ops to the accumulator, then the next binary op.
                // `arb_testcase` makes one binary op per fold step, so `None` cannot
                // occur here; `items` is never empty, so `unwrap` cannot fail.
                let mut res = items.into_iter().reduce(|mut a, b| loop {
                    match ops.next() {
                        Some(Operator::Unary(op)) => a = op.apply(a),
                        Some(Operator::Binary(op)) => break op.apply(a, b),
                        None => unreachable!(),
                    }
                }).unwrap();

                // Every binary op was consumed by the fold; apply the trailing unary ops.
                loop {
                    match ops.next() {
                        Some(Operator::Unary(op)) => res = op.apply(res),
                        Some(Operator::Binary(_)) => unreachable!(),
                        None => break res,
                    }
                }
            }
            let deferred_result = evaluate(values, &operations);
            let evaluated_result = evaluate(elements, &operations);

            // Deferring inversion must not change the final value.
            assert_eq!(deferred_result.evaluate(), evaluated_result);
        }
    }
}