use alloc::vec;
use alloc::vec::Vec;
use core::array;
use core::cmp::Reverse;
use core::marker::PhantomData;
use itertools::Itertools;
use p3_field::PackedValue;
use p3_matrix::Matrix;
use p3_maybe_rayon::prelude::*;
use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
use serde::{Deserialize, Serialize};
use tracing::instrument;
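
/// A binary Merkle tree over one or more matrices, with leaves of type `F` and digests
/// of type `[W; DIGEST_ELEMS]`. Matrices of different heights are committed to at the
/// digest layer matching their padded (power-of-two) height.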
#[derive(Debug, Serialize, Deserialize)]
pub struct MerkleTree<F, W, M, const DIGEST_ELEMS: usize> {
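    /// The matrices committed to, in the order supplied by the caller.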
pub(crate) leaves: Vec<M>,
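    /// All layers of digests, from the leaf hashes up to the root; every layer except
    /// the last is padded to an even length, and the final layer contains exactly one
    /// digest, the root.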
#[serde(bound(serialize = "[W; DIGEST_ELEMS]: Serialize"))]
#[serde(bound(deserialize = "[W; DIGEST_ELEMS]: Deserialize<'de>"))]
pub(crate) digest_layers: Vec<Vec<[W; DIGEST_ELEMS]>>,
_phantom: PhantomData<F>,
}

impl<F: Clone + Send + Sync, W: Clone, M: Matrix<F>, const DIGEST_ELEMS: usize>
MerkleTree<F, W, M, DIGEST_ELEMS>
{
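    /// Builds a Merkle tree over the given matrices, hashing rows with `h` and combining
    /// pairs of child digests with `c`.
    ///
    /// Matrix heights need not be powers of two, but two heights that round up to the
    /// same power of two must be equal, since each matrix is injected at the layer
    /// matching its padded height.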
#[instrument(name = "build merkle tree", level = "debug", skip_all,
fields(dimensions = alloc::format!("{:?}", leaves.iter().map(|l| l.dimensions()).collect::<Vec<_>>())))]
pub fn new<P, PW, H, C>(h: &H, c: &C, leaves: Vec<M>) -> Self
where
P: PackedValue<Value = F>,
PW: PackedValue<Value = W>,
H: CryptographicHasher<F, [W; DIGEST_ELEMS]>,
H: CryptographicHasher<P, [PW; DIGEST_ELEMS]>,
H: Sync,
C: PseudoCompressionFunction<[W; DIGEST_ELEMS], 2>,
C: PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>,
C: Sync,
{
assert!(!leaves.is_empty(), "No matrices given?");
assert_eq!(P::WIDTH, PW::WIDTH, "Packing widths must match");
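        // Process matrices tallest first, so each one is injected at the layer whose
        // padded size matches its height.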
let mut leaves_largest_first = leaves
.iter()
.sorted_by_key(|l| Reverse(l.height()))
.peekable();
assert!(
leaves_largest_first
.clone()
.map(|m| m.height())
.tuple_windows()
.all(|(curr, next)| curr == next
|| curr.next_power_of_two() != next.next_power_of_two()),
"matrix heights that round up to the same power of two must be equal"
);
let max_height = leaves_largest_first.peek().unwrap().height();
let tallest_matrices = leaves_largest_first
.peeking_take_while(|m| m.height() == max_height)
.collect_vec();
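        // The leaf layer hashes the rows of all matrices of maximal height.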
let mut digest_layers = vec![first_digest_layer::<P, PW, H, M, DIGEST_ELEMS>(
h,
tallest_matrices,
)];
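        // Build the remaining layers bottom-up, compressing pairs of digests and
        // injecting shorter matrices as their padded heights are reached.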
loop {
let prev_layer = digest_layers.last().unwrap().as_slice();
if prev_layer.len() == 1 {
break;
}
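            // The padded (power-of-two) size of the next layer, used to select which
            // matrices to inject at this level.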
let next_layer_len = (prev_layer.len() / 2).next_power_of_two();
let matrices_to_inject = leaves_largest_first
.peeking_take_while(|m| m.height().next_power_of_two() == next_layer_len)
.collect_vec();
let next_digests = compress_and_inject::<P, PW, H, C, M, DIGEST_ELEMS>(
prev_layer,
matrices_to_inject,
h,
c,
);
digest_layers.push(next_digests);
}
Self {
leaves,
digest_layers,
_phantom: PhantomData,
}
}
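
    /// The root digest of the tree.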
#[must_use]
pub fn root(&self) -> Hash<F, W, DIGEST_ELEMS>
where
W: Copy,
{
self.digest_layers.last().unwrap()[0].into()
}
}
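
/// Hashes the rows of the tallest matrices to produce the first (leaf) digest layer,
/// padded to an even length except when it is already the root layer.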
#[instrument(name = "first digest layer", level = "debug", skip_all)]
fn first_digest_layer<P, PW, H, M, const DIGEST_ELEMS: usize>(
h: &H,
tallest_matrices: Vec<&M>,
) -> Vec<[PW::Value; DIGEST_ELEMS]>
where
P: PackedValue,
PW: PackedValue,
H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>,
H: CryptographicHasher<P, [PW; DIGEST_ELEMS]>,
H: Sync,
M: Matrix<P::Value>,
{
let width = PW::WIDTH;
let max_height = tallest_matrices[0].height();
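    // Pad the layer to an even number of digests, except when it is already the root.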
let max_height_padded = if max_height == 1 {
1
} else {
max_height + max_height % 2
};
let default_digest: [PW::Value; DIGEST_ELEMS] = [PW::Value::default(); DIGEST_ELEMS];
let mut digests = vec![default_digest; max_height_padded];
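    // Hash rows in packed chunks of `width`, unpacking each packed digest into `width`
    // scalar digests.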
digests[0..max_height]
.par_chunks_exact_mut(width)
.enumerate()
.for_each(|(i, digests_chunk)| {
let first_row = i * width;
let packed_digest: [PW; DIGEST_ELEMS] = h.hash_iter(
tallest_matrices
.iter()
.flat_map(|m| m.vertically_packed_row(first_row)),
);
for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) {
*dst = src;
}
});
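    // Hash the leftover rows that don't fill a complete packed chunk.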
#[allow(clippy::needless_range_loop)]
for i in (max_height / width * width)..max_height {
digests[i] = h.hash_iter(tallest_matrices.iter().flat_map(|m| m.row(i)));
}
digests
}
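
/// Compresses each pair of digests in `prev_layer` into one digest and, where the
/// injected matrices have a row at that index, hashes that row and folds its digest in
/// with a second compression.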
fn compress_and_inject<P, PW, H, C, M, const DIGEST_ELEMS: usize>(
prev_layer: &[[PW::Value; DIGEST_ELEMS]],
matrices_to_inject: Vec<&M>,
h: &H,
c: &C,
) -> Vec<[PW::Value; DIGEST_ELEMS]>
where
P: PackedValue,
PW: PackedValue,
H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>,
H: CryptographicHasher<P, [PW; DIGEST_ELEMS]>,
H: Sync,
C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>,
C: PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>,
C: Sync,
M: Matrix<P::Value>,
{
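    // With no matrices to inject, this is a plain compression layer.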
if matrices_to_inject.is_empty() {
return compress::<PW, C, DIGEST_ELEMS>(prev_layer, c);
}
let width = PW::WIDTH;
let next_len = matrices_to_inject[0].height();
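    // Pad to an even number of digests, except at the root, by rounding
    // prev_layer.len() / 2 up to the next even integer.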
let next_len_padded = if prev_layer.len() == 2 {
1
} else {
(prev_layer.len() / 2 + 1) & !1
};
let default_digest: [PW::Value; DIGEST_ELEMS] = [PW::Value::default(); DIGEST_ELEMS];
let mut next_digests = vec![default_digest; next_len_padded];
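    // Compress pairs of digests and fold in the injected row hashes, `width` nodes at a
    // time using packed values.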
next_digests[0..next_len]
.par_chunks_exact_mut(width)
.enumerate()
.for_each(|(i, digests_chunk)| {
let first_row = i * width;
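            // Gather the left and right children of `width` consecutive nodes into
            // packed columns.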
let left = array::from_fn(|j| PW::from_fn(|k| prev_layer[2 * (first_row + k)][j]));
let right = array::from_fn(|j| PW::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j]));
let mut packed_digest = c.compress([left, right]);
let tallest_digest = h.hash_iter(
matrices_to_inject
.iter()
.flat_map(|m| m.vertically_packed_row(first_row)),
);
packed_digest = c.compress([packed_digest, tallest_digest]);
for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) {
*dst = src;
}
});
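    // Handle the injected rows left over after the packed chunks.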
for i in (next_len / width * width)..next_len {
let left = prev_layer[2 * i];
let right = prev_layer[2 * i + 1];
let digest = c.compress([left, right]);
let rows_digest = h.hash_iter(matrices_to_inject.iter().flat_map(|m| m.row(i)));
next_digests[i] = c.compress([digest, rows_digest]);
}
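    // Nodes past the injected matrices' height have no rows to inject; compress their
    // children and fold in a default digest so every node in the layer has the same
    // shape.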
for i in next_len..(prev_layer.len() / 2) {
let left = prev_layer[2 * i];
let right = prev_layer[2 * i + 1];
let digest = c.compress([left, right]);
next_digests[i] = c.compress([digest, default_digest]);
}
next_digests
}
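
/// Compresses each pair of digests in `prev_layer` into one digest, halving the layer
/// (padded to an even length except at the root).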
fn compress<P, C, const DIGEST_ELEMS: usize>(
prev_layer: &[[P::Value; DIGEST_ELEMS]],
c: &C,
) -> Vec<[P::Value; DIGEST_ELEMS]>
where
P: PackedValue,
C: PseudoCompressionFunction<[P::Value; DIGEST_ELEMS], 2>,
C: PseudoCompressionFunction<[P; DIGEST_ELEMS], 2>,
C: Sync,
{
let width = P::WIDTH;
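    // Pad to an even number of digests, except at the root, by rounding
    // prev_layer.len() / 2 up to the next even integer.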
let next_len_padded = if prev_layer.len() == 2 {
1
} else {
(prev_layer.len() / 2 + 1) & !1
};
let next_len = prev_layer.len() / 2;
let default_digest: [P::Value; DIGEST_ELEMS] = [P::Value::default(); DIGEST_ELEMS];
let mut next_digests = vec![default_digest; next_len_padded];
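    // Compress pairs of digests `width` nodes at a time using packed values.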
next_digests[0..next_len]
.par_chunks_exact_mut(width)
.enumerate()
.for_each(|(i, digests_chunk)| {
let first_row = i * width;
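            // Gather the left and right children of `width` consecutive nodes into
            // packed columns.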
let left = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k)][j]));
let right = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j]));
let packed_digest = c.compress([left, right]);
for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) {
*dst = src;
}
});
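    // Compress the leftover pairs that don't fill a complete packed chunk.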
for i in (next_len / width * width)..next_len {
let left = prev_layer[2 * i];
let right = prev_layer[2 * i + 1];
next_digests[i] = c.compress([left, right]);
}
next_digests
}
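
/// Converts a packed digest `[P; N]` into an iterator over its `P::WIDTH` underlying
/// scalar digests `[P::Value; N]`.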
#[inline]
fn unpack_array<P: PackedValue, const N: usize>(
packed_digest: [P; N],
) -> impl Iterator<Item = [P::Value; N]> {
(0..P::WIDTH).map(move |j| packed_digest.map(|p| p.as_slice()[j]))
}