// zkhash/merkle_tree/merkle_tree_fp.rs
1use ark_ff::PrimeField;
2use std::marker::PhantomData;
3
4pub trait MerkleTreeHash<F: PrimeField> {
5 fn compress(&self, input: &[&F]) -> F;
6}
7
8#[derive(Clone, Debug)]
9pub struct MerkleTree<F: PrimeField, P: MerkleTreeHash<F>> {
10 perm: P,
11 field: PhantomData<F>,
12}
13
14impl<F: PrimeField, P: MerkleTreeHash<F>> MerkleTree<F, P> {
15 pub fn new(perm: P) -> Self {
16 MerkleTree {
17 perm,
18 field: PhantomData,
19 }
20 }
21
22 fn round_up_pow_n(input: usize, n: usize) -> usize {
23 debug_assert!(n >= 1);
24 let mut res = 1;
25 loop {
27 res *= n;
28 if res >= input {
29 break;
30 }
31 }
32 res
33 }
34
35 pub fn accumulate(&mut self, set: &[F]) -> F {
36 let set_size = set.len();
37 let mut bound = Self::round_up_pow_n(set_size, 2);
38 loop {
39 if bound >= 2 {
40 break;
41 }
42 bound *= 2;
43 }
44 let mut nodes: Vec<F> = Vec::with_capacity(bound);
45 for s in set {
46 nodes.push(s.to_owned());
47 }
48 for _ in nodes.len()..bound {
50 nodes.push(nodes[set_size - 1].to_owned());
51 }
52
53 while nodes.len() > 1 {
54 let new_len = nodes.len() / 2;
55 let mut new_nodes: Vec<F> = Vec::with_capacity(new_len);
56 for i in (0..nodes.len()).step_by(2) {
57 let inp = [&nodes[i], &nodes[i + 1]];
58 let dig = self.perm.compress(&inp);
59 new_nodes.push(dig);
60 }
61 nodes = new_nodes;
62 }
63 nodes[0].to_owned()
64 }
65}