1use alloc::vec;
2use alloc::vec::Vec;
3use core::array;
4use core::cmp::Reverse;
5use core::marker::PhantomData;
6
7use itertools::Itertools;
8use p3_field::PackedValue;
9use p3_matrix::Matrix;
10use p3_maybe_rayon::prelude::*;
11use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
12use serde::{Deserialize, Serialize};
13use tracing::instrument;
14
/// A binary Merkle tree built over one or more leaf matrices.
///
/// Row hashes form the bottom layer; each subsequent layer compresses
/// sibling digests until a single root digest remains. Matrices of
/// different heights are folded in ("injected") at the internal layer
/// matching their power-of-two-rounded height — see `MerkleTree::new`.
#[derive(Debug, Serialize, Deserialize)]
pub struct MerkleTree<F, W, M, const DIGEST_ELEMS: usize> {
    // The original leaf matrices, retained alongside the digests.
    // NOTE(review): presumably kept so openings can return matrix rows —
    // confirm against the opening/proving code elsewhere in the crate.
    pub(crate) leaves: Vec<M>,

    #[serde(
        bound(serialize = "[W; DIGEST_ELEMS]: Serialize"),
        bound(deserialize = "[W; DIGEST_ELEMS]: Deserialize<'de>")
    )]
    // digest_layers[0] is the row-hash layer; each later layer is built by
    // `compress_and_inject`. The final layer contains exactly one digest,
    // the root (see `root()`).
    pub(crate) digest_layers: Vec<Vec<[W; DIGEST_ELEMS]>>,

    // `F` (the leaf element type) does not appear in any stored field;
    // PhantomData ties it to the struct's type parameters.
    _phantom: PhantomData<F>,
}
60
impl<F: Clone + Send + Sync, W: Clone, M: Matrix<F>, const DIGEST_ELEMS: usize>
    MerkleTree<F, W, M, DIGEST_ELEMS>
{
    /// Build a Merkle tree over the given leaf matrices.
    ///
    /// `h` hashes matrix rows into `DIGEST_ELEMS`-word digests and `c`
    /// compresses two child digests into a parent. Both must also be
    /// implemented over the packed types `P`/`PW` so that `P::WIDTH` rows
    /// (or sibling pairs) can be processed per call.
    ///
    /// Matrices need not share a height: a shorter matrix is injected at
    /// the layer whose length equals its height rounded up to a power of
    /// two.
    ///
    /// # Panics
    /// - If `leaves` is empty.
    /// - If two matrices have *different* heights that round up to the
    ///   *same* power of two (they would compete for the same layer).
    #[instrument(name = "build merkle tree", level = "debug", skip_all,
        fields(dimensions = alloc::format!("{:?}", leaves.iter().map(|l| l.dimensions()).collect::<Vec<_>>())))]
    pub fn new<P, PW, H, C>(h: &H, c: &C, leaves: Vec<M>) -> Self
    where
        P: PackedValue<Value = F>,
        PW: PackedValue<Value = W>,
        H: CryptographicHasher<F, [W; DIGEST_ELEMS]>
            + CryptographicHasher<P, [PW; DIGEST_ELEMS]>
            + Sync,
        C: PseudoCompressionFunction<[W; DIGEST_ELEMS], 2>
            + PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>
            + Sync,
    {
        assert!(!leaves.is_empty(), "No matrices given?");
        // Compile-time check that both packings advance the same number of
        // scalar lanes, so packed hashing and packed compression line up.
        const {
            assert!(P::WIDTH == PW::WIDTH, "Packing widths must match");
        }

        // Visit matrices tallest-first; `peekable` plus the
        // `peeking_take_while` calls below peel off one height-group at a
        // time as we descend through the layers.
        let mut leaves_largest_first = leaves
            .iter()
            .sorted_by_key(|l| Reverse(l.height()))
            .peekable();

        // Equal heights are fine, but distinct heights must round up to
        // distinct powers of two so each height-group maps to a unique
        // layer. (The iterator is sorted, so adjacent pairs suffice.)
        assert!(
            leaves_largest_first
                .clone()
                .map(|m| m.height())
                .tuple_windows()
                .all(|(curr, next)| curr == next
                    || curr.next_power_of_two() != next.next_power_of_two()),
            "matrix heights that round up to the same power of two must be equal"
        );

        let max_height = leaves_largest_first.peek().unwrap().height();
        // All matrices of maximal height are hashed into the bottom layer.
        let tallest_matrices = leaves_largest_first
            .peeking_take_while(|m| m.height() == max_height)
            .collect_vec();

        let mut digest_layers = vec![first_digest_layer::<P, _, _, _, DIGEST_ELEMS>(
            h,
            &tallest_matrices,
        )];
        // Repeatedly compress the previous layer (roughly halving it),
        // folding in any matrices whose rounded height matches the new
        // layer, until a single root digest remains.
        loop {
            let prev_layer = digest_layers.last().unwrap().as_slice();
            if prev_layer.len() == 1 {
                break;
            }
            let next_layer_len = (prev_layer.len() / 2).next_power_of_two();

            // Matrices whose height rounds up to this layer's length get
            // injected here.
            let matrices_to_inject = leaves_largest_first
                .peeking_take_while(|m| m.height().next_power_of_two() == next_layer_len)
                .collect_vec();

            let next_digests = compress_and_inject::<P, _, _, _, _, DIGEST_ELEMS>(
                prev_layer,
                &matrices_to_inject,
                h,
                c,
            );
            digest_layers.push(next_digests);
        }

        Self {
            leaves,
            digest_layers,
            _phantom: PhantomData,
        }
    }

    /// The root digest of the tree: the single entry of the final layer
    /// built by [`Self::new`].
    #[must_use]
    pub fn root(&self) -> Hash<F, W, DIGEST_ELEMS>
    where
        W: Copy,
    {
        self.digest_layers.last().unwrap()[0].into()
    }
}
162
/// Hash each row of the tallest matrices into the bottom digest layer.
///
/// Returns one digest per row, padded with default digests to an even
/// length (a height of exactly 1 is left unpadded) so the layer can later
/// be consumed in sibling pairs.
#[instrument(name = "first digest layer", level = "debug", skip_all)]
fn first_digest_layer<P, PW, H, M, const DIGEST_ELEMS: usize>(
    h: &H,
    tallest_matrices: &[&M],
) -> Vec<[PW::Value; DIGEST_ELEMS]>
where
    P: PackedValue,
    PW: PackedValue,
    H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>
        + CryptographicHasher<P, [PW; DIGEST_ELEMS]>
        + Sync,
    M: Matrix<P::Value>,
{
    let width = PW::WIDTH;

    // All matrices here share the maximal height (the caller filtered by
    // `height() == max_height`), so reading the first one suffices.
    let max_height = tallest_matrices[0].height();

    // Round an odd height up to even so the layer splits into pairs; the
    // extra slot keeps the default (zero-like) digest written below.
    let max_height_padded = if max_height == 1 {
        1
    } else {
        max_height + max_height % 2
    };

    let default_digest = [PW::Value::default(); DIGEST_ELEMS];

    let mut digests = vec![default_digest; max_height_padded];

    // Packed phase: hash `width` consecutive rows at a time, in parallel
    // over disjoint output chunks.
    digests[0..max_height]
        .par_chunks_exact_mut(width)
        .enumerate()
        .for_each(|(i, digests_chunk)| {
            let first_row = i * width;

            // NOTE(review): `vertically_packed_row` presumably yields rows
            // `first_row..first_row + width` packed lane-wise — confirm
            // against the `Matrix` trait definition.
            let packed_digest: [PW; DIGEST_ELEMS] = h.hash_iter(
                tallest_matrices
                    .iter()
                    .flat_map(|m| m.vertically_packed_row(first_row)),
            );

            // Transpose the packed digest back into `width` scalar digests.
            for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) {
                *dst = src;
            }
        });

    // Scalar phase: hash the trailing rows that do not fill a whole packed
    // chunk of `width` rows.
    #[allow(clippy::needless_range_loop)]
    for i in ((max_height / width) * width)..max_height {
        // SAFETY: `i < max_height` and every matrix in `tallest_matrices`
        // has height `max_height`, so row `i` exists in each matrix.
        unsafe {
            digests[i] = h.hash_iter(tallest_matrices.iter().flat_map(|m| m.row_unchecked(i)));
        }
    }

    digests
}
253
/// Compress `prev_layer` one level — pairing digests `2i` and `2i + 1` —
/// and fold the row hashes of `matrices_to_inject` into the new layer.
///
/// For indices with matrix rows to inject, the parent digest is
/// `compress([compress([left, right]), hash(rows)])`; past the injected
/// matrices' height it is `compress([compress([left, right]), default])`.
/// The output is padded with default digests to an even length, except
/// when it is the root layer of length 1.
fn compress_and_inject<P, PW, H, C, M, const DIGEST_ELEMS: usize>(
    prev_layer: &[[PW::Value; DIGEST_ELEMS]],
    matrices_to_inject: &[&M],
    h: &H,
    c: &C,
) -> Vec<[PW::Value; DIGEST_ELEMS]>
where
    P: PackedValue,
    PW: PackedValue,
    H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>
        + CryptographicHasher<P, [PW; DIGEST_ELEMS]>
        + Sync,
    C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>
        + PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>
        + Sync,
    M: Matrix<P::Value>,
{
    // Nothing to inject: this level is a plain compression step.
    if matrices_to_inject.is_empty() {
        return compress::<PW, _, DIGEST_ELEMS>(prev_layer, c);
    }

    let width = PW::WIDTH;
    // Number of output slots that receive injected rows.
    // NOTE(review): assumes `next_len <= prev_layer.len() / 2` — upheld by
    // the layer bookkeeping in `MerkleTree::new`; confirm.
    let next_len = matrices_to_inject[0].height();
    // Round the new layer's length (prev_layer.len() / 2) up to even, so it
    // can itself be compressed in pairs — unless prev_layer is one pair, in
    // which case the new layer is the length-1 root.
    let next_len_padded = if prev_layer.len() == 2 {
        1
    } else {
        (prev_layer.len() / 2 + 1) & !1
    };

    let default_digest = [PW::Value::default(); DIGEST_ELEMS];
    let mut next_digests = vec![default_digest; next_len_padded];
    // Packed phase: handle `width` sibling pairs (and `width` injected rows)
    // at a time, in parallel. `left`/`right` gather lane `k`'s pair into
    // packed element position `j`.
    next_digests[0..next_len]
        .par_chunks_exact_mut(width)
        .enumerate()
        .for_each(|(i, digests_chunk)| {
            let first_row = i * width;
            let left = array::from_fn(|j| PW::from_fn(|k| prev_layer[2 * (first_row + k)][j]));
            let right = array::from_fn(|j| PW::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j]));
            let mut packed_digest = c.compress([left, right]);
            let tallest_digest = h.hash_iter(
                matrices_to_inject
                    .iter()
                    .flat_map(|m| m.vertically_packed_row(first_row)),
            );
            packed_digest = c.compress([packed_digest, tallest_digest]);
            for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) {
                *dst = src;
            }
        });

    // Scalar tail: injected indices not covered by a full packed chunk.
    for i in (next_len / width * width)..next_len {
        let left = prev_layer[2 * i];
        let right = prev_layer[2 * i + 1];
        let digest = c.compress([left, right]);
        // SAFETY: `i < next_len`, the height of every matrix in
        // `matrices_to_inject` (see NOTE above on shared heights).
        let rows_digest = unsafe {
            h.hash_iter(matrices_to_inject.iter().flat_map(|m| m.row_unchecked(i)))
        };
        next_digests[i] = c.compress([digest, rows_digest]);
    }

    // Indices past the injected matrices' height: no rows to fold in, so
    // compress the digest pair with a default digest instead.
    for i in next_len..(prev_layer.len() / 2) {
        let left = prev_layer[2 * i];
        let right = prev_layer[2 * i + 1];
        let digest = c.compress([left, right]);
        next_digests[i] = c.compress([digest, default_digest]);
    }

    next_digests
}
335
336fn compress<P, C, const DIGEST_ELEMS: usize>(
344 prev_layer: &[[P::Value; DIGEST_ELEMS]],
345 c: &C,
346) -> Vec<[P::Value; DIGEST_ELEMS]>
347where
348 P: PackedValue,
349 C: PseudoCompressionFunction<[P::Value; DIGEST_ELEMS], 2>
350 + PseudoCompressionFunction<[P; DIGEST_ELEMS], 2>
351 + Sync,
352{
353 let width = P::WIDTH;
354 let next_len_padded = if prev_layer.len() == 2 {
356 1
357 } else {
358 (prev_layer.len() / 2 + 1) & !1
360 };
361 let next_len = prev_layer.len() / 2;
362
363 let default_digest = [P::Value::default(); DIGEST_ELEMS];
364 let mut next_digests = vec![default_digest; next_len_padded];
365
366 next_digests[0..next_len]
367 .par_chunks_exact_mut(width)
368 .enumerate()
369 .for_each(|(i, digests_chunk)| {
370 let first_row = i * width;
371 let left = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k)][j]));
372 let right = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j]));
373 let packed_digest = c.compress([left, right]);
374 for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) {
375 *dst = src;
376 }
377 });
378
379 for i in (next_len / width * width)..next_len {
382 let left = prev_layer[2 * i];
383 let right = prev_layer[2 * i + 1];
384 next_digests[i] = c.compress([left, right]);
385 }
386
387 next_digests
389}
390
391#[inline]
396fn unpack_array<P: PackedValue, const N: usize>(
397 packed_digest: [P; N],
398) -> impl Iterator<Item = [P::Value; N]> {
399 (0..P::WIDTH).map(move |j| packed_digest.map(|p| p.as_slice()[j]))
400}
401
#[cfg(test)]
mod tests {
    use p3_symmetric::PseudoCompressionFunction;
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use super::*;

    /// Toy compression function: byte-wise XOR of the two 32-byte inputs.
    #[derive(Clone, Copy)]
    struct DummyCompressionFunction;

    impl PseudoCompressionFunction<[u8; 32], 2> for DummyCompressionFunction {
        fn compress(&self, input: [[u8; 32]; 2]) -> [u8; 32] {
            let mut output = [0u8; 32];
            for ((out, lhs), rhs) in output.iter_mut().zip(&input[0]).zip(&input[1]) {
                *out = lhs ^ rhs;
            }
            output
        }
    }

    /// Four inputs compress cleanly into two outputs — no padding needed.
    #[test]
    fn test_compress_even_length() {
        let layer = [[0x01u8; 32], [0x02; 32], [0x03; 32], [0x04; 32]];
        // 0x01 ^ 0x02 == 0x03 and 0x03 ^ 0x04 == 0x07.
        let want = vec![[0x03u8; 32], [0x07; 32]];
        let got = compress::<u8, DummyCompressionFunction, 32>(&layer, &DummyCompressionFunction);
        assert_eq!(got, want);
    }

    /// Three inputs: one full pair plus a dangling element; the output is
    /// padded to an even length with a default (zero) digest.
    #[test]
    fn test_compress_odd_length() {
        let layer = [[0x05u8; 32], [0x06; 32], [0x07; 32]];
        // 0x05 ^ 0x06 == 0x03; the second slot is padding.
        let want = vec![[0x03u8; 32], [0x00; 32]];
        let got = compress::<u8, DummyCompressionFunction, 32>(&layer, &DummyCompressionFunction);
        assert_eq!(got, want);
    }

    /// Cross-check `compress` against a straightforward reference
    /// computation over random digests.
    #[test]
    fn test_compress_random_values() {
        let mut rng = SmallRng::seed_from_u64(1);
        let layer: Vec<[u8; 32]> = (0..8).map(|_| rng.random()).collect();
        let want: Vec<[u8; 32]> = layer
            .chunks_exact(2)
            .map(|pair| core::array::from_fn(|i| pair[0][i] ^ pair[1][i]))
            .collect();
        let got = compress::<u8, DummyCompressionFunction, 32>(&layer, &DummyCompressionFunction);
        assert_eq!(got, want);
    }

    /// Exactly one pair compresses to a single (root) digest — unpadded.
    #[test]
    fn test_compress_root_case_single_pair() {
        let layer = [[0xAAu8; 32], [0x55; 32]];
        // 0xAA ^ 0x55 == 0xFF.
        let got = compress::<u8, DummyCompressionFunction, 32>(&layer, &DummyCompressionFunction);
        assert_eq!(got, vec![[0xFF; 32]]);
    }

    /// Six inputs yield three real outputs, padded up to a length of four.
    #[test]
    fn test_compress_non_power_of_two_with_padding() {
        let layer = [
            [0x01u8; 32],
            [0x02; 32],
            [0x03; 32],
            [0x04; 32],
            [0x05; 32],
            [0x06; 32],
        ];
        // 0x01^0x02, 0x03^0x04, 0x05^0x06, then one zero padding digest.
        let want = vec![[0x03u8; 32], [0x07; 32], [0x03; 32], [0x00; 32]];
        let got = compress::<u8, DummyCompressionFunction, 32>(&layer, &DummyCompressionFunction);
        assert_eq!(got.len(), 4);
        assert_eq!(got, want);
    }

    /// `unpack_array` transposes packed digests into per-lane scalar digests.
    #[test]
    fn test_unpack_array_basic() {
        let packed: [[u8; 4]; 2] = [[0, 1, 2, 3], [4, 5, 6, 7]];
        let lanes: Vec<[u8; 2]> = unpack_array::<[u8; 4], 2>(packed).collect();
        assert_eq!(lanes, vec![[0, 4], [1, 5], [2, 6], [3, 7]]);
    }
}