p3_merkle_tree/
merkle_tree.rs

1use alloc::vec;
2use alloc::vec::Vec;
3use core::array;
4use core::cmp::Reverse;
5use core::marker::PhantomData;
6
7use itertools::Itertools;
8use p3_field::PackedValue;
9use p3_matrix::Matrix;
10use p3_maybe_rayon::prelude::*;
11use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
12use serde::{Deserialize, Serialize};
13use tracing::instrument;
14
15/// A binary Merkle tree whose leaves are vectors of matrix rows.
16///
17/// * `F` – scalar element type inside each matrix row.
18/// * `W` – scalar element type of every digest word.
19/// * `M` – matrix type. Must implement [`Matrix<F>`].
20/// * `DIGEST_ELEMS` – number of `W` words in one digest.
21///
22/// The tree is **balanced only at the digest layer**.
23/// Leaf matrices may have arbitrary heights as long as any two heights
24/// that round **up** to the same power-of-two are equal.
25///
26/// Use [`root`] to fetch the final digest once the tree is built.
27///
28/// This generally shouldn't be used directly. If you're using a Merkle tree as an MMCS,
29/// see `MerkleTreeMmcs`.
30#[derive(Debug, Serialize, Deserialize)]
31pub struct MerkleTree<F, W, M, const DIGEST_ELEMS: usize> {
32    /// All leaf matrices in insertion order.
33    ///
34    /// Each matrix contributes rows to one or more digest layers, depending on its height.
35    /// Specifically, only the tallest matrices are included in the first digest layer,
36    /// while shorter matrices are injected into higher digest layers at positions determined
37    /// by their padded heights.
38    ///
39    /// This vector is retained only for inspection or re-opening of the tree; it is not used
40    /// after construction time.
41    pub(crate) leaves: Vec<M>,
42
43    /// All intermediate digest layers, index 0 being the first layer above
44    /// the leaves and the last layer containing exactly one root digest.
45    ///
46    /// Every inner vector holds contiguous digests `[left₀, right₀, left₁,
47    /// right₁, …]`; higher layers refer to these by index.
48    ///
49    /// Serialization requires that `[W; DIGEST_ELEMS]` implements `Serialize` and
50    /// `Deserialize`. This is automatically satisfied when `W` is a fixed-size type.
51    #[serde(
52        bound(serialize = "[W; DIGEST_ELEMS]: Serialize"),
53        bound(deserialize = "[W; DIGEST_ELEMS]: Deserialize<'de>")
54    )]
55    pub(crate) digest_layers: Vec<Vec<[W; DIGEST_ELEMS]>>,
56
57    /// Zero-sized marker that binds the generic `F` but occupies no space.
58    _phantom: PhantomData<F>,
59}
60
61impl<F: Clone + Send + Sync, W: Clone, M: Matrix<F>, const DIGEST_ELEMS: usize>
62    MerkleTree<F, W, M, DIGEST_ELEMS>
63{
64    /// Build a tree from **one or more matrices**.
65    ///
66    /// * `h` – hashing function used on raw rows.
67    /// * `c` – 2-to-1 compression function used on digests.
68    /// * `leaves` – matrices to commit to. Must be non-empty.
69    ///
70    /// Matrices do **not** need to have power-of-two heights. However, any two matrices
71    /// whose heights **round up** to the same power-of-two must have **equal actual height**.
72    /// This ensures proper balancing when folding digests layer-by-layer.
73    ///
74    /// All matrices are hashed row-by-row with `h`. The resulting digests are
75    /// then folded upwards with `c` until a single root remains.
76    ///
77    /// # Panics
78    /// * If `leaves` is empty.
79    /// * If the packing widths of `P` and `PW` differ.
80    /// * If two leaf heights *round up* to the same power-of-two but are not
81    ///   equal (violates balancing rule).
82    #[instrument(name = "build merkle tree", level = "debug", skip_all,
83                 fields(dimensions = alloc::format!("{:?}", leaves.iter().map(|l| l.dimensions()).collect::<Vec<_>>())))]
84    pub fn new<P, PW, H, C>(h: &H, c: &C, leaves: Vec<M>) -> Self
85    where
86        P: PackedValue<Value = F>,
87        PW: PackedValue<Value = W>,
88        H: CryptographicHasher<F, [W; DIGEST_ELEMS]>
89            + CryptographicHasher<P, [PW; DIGEST_ELEMS]>
90            + Sync,
91        C: PseudoCompressionFunction<[W; DIGEST_ELEMS], 2>
92            + PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>
93            + Sync,
94    {
95        assert!(!leaves.is_empty(), "No matrices given?");
96        const {
97            assert!(P::WIDTH == PW::WIDTH, "Packing widths must match");
98        }
99
100        let mut leaves_largest_first = leaves
101            .iter()
102            .sorted_by_key(|l| Reverse(l.height()))
103            .peekable();
104
105        // check height property
106        assert!(
107            leaves_largest_first
108                .clone()
109                .map(|m| m.height())
110                .tuple_windows()
111                .all(|(curr, next)| curr == next
112                    || curr.next_power_of_two() != next.next_power_of_two()),
113            "matrix heights that round up to the same power of two must be equal"
114        );
115
116        let max_height = leaves_largest_first.peek().unwrap().height();
117        let tallest_matrices = leaves_largest_first
118            .peeking_take_while(|m| m.height() == max_height)
119            .collect_vec();
120
121        let mut digest_layers = vec![first_digest_layer::<P, _, _, _, DIGEST_ELEMS>(
122            h,
123            &tallest_matrices,
124        )];
125        loop {
126            let prev_layer = digest_layers.last().unwrap().as_slice();
127            if prev_layer.len() == 1 {
128                break;
129            }
130            let next_layer_len = (prev_layer.len() / 2).next_power_of_two();
131
132            // The matrices that get injected at this layer.
133            let matrices_to_inject = leaves_largest_first
134                .peeking_take_while(|m| m.height().next_power_of_two() == next_layer_len)
135                .collect_vec();
136
137            let next_digests = compress_and_inject::<P, _, _, _, _, DIGEST_ELEMS>(
138                prev_layer,
139                &matrices_to_inject,
140                h,
141                c,
142            );
143            digest_layers.push(next_digests);
144        }
145
146        Self {
147            leaves,
148            digest_layers,
149            _phantom: PhantomData,
150        }
151    }
152
153    /// Return the root digest of the tree.
154    #[must_use]
155    pub fn root(&self) -> Hash<F, W, DIGEST_ELEMS>
156    where
157        W: Copy,
158    {
159        self.digest_layers.last().unwrap()[0].into()
160    }
161}
162
163/// Hash every row of the tallest matrices and build the first digest layer.
164///
165/// This function is responsible for creating the first layer of Merkle digests,
166/// starting from raw rows of the tallest matrices. Each row is hashed using the
167/// provided cryptographic hasher `h`. The result is a vector of digests that serve
168/// as the base (leaf-level) nodes for the rest of the Merkle tree.
169///
170/// # Details
171/// - We always return an *even number of digests* (except when height is 1), to
172///   ensure even pairing at higher layers.
173/// - Matrices are "vertically packed" to allow SIMD-friendly parallel hashing,
174///   meaning rows can be processed in batches.
175/// - If the total number of rows isn't a multiple of the SIMD packing width,
176///   the final few rows are handled using a fallback scalar path.
177///
178/// # Arguments
179/// - `h`: Reference to the cryptographic hasher.
180/// - `tallest_matrices`: References to the tallest matrices (all must have same height).
181///
182/// # Returns
183/// A vector of `[PW::Value; DIGEST_ELEMS]`, containing the digests of each row.
184#[instrument(name = "first digest layer", level = "debug", skip_all)]
185fn first_digest_layer<P, PW, H, M, const DIGEST_ELEMS: usize>(
186    h: &H,
187    tallest_matrices: &[&M],
188) -> Vec<[PW::Value; DIGEST_ELEMS]>
189where
190    P: PackedValue,
191    PW: PackedValue,
192    H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>
193        + CryptographicHasher<P, [PW; DIGEST_ELEMS]>
194        + Sync,
195    M: Matrix<P::Value>,
196{
197    // The number of rows to pack and hash together in one SIMD batch.
198    let width = PW::WIDTH;
199
200    // Get the height of the tallest matrices (they are guaranteed to be equal).
201    let max_height = tallest_matrices[0].height();
202
203    // Compute the padded height to ensure we end up with an even number of digests.
204    // **Exception:** if there's only 1 row, we keep it as 1.
205    let max_height_padded = if max_height == 1 {
206        1
207    } else {
208        max_height + max_height % 2
209    };
210
211    // Prepare a default digest value to fill unused slots or padding.
212    let default_digest = [PW::Value::default(); DIGEST_ELEMS];
213
214    // Allocate the digest vector with padded size, initialized to default digest.
215    let mut digests = vec![default_digest; max_height_padded];
216
217    // Parallel loop: process complete batches of `width` rows at a time.
218    digests[0..max_height]
219        .par_chunks_exact_mut(width)
220        .enumerate()
221        .for_each(|(i, digests_chunk)| {
222            // Compute the starting row index for this chunk.
223            let first_row = i * width;
224
225            // Collect all vertically packed rows from each matrix at `first_row`.
226            // These packed rows are then hashed together using `h`.
227            let packed_digest: [PW; DIGEST_ELEMS] = h.hash_iter(
228                tallest_matrices
229                    .iter()
230                    .flat_map(|m| m.vertically_packed_row(first_row)),
231            );
232
233            // Unpack the resulting packed digest into individual scalar digests.
234            PW::unpack_into(&packed_digest, digests_chunk);
235        });
236
237    // Handle leftover rows that do not form a full SIMD batch (if any).
238    #[allow(clippy::needless_range_loop)]
239    for i in ((max_height / width) * width)..max_height {
240        unsafe {
241            // Safety: The loop guarantees i < max_height == matrix height.
242            // Use `row_unchecked` to avoid bounds checks for performance.
243            digests[i] = h.hash_iter(tallest_matrices.iter().flat_map(|m| m.row_unchecked(i)));
244        }
245    }
246
247    // Return the final digest vector (now fully populated).
248    digests
249}
250
251/// Fold one digest layer into the next and, when present, mix in rows
252/// taken from smaller matrices whose padded height equals `prev_layer.len()/2`.
253///
254/// Pads the output so its length is even unless it becomes the root.
255fn compress_and_inject<P, PW, H, C, M, const DIGEST_ELEMS: usize>(
256    prev_layer: &[[PW::Value; DIGEST_ELEMS]],
257    matrices_to_inject: &[&M],
258    h: &H,
259    c: &C,
260) -> Vec<[PW::Value; DIGEST_ELEMS]>
261where
262    P: PackedValue,
263    PW: PackedValue,
264    H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>
265        + CryptographicHasher<P, [PW; DIGEST_ELEMS]>
266        + Sync,
267    C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>
268        + PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>
269        + Sync,
270    M: Matrix<P::Value>,
271{
272    if matrices_to_inject.is_empty() {
273        return compress::<PW, _, DIGEST_ELEMS>(prev_layer, c);
274    }
275
276    let width = PW::WIDTH;
277    let next_len = matrices_to_inject[0].height();
278    // We always want to return an even number of digests, except when it's the root.
279    let next_len_padded = if prev_layer.len() == 2 {
280        1
281    } else {
282        // Round prev_layer.len() / 2 up to the next even integer.
283        (prev_layer.len() / 2 + 1) & !1
284    };
285
286    let default_digest = [PW::Value::default(); DIGEST_ELEMS];
287    let mut next_digests = vec![default_digest; next_len_padded];
288    next_digests[0..next_len]
289        .par_chunks_exact_mut(width)
290        .enumerate()
291        .for_each(|(i, digests_chunk)| {
292            let first_row = i * width;
293            let left = array::from_fn(|j| PW::from_fn(|k| prev_layer[2 * (first_row + k)][j]));
294            let right = array::from_fn(|j| PW::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j]));
295            let mut packed_digest = c.compress([left, right]);
296            let tallest_digest = h.hash_iter(
297                matrices_to_inject
298                    .iter()
299                    .flat_map(|m| m.vertically_packed_row(first_row)),
300            );
301            packed_digest = c.compress([packed_digest, tallest_digest]);
302            PW::unpack_into(&packed_digest, digests_chunk);
303        });
304
305    // If our packing width did not divide next_len, fall back to single-threaded scalar code
306    // for the last bit.
307    for i in (next_len / width * width)..next_len {
308        let left = prev_layer[2 * i];
309        let right = prev_layer[2 * i + 1];
310        let digest = c.compress([left, right]);
311        let rows_digest = unsafe {
312            // Safety: Clearly i < next_len = m.height().
313            h.hash_iter(matrices_to_inject.iter().flat_map(|m| m.row_unchecked(i)))
314        };
315        next_digests[i] = c.compress([digest, rows_digest]);
316    }
317
318    // At this point, we've exceeded the height of the matrices to inject, so we continue the
319    // process above except with default_digest in place of an input digest.
320    // We only need go as far as half the length of the previous layer.
321    for i in next_len..(prev_layer.len() / 2) {
322        let left = prev_layer[2 * i];
323        let right = prev_layer[2 * i + 1];
324        let digest = c.compress([left, right]);
325        next_digests[i] = c.compress([digest, default_digest]);
326    }
327
328    next_digests
329}
330
331/// Pure compression step used when no extra rows are injected.
332///
333/// Takes pairs of digests from `prev_layer`, feeds them to `c`,
334/// and writes the results in order.
335///
336/// Pads with the zero digest so the caller always receives an even-sized
337/// slice, except when the tree has shrunk to its single root.
338fn compress<P, C, const DIGEST_ELEMS: usize>(
339    prev_layer: &[[P::Value; DIGEST_ELEMS]],
340    c: &C,
341) -> Vec<[P::Value; DIGEST_ELEMS]>
342where
343    P: PackedValue,
344    C: PseudoCompressionFunction<[P::Value; DIGEST_ELEMS], 2>
345        + PseudoCompressionFunction<[P; DIGEST_ELEMS], 2>
346        + Sync,
347{
348    let width = P::WIDTH;
349    // Always return an even number of digests, except when it's the root.
350    let next_len_padded = if prev_layer.len() == 2 {
351        1
352    } else {
353        // Round prev_layer.len() / 2 up to the next even integer.
354        (prev_layer.len() / 2 + 1) & !1
355    };
356    let next_len = prev_layer.len() / 2;
357
358    let default_digest = [P::Value::default(); DIGEST_ELEMS];
359    let mut next_digests = vec![default_digest; next_len_padded];
360
361    next_digests[0..next_len]
362        .par_chunks_exact_mut(width)
363        .enumerate()
364        .for_each(|(i, digests_chunk)| {
365            let first_row = i * width;
366            let left = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k)][j]));
367            let right = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j]));
368            let packed_digest = c.compress([left, right]);
369            P::unpack_into(&packed_digest, digests_chunk);
370        });
371
372    // If our packing width did not divide next_len, fall back to single-threaded scalar code
373    // for the last bit.
374    for i in (next_len / width * width)..next_len {
375        let left = prev_layer[2 * i];
376        let right = prev_layer[2 * i + 1];
377        next_digests[i] = c.compress([left, right]);
378    }
379
380    // Everything has been initialized so we can safely cast.
381    next_digests
382}
383
#[cfg(test)]
mod tests {
    use p3_symmetric::PseudoCompressionFunction;
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use super::*;

    /// Element-wise XOR of two 32-byte digests.
    fn xor_digests(a: &[u8; 32], b: &[u8; 32]) -> [u8; 32] {
        core::array::from_fn(|i| a[i] ^ b[i])
    }

    /// Toy 2-to-1 compression: XOR of the two inputs. Not cryptographic; tests only.
    #[derive(Clone, Copy)]
    struct DummyCompressionFunction;

    impl PseudoCompressionFunction<[u8; 32], 2> for DummyCompressionFunction {
        fn compress(&self, input: [[u8; 32]; 2]) -> [u8; 32] {
            let [lhs, rhs] = input;
            xor_digests(&lhs, &rhs)
        }
    }

    /// Run `compress` with the XOR dummy over `u8` digests.
    fn run_compress(layer: &[[u8; 32]]) -> Vec<[u8; 32]> {
        compress::<u8, DummyCompressionFunction, 32>(layer, &DummyCompressionFunction)
    }

    #[test]
    fn test_compress_even_length() {
        let layer = [[0x01; 32], [0x02; 32], [0x03; 32], [0x04; 32]];
        // 0x01 ^ 0x02 = 0x03, 0x03 ^ 0x04 = 0x07
        assert_eq!(run_compress(&layer), vec![[0x03; 32], [0x07; 32]]);
    }

    #[test]
    fn test_compress_odd_length() {
        // The unpaired last digest is dropped; one zero digest pads to an even length.
        let layer = [[0x05; 32], [0x06; 32], [0x07; 32]];
        // 0x05 ^ 0x06 = 0x03
        assert_eq!(run_compress(&layer), vec![[0x03; 32], [0x00; 32]]);
    }

    #[test]
    fn test_compress_random_values() {
        let mut rng = SmallRng::seed_from_u64(1);
        let layer: Vec<[u8; 32]> = (0..8).map(|_| rng.random()).collect();
        let expected: Vec<[u8; 32]> = layer
            .chunks_exact(2)
            .map(|pair| xor_digests(&pair[0], &pair[1]))
            .collect();
        assert_eq!(run_compress(&layer), expected);
    }

    #[test]
    fn test_compress_root_case_single_pair() {
        // A two-digest layer folds into exactly one root digest: 0xAA ^ 0x55 = 0xFF.
        let layer = [[0xAA; 32], [0x55; 32]];
        assert_eq!(run_compress(&layer), vec![[0xFF; 32]]);
    }

    #[test]
    fn test_compress_non_power_of_two_with_padding() {
        // Six inputs give three real digests. The output is not the root, so it is padded
        // to the next even length (four) with a zero digest.
        let layer = [
            [0x01; 32], [0x02; 32], [0x03; 32], [0x04; 32], [0x05; 32], [0x06; 32],
        ];
        let result = run_compress(&layer);

        assert_eq!(result.len(), 4);
        assert_eq!(
            result,
            vec![
                [0x03; 32], // 01 ^ 02
                [0x07; 32], // 03 ^ 04
                [0x03; 32], // 05 ^ 06
                [0x00; 32], // padding
            ]
        );
    }
}