p3_merkle_tree/
merkle_tree.rs

use alloc::vec;
use alloc::vec::Vec;
use core::array;
use core::cmp::Reverse;
use core::marker::PhantomData;

use itertools::Itertools;
use p3_field::PackedValue;
use p3_matrix::Matrix;
use p3_maybe_rayon::prelude::*;
use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
use serde::{Deserialize, Serialize};
use tracing::instrument;

/// A binary Merkle tree for packed data. It has leaves of type `F` and digests of type
/// `[W; DIGEST_ELEMS]`.
///
/// This generally shouldn't be used directly. If you're using a Merkle tree as an MMCS,
/// see `MerkleTreeMmcs`.
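///
/// A minimal usage sketch (assuming a hasher `h` and compressor `c` satisfying the
/// bounds on [`Self::new`]):
/// ```ignore
/// let tree = MerkleTree::new::<P, PW, _, _>(&h, &c, vec![matrix]);
/// let root = tree.root();
/// ```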
#[derive(Debug, Serialize, Deserialize)]
pub struct MerkleTree<F, W, M, const DIGEST_ELEMS: usize> {
    pub(crate) leaves: Vec<M>,
    // Enable serialization for this type whenever the underlying array type supports it (len 1-32).
    #[serde(bound(serialize = "[W; DIGEST_ELEMS]: Serialize"))]
    // Enable deserialization for this type whenever the underlying array type supports it (len 1-32).
    #[serde(bound(deserialize = "[W; DIGEST_ELEMS]: Deserialize<'de>"))]
    pub(crate) digest_layers: Vec<Vec<[W; DIGEST_ELEMS]>>,
    _phantom: PhantomData<F>,
}

impl<F: Clone + Send + Sync, W: Clone, M: Matrix<F>, const DIGEST_ELEMS: usize>
    MerkleTree<F, W, M, DIGEST_ELEMS>
{
    /// Matrix heights need not be powers of two. However, if the heights of two given matrices
    /// round up to the same power of two, they must be equal.
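    /// For example, heights 8 and 5 may not be combined, since both round up to 8;
    /// heights 8 and 4 (or 8 and 8) are fine.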
    #[instrument(name = "build merkle tree", level = "debug", skip_all,
                 fields(dimensions = alloc::format!("{:?}", leaves.iter().map(|l| l.dimensions()).collect::<Vec<_>>())))]
    pub fn new<P, PW, H, C>(h: &H, c: &C, leaves: Vec<M>) -> Self
    where
        P: PackedValue<Value = F>,
        PW: PackedValue<Value = W>,
        H: CryptographicHasher<F, [W; DIGEST_ELEMS]>,
        H: CryptographicHasher<P, [PW; DIGEST_ELEMS]>,
        H: Sync,
        C: PseudoCompressionFunction<[W; DIGEST_ELEMS], 2>,
        C: PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>,
        C: Sync,
    {
        assert!(!leaves.is_empty(), "No matrices given?");

        assert_eq!(P::WIDTH, PW::WIDTH, "Packing widths must match");

        let mut leaves_largest_first = leaves
            .iter()
            .sorted_by_key(|l| Reverse(l.height()))
            .peekable();

        // Check the height invariant: heights that round up to the same power of two
        // must be equal.
        assert!(
            leaves_largest_first
                .clone()
                .map(|m| m.height())
                .tuple_windows()
                .all(|(curr, next)| curr == next
                    || curr.next_power_of_two() != next.next_power_of_two()),
            "matrix heights that round up to the same power of two must be equal"
        );

        let max_height = leaves_largest_first.peek().unwrap().height();
        let tallest_matrices = leaves_largest_first
            .peeking_take_while(|m| m.height() == max_height)
            .collect_vec();

        let mut digest_layers = vec![first_digest_layer::<P, PW, H, M, DIGEST_ELEMS>(
            h,
            tallest_matrices,
        )];
        loop {
            let prev_layer = digest_layers.last().unwrap().as_slice();
            if prev_layer.len() == 1 {
                break;
            }
            let next_layer_len = (prev_layer.len() / 2).next_power_of_two();

            // The matrices that get injected at this layer.
            let matrices_to_inject = leaves_largest_first
                .peeking_take_while(|m| m.height().next_power_of_two() == next_layer_len)
                .collect_vec();

            let next_digests = compress_and_inject::<P, PW, H, C, M, DIGEST_ELEMS>(
                prev_layer,
                matrices_to_inject,
                h,
                c,
            );
            digest_layers.push(next_digests);
        }

        Self {
            leaves,
            digest_layers,
            _phantom: PhantomData,
        }
    }

    #[must_use]
    pub fn root(&self) -> Hash<F, W, DIGEST_ELEMS>
    where
        W: Copy,
    {
        self.digest_layers.last().unwrap()[0].into()
    }
}

#[instrument(name = "first digest layer", level = "debug", skip_all)]
fn first_digest_layer<P, PW, H, M, const DIGEST_ELEMS: usize>(
    h: &H,
    tallest_matrices: Vec<&M>,
) -> Vec<[PW::Value; DIGEST_ELEMS]>
where
    P: PackedValue,
    PW: PackedValue,
    H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>,
    H: CryptographicHasher<P, [PW; DIGEST_ELEMS]>,
    H: Sync,
    M: Matrix<P::Value>,
{
    let width = PW::WIDTH;
    let max_height = tallest_matrices[0].height();
    // We always want to return an even number of digests, except when it's the root.
    let max_height_padded = if max_height == 1 {
        1
    } else {
        max_height + max_height % 2
    };

    let default_digest: [PW::Value; DIGEST_ELEMS] = [PW::Value::default(); DIGEST_ELEMS];
    let mut digests = vec![default_digest; max_height_padded];

    digests[0..max_height]
        .par_chunks_exact_mut(width)
        .enumerate()
        .for_each(|(i, digests_chunk)| {
            let first_row = i * width;
            let packed_digest: [PW; DIGEST_ELEMS] = h.hash_iter(
                tallest_matrices
                    .iter()
                    .flat_map(|m| m.vertically_packed_row(first_row)),
            );
            for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) {
                *dst = src;
            }
        });

    // If our packing width did not divide max_height, fall back to single-threaded scalar code
    // for the remaining rows.
    #[allow(clippy::needless_range_loop)]
    for i in (max_height / width * width)..max_height {
        digests[i] = h.hash_iter(tallest_matrices.iter().flat_map(|m| m.row(i)));
    }

    // Every entry has been initialized at this point.
    digests
}

/// Compress `n` digests from the previous layer into `n/2` digests, while potentially mixing in
/// some leaf data, if there are input matrices with (padded) height `n/2`.
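///
/// For example, with `prev_layer.len() == 8` and a single matrix of height 4 to inject,
/// output digest `i` is, in pseudocode,
/// `compress([compress([prev[2i], prev[2i+1]]), hash(row i)])`.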
fn compress_and_inject<P, PW, H, C, M, const DIGEST_ELEMS: usize>(
    prev_layer: &[[PW::Value; DIGEST_ELEMS]],
    matrices_to_inject: Vec<&M>,
    h: &H,
    c: &C,
) -> Vec<[PW::Value; DIGEST_ELEMS]>
where
    P: PackedValue,
    PW: PackedValue,
    H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>,
    H: CryptographicHasher<P, [PW; DIGEST_ELEMS]>,
    H: Sync,
    C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>,
    C: PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>,
    C: Sync,
    M: Matrix<P::Value>,
{
    if matrices_to_inject.is_empty() {
        return compress::<PW, C, DIGEST_ELEMS>(prev_layer, c);
    }

    let width = PW::WIDTH;
    let next_len = matrices_to_inject[0].height();
    // We always want to return an even number of digests, except when it's the root.
    // `(n / 2 + 1) & !1` rounds `n / 2` up to the nearest even number.
    let next_len_padded = if prev_layer.len() == 2 {
        1
    } else {
        (prev_layer.len() / 2 + 1) & !1
    };

    let default_digest: [PW::Value; DIGEST_ELEMS] = [PW::Value::default(); DIGEST_ELEMS];
    let mut next_digests = vec![default_digest; next_len_padded];
    next_digests[0..next_len]
        .par_chunks_exact_mut(width)
        .enumerate()
        .for_each(|(i, digests_chunk)| {
            let first_row = i * width;
            let left = array::from_fn(|j| PW::from_fn(|k| prev_layer[2 * (first_row + k)][j]));
            let right = array::from_fn(|j| PW::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j]));
            let mut packed_digest = c.compress([left, right]);
            let tallest_digest = h.hash_iter(
                matrices_to_inject
                    .iter()
                    .flat_map(|m| m.vertically_packed_row(first_row)),
            );
            packed_digest = c.compress([packed_digest, tallest_digest]);
            for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) {
                *dst = src;
            }
        });

    // If our packing width did not divide next_len, fall back to single-threaded scalar code
    // for the remaining rows.
    for i in (next_len / width * width)..next_len {
        let left = prev_layer[2 * i];
        let right = prev_layer[2 * i + 1];
        let digest = c.compress([left, right]);
        let rows_digest = h.hash_iter(matrices_to_inject.iter().flat_map(|m| m.row(i)));
        next_digests[i] = c.compress([digest, rows_digest]);
    }

    // At this point, we've exceeded the height of the matrices to inject, so we continue the
    // process above, except with default_digest in place of an input digest.
    // We only need to go as far as half the length of the previous layer.
    for i in next_len..(prev_layer.len() / 2) {
        let left = prev_layer[2 * i];
        let right = prev_layer[2 * i + 1];
        let digest = c.compress([left, right]);
        next_digests[i] = c.compress([digest, default_digest]);
    }

    next_digests
}

/// Compress `n` digests from the previous layer into `n/2` digests.
fn compress<P, C, const DIGEST_ELEMS: usize>(
    prev_layer: &[[P::Value; DIGEST_ELEMS]],
    c: &C,
) -> Vec<[P::Value; DIGEST_ELEMS]>
where
    P: PackedValue,
    C: PseudoCompressionFunction<[P::Value; DIGEST_ELEMS], 2>,
    C: PseudoCompressionFunction<[P; DIGEST_ELEMS], 2>,
    C: Sync,
{
    let width = P::WIDTH;
    // Always return an even number of digests, except when it's the root.
    // `(n / 2 + 1) & !1` rounds `n / 2` up to the nearest even number.
    let next_len_padded = if prev_layer.len() == 2 {
        1
    } else {
        (prev_layer.len() / 2 + 1) & !1
    };
    let next_len = prev_layer.len() / 2;

    let default_digest: [P::Value; DIGEST_ELEMS] = [P::Value::default(); DIGEST_ELEMS];
    let mut next_digests = vec![default_digest; next_len_padded];

    next_digests[0..next_len]
        .par_chunks_exact_mut(width)
        .enumerate()
        .for_each(|(i, digests_chunk)| {
            let first_row = i * width;
            let left = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k)][j]));
            let right = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j]));
            let packed_digest = c.compress([left, right]);
            for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) {
                *dst = src;
            }
        });

    // If our packing width did not divide next_len, fall back to single-threaded scalar code
    // for the remaining rows.
    for i in (next_len / width * width)..next_len {
        let left = prev_layer[2 * i];
        let right = prev_layer[2 * i + 1];
        next_digests[i] = c.compress([left, right]);
    }

    // Every entry has been initialized at this point.
    next_digests
}

/// Converts a packed array `[P; N]` into its underlying `P::WIDTH` scalar arrays.
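///
/// For example, with `P::WIDTH == 4` and `N == 2`, a `[P; 2]` yields four `[P::Value; 2]`
/// arrays, the `j`-th collecting lane `j` of each packed element.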
#[inline]
fn unpack_array<P: PackedValue, const N: usize>(
    packed_digest: [P; N],
) -> impl Iterator<Item = [P::Value; N]> {
    (0..P::WIDTH).map(move |j| packed_digest.map(|p| p.as_slice()[j]))
}
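
#[cfg(test)]
mod tests {
    // A sanity-check sketch of the padding arithmetic used above (no hashing involved):
    // every non-root layer should have an even length so that the next layer can pair
    // digests. The helpers below mirror the inline expressions in the functions above.

    /// Mirrors `max_height_padded` in `first_digest_layer`.
    fn pad_first_layer(max_height: usize) -> usize {
        if max_height == 1 {
            1
        } else {
            max_height + max_height % 2
        }
    }

    /// Mirrors `next_len_padded` in `compress` and `compress_and_inject`.
    fn pad_next_layer(prev_len: usize) -> usize {
        if prev_len == 2 {
            1
        } else {
            (prev_len / 2 + 1) & !1
        }
    }

    #[test]
    fn padded_lengths_are_even_except_root() {
        for h in 2..=100 {
            assert_eq!(pad_first_layer(h) % 2, 0);
        }
        for n in (4..=100).step_by(2) {
            let padded = pad_next_layer(n);
            assert_eq!(padded % 2, 0);
            // Rounds n / 2 up by at most one.
            assert!(padded == n / 2 || padded == n / 2 + 1);
        }
        // The root layer is the only odd-length layer.
        assert_eq!(pad_first_layer(1), 1);
        assert_eq!(pad_next_layer(2), 1);
    }
}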