1use alloc::vec;
2use alloc::vec::Vec;
3use core::array;
4use core::cmp::Reverse;
5use core::marker::PhantomData;
6
7use itertools::Itertools;
8use p3_field::PackedValue;
9use p3_matrix::Matrix;
10use p3_maybe_rayon::prelude::*;
11use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
12use serde::{Deserialize, Serialize};
13use tracing::instrument;
14
/// A binary Merkle tree whose leaf data are the rows of one or more matrices.
///
/// `F` is the matrix element type, `W` the digest word type, `M` the matrix type, and
/// `DIGEST_ELEMS` the number of `W` words making up one digest.
#[derive(Debug, Serialize, Deserialize)]
pub struct MerkleTree<F, W, M, const DIGEST_ELEMS: usize> {
    /// The committed matrices, stored in the order they were passed to `new`
    /// (construction only sorts a borrowed view of them).
    pub(crate) leaves: Vec<M>,

    /// All digest layers, from the bottom up.
    ///
    /// `digest_layers[0]` holds the row hashes of the tallest matrices. Each later layer
    /// compresses adjacent pairs of the layer before it, mixing in the row hashes of shorter
    /// matrices at the layer where they are injected. The last layer holds the single root
    /// digest.
    ///
    /// The serde `bound`s replace the bounds serde would otherwise infer for this field, so that
    /// the digest array type itself must be (de)serializable. NOTE(review): presumably needed
    /// because serde has no impl for arrays of an arbitrary const length; confirm.
    #[serde(
        bound(serialize = "[W; DIGEST_ELEMS]: Serialize"),
        bound(deserialize = "[W; DIGEST_ELEMS]: Deserialize<'de>")
    )]
    pub(crate) digest_layers: Vec<Vec<[W; DIGEST_ELEMS]>>,

    /// Ties the element type `F`, which no field stores directly, to the struct.
    _phantom: PhantomData<F>,
}
60
61impl<F: Clone + Send + Sync, W: Clone, M: Matrix<F>, const DIGEST_ELEMS: usize>
62 MerkleTree<F, W, M, DIGEST_ELEMS>
63{
64 #[instrument(name = "build merkle tree", level = "debug", skip_all,
83 fields(dimensions = alloc::format!("{:?}", leaves.iter().map(|l| l.dimensions()).collect::<Vec<_>>())))]
84 pub fn new<P, PW, H, C>(h: &H, c: &C, leaves: Vec<M>) -> Self
85 where
86 P: PackedValue<Value = F>,
87 PW: PackedValue<Value = W>,
88 H: CryptographicHasher<F, [W; DIGEST_ELEMS]>
89 + CryptographicHasher<P, [PW; DIGEST_ELEMS]>
90 + Sync,
91 C: PseudoCompressionFunction<[W; DIGEST_ELEMS], 2>
92 + PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>
93 + Sync,
94 {
95 assert!(!leaves.is_empty(), "No matrices given?");
96 const {
97 assert!(P::WIDTH == PW::WIDTH, "Packing widths must match");
98 }
99
100 let mut leaves_largest_first = leaves
101 .iter()
102 .sorted_by_key(|l| Reverse(l.height()))
103 .peekable();
104
105 assert!(
107 leaves_largest_first
108 .clone()
109 .map(|m| m.height())
110 .tuple_windows()
111 .all(|(curr, next)| curr == next
112 || curr.next_power_of_two() != next.next_power_of_two()),
113 "matrix heights that round up to the same power of two must be equal"
114 );
115
116 let max_height = leaves_largest_first.peek().unwrap().height();
117 let tallest_matrices = leaves_largest_first
118 .peeking_take_while(|m| m.height() == max_height)
119 .collect_vec();
120
121 let mut digest_layers = vec![first_digest_layer::<P, _, _, _, DIGEST_ELEMS>(
122 h,
123 &tallest_matrices,
124 )];
125 loop {
126 let prev_layer = digest_layers.last().unwrap().as_slice();
127 if prev_layer.len() == 1 {
128 break;
129 }
130 let next_layer_len = (prev_layer.len() / 2).next_power_of_two();
131
132 let matrices_to_inject = leaves_largest_first
134 .peeking_take_while(|m| m.height().next_power_of_two() == next_layer_len)
135 .collect_vec();
136
137 let next_digests = compress_and_inject::<P, _, _, _, _, DIGEST_ELEMS>(
138 prev_layer,
139 &matrices_to_inject,
140 h,
141 c,
142 );
143 digest_layers.push(next_digests);
144 }
145
146 Self {
147 leaves,
148 digest_layers,
149 _phantom: PhantomData,
150 }
151 }
152
153 #[must_use]
155 pub fn root(&self) -> Hash<F, W, DIGEST_ELEMS>
156 where
157 W: Copy,
158 {
159 self.digest_layers.last().unwrap()[0].into()
160 }
161}
162
163#[instrument(name = "first digest layer", level = "debug", skip_all)]
185fn first_digest_layer<P, PW, H, M, const DIGEST_ELEMS: usize>(
186 h: &H,
187 tallest_matrices: &[&M],
188) -> Vec<[PW::Value; DIGEST_ELEMS]>
189where
190 P: PackedValue,
191 PW: PackedValue,
192 H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>
193 + CryptographicHasher<P, [PW; DIGEST_ELEMS]>
194 + Sync,
195 M: Matrix<P::Value>,
196{
197 let width = PW::WIDTH;
199
200 let max_height = tallest_matrices[0].height();
202
203 let max_height_padded = if max_height == 1 {
206 1
207 } else {
208 max_height + max_height % 2
209 };
210
211 let default_digest = [PW::Value::default(); DIGEST_ELEMS];
213
214 let mut digests = vec![default_digest; max_height_padded];
216
217 digests[0..max_height]
219 .par_chunks_exact_mut(width)
220 .enumerate()
221 .for_each(|(i, digests_chunk)| {
222 let first_row = i * width;
224
225 let packed_digest: [PW; DIGEST_ELEMS] = h.hash_iter(
228 tallest_matrices
229 .iter()
230 .flat_map(|m| m.vertically_packed_row(first_row)),
231 );
232
233 PW::unpack_into(&packed_digest, digests_chunk);
235 });
236
237 #[allow(clippy::needless_range_loop)]
239 for i in ((max_height / width) * width)..max_height {
240 unsafe {
241 digests[i] = h.hash_iter(tallest_matrices.iter().flat_map(|m| m.row_unchecked(i)));
244 }
245 }
246
247 digests
249}
250
251fn compress_and_inject<P, PW, H, C, M, const DIGEST_ELEMS: usize>(
256 prev_layer: &[[PW::Value; DIGEST_ELEMS]],
257 matrices_to_inject: &[&M],
258 h: &H,
259 c: &C,
260) -> Vec<[PW::Value; DIGEST_ELEMS]>
261where
262 P: PackedValue,
263 PW: PackedValue,
264 H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>
265 + CryptographicHasher<P, [PW; DIGEST_ELEMS]>
266 + Sync,
267 C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>
268 + PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>
269 + Sync,
270 M: Matrix<P::Value>,
271{
272 if matrices_to_inject.is_empty() {
273 return compress::<PW, _, DIGEST_ELEMS>(prev_layer, c);
274 }
275
276 let width = PW::WIDTH;
277 let next_len = matrices_to_inject[0].height();
278 let next_len_padded = if prev_layer.len() == 2 {
280 1
281 } else {
282 (prev_layer.len() / 2 + 1) & !1
284 };
285
286 let default_digest = [PW::Value::default(); DIGEST_ELEMS];
287 let mut next_digests = vec![default_digest; next_len_padded];
288 next_digests[0..next_len]
289 .par_chunks_exact_mut(width)
290 .enumerate()
291 .for_each(|(i, digests_chunk)| {
292 let first_row = i * width;
293 let left = array::from_fn(|j| PW::from_fn(|k| prev_layer[2 * (first_row + k)][j]));
294 let right = array::from_fn(|j| PW::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j]));
295 let mut packed_digest = c.compress([left, right]);
296 let tallest_digest = h.hash_iter(
297 matrices_to_inject
298 .iter()
299 .flat_map(|m| m.vertically_packed_row(first_row)),
300 );
301 packed_digest = c.compress([packed_digest, tallest_digest]);
302 PW::unpack_into(&packed_digest, digests_chunk);
303 });
304
305 for i in (next_len / width * width)..next_len {
308 let left = prev_layer[2 * i];
309 let right = prev_layer[2 * i + 1];
310 let digest = c.compress([left, right]);
311 let rows_digest = unsafe {
312 h.hash_iter(matrices_to_inject.iter().flat_map(|m| m.row_unchecked(i)))
314 };
315 next_digests[i] = c.compress([digest, rows_digest]);
316 }
317
318 for i in next_len..(prev_layer.len() / 2) {
322 let left = prev_layer[2 * i];
323 let right = prev_layer[2 * i + 1];
324 let digest = c.compress([left, right]);
325 next_digests[i] = c.compress([digest, default_digest]);
326 }
327
328 next_digests
329}
330
331fn compress<P, C, const DIGEST_ELEMS: usize>(
339 prev_layer: &[[P::Value; DIGEST_ELEMS]],
340 c: &C,
341) -> Vec<[P::Value; DIGEST_ELEMS]>
342where
343 P: PackedValue,
344 C: PseudoCompressionFunction<[P::Value; DIGEST_ELEMS], 2>
345 + PseudoCompressionFunction<[P; DIGEST_ELEMS], 2>
346 + Sync,
347{
348 let width = P::WIDTH;
349 let next_len_padded = if prev_layer.len() == 2 {
351 1
352 } else {
353 (prev_layer.len() / 2 + 1) & !1
355 };
356 let next_len = prev_layer.len() / 2;
357
358 let default_digest = [P::Value::default(); DIGEST_ELEMS];
359 let mut next_digests = vec![default_digest; next_len_padded];
360
361 next_digests[0..next_len]
362 .par_chunks_exact_mut(width)
363 .enumerate()
364 .for_each(|(i, digests_chunk)| {
365 let first_row = i * width;
366 let left = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k)][j]));
367 let right = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j]));
368 let packed_digest = c.compress([left, right]);
369 P::unpack_into(&packed_digest, digests_chunk);
370 });
371
372 for i in (next_len / width * width)..next_len {
375 let left = prev_layer[2 * i];
376 let right = prev_layer[2 * i + 1];
377 next_digests[i] = c.compress([left, right]);
378 }
379
380 next_digests
382}
383
#[cfg(test)]
mod tests {
    use p3_symmetric::PseudoCompressionFunction;
    use rand::rngs::SmallRng;
    use rand::{Rng, SeedableRng};

    use super::*;

    /// Toy compressor for tests: byte-wise XOR of the two 32-byte inputs.
    #[derive(Clone, Copy)]
    struct DummyCompressionFunction;

    impl PseudoCompressionFunction<[u8; 32], 2> for DummyCompressionFunction {
        fn compress(&self, input: [[u8; 32]; 2]) -> [u8; 32] {
            let [lhs, rhs] = input;
            core::array::from_fn(|i| lhs[i] ^ rhs[i])
        }
    }

    /// Runs `compress` with scalar `u8` packing (width 1) and the XOR compressor.
    fn compress_bytes(layer: &[[u8; 32]]) -> Vec<[u8; 32]> {
        compress::<u8, DummyCompressionFunction, 32>(layer, &DummyCompressionFunction)
    }

    #[test]
    fn test_compress_even_length() {
        let layer = [[0x01; 32], [0x02; 32], [0x03; 32], [0x04; 32]];
        let expected: Vec<[u8; 32]> = vec![[0x03; 32], [0x07; 32]];
        assert_eq!(compress_bytes(&layer), expected);
    }

    #[test]
    fn test_compress_odd_length() {
        // The unpaired last digest is ignored; the output is padded to an even length.
        let layer = [[0x05; 32], [0x06; 32], [0x07; 32]];
        let expected: Vec<[u8; 32]> = vec![[0x03; 32], [0x00; 32]];
        assert_eq!(compress_bytes(&layer), expected);
    }

    #[test]
    fn test_compress_random_values() {
        let mut rng = SmallRng::seed_from_u64(1);
        let layer: Vec<[u8; 32]> = (0..8).map(|_| rng.random()).collect();
        let expected: Vec<[u8; 32]> = layer
            .chunks_exact(2)
            .map(|pair| core::array::from_fn(|i| pair[0][i] ^ pair[1][i]))
            .collect();
        assert_eq!(compress_bytes(&layer), expected);
    }

    #[test]
    fn test_compress_root_case_single_pair() {
        // A two-digest layer collapses to the single root digest, with no padding.
        let layer = [[0xAA; 32], [0x55; 32]];
        let expected: Vec<[u8; 32]> = vec![[0xFF; 32]];
        assert_eq!(compress_bytes(&layer), expected);
    }

    #[test]
    fn test_compress_non_power_of_two_with_padding() {
        let layer = [
            [0x01; 32], [0x02; 32], [0x03; 32], [0x04; 32], [0x05; 32], [0x06; 32],
        ];
        // Three real digests, then one default digest pads the length to four.
        let expected: Vec<[u8; 32]> = vec![[0x03; 32], [0x07; 32], [0x03; 32], [0x00; 32]];
        let result = compress_bytes(&layer);
        assert_eq!(result, expected);
        assert_eq!(result.len(), 4);
    }
}