// zkhash/merkle_tree/merkle_tree_f2.rs

1use sha2::{
2    digest::{FixedOutputReset, Output},
3    Digest,
4};
5
/// Reusable hasher state for computing binary Merkle roots over digests
/// produced by the hash function `F` (see `MerkleTree::accumulate`).
///
/// Trait bounds are placed on the `impl` blocks that need them rather than on
/// the struct itself, so the type can be named without repeating them.
#[derive(Clone, Debug)]
pub struct MerkleTree<F> {
    /// Hasher reused for every pairwise compression.
    hasher: F,
}
10
11impl<F: Digest + FixedOutputReset + Clone> Default for MerkleTree<F> {
12    fn default() -> Self {
13        MerkleTree { hasher: F::new() }
14    }
15}
16
17impl<F: Digest + FixedOutputReset + Clone> MerkleTree<F> {
18    pub fn new() -> Self {
19        MerkleTree { hasher: F::new() }
20    }
21
22    fn round_up_pow_n(input: usize, n: usize) -> usize {
23        debug_assert!(n >= 1);
24        let mut res = 1;
25        // try powers, starting from n
26        loop {
27            res *= n;
28            if res >= input {
29                break;
30            }
31        }
32        res
33    }
34
35    fn compress(&mut self, input: &[&Output<F>; 2]) -> Output<F> {
36        <F as Digest>::update(&mut self.hasher, input[0]);
37        <F as Digest>::update(&mut self.hasher, input[1]);
38        self.hasher.finalize_reset()
39    }
40
41    pub fn accumulate(&mut self, set: &[Output<F>]) -> Output<F> {
42        let set_size = set.len();
43        let mut bound = Self::round_up_pow_n(set_size, 2);
44        loop {
45            if bound >= 2 {
46                break;
47            }
48            bound *= 2;
49        }
50        let mut nodes: Vec<Output<F>> = Vec::with_capacity(bound);
51        for s in set {
52            nodes.push(s.to_owned());
53        }
54        // pad
55        for _ in nodes.len()..bound {
56            nodes.push(nodes[set_size - 1].to_owned());
57        }
58
59        while nodes.len() > 1 {
60            let new_len = nodes.len() / 2;
61            let mut new_nodes: Vec<Output<F>> = Vec::with_capacity(new_len);
62            for i in (0..nodes.len()).step_by(2) {
63                let inp = [&nodes[i], &nodes[i + 1]];
64                let dig = self.compress(&inp);
65                new_nodes.push(dig);
66            }
67            nodes = new_nodes;
68        }
69        nodes[0].to_owned()
70    }
71}