zkhash/merkle_tree/
merkle_tree_fp.rs

1use ark_ff::PrimeField;
2use std::marker::PhantomData;
3
4pub trait MerkleTreeHash<F: PrimeField> {
5    fn compress(&self, input: &[&F]) -> F;
6}
7
/// A Merkle-tree accumulator over elements of type `F`, parameterized by the
/// compression function `P` used to combine sibling nodes.
///
/// The value stores no nodes; it only holds the compression function. Roots are
/// computed on demand from a slice of leaves.
///
/// The trait bounds (`F: PrimeField`, `P: MerkleTreeHash<F>`) live on the `impl`
/// blocks rather than on the type, so `MerkleTree<F, P>` can be named in generic
/// code without repeating them.
#[derive(Clone, Debug)]
pub struct MerkleTree<F, P> {
    /// Compression function used to hash pairs of child nodes.
    perm: P,
    /// Records the element type `F` without storing a value of it.
    field: PhantomData<F>,
}
13
14impl<F: PrimeField, P: MerkleTreeHash<F>> MerkleTree<F, P> {
15    pub fn new(perm: P) -> Self {
16        MerkleTree {
17            perm,
18            field: PhantomData,
19        }
20    }
21
22    fn round_up_pow_n(input: usize, n: usize) -> usize {
23        debug_assert!(n >= 1);
24        let mut res = 1;
25        // try powers, starting from n
26        loop {
27            res *= n;
28            if res >= input {
29                break;
30            }
31        }
32        res
33    }
34
35    pub fn accumulate(&mut self, set: &[F]) -> F {
36        let set_size = set.len();
37        let mut bound = Self::round_up_pow_n(set_size, 2);
38        loop {
39            if bound >= 2 {
40                break;
41            }
42            bound *= 2;
43        }
44        let mut nodes: Vec<F> = Vec::with_capacity(bound);
45        for s in set {
46            nodes.push(s.to_owned());
47        }
48        // pad
49        for _ in nodes.len()..bound {
50            nodes.push(nodes[set_size - 1].to_owned());
51        }
52
53        while nodes.len() > 1 {
54            let new_len = nodes.len() / 2;
55            let mut new_nodes: Vec<F> = Vec::with_capacity(new_len);
56            for i in (0..nodes.len()).step_by(2) {
57                let inp = [&nodes[i], &nodes[i + 1]];
58                let dig = self.perm.compress(&inp);
59                new_nodes.push(dig);
60            }
61            nodes = new_nodes;
62        }
63        nodes[0].to_owned()
64    }
65}