addchain/
bbbd.rs

1//! The Bergeron-Berstel-Brlek-Duboc algorithm for finding short addition chains.
2//!
3//! References:
4//! - Bergeron, Berstel, Brlek, Duboc.
5//!   ["Addition chains using continued fractions."][BBBD1989]
6//! - [Handbook of Elliptic and Hyperelliptic Curve Cryptography][HEHCC], Chapter 9:
7//!   Exponentiation
8//!
9//! [BBBD1989]: https://doi.org/10.1016/0196-6774(89)90036-9
10//! [HEHCC]: https://www.hyperelliptic.org/HEHCC/index.html
11
12use num_bigint::BigUint;
13use num_integer::Integer;
14use num_traits::{One, Zero};
15use std::ops::{Add, Mul};
16
17/// A wrapper around an addition chain. Addition and multiplication operations are defined
18/// according to the BBBD algorithm.
19#[derive(Debug)]
20pub(super) struct Chain(Vec<BigUint>);
21
22impl Add<BigUint> for Chain {
23    type Output = Self;
24
25    fn add(mut self, k: BigUint) -> Self {
26        self.0.push(k + self.0.last().expect("chain is not empty"));
27        self
28    }
29}
30
31impl Mul<Chain> for Chain {
32    type Output = Self;
33
34    fn mul(mut self, mut other: Chain) -> Self {
35        let last = self.0.last().expect("chain is not empty");
36
37        // The first element of every chain is 1, so we skip it to prevent duplicate
38        // entries in the resulting chain.
39        assert!(other.0.remove(0).is_one());
40
41        for w in other.0.iter_mut() {
42            *w *= last;
43        }
44        self.0.append(&mut other.0);
45
46        self
47    }
48}
49
50pub(super) fn find_shortest_chain(n: BigUint) -> Vec<BigUint> {
51    minchain(n).0
52}
53
54fn minchain(n: BigUint) -> Chain {
55    let log_n = n.bits() - 1;
56    if n == BigUint::one() << log_n {
57        Chain((0..=log_n).map(|i| BigUint::one() << i).collect())
58    } else if n == BigUint::from(3u32) {
59        Chain(vec![BigUint::one(), BigUint::from(2u32), n])
60    } else {
61        // The minchain() algorithm on page 162 of HEHCC indicates that k should be set to
62        // 2^(log(n) / 2) in the call to chain(). This is at odds with the definition of k
63        // at the bottom of page 161; the latter gives the intended result.
64        let k = &n / (BigUint::one() << (log_n / 2));
65        chain(n, k)
66    }
67}
68
69fn chain(n: BigUint, k: BigUint) -> Chain {
70    let (q, r) = n.div_rem(&k);
71    if r.is_zero() || r.is_one() {
72        // We handle the r = 1 case here to prevent unnecessary recursion.
73        minchain(k) * minchain(q) + r
74    } else {
75        chain(k, r.clone()) * minchain(q) + r
76    }
77}
78
#[cfg(test)]
mod tests {
    use num_bigint::BigUint;

    use super::minchain;

    #[test]
    fn minchain_87() {
        // Example 9.37 from HEHCC.
        let expected: Vec<BigUint> = [1u32, 2, 3, 6, 7, 10, 20, 40, 80, 87]
            .iter()
            .copied()
            .map(BigUint::from)
            .collect();

        let actual = minchain(BigUint::from(87u32)).0;
        assert_eq!(actual, expected);
    }
}