addchain/
lib.rs

1//! *Library for generating addition chains*
2//!
3//! An addition chain `C` for some positive integer `n` is a sequence of integers that
4//! have the following properties:
5//!
6//! - The first integer is 1.
7//! - The last integer is `n`.
8//! - Integers only appear once.
9//! - Every integer is either the sum of two earlier integers, or double an earlier
10//!   integer.
11//!
12//! An addition chain corresponds to a series of `len(C) - 1` primitive operations
13//! (doubling and addition) that can be used to compute a target integer. An *optimal*
14//! addition chain for `n` has the shortest possible length, and therefore requires the
15//! fewest operations to compute `n`. This is particularly useful in cryptographic
16//! algorithms such as modular exponentiation, where `n` is usually at least `2^128`.
17//!
18//! # Example
19//!
//! To compute the number 87, we can represent it in binary as `1010111`, and then using
//! the binary double-and-add algorithm (where we start from 1 for the leading bit, then for
//! each remaining bit double the running value and add 1 if that bit is set) we have the
//! following steps:
23//! ```text
24//!  i | n_i | Operation | b_i
25//! ---|-----|-----------|-----
26//!  0 |  1  |           |  1
27//!  1 |  2  | n_0 * 2   |  0
28//!  2 |  4  | n_1 * 2   |  1
29//!  3 |  5  | n_2 + n_0 |
30//!  4 | 10  | n_3 * 2   |  0
31//!  5 | 20  | n_4 * 2   |  1
32//!  6 | 21  | n_5 + n_0 |
33//!  7 | 42  | n_6 * 2   |  1
34//!  8 | 43  | n_7 + n_0 |
35//!  9 | 86  | n_8 * 2   |  1
36//! 10 | 87  | n_9 + n_0 |
37//! ```
38//!
39//! This corresponds to the addition chain `[1, 2, 4, 5, 10, 20, 21, 42, 43, 86, 87]`,
40//! which has length 11. However, the optimal addition chain length for 87 is 10, and
41//! several addition chains can be constructed with optimal length. One such chain is
42//! `[1, 2, 3, 6, 7, 10, 20, 40, 80, 87]`, which corresponds to the following steps:
43//! ```text
44//!  i | n_i | Operation
45//! ---|-----|----------
46//!  0 |  1  |
47//!  1 |  2  | n_0 * 2
48//!  2 |  3  | n_1 + n_0
49//!  3 |  6  | n_2 * 2
50//!  4 |  7  | n_3 + n_0
51//!  5 | 10  | n_4 + n_2
52//!  6 | 20  | n_5 * 2
53//!  7 | 40  | n_6 * 2
54//!  8 | 80  | n_7 * 2
55//!  9 | 87  | n_8 + n_4
56//! ```
57//!
58//! # Usage
59//!
60//! ```
61//! use addchain::{build_addition_chain, Step};
62//! use num_bigint::BigUint;
63//!
64//! assert_eq!(
65//!     build_addition_chain(BigUint::from(87u32)),
66//!     vec![
67//!         Step::Double { index: 0 },
68//!         Step::Add { left: 1, right: 0 },
69//!         Step::Double { index: 2 },
70//!         Step::Add { left: 3, right: 0 },
71//!         Step::Add { left: 4, right: 2 },
72//!         Step::Double { index: 5 },
73//!         Step::Double { index: 6 },
74//!         Step::Double { index: 7 },
75//!         Step::Add { left: 8, right: 4 },
76//!     ],
77//! );
78//! ```
79
80use num_bigint::BigUint;
81use num_traits::One;
82
83mod bbbd;
84
/// The error kinds returned by `addchain` APIs.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The provided chain is invalid: it is empty, its first element is not 1, or one of
    /// its later elements is not the sum of two earlier elements (or double an earlier one).
    InvalidChain,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::InvalidChain => f.write_str("invalid addition chain"),
        }
    }
}

// Lets callers propagate `Error` into `Box<dyn std::error::Error>` and similar wrappers.
impl std::error::Error for Error {}
91
92/// Returns the shortest addition chain we can find for the given number, using all
93/// available algorithms.
94pub fn find_shortest_chain(n: BigUint) -> Vec<BigUint> {
95    bbbd::find_shortest_chain(n)
96}
97
/// A single step in computing an addition chain.
///
/// Steps refer to earlier chain elements by their zero-based index. Element 0 of a chain is
/// always 1, so the value produced by step `i` (counting from zero) becomes chain element
/// `i + 1`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum Step {
    /// Double the chain element at `index`.
    Double {
        /// Index of the chain element to double.
        index: usize,
    },
    /// Add the chain elements at `left` and `right`.
    Add {
        /// Index of one operand. In steps produced by [`build_steps`] this is always the
        /// later of the two operands (`left > right`).
        left: usize,
        /// Index of the other operand.
        right: usize,
    },
}
104
105/// Converts an addition chain into a series of steps.
106pub fn build_steps(chain: Vec<BigUint>) -> Result<Vec<Step>, Error> {
107    match chain.get(0) {
108        Some(n) if n.is_one() => (),
109        _ => return Err(Error::InvalidChain),
110    }
111
112    let mut steps = vec![];
113
114    for (i, val) in chain.iter().enumerate().skip(1) {
115        // Find the pair of previous values that add to this one
116        'search: for (j, left) in chain[..i].iter().enumerate() {
117            for (k, right) in chain[..=j].iter().enumerate() {
118                if val == &(left + right) {
119                    // Found the pair!
120                    if j == k {
121                        steps.push(Step::Double { index: j })
122                    } else {
123                        steps.push(Step::Add { left: j, right: k });
124                    }
125                    break 'search;
126                }
127            }
128        }
129
130        // We must always find a matching pair
131        if steps.len() != i {
132            return Err(Error::InvalidChain);
133        }
134    }
135
136    Ok(steps)
137}
138
139/// Generates a series of steps that will compute an addition chain for the given number.
140/// The addition chain is the shortest we can find using all available algorithms.
141pub fn build_addition_chain(n: BigUint) -> Vec<Step> {
142    build_steps(find_shortest_chain(n)).expect("chain is valid")
143}
144
145#[cfg(test)]
146mod tests {
147    use num_bigint::BigUint;
148
149    use super::{build_steps, Error, Step};
150
151    #[test]
152    fn steps_from_valid_chains() {
153        assert_eq!(
154            build_steps(vec![
155                BigUint::from(1u32),
156                BigUint::from(2u32),
157                BigUint::from(3u32),
158            ]),
159            Ok(vec![
160                Step::Double { index: 0 },
161                Step::Add { left: 1, right: 0 }
162            ]),
163        );
164
165        assert_eq!(
166            build_steps(vec![
167                BigUint::from(1u32),
168                BigUint::from(2u32),
169                BigUint::from(4u32),
170                BigUint::from(8u32),
171            ]),
172            Ok(vec![
173                Step::Double { index: 0 },
174                Step::Double { index: 1 },
175                Step::Double { index: 2 },
176            ]),
177        );
178
179        assert_eq!(
180            build_steps(vec![
181                BigUint::from(1u32),
182                BigUint::from(2u32),
183                BigUint::from(3u32),
184                BigUint::from(6u32),
185                BigUint::from(7u32),
186                BigUint::from(10u32),
187                BigUint::from(20u32),
188                BigUint::from(40u32),
189                BigUint::from(80u32),
190                BigUint::from(87u32),
191            ]),
192            Ok(vec![
193                Step::Double { index: 0 },
194                Step::Add { left: 1, right: 0 },
195                Step::Double { index: 2 },
196                Step::Add { left: 3, right: 0 },
197                Step::Add { left: 4, right: 2 },
198                Step::Double { index: 5 },
199                Step::Double { index: 6 },
200                Step::Double { index: 7 },
201                Step::Add { left: 8, right: 4 },
202            ]),
203        );
204    }
205
206    #[test]
207    fn invalid_chains() {
208        // First element is not one.
209        assert_eq!(
210            build_steps(vec![BigUint::from(2u32), BigUint::from(3u32),]),
211            Err(Error::InvalidChain),
212        );
213
214        // Missing an element of a pair.
215        assert_eq!(
216            build_steps(vec![
217                BigUint::from(1u32),
218                BigUint::from(4u32),
219                BigUint::from(8u32),
220            ]),
221            Err(Error::InvalidChain),
222        );
223    }
224}