substrate_bn/
arith.rs

1use core::cmp::Ordering;
2use rand::Rng;
3use crunchy::unroll;
4
5use byteorder::{BigEndian, ByteOrder};
6
/// A 256-bit unsigned integer held entirely on the stack, used as the raw
/// representation for prime field arithmetic.
///
/// The value is stored as two `u128` limbs, least-significant limb first:
/// index 0 carries bits 0..128 and index 1 carries bits 128..256.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(C)]
pub struct U256(pub [u128; 2]);
12
13impl From<[u64; 4]> for U256 {
14    fn from(d: [u64; 4]) -> Self {
15        let mut a = [0u128; 2];
16        a[0] = (d[1] as u128) << 64 | d[0] as u128;
17        a[1] = (d[3] as u128) << 64 | d[2] as u128;
18        U256(a)
19    }
20}
21
22impl From<u64> for U256 {
23    fn from(d: u64) -> Self {
24        U256::from([d, 0, 0, 0])
25    }
26}
27
/// A 512-bit unsigned integer held entirely on the stack, used for
/// extension field serialization and scalar interpretation.
///
/// The value is stored as four `u128` limbs, least-significant limb first.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(C)]
pub struct U512(pub [u128; 4]);
33
34impl From<[u64; 8]> for U512 {
35    fn from(d: [u64; 8]) -> Self {
36        let mut a = [0u128; 4];
37        a[0] = (d[1] as u128) << 64 | d[0] as u128;
38        a[1] = (d[3] as u128) << 64 | d[2] as u128;
39        a[2] = (d[5] as u128) << 64 | d[4] as u128;
40        a[3] = (d[7] as u128) << 64 | d[6] as u128;
41        U512(a)
42    }
43}
44
45impl U512 {
46    /// Multiplies c1 by modulo, adds c0.
47    pub fn new(c1: &U256, c0: &U256, modulo: &U256) -> U512 {
48        let mut res = [0; 4];
49
50        debug_assert_eq!(c1.0.len(), 2);
51        unroll! {
52            for i in 0..2 {
53                mac_digit(i, &mut res, &modulo.0, c1.0[i]);
54            }
55        }
56
57        let mut carry = 0;
58
59        debug_assert_eq!(res.len(), 4);
60        unroll! {
61            for i in 0..2 {
62                res[i] = adc(res[i], c0.0[i], &mut carry);
63            }
64        }
65
66        unroll! {
67            for i in 0..2 {
68                let (a1, a0) = split_u128(res[i + 2]);
69                let (c, r0) = split_u128(a0 + carry);
70                let (c, r1) = split_u128(a1 + c);
71                carry = c;
72
73                res[i + 2] = combine_u128(r1, r0);
74            }
75        }
76
77        debug_assert!(0 == carry);
78
79        U512(res)
80    }
81
82     pub fn from_slice(s: &[u8]) -> Result<U512, Error> {
83        if s.len() != 64 {
84            return Err(Error::InvalidLength {
85                expected: 32,
86                actual: s.len(),
87            });
88        }
89
90        let mut n = [0; 4];
91        for (l, i) in (0..4).rev().zip((0..4).map(|i| i * 16)) {
92            n[l] = BigEndian::read_u128(&s[i..]);
93        }
94
95        Ok(U512(n))
96    }
97
98    /// Get a random U512
99    pub fn random<R: Rng>(rng: &mut R) -> U512 {
100        U512(rng.gen())
101    }
102
103    pub fn get_bit(&self, n: usize) -> Option<bool> {
104        if n >= 512 {
105            None
106        } else {
107            let part = n / 128;
108            let bit = n - (128 * part);
109
110            Some(self.0[part] & (1 << bit) > 0)
111        }
112    }
113
114    /// Divides self by modulo, returning remainder and, if
115    /// possible, a quotient smaller than the modulus.
116    pub fn divrem(&self, modulo: &U256) -> (Option<U256>, U256) {
117        let mut q = Some(U256::zero());
118        let mut r = U256::zero();
119
120        for i in (0..512).rev() {
121            // NB: modulo's first two bits are always unset
122            // so this will never destroy information
123            mul2(&mut r.0);
124            assert!(r.set_bit(0, self.get_bit(i).unwrap()));
125            if &r >= modulo {
126                sub_noborrow(&mut r.0, &modulo.0);
127                if q.is_some() && !q.as_mut().unwrap().set_bit(i, true) {
128                    q = None
129                }
130            }
131        }
132
133        if q.is_some() && (q.as_ref().unwrap() >= modulo) {
134            (None, r)
135        } else {
136            (q, r)
137        }
138    }
139
140    pub fn interpret(buf: &[u8; 64]) -> U512 {
141        let mut n = [0; 4];
142        for (l, i) in (0..4).rev().zip((0..4).map(|i| i * 16)) {
143            n[l] = BigEndian::read_u128(&buf[i..]);
144        }
145
146        U512(n)
147    }
148}
149
150impl Ord for U512 {
151    #[inline]
152    fn cmp(&self, other: &U512) -> Ordering {
153        for (a, b) in self.0.iter().zip(other.0.iter()).rev() {
154            if *a < *b {
155                return Ordering::Less;
156            } else if *a > *b {
157                return Ordering::Greater;
158            }
159        }
160
161        return Ordering::Equal;
162    }
163}
164
165impl PartialOrd for U512 {
166    #[inline]
167    fn partial_cmp(&self, other: &U512) -> Option<Ordering> {
168        Some(self.cmp(other))
169    }
170}
171
172impl Ord for U256 {
173    #[inline]
174    fn cmp(&self, other: &U256) -> Ordering {
175        for (a, b) in self.0.iter().zip(other.0.iter()).rev() {
176            if *a < *b {
177                return Ordering::Less;
178            } else if *a > *b {
179                return Ordering::Greater;
180            }
181        }
182
183        return Ordering::Equal;
184    }
185}
186
187impl PartialOrd for U256 {
188    #[inline]
189    fn partial_cmp(&self, other: &U256) -> Option<Ordering> {
190        Some(self.cmp(other))
191    }
192}
193
/// Errors produced by the `U256` / `U512` byte conversions.
#[derive(Debug)]
pub enum Error {
    /// A byte slice did not have the length the conversion requires.
    InvalidLength {
        /// Number of bytes the conversion requires.
        expected: usize,
        /// Number of bytes that were actually supplied.
        actual: usize,
    },
}
199
200impl U256 {
201    /// Initialize U256 from slice of bytes (big endian)
202    pub fn from_slice(s: &[u8]) -> Result<U256, Error> {
203        if s.len() != 32 {
204            return Err(Error::InvalidLength {
205                expected: 32,
206                actual: s.len(),
207            });
208        }
209
210        let mut n = [0; 2];
211        for (l, i) in (0..2).rev().zip((0..2).map(|i| i * 16)) {
212            n[l] = BigEndian::read_u128(&s[i..]);
213        }
214
215        Ok(U256(n))
216    }
217
218    pub fn to_big_endian(&self, s: &mut [u8]) -> Result<(), Error> {
219        if s.len() != 32 {
220            return Err(Error::InvalidLength {
221                expected: 32,
222                actual: s.len(),
223            });
224        }
225
226        for (l, i) in (0..2).rev().zip((0..2).map(|i| i * 16)) {
227            BigEndian::write_u128(&mut s[i..], self.0[l]);
228        }
229
230        Ok(())
231    }
232
233    #[inline]
234    pub fn zero() -> U256 {
235        U256([0, 0])
236    }
237
238    #[inline]
239    pub fn one() -> U256 {
240        U256([1, 0])
241    }
242
243    /// Produce a random number (mod `modulo`)
244    pub fn random<R: Rng>(rng: &mut R, modulo: &U256) -> U256 {
245        U512::random(rng).divrem(modulo).1
246    }
247
248    pub fn is_zero(&self) -> bool {
249        self.0[0] == 0 && self.0[1] == 0
250    }
251
252    pub fn set_bit(&mut self, n: usize, to: bool) -> bool {
253        if n >= 256 {
254            false
255        } else {
256            let part = n / 128;
257            let bit = n - (128 * part);
258
259            if to {
260                self.0[part] |= 1 << bit;
261            } else {
262                self.0[part] &= !(1 << bit);
263            }
264
265            true
266        }
267    }
268
269    pub fn get_bit(&self, n: usize) -> Option<bool> {
270        if n >= 256 {
271            None
272        } else {
273            let part = n / 128;
274            let bit = n - (128 * part);
275
276            Some(self.0[part] & (1 << bit) > 0)
277        }
278    }
279
280    /// Add `other` to `self` (mod `modulo`)
281    pub fn add(&mut self, other: &U256, modulo: &U256) {
282        add_nocarry(&mut self.0, &other.0);
283
284        if *self >= *modulo {
285            sub_noborrow(&mut self.0, &modulo.0);
286        }
287    }
288
289    /// Subtract `other` from `self` (mod `modulo`)
290    pub fn sub(&mut self, other: &U256, modulo: &U256) {
291        if *self < *other {
292            add_nocarry(&mut self.0, &modulo.0);
293        }
294
295        sub_noborrow(&mut self.0, &other.0);
296    }
297
298    /// Multiply `self` by `other` (mod `modulo`) via the Montgomery
299    /// multiplication method.
300    pub fn mul(&mut self, other: &U256, modulo: &U256, inv: u128) {
301        mul_reduce(&mut self.0, &other.0, &modulo.0, inv);
302
303        if *self >= *modulo {
304            sub_noborrow(&mut self.0, &modulo.0);
305        }
306    }
307
308    /// Turn `self` into its additive inverse (mod `modulo`)
309    pub fn neg(&mut self, modulo: &U256) {
310        if *self > Self::zero() {
311            let mut tmp = modulo.0;
312            sub_noborrow(&mut tmp, &self.0);
313
314            self.0 = tmp;
315        }
316    }
317
318    #[inline]
319    pub fn is_even(&self) -> bool {
320        self.0[0] & 1 == 0
321    }
322
323    /// Turn `self` into its multiplicative inverse (mod `modulo`)
324    pub fn invert(&mut self, modulo: &U256) {
325        // Guajardo Kumar Paar Pelzl
326        // Efficient Software-Implementation of Finite Fields with Applications to Cryptography
327        // Algorithm 16 (BEA for Inversion in Fp)
328
329        let mut u = *self;
330        let mut v = *modulo;
331        let mut b = U256::one();
332        let mut c = U256::zero();
333
334        while u != U256::one() && v != U256::one() {
335            while u.is_even() {
336                div2(&mut u.0);
337
338                if b.is_even() {
339                    div2(&mut b.0);
340                } else {
341                    add_nocarry(&mut b.0, &modulo.0);
342                    div2(&mut b.0);
343                }
344            }
345            while v.is_even() {
346                div2(&mut v.0);
347
348                if c.is_even() {
349                    div2(&mut c.0);
350                } else {
351                    add_nocarry(&mut c.0, &modulo.0);
352                    div2(&mut c.0);
353                }
354            }
355
356            if u >= v {
357                sub_noborrow(&mut u.0, &v.0);
358                b.sub(&c, modulo);
359            } else {
360                sub_noborrow(&mut v.0, &u.0);
361                c.sub(&b, modulo);
362            }
363        }
364
365        if u == U256::one() {
366            self.0 = b.0;
367        } else {
368            self.0 = c.0;
369        }
370    }
371
372    /// Return an Iterator<Item=bool> over all bits from
373    /// MSB to LSB.
374    pub fn bits(&self) -> BitIterator {
375        BitIterator { int: &self, n: 256 }
376    }
377}
378
379pub struct BitIterator<'a> {
380    int: &'a U256,
381    n: usize,
382}
383
384impl<'a> Iterator for BitIterator<'a> {
385    type Item = bool;
386
387    fn next(&mut self) -> Option<bool> {
388        if self.n == 0 {
389            None
390        } else {
391            self.n -= 1;
392
393            self.int.get_bit(self.n)
394        }
395    }
396}
397
/// Halves a 256-bit value in place (logical shift right by one bit).
#[inline]
fn div2(a: &mut [u128; 2]) {
    // The lowest bit of the high limb becomes the top bit of the low limb.
    let carried_bit = (a[1] & 1) << 127;
    a[0] = (a[0] >> 1) | carried_bit;
    a[1] >>= 1;
}
406
/// Doubles a 256-bit value in place (shift left by one bit); the bit
/// shifted out of the top limb is discarded.
#[inline]
fn mul2(a: &mut [u128; 2]) {
    // The top bit of the low limb becomes the lowest bit of the high limb.
    let carried_bit = a[0] >> 127;
    a[1] = (a[1] << 1) | carried_bit;
    a[0] <<= 1;
}
415
/// Splits a `u128` into its `(high, low)` 64-bit halves, each returned
/// widened to `u128`.
#[inline(always)]
fn split_u128(i: u128) -> (u128, u128) {
    let high = i >> 64;
    // Truncating to u64 keeps exactly the low 64 bits.
    let low = u128::from(i as u64);
    (high, low)
}
420
/// Joins a `(high, low)` pair into one `u128`: `hi` is shifted into the
/// upper 64 bits and OR-ed with `lo`.
#[inline(always)]
fn combine_u128(hi: u128, lo: u128) -> u128 {
    let upper = hi << 64;
    lo | upper
}
425
426#[inline]
427fn adc(a: u128, b: u128, carry: &mut u128) -> u128 {
428    let (a1, a0) = split_u128(a);
429    let (b1, b0) = split_u128(b);
430    let (c, r0) = split_u128(a0 + b0 + *carry);
431    let (c, r1) = split_u128(a1 + b1 + c);
432    *carry = c;
433
434    combine_u128(r1, r0)
435}
436
437#[inline]
438fn add_nocarry(a: &mut [u128; 2], b: &[u128; 2]) {
439    let mut carry = 0;
440
441    for (a, b) in a.into_iter().zip(b.iter()) {
442        *a = adc(*a, *b, &mut carry);
443    }
444
445    debug_assert!(0 == carry);
446}
447
448#[inline]
449fn sub_noborrow(a: &mut [u128; 2], b: &[u128; 2]) {
450    #[inline]
451    fn sbb(a: u128, b: u128, borrow: &mut u128) -> u128 {
452        let (a1, a0) = split_u128(a);
453        let (b1, b0) = split_u128(b);
454        let (b, r0) = split_u128((1 << 64) + a0 - b0 - *borrow);
455        let (b, r1) = split_u128((1 << 64) + a1 - b1 - ((b == 0) as u128));
456
457        *borrow = (b == 0) as u128;
458
459        combine_u128(r1, r0)
460    }
461
462    let mut borrow = 0;
463
464    for (a, b) in a.into_iter().zip(b.iter()) {
465        *a = sbb(*a, *b, &mut borrow);
466    }
467
468    debug_assert!(0 == borrow);
469}
470
471// TODO: Make `from_index` a const param
472#[inline(always)]
473fn mac_digit(from_index: usize, acc: &mut [u128; 4], b: &[u128; 2], c: u128) {
474    #[inline]
475    fn mac_with_carry(a: u128, b: u128, c: u128, carry: &mut u128) -> u128 {
476        let (b_hi, b_lo) = split_u128(b);
477        let (c_hi, c_lo) = split_u128(c);
478
479        let (a_hi, a_lo) = split_u128(a);
480        let (carry_hi, carry_lo) = split_u128(*carry);
481        let (x_hi, x_lo) = split_u128(b_lo * c_lo + a_lo + carry_lo);
482        let (y_hi, y_lo) = split_u128(b_lo * c_hi);
483        let (z_hi, z_lo) = split_u128(b_hi * c_lo);
484        // Brackets to allow better ILP
485        let (r_hi, r_lo) = split_u128((x_hi + y_lo) + (z_lo + a_hi) + carry_hi);
486
487        *carry = (b_hi * c_hi) + r_hi + y_hi + z_hi;
488
489        combine_u128(r_lo, x_lo)
490    }
491
492    if c == 0 {
493        return;
494    }
495
496    let mut carry = 0;
497
498    debug_assert_eq!(acc.len(), 4);
499    unroll! {
500        for i in 0..2 {
501            let a_index = i + from_index;
502            acc[a_index] = mac_with_carry(acc[a_index], b[i], c, &mut carry);
503        }
504    }
505    unroll! {
506        for i in 0..2 {
507            let a_index = i + from_index + 2;
508            if a_index < 4 {
509                let (a_hi, a_lo) = split_u128(acc[a_index]);
510                let (carry_hi, carry_lo) = split_u128(carry);
511                let (x_hi, x_lo) = split_u128(a_lo + carry_lo);
512                let (r_hi, r_lo) = split_u128(x_hi + a_hi + carry_hi);
513
514                carry = r_hi;
515
516                acc[a_index] = combine_u128(r_lo, x_lo);
517            }
518        }
519    }
520
521    debug_assert!(carry == 0);
522}
523
524#[inline]
525fn mul_reduce(this: &mut [u128; 2], by: &[u128; 2], modulus: &[u128; 2], inv: u128) {
526    // The Montgomery reduction here is based on Algorithm 14.32 in
527    // Handbook of Applied Cryptography
528    // <http://cacr.uwaterloo.ca/hac/about/chap14.pdf>.
529
530    let mut res = [0; 2 * 2];
531    unroll! {
532        for i in 0..2 {
533            mac_digit(i, &mut res, by, this[i]);
534        }
535    }
536
537    unroll! {
538        for i in 0..2 {
539            let k = inv.wrapping_mul(res[i]);
540            mac_digit(i, &mut res, modulus, k);
541        }
542    }
543
544    this.copy_from_slice(&res[2..]);
545}
546
/// Rebuilds a random `U256` bit by bit from its `bits()` iterator and checks
/// that the reconstruction equals the original.
#[test]
fn setting_bits() {
    let mut rng = ::rand::thread_rng();
    let modulo = U256::from([0xffffffffffffffff; 4]);

    let original = U256::random(&mut rng, &modulo);
    let mut rebuilt = U256::zero();
    // `bits()` yields MSB first, so pair it with descending bit positions.
    for (position, bit) in (0..256).rev().zip(original.bits()) {
        assert!(rebuilt.set_bit(position, bit));
    }

    assert_eq!(original, rebuilt);
}
560
/// A big-endian buffer whose only set byte is a trailing 1 must parse as one.
#[test]
fn from_slice() {
    let mut bytes = [0u8; 32];
    bytes[31] = 1;

    let parsed = U256::from_slice(&bytes)
        .expect("U256 should initialize ok from slice in `from_slice` test");

    assert_eq!(parsed, U256::one());
}
571
/// Serializing one must give 31 zero bytes followed by a single 1.
#[test]
fn to_big_endian() {
    let mut expected = [0u8; 32];
    expected[31] = 1u8;

    let mut written = [0u8; 32];
    U256::one()
        .to_big_endian(&mut written)
        .expect("U256 should convert to bytes ok in `to_big_endian` test");

    assert_eq!(written, expected);
}
587
/// Exercises `U512::divrem` with random round trips and hand-picked values
/// around the modulus and its square.
#[test]
fn testing_divrem() {
    let mut rng = ::rand::thread_rng();

    // The modulus q used by every case except the last one.
    let modulo = U256::from([
        0x3c208c16d87cfd47,
        0x97816a916871ca8d,
        0xb85045b68181585d,
        0x30644e72e131a029,
    ]);

    // q - k for small k (only the lowest word changes).
    let q_minus = |k: u64| {
        U256::from([
            0x3c208c16d87cfd47 - k,
            0x97816a916871ca8d,
            0xb85045b68181585d,
            0x30644e72e131a029,
        ])
    };

    // A value sharing every word of q^2 except the lowest, which is `low`.
    // q^2 itself has the lowest word 0x3b5458a2275d69b1.
    let near_q_squared = |low: u64| {
        U512::from([
            low,
            0xa602072d09eac101,
            0x4a50189c6d96cadc,
            0x04689e957a1242c8,
            0x26edfa5c34c6b38d,
            0xb00b855116375606,
            0x599a6f7c0348d21c,
            0x0925c4b8763cbf9c,
        ])
    };

    // Random round trip: (c1 * q + c0) must split back into (c1, c0).
    for _ in 0..100 {
        let c0 = U256::random(&mut rng, &modulo);
        let c1 = U256::random(&mut rng, &modulo);

        let c1q_plus_c0 = U512::new(&c1, &c0, &modulo);

        let (new_c1, new_c0) = c1q_plus_c0.divrem(&modulo);

        assert!(c1 == new_c1.unwrap());
        assert!(c0 == new_c0);
    }

    {
        // Modulus should become 1*q + 0
        let a = U512::from([
            0x3c208c16d87cfd47,
            0x97816a916871ca8d,
            0xb85045b68181585d,
            0x30644e72e131a029,
            0,
            0,
            0,
            0,
        ]);

        let (c1, c0) = a.divrem(&modulo);
        assert_eq!(c1.unwrap(), U256::one());
        assert_eq!(c0, U256::zero());
    }

    {
        // Modulus squared minus 1 should be (q-1) q + q-1
        let (c1, c0) = near_q_squared(0x3b5458a2275d69b0).divrem(&modulo);
        assert_eq!(c1.unwrap(), q_minus(1));
        assert_eq!(c0, q_minus(1));
    }

    {
        // Modulus squared minus 2 should be (q-1) q + q-2
        let (c1, c0) = near_q_squared(0x3b5458a2275d69af).divrem(&modulo);
        assert_eq!(c1.unwrap(), q_minus(1));
        assert_eq!(c0, q_minus(2));
    }

    {
        // Ridiculously large number should fail
        let a = U512::from([0xffffffffffffffff; 8]);

        let (c1, c0) = a.divrem(&modulo);
        assert!(c1.is_none());
        assert_eq!(
            c0,
            U256::from([
                0xf32cfc5b538afa88,
                0xb5e71911d44501fb,
                0x47ab1eff0a417ff6,
                0x06d89f71cab8351f
            ])
        );
    }

    {
        // Modulus squared should fail
        let (c1, c0) = near_q_squared(0x3b5458a2275d69b1).divrem(&modulo);
        assert!(c1.is_none());
        assert_eq!(c0, U256::zero());
    }

    {
        // Modulus squared plus one should fail
        let (c1, c0) = near_q_squared(0x3b5458a2275d69b2).divrem(&modulo);
        assert!(c1.is_none());
        assert_eq!(c0, U256::one());
    }

    {
        let modulo = U256::from([
            0x43e1f593f0000001,
            0x2833e84879b97091,
            0xb85045b68181585d,
            0x30644e72e131a029,
        ]);

        // Fr modulus masked off is valid
        let a = U512::from([
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0xffffffffffffffff,
            0x07ffffffffffffff,
        ]);

        let (c1, c0) = a.divrem(&modulo);

        assert!(c1.unwrap() < modulo);
        assert!(c0 < modulo);
    }
}