1use num_bigint::BigUint;
13use num_integer::Integer;
14use num_traits::{One, Zero};
15use std::ops::{Add, Mul};
16
17#[derive(Debug)]
20pub(super) struct Chain(Vec<BigUint>);
21
22impl Add<BigUint> for Chain {
23 type Output = Self;
24
25 fn add(mut self, k: BigUint) -> Self {
26 self.0.push(k + self.0.last().expect("chain is not empty"));
27 self
28 }
29}
30
31impl Mul<Chain> for Chain {
32 type Output = Self;
33
34 fn mul(mut self, mut other: Chain) -> Self {
35 let last = self.0.last().expect("chain is not empty");
36
37 assert!(other.0.remove(0).is_one());
40
41 for w in other.0.iter_mut() {
42 *w *= last;
43 }
44 self.0.append(&mut other.0);
45
46 self
47 }
48}
49
50pub(super) fn find_shortest_chain(n: BigUint) -> Vec<BigUint> {
51 minchain(n).0
52}
53
54fn minchain(n: BigUint) -> Chain {
55 let log_n = n.bits() - 1;
56 if n == BigUint::one() << log_n {
57 Chain((0..=log_n).map(|i| BigUint::one() << i).collect())
58 } else if n == BigUint::from(3u32) {
59 Chain(vec![BigUint::one(), BigUint::from(2u32), n])
60 } else {
61 let k = &n / (BigUint::one() << (log_n / 2));
65 chain(n, k)
66 }
67}
68
69fn chain(n: BigUint, k: BigUint) -> Chain {
70 let (q, r) = n.div_rem(&k);
71 if r.is_zero() || r.is_one() {
72 minchain(k) * minchain(q) + r
74 } else {
75 chain(k, r.clone()) * minchain(q) + r
76 }
77}
78
#[cfg(test)]
mod tests {
    use num_bigint::BigUint;

    use super::minchain;

    /// The BBBD chain for 87 is 1, 2, 3, 6, 7, 10, 20, 40, 80, 87.
    #[test]
    fn minchain_87() {
        let expected: Vec<BigUint> = [1u32, 2, 3, 6, 7, 10, 20, 40, 80, 87]
            .iter()
            .map(|&v| BigUint::from(v))
            .collect();

        let actual = minchain(BigUint::from(87u32));
        assert_eq!(actual.0, expected);
    }
}