// p3_merkle_tree/merkle_tree.rs

1use alloc::vec;
2use alloc::vec::Vec;
3use core::array;
4use core::cmp::Reverse;
5use core::marker::PhantomData;
6
7use itertools::Itertools;
8use p3_field::PackedValue;
9use p3_matrix::Matrix;
10use p3_maybe_rayon::prelude::*;
11use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
12use serde::{Deserialize, Serialize};
13use tracing::instrument;
14
/// A binary Merkle tree whose leaves are vectors of matrix rows.
///
/// * `F` – scalar element type inside each matrix row.
/// * `W` – scalar element type of every digest word.
/// * `M` – matrix type. Must implement [`Matrix<F>`].
/// * `DIGEST_ELEMS` – number of `W` words in one digest.
///
/// The tree is **balanced only at the digest layer**.
/// Leaf matrices may have arbitrary heights as long as any two heights
/// that round **up** to the same power-of-two are equal.
///
/// Use [`root`] to fetch the final digest once the tree is built.
///
/// This generally shouldn't be used directly. If you're using a Merkle tree as an MMCS,
/// see `MerkleTreeMmcs`.
#[derive(Debug, Serialize, Deserialize)]
pub struct MerkleTree<F, W, M, const DIGEST_ELEMS: usize> {
    /// All leaf matrices in insertion order.
    ///
    /// Each matrix contributes rows to one or more digest layers, depending on its height.
    /// Specifically, only the tallest matrices are included in the first digest layer,
    /// while shorter matrices are injected into higher digest layers at positions determined
    /// by their padded heights.
    ///
    /// This vector is retained only for inspection or re-opening of the tree; it is not used
    /// after construction time.
    pub(crate) leaves: Vec<M>,

    /// All intermediate digest layers, index 0 being the first layer above
    /// the leaves and the last layer containing exactly one root digest.
    ///
    /// Every inner vector holds contiguous digests `[left₀, right₀, left₁,
    /// right₁, …]`; higher layers refer to these by index. Layers other than
    /// the root are padded (with the default digest) to an even length.
    ///
    /// Serialization requires that `[W; DIGEST_ELEMS]` implements `Serialize` and
    /// `Deserialize`. This is automatically satisfied when `W` is a fixed-size type.
    #[serde(
        bound(serialize = "[W; DIGEST_ELEMS]: Serialize"),
        bound(deserialize = "[W; DIGEST_ELEMS]: Deserialize<'de>")
    )]
    pub(crate) digest_layers: Vec<Vec<[W; DIGEST_ELEMS]>>,

    /// Zero-sized marker that binds the generic `F` but occupies no space.
    _phantom: PhantomData<F>,
}
60
61impl<F: Clone + Send + Sync, W: Clone, M: Matrix<F>, const DIGEST_ELEMS: usize>
62    MerkleTree<F, W, M, DIGEST_ELEMS>
63{
64    /// Build a tree from **one or more matrices**.
65    ///
66    /// * `h` – hashing function used on raw rows.
67    /// * `c` – 2-to-1 compression function used on digests.
68    /// * `leaves` – matrices to commit to. Must be non-empty.
69    ///
70    /// Matrices do **not** need to have power-of-two heights. However, any two matrices
71    /// whose heights **round up** to the same power-of-two must have **equal actual height**.
72    /// This ensures proper balancing when folding digests layer-by-layer.
73    ///
74    /// All matrices are hashed row-by-row with `h`. The resulting digests are
75    /// then folded upwards with `c` until a single root remains.
76    ///
77    /// # Panics
78    /// * If `leaves` is empty.
79    /// * If the packing widths of `P` and `PW` differ.
80    /// * If two leaf heights *round up* to the same power-of-two but are not
81    ///   equal (violates balancing rule).
82    #[instrument(name = "build merkle tree", level = "debug", skip_all,
83                 fields(dimensions = alloc::format!("{:?}", leaves.iter().map(|l| l.dimensions()).collect::<Vec<_>>())))]
84    pub fn new<P, PW, H, C>(h: &H, c: &C, leaves: Vec<M>) -> Self
85    where
86        P: PackedValue<Value = F>,
87        PW: PackedValue<Value = W>,
88        H: CryptographicHasher<F, [W; DIGEST_ELEMS]>
89            + CryptographicHasher<P, [PW; DIGEST_ELEMS]>
90            + Sync,
91        C: PseudoCompressionFunction<[W; DIGEST_ELEMS], 2>
92            + PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>
93            + Sync,
94    {
95        assert!(!leaves.is_empty(), "No matrices given?");
96        const {
97            assert!(P::WIDTH == PW::WIDTH, "Packing widths must match");
98        }
99
100        let mut leaves_largest_first = leaves
101            .iter()
102            .sorted_by_key(|l| Reverse(l.height()))
103            .peekable();
104
105        // check height property
106        assert!(
107            leaves_largest_first
108                .clone()
109                .map(|m| m.height())
110                .tuple_windows()
111                .all(|(curr, next)| curr == next
112                    || curr.next_power_of_two() != next.next_power_of_two()),
113            "matrix heights that round up to the same power of two must be equal"
114        );
115
116        let max_height = leaves_largest_first.peek().unwrap().height();
117        let tallest_matrices = leaves_largest_first
118            .peeking_take_while(|m| m.height() == max_height)
119            .collect_vec();
120
121        let mut digest_layers = vec![first_digest_layer::<P, _, _, _, DIGEST_ELEMS>(
122            h,
123            &tallest_matrices,
124        )];
125        loop {
126            let prev_layer = digest_layers.last().unwrap().as_slice();
127            if prev_layer.len() == 1 {
128                break;
129            }
130            let next_layer_len = (prev_layer.len() / 2).next_power_of_two();
131
132            // The matrices that get injected at this layer.
133            let matrices_to_inject = leaves_largest_first
134                .peeking_take_while(|m| m.height().next_power_of_two() == next_layer_len)
135                .collect_vec();
136
137            let next_digests = compress_and_inject::<P, _, _, _, _, DIGEST_ELEMS>(
138                prev_layer,
139                &matrices_to_inject,
140                h,
141                c,
142            );
143            digest_layers.push(next_digests);
144        }
145
146        Self {
147            leaves,
148            digest_layers,
149            _phantom: PhantomData,
150        }
151    }
152
153    /// Return the root digest of the tree.
154    #[must_use]
155    pub fn root(&self) -> Hash<F, W, DIGEST_ELEMS>
156    where
157        W: Copy,
158    {
159        self.digest_layers.last().unwrap()[0].into()
160    }
161}
162
/// Hash every row of the tallest matrices and build the first digest layer.
///
/// This function is responsible for creating the first layer of Merkle digests,
/// starting from raw rows of the tallest matrices. Each row is hashed using the
/// provided cryptographic hasher `h`. The result is a vector of digests that serve
/// as the base (leaf-level) nodes for the rest of the Merkle tree.
///
/// # Details
/// - We always return an *even number of digests* (except when height is 1), to
///   ensure even pairing at higher layers. Any padding slot holds the default digest.
/// - Matrices are "vertically packed" to allow SIMD-friendly parallel hashing,
///   meaning rows can be processed in batches of `PW::WIDTH`.
/// - If the total number of rows isn't a multiple of the SIMD packing width,
///   the final few rows are handled using a fallback scalar path.
///
/// # Arguments
/// - `h`: Reference to the cryptographic hasher.
/// - `tallest_matrices`: References to the tallest matrices (all must have same height).
///
/// # Returns
/// A vector of `[PW::Value; DIGEST_ELEMS]`, containing the digests of each row.
///
/// # Panics
/// Indexes `tallest_matrices[0]`, so the slice must be non-empty.
#[instrument(name = "first digest layer", level = "debug", skip_all)]
fn first_digest_layer<P, PW, H, M, const DIGEST_ELEMS: usize>(
    h: &H,
    tallest_matrices: &[&M],
) -> Vec<[PW::Value; DIGEST_ELEMS]>
where
    P: PackedValue,
    PW: PackedValue,
    H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>
        + CryptographicHasher<P, [PW; DIGEST_ELEMS]>
        + Sync,
    M: Matrix<P::Value>,
{
    // The number of rows to pack and hash together in one SIMD batch.
    let width = PW::WIDTH;

    // Get the height of the tallest matrices (they are guaranteed to be equal
    // by the balancing check in `MerkleTree::new`).
    let max_height = tallest_matrices[0].height();

    // Compute the padded height to ensure we end up with an even number of digests.
    // **Exception:** if there's only 1 row (the root case), we keep it as 1.
    let max_height_padded = if max_height == 1 {
        1
    } else {
        max_height + max_height % 2
    };

    // Prepare a default digest value to fill unused slots or padding.
    let default_digest = [PW::Value::default(); DIGEST_ELEMS];

    // Allocate the digest vector with padded size, initialized to default digest.
    // Any padding slot past `max_height` simply keeps this value.
    let mut digests = vec![default_digest; max_height_padded];

    // Parallel loop: process complete batches of `width` rows at a time.
    digests[0..max_height]
        .par_chunks_exact_mut(width)
        .enumerate()
        .for_each(|(i, digests_chunk)| {
            // Compute the starting row index for this chunk.
            let first_row = i * width;

            // Collect all vertically packed rows from each matrix at `first_row`.
            // Concatenating the matrices' rows per position, these packed rows
            // are hashed together in one pass using `h`.
            let packed_digest: [PW; DIGEST_ELEMS] = h.hash_iter(
                tallest_matrices
                    .iter()
                    .flat_map(|m| m.vertically_packed_row(first_row)),
            );

            // Unpack the resulting packed digest into individual scalar digests
            // (one per SIMD lane). Then, assign each to its slot in the current chunk.
            for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) {
                *dst = src;
            }
        });

    // Handle leftover rows that do not form a full SIMD batch (if any),
    // hashing them one at a time on the scalar path.
    #[allow(clippy::needless_range_loop)]
    for i in ((max_height / width) * width)..max_height {
        unsafe {
            // Safety: The loop guarantees i < max_height == matrix height.
            // Use `row_unchecked` to avoid bounds checks for performance.
            digests[i] = h.hash_iter(tallest_matrices.iter().flat_map(|m| m.row_unchecked(i)));
        }
    }

    // Return the final digest vector (now fully populated).
    digests
}
253
254/// Fold one digest layer into the next and, when present, mix in rows
255/// taken from smaller matrices whose padded height equals `prev_layer.len()/2`.
256///
257/// Pads the output so its length is even unless it becomes the root.
258fn compress_and_inject<P, PW, H, C, M, const DIGEST_ELEMS: usize>(
259    prev_layer: &[[PW::Value; DIGEST_ELEMS]],
260    matrices_to_inject: &[&M],
261    h: &H,
262    c: &C,
263) -> Vec<[PW::Value; DIGEST_ELEMS]>
264where
265    P: PackedValue,
266    PW: PackedValue,
267    H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>
268        + CryptographicHasher<P, [PW; DIGEST_ELEMS]>
269        + Sync,
270    C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>
271        + PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>
272        + Sync,
273    M: Matrix<P::Value>,
274{
275    if matrices_to_inject.is_empty() {
276        return compress::<PW, _, DIGEST_ELEMS>(prev_layer, c);
277    }
278
279    let width = PW::WIDTH;
280    let next_len = matrices_to_inject[0].height();
281    // We always want to return an even number of digests, except when it's the root.
282    let next_len_padded = if prev_layer.len() == 2 {
283        1
284    } else {
285        // Round prev_layer.len() / 2 up to the next even integer.
286        (prev_layer.len() / 2 + 1) & !1
287    };
288
289    let default_digest = [PW::Value::default(); DIGEST_ELEMS];
290    let mut next_digests = vec![default_digest; next_len_padded];
291    next_digests[0..next_len]
292        .par_chunks_exact_mut(width)
293        .enumerate()
294        .for_each(|(i, digests_chunk)| {
295            let first_row = i * width;
296            let left = array::from_fn(|j| PW::from_fn(|k| prev_layer[2 * (first_row + k)][j]));
297            let right = array::from_fn(|j| PW::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j]));
298            let mut packed_digest = c.compress([left, right]);
299            let tallest_digest = h.hash_iter(
300                matrices_to_inject
301                    .iter()
302                    .flat_map(|m| m.vertically_packed_row(first_row)),
303            );
304            packed_digest = c.compress([packed_digest, tallest_digest]);
305            for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) {
306                *dst = src;
307            }
308        });
309
310    // If our packing width did not divide next_len, fall back to single-threaded scalar code
311    // for the last bit.
312    for i in (next_len / width * width)..next_len {
313        let left = prev_layer[2 * i];
314        let right = prev_layer[2 * i + 1];
315        let digest = c.compress([left, right]);
316        let rows_digest = unsafe {
317            // Safety: Clearly i < next_len = m.height().
318            h.hash_iter(matrices_to_inject.iter().flat_map(|m| m.row_unchecked(i)))
319        };
320        next_digests[i] = c.compress([digest, rows_digest]);
321    }
322
323    // At this point, we've exceeded the height of the matrices to inject, so we continue the
324    // process above except with default_digest in place of an input digest.
325    // We only need go as far as half the length of the previous layer.
326    for i in next_len..(prev_layer.len() / 2) {
327        let left = prev_layer[2 * i];
328        let right = prev_layer[2 * i + 1];
329        let digest = c.compress([left, right]);
330        next_digests[i] = c.compress([digest, default_digest]);
331    }
332
333    next_digests
334}
335
336/// Pure compression step used when no extra rows are injected.
337///
338/// Takes pairs of digests from `prev_layer`, feeds them to `c`,
339/// and writes the results in order.
340///
341/// Pads with the zero digest so the caller always receives an even-sized
342/// slice, except when the tree has shrunk to its single root.
343fn compress<P, C, const DIGEST_ELEMS: usize>(
344    prev_layer: &[[P::Value; DIGEST_ELEMS]],
345    c: &C,
346) -> Vec<[P::Value; DIGEST_ELEMS]>
347where
348    P: PackedValue,
349    C: PseudoCompressionFunction<[P::Value; DIGEST_ELEMS], 2>
350        + PseudoCompressionFunction<[P; DIGEST_ELEMS], 2>
351        + Sync,
352{
353    let width = P::WIDTH;
354    // Always return an even number of digests, except when it's the root.
355    let next_len_padded = if prev_layer.len() == 2 {
356        1
357    } else {
358        // Round prev_layer.len() / 2 up to the next even integer.
359        (prev_layer.len() / 2 + 1) & !1
360    };
361    let next_len = prev_layer.len() / 2;
362
363    let default_digest = [P::Value::default(); DIGEST_ELEMS];
364    let mut next_digests = vec![default_digest; next_len_padded];
365
366    next_digests[0..next_len]
367        .par_chunks_exact_mut(width)
368        .enumerate()
369        .for_each(|(i, digests_chunk)| {
370            let first_row = i * width;
371            let left = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k)][j]));
372            let right = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j]));
373            let packed_digest = c.compress([left, right]);
374            for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) {
375                *dst = src;
376            }
377        });
378
379    // If our packing width did not divide next_len, fall back to single-threaded scalar code
380    // for the last bit.
381    for i in (next_len / width * width)..next_len {
382        let left = prev_layer[2 * i];
383        let right = prev_layer[2 * i + 1];
384        next_digests[i] = c.compress([left, right]);
385    }
386
387    // Everything has been initialized so we can safely cast.
388    next_digests
389}
390
391/// Converts a packed array `[P; N]` into its underlying `P::WIDTH` scalar arrays.
392///
393/// Interprets `[P; N]` as the matrix `[[P::Value; P::WIDTH]; N]`, performs a transpose to
394/// get `[[P::Value; N] P::WIDTH]` and returns these `P::Value` arrays as an iterator.
395#[inline]
396fn unpack_array<P: PackedValue, const N: usize>(
397    packed_digest: [P; N],
398) -> impl Iterator<Item = [P::Value; N]> {
399    (0..P::WIDTH).map(move |j| packed_digest.map(|p| p.as_slice()[j]))
400}
401
#[cfg(test)]
mod tests {
    use p3_symmetric::PseudoCompressionFunction;
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use super::*;

    /// XOR-based stand-in for a real 2-to-1 compression function.
    #[derive(Clone, Copy)]
    struct DummyCompressionFunction;

    impl PseudoCompressionFunction<[u8; 32], 2> for DummyCompressionFunction {
        fn compress(&self, input: [[u8; 32]; 2]) -> [u8; 32] {
            let [a, b] = input;
            core::array::from_fn(|i| a[i] ^ b[i])
        }
    }

    #[test]
    fn test_compress_even_length() {
        // Two full pairs, no padding needed.
        let layer = [[0x01; 32], [0x02; 32], [0x03; 32], [0x04; 32]];
        let out = compress::<u8, DummyCompressionFunction, 32>(&layer, &DummyCompressionFunction);
        // 0x01 ^ 0x02 = 0x03 and 0x03 ^ 0x04 = 0x07.
        assert_eq!(out, vec![[0x03; 32], [0x07; 32]]);
    }

    #[test]
    fn test_compress_odd_length() {
        // One full pair plus a dangling digest (ignored); the output is padded
        // with the zero digest to an even length.
        let layer = [[0x05; 32], [0x06; 32], [0x07; 32]];
        let out = compress::<u8, DummyCompressionFunction, 32>(&layer, &DummyCompressionFunction);
        // 0x05 ^ 0x06 = 0x03, then a zero padding digest.
        assert_eq!(out, vec![[0x03; 32], [0x00; 32]]);
    }

    #[test]
    fn test_compress_random_values() {
        let mut rng = SmallRng::seed_from_u64(1);
        let layer: Vec<[u8; 32]> = (0..8).map(|_| rng.random()).collect();
        // Reference result: elementwise XOR of each adjacent pair.
        let expected: Vec<[u8; 32]> = layer
            .chunks_exact(2)
            .map(|pair| core::array::from_fn(|i| pair[0][i] ^ pair[1][i]))
            .collect();
        let out = compress::<u8, DummyCompressionFunction, 32>(&layer, &DummyCompressionFunction);
        assert_eq!(out, expected);
    }

    #[test]
    fn test_compress_root_case_single_pair() {
        // A two-digest layer is the “root-formation” case and must yield
        // exactly one digest: 0xAA ^ 0x55 = 0xFF.
        let layer = [[0xAA; 32], [0x55; 32]];
        let out = compress::<u8, DummyCompressionFunction, 32>(&layer, &DummyCompressionFunction);
        assert_eq!(out, vec![[0xFF; 32]]);
    }

    #[test]
    fn test_compress_non_power_of_two_with_padding() {
        // Six inputs give three real pairs; since this is not the root the
        // output is padded with one zero digest to the next even length (4).
        let layer = [
            [0x01; 32], [0x02; 32], [0x03; 32], [0x04; 32], [0x05; 32], [0x06; 32],
        ];
        let out = compress::<u8, DummyCompressionFunction, 32>(&layer, &DummyCompressionFunction);
        assert_eq!(
            out,
            vec![
                [0x03; 32], // 01 ^ 02
                [0x07; 32], // 03 ^ 04
                [0x03; 32], // 05 ^ 06
                [0x00; 32], // padding
            ]
        );
        // Validate the padding branch explicitly.
        assert_eq!(out.len(), 4);
    }

    #[test]
    fn test_unpack_array_basic() {
        // Two packed “words”, each four lanes wide; unpacking must yield one
        // scalar array per lane, in lane order.
        let packed: [[u8; 4]; 2] = [[0, 1, 2, 3], [4, 5, 6, 7]];
        let lanes: Vec<[u8; 2]> = unpack_array::<[u8; 4], 2>(packed).collect();
        assert_eq!(lanes, vec![[0, 4], [1, 5], [2, 6], [3, 7]]);
    }
}