1use alloc::vec;
2use alloc::vec::Vec;
3use core::array;
4use core::cmp::Reverse;
5use core::marker::PhantomData;
6
7use itertools::Itertools;
8use p3_field::PackedValue;
9use p3_matrix::Matrix;
10use p3_maybe_rayon::prelude::*;
11use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
12use serde::{Deserialize, Serialize};
13use tracing::instrument;
14
15#[derive(Debug, Serialize, Deserialize)]
21pub struct MerkleTree<F, W, M, const DIGEST_ELEMS: usize> {
22 pub(crate) leaves: Vec<M>,
23 #[serde(bound(serialize = "[W; DIGEST_ELEMS]: Serialize"))]
25 #[serde(bound(deserialize = "[W; DIGEST_ELEMS]: Deserialize<'de>"))]
27 pub(crate) digest_layers: Vec<Vec<[W; DIGEST_ELEMS]>>,
28 _phantom: PhantomData<F>,
29}
30
31impl<F: Clone + Send + Sync, W: Clone, M: Matrix<F>, const DIGEST_ELEMS: usize>
32 MerkleTree<F, W, M, DIGEST_ELEMS>
33{
34 #[instrument(name = "build merkle tree", level = "debug", skip_all,
37 fields(dimensions = alloc::format!("{:?}", leaves.iter().map(|l| l.dimensions()).collect::<Vec<_>>())))]
38 pub fn new<P, PW, H, C>(h: &H, c: &C, leaves: Vec<M>) -> Self
39 where
40 P: PackedValue<Value = F>,
41 PW: PackedValue<Value = W>,
42 H: CryptographicHasher<F, [W; DIGEST_ELEMS]>,
43 H: CryptographicHasher<P, [PW; DIGEST_ELEMS]>,
44 H: Sync,
45 C: PseudoCompressionFunction<[W; DIGEST_ELEMS], 2>,
46 C: PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>,
47 C: Sync,
48 {
49 assert!(!leaves.is_empty(), "No matrices given?");
50
51 assert_eq!(P::WIDTH, PW::WIDTH, "Packing widths must match");
52
53 let mut leaves_largest_first = leaves
54 .iter()
55 .sorted_by_key(|l| Reverse(l.height()))
56 .peekable();
57
58 assert!(
60 leaves_largest_first
61 .clone()
62 .map(|m| m.height())
63 .tuple_windows()
64 .all(|(curr, next)| curr == next
65 || curr.next_power_of_two() != next.next_power_of_two()),
66 "matrix heights that round up to the same power of two must be equal"
67 );
68
69 let max_height = leaves_largest_first.peek().unwrap().height();
70 let tallest_matrices = leaves_largest_first
71 .peeking_take_while(|m| m.height() == max_height)
72 .collect_vec();
73
74 let mut digest_layers = vec![first_digest_layer::<P, PW, H, M, DIGEST_ELEMS>(
75 h,
76 tallest_matrices,
77 )];
78 loop {
79 let prev_layer = digest_layers.last().unwrap().as_slice();
80 if prev_layer.len() == 1 {
81 break;
82 }
83 let next_layer_len = (prev_layer.len() / 2).next_power_of_two();
84
85 let matrices_to_inject = leaves_largest_first
87 .peeking_take_while(|m| m.height().next_power_of_two() == next_layer_len)
88 .collect_vec();
89
90 let next_digests = compress_and_inject::<P, PW, H, C, M, DIGEST_ELEMS>(
91 prev_layer,
92 matrices_to_inject,
93 h,
94 c,
95 );
96 digest_layers.push(next_digests);
97 }
98
99 Self {
100 leaves,
101 digest_layers,
102 _phantom: PhantomData,
103 }
104 }
105
106 #[must_use]
107 pub fn root(&self) -> Hash<F, W, DIGEST_ELEMS>
108 where
109 W: Copy,
110 {
111 self.digest_layers.last().unwrap()[0].into()
112 }
113}
114
115#[instrument(name = "first digest layer", level = "debug", skip_all)]
116fn first_digest_layer<P, PW, H, M, const DIGEST_ELEMS: usize>(
117 h: &H,
118 tallest_matrices: Vec<&M>,
119) -> Vec<[PW::Value; DIGEST_ELEMS]>
120where
121 P: PackedValue,
122 PW: PackedValue,
123 H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>,
124 H: CryptographicHasher<P, [PW; DIGEST_ELEMS]>,
125 H: Sync,
126 M: Matrix<P::Value>,
127{
128 let width = PW::WIDTH;
129 let max_height = tallest_matrices[0].height();
130 let max_height_padded = if max_height == 1 {
132 1
133 } else {
134 max_height + max_height % 2
135 };
136
137 let default_digest: [PW::Value; DIGEST_ELEMS] = [PW::Value::default(); DIGEST_ELEMS];
138 let mut digests = vec![default_digest; max_height_padded];
139
140 digests[0..max_height]
141 .par_chunks_exact_mut(width)
142 .enumerate()
143 .for_each(|(i, digests_chunk)| {
144 let first_row = i * width;
145 let packed_digest: [PW; DIGEST_ELEMS] = h.hash_iter(
146 tallest_matrices
147 .iter()
148 .flat_map(|m| m.vertically_packed_row(first_row)),
149 );
150 for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) {
151 *dst = src;
152 }
153 });
154
155 #[allow(clippy::needless_range_loop)]
158 for i in (max_height / width * width)..max_height {
159 digests[i] = h.hash_iter(tallest_matrices.iter().flat_map(|m| m.row(i)));
160 }
161
162 digests
164}
165
166fn compress_and_inject<P, PW, H, C, M, const DIGEST_ELEMS: usize>(
169 prev_layer: &[[PW::Value; DIGEST_ELEMS]],
170 matrices_to_inject: Vec<&M>,
171 h: &H,
172 c: &C,
173) -> Vec<[PW::Value; DIGEST_ELEMS]>
174where
175 P: PackedValue,
176 PW: PackedValue,
177 H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>,
178 H: CryptographicHasher<P, [PW; DIGEST_ELEMS]>,
179 H: Sync,
180 C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>,
181 C: PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>,
182 C: Sync,
183 M: Matrix<P::Value>,
184{
185 if matrices_to_inject.is_empty() {
186 return compress::<PW, C, DIGEST_ELEMS>(prev_layer, c);
187 }
188
189 let width = PW::WIDTH;
190 let next_len = matrices_to_inject[0].height();
191 let next_len_padded = if prev_layer.len() == 2 {
193 1
194 } else {
195 (prev_layer.len() / 2 + 1) & !1
196 };
197
198 let default_digest: [PW::Value; DIGEST_ELEMS] = [PW::Value::default(); DIGEST_ELEMS];
199 let mut next_digests = vec![default_digest; next_len_padded];
200 next_digests[0..next_len]
201 .par_chunks_exact_mut(width)
202 .enumerate()
203 .for_each(|(i, digests_chunk)| {
204 let first_row = i * width;
205 let left = array::from_fn(|j| PW::from_fn(|k| prev_layer[2 * (first_row + k)][j]));
206 let right = array::from_fn(|j| PW::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j]));
207 let mut packed_digest = c.compress([left, right]);
208 let tallest_digest = h.hash_iter(
209 matrices_to_inject
210 .iter()
211 .flat_map(|m| m.vertically_packed_row(first_row)),
212 );
213 packed_digest = c.compress([packed_digest, tallest_digest]);
214 for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) {
215 *dst = src;
216 }
217 });
218
219 for i in (next_len / width * width)..next_len {
222 let left = prev_layer[2 * i];
223 let right = prev_layer[2 * i + 1];
224 let digest = c.compress([left, right]);
225 let rows_digest = h.hash_iter(matrices_to_inject.iter().flat_map(|m| m.row(i)));
226 next_digests[i] = c.compress([digest, rows_digest]);
227 }
228
229 for i in next_len..(prev_layer.len() / 2) {
233 let left = prev_layer[2 * i];
234 let right = prev_layer[2 * i + 1];
235 let digest = c.compress([left, right]);
236 next_digests[i] = c.compress([digest, default_digest]);
237 }
238
239 next_digests
240}
241
242fn compress<P, C, const DIGEST_ELEMS: usize>(
244 prev_layer: &[[P::Value; DIGEST_ELEMS]],
245 c: &C,
246) -> Vec<[P::Value; DIGEST_ELEMS]>
247where
248 P: PackedValue,
249 C: PseudoCompressionFunction<[P::Value; DIGEST_ELEMS], 2>,
250 C: PseudoCompressionFunction<[P; DIGEST_ELEMS], 2>,
251 C: Sync,
252{
253 let width = P::WIDTH;
254 let next_len_padded = if prev_layer.len() == 2 {
256 1
257 } else {
258 (prev_layer.len() / 2 + 1) & !1
259 };
260 let next_len = prev_layer.len() / 2;
261
262 let default_digest: [P::Value; DIGEST_ELEMS] = [P::Value::default(); DIGEST_ELEMS];
263 let mut next_digests = vec![default_digest; next_len_padded];
264
265 next_digests[0..next_len]
266 .par_chunks_exact_mut(width)
267 .enumerate()
268 .for_each(|(i, digests_chunk)| {
269 let first_row = i * width;
270 let left = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k)][j]));
271 let right = array::from_fn(|j| P::from_fn(|k| prev_layer[2 * (first_row + k) + 1][j]));
272 let packed_digest = c.compress([left, right]);
273 for (dst, src) in digests_chunk.iter_mut().zip(unpack_array(packed_digest)) {
274 *dst = src;
275 }
276 });
277
278 for i in (next_len / width * width)..next_len {
281 let left = prev_layer[2 * i];
282 let right = prev_layer[2 * i + 1];
283 next_digests[i] = c.compress([left, right]);
284 }
285
286 next_digests
288}
289
290#[inline]
292fn unpack_array<P: PackedValue, const N: usize>(
293 packed_digest: [P; N],
294) -> impl Iterator<Item = [P::Value; N]> {
295 (0..P::WIDTH).map(move |j| packed_digest.map(|p| p.as_slice()[j]))
296}