zkhash/merkle_tree/merkle_tree_sapling.rs

1use jubjub::Base;
2
/// Element type stored in, and produced by, the Merkle tree (`jubjub::Base`).
type F = Base;
4
/// Two-to-one compression function used to build a [`MerkleTree`].
pub trait MerkleTreeHash {
    /// Compresses the sibling pair `input = [left, right]` into a single parent node.
    ///
    /// `level` identifies the layer being computed: [`MerkleTree::accumulate`] passes 0
    /// for the layer directly above the leaves and increments it by one per layer
    /// toward the root, so an implementation can vary its hashing per level.
    fn compress(&self, level: usize, input: &[&F; 2]) -> F;
}
8
/// Builds Merkle roots by repeatedly applying a two-to-one compression function.
///
/// The `P: MerkleTreeHash` requirement lives on the `impl` blocks that actually
/// hash, not on the struct, so `MerkleTree<P>` can be named in signatures and
/// containers without repeating the bound.
#[derive(Clone, Debug)]
pub struct MerkleTree<P> {
    /// Compression function applied to each sibling pair.
    perm: P,
}
13
14impl<P: MerkleTreeHash> MerkleTree<P> {
15    pub fn new(perm: P) -> Self {
16        MerkleTree { perm }
17    }
18
19    fn round_up_pow_n(input: usize, n: usize) -> usize {
20        debug_assert!(n >= 1);
21        let mut res = 1;
22        // try powers, starting from n
23        loop {
24            res *= n;
25            if res >= input {
26                break;
27            }
28        }
29        res
30    }
31
32    pub fn accumulate(&mut self, set: &[F]) -> F {
33        let set_size = set.len();
34        let mut bound = Self::round_up_pow_n(set_size, 2);
35        loop {
36            if bound >= 2 {
37                break;
38            }
39            bound *= 2;
40        }
41        let mut nodes: Vec<F> = Vec::with_capacity(bound);
42        for s in set {
43            nodes.push(s.to_owned());
44        }
45        // pad
46        for _ in nodes.len()..bound {
47            nodes.push(nodes[set_size - 1].to_owned());
48        }
49
50        let mut lv = 0;
51        while nodes.len() > 1 {
52            let new_len = nodes.len() / 2;
53            let mut new_nodes: Vec<F> = Vec::with_capacity(new_len);
54            for i in (0..nodes.len()).step_by(2) {
55                let inp = [&nodes[i], &nodes[i + 1]];
56                let dig = self.perm.compress(lv, &inp);
57                new_nodes.push(dig);
58            }
59            lv += 1;
60            nodes = new_nodes;
61        }
62        nodes[0].to_owned()
63    }
64}