1use alloc::vec::Vec;
2use core::cmp::Reverse;
3use core::marker::PhantomData;
4
5use itertools::Itertools;
6use p3_commit::Mmcs;
7use p3_field::PackedValue;
8use p3_matrix::{Dimensions, Matrix};
9use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
10use p3_util::log2_ceil_usize;
11use serde::{Deserialize, Serialize};
12
13use crate::MerkleTree;
14use crate::MerkleTreeError::{EmptyBatch, RootMismatch, WrongBatchSize, WrongHeight};
15
/// An [`Mmcs`] implementation that commits to a batch of matrices with a [`MerkleTree`].
///
/// Type parameters (as constrained by the `Mmcs` implementation for this type):
/// - `P`: packed type of the committed matrix entries; `P::Value` is the scalar entry type.
/// - `PW`: packed type of digest words; one digest is `[PW::Value; DIGEST_ELEMS]`.
/// - `H`: hasher that turns opened matrix rows into digests.
/// - `C`: two-to-one compression function that merges a pair of digests into one.
/// - `DIGEST_ELEMS`: number of words in one digest.
#[derive(Copy, Clone, Debug)]
pub struct MerkleTreeMmcs<P, PW, H, C, const DIGEST_ELEMS: usize> {
    /// Hashes matrix rows into digests.
    hash: H,
    /// Merges two sibling digests into their parent digest.
    compress: C,
    /// `P` and `PW` are only used by the trait implementations; no values of them are stored.
    _phantom: PhantomData<(P, PW)>,
}
29
/// Errors returned when verifying a batch opening against a Merkle-tree commitment.
///
/// Derives `Clone`, `PartialEq` and `Eq` in addition to `Debug` so callers (and tests) can
/// compare errors directly instead of pattern-matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MerkleTreeError {
    /// The number of opened rows differs from the number of matrix dimensions supplied.
    WrongBatchSize,
    /// An opened row has the wrong width.
    ///
    /// NOTE(review): not constructed by the verifier in this module; presumably reserved for a
    /// row-width check — confirm before relying on it.
    WrongWidth,
    /// The number of sibling digests in the proof differs from `log2_ceil(max_height)`, the
    /// number of tree layers above the leaves.
    WrongHeight {
        /// Height of the tallest matrix in the batch.
        max_height: usize,
        /// Number of sibling digests actually supplied in the proof.
        num_siblings: usize,
    },
    /// The root recomputed from the openings and the proof does not equal the commitment.
    RootMismatch,
    /// The batch contains no matrices.
    EmptyBatch,
}
41
42impl<P, PW, H, C, const DIGEST_ELEMS: usize> MerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS> {
43 pub const fn new(hash: H, compress: C) -> Self {
44 Self {
45 hash,
46 compress,
47 _phantom: PhantomData,
48 }
49 }
50}
51
52impl<P, PW, H, C, const DIGEST_ELEMS: usize> Mmcs<P::Value>
53 for MerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS>
54where
55 P: PackedValue,
56 PW: PackedValue,
57 H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>,
58 H: CryptographicHasher<P, [PW; DIGEST_ELEMS]>,
59 H: Sync,
60 C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>,
61 C: PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>,
62 C: Sync,
63 PW::Value: Eq,
64 [PW::Value; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>,
65{
66 type ProverData<M> = MerkleTree<P::Value, PW::Value, M, DIGEST_ELEMS>;
67 type Commitment = Hash<P::Value, PW::Value, DIGEST_ELEMS>;
68 type Proof = Vec<[PW::Value; DIGEST_ELEMS]>;
69 type Error = MerkleTreeError;
70
71 fn commit<M: Matrix<P::Value>>(
72 &self,
73 inputs: Vec<M>,
74 ) -> (Self::Commitment, Self::ProverData<M>) {
75 let tree = MerkleTree::new::<P, PW, H, C>(&self.hash, &self.compress, inputs);
76 let root = tree.root();
77 (root, tree)
78 }
79
80 fn open_batch<M: Matrix<P::Value>>(
81 &self,
82 index: usize,
83 prover_data: &MerkleTree<P::Value, PW::Value, M, DIGEST_ELEMS>,
84 ) -> (Vec<Vec<P::Value>>, Vec<[PW::Value; DIGEST_ELEMS]>) {
85 let max_height = self.get_max_height(prover_data);
86 let log_max_height = log2_ceil_usize(max_height);
87
88 let openings = prover_data
89 .leaves
90 .iter()
91 .map(|matrix| {
92 let log2_height = log2_ceil_usize(matrix.height());
93 let bits_reduced = log_max_height - log2_height;
94 let reduced_index = index >> bits_reduced;
95 matrix.row(reduced_index).collect()
96 })
97 .collect_vec();
98
99 let proof: Vec<_> = (0..log_max_height)
100 .map(|i| prover_data.digest_layers[i][(index >> i) ^ 1])
101 .collect();
102
103 (openings, proof)
104 }
105
106 fn get_matrices<'a, M: Matrix<P::Value>>(
107 &self,
108 prover_data: &'a Self::ProverData<M>,
109 ) -> Vec<&'a M> {
110 prover_data.leaves.iter().collect()
111 }
112
113 fn verify_batch(
114 &self,
115 commit: &Self::Commitment,
116 dimensions: &[Dimensions],
117 mut index: usize,
118 opened_values: &[Vec<P::Value>],
119 proof: &Self::Proof,
120 ) -> Result<(), Self::Error> {
121 if dimensions.len() != opened_values.len() {
123 return Err(WrongBatchSize);
124 }
125
126 let Some(max_height) = dimensions.iter().map(|dim| dim.height).max() else {
135 return Err(EmptyBatch);
137 };
138 let log_max_height = log2_ceil_usize(max_height);
139 if proof.len() != log_max_height {
140 return Err(WrongHeight {
141 max_height,
142 num_siblings: proof.len(),
143 });
144 }
145
146 let mut heights_tallest_first = dimensions
147 .iter()
148 .enumerate()
149 .sorted_by_key(|(_, dims)| Reverse(dims.height))
150 .peekable();
151
152 let Some(mut curr_height_padded) = heights_tallest_first
153 .peek()
154 .map(|x| x.1.height.next_power_of_two())
155 else {
156 return Err(EmptyBatch);
158 };
159
160 let mut root = self.hash.hash_iter_slices(
161 heights_tallest_first
162 .peeking_take_while(|(_, dims)| {
163 dims.height.next_power_of_two() == curr_height_padded
164 })
165 .map(|(i, _)| opened_values[i].as_slice()),
166 );
167
168 for &sibling in proof.iter() {
169 let (left, right) = if index & 1 == 0 {
170 (root, sibling)
171 } else {
172 (sibling, root)
173 };
174
175 root = self.compress.compress([left, right]);
176 index >>= 1;
177 curr_height_padded >>= 1;
178
179 let next_height = heights_tallest_first
180 .peek()
181 .map(|(_, dims)| dims.height)
182 .filter(|h| h.next_power_of_two() == curr_height_padded);
183 if let Some(next_height) = next_height {
184 let next_height_openings_digest = self.hash.hash_iter_slices(
185 heights_tallest_first
186 .peeking_take_while(|(_, dims)| dims.height == next_height)
187 .map(|(i, _)| opened_values[i].as_slice()),
188 );
189
190 root = self.compress.compress([root, next_height_openings_digest]);
191 }
192 }
193
194 if commit == &root {
195 Ok(())
196 } else {
197 Err(RootMismatch)
198 }
199 }
200}
201
#[cfg(test)]
mod tests {
    use alloc::vec;

    use itertools::Itertools;
    use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
    use p3_commit::Mmcs;
    use p3_field::{Field, FieldAlgebra};
    use p3_matrix::dense::RowMajorMatrix;
    use p3_matrix::{Dimensions, Matrix};
    use p3_symmetric::{
        CryptographicHasher, PaddingFreeSponge, PseudoCompressionFunction, TruncatedPermutation,
    };
    use rand::thread_rng;

    use super::{MerkleTreeError, MerkleTreeMmcs};

    type F = BabyBear;

    type Perm = Poseidon2BabyBear<16>;
    type MyHash = PaddingFreeSponge<Perm, 16, 8, 8>;
    type MyCompress = TruncatedPermutation<Perm, 2, 8, 16>;
    type MyMmcs =
        MerkleTreeMmcs<<F as Field>::Packing, <F as Field>::Packing, MyHash, MyCompress, 8>;

    /// Builds a hasher and compression function from one randomly generated Poseidon2
    /// permutation. Also builds an MMCS using clones of both, so tests can recompute expected
    /// digests by hand.
    fn setup() -> (MyHash, MyCompress, MyMmcs) {
        let perm = Perm::new_from_rng_128(&mut thread_rng());
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash.clone(), compress.clone());
        (hash, compress, mmcs)
    }

    #[test]
    fn commit_single_1x8() {
        let (hash, compress, mmcs) = setup();

        // v = [2, 1, 2, 2, 0, 0, 1, 0]
        let v = vec![
            F::TWO,
            F::ONE,
            F::TWO,
            F::TWO,
            F::ZERO,
            F::ZERO,
            F::ONE,
            F::ZERO,
        ];
        let (commit, _) = mmcs.commit_vec(v.clone());

        let expected_result = compress.compress([
            compress.compress([
                compress.compress([hash.hash_item(v[0]), hash.hash_item(v[1])]),
                compress.compress([hash.hash_item(v[2]), hash.hash_item(v[3])]),
            ]),
            compress.compress([
                compress.compress([hash.hash_item(v[4]), hash.hash_item(v[5])]),
                compress.compress([hash.hash_item(v[6]), hash.hash_item(v[7])]),
            ]),
        ]);
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_single_8x1() {
        let (hash, _, mmcs) = setup();

        let mat = RowMajorMatrix::<F>::rand(&mut thread_rng(), 1, 8);
        let (commit, _) = mmcs.commit(vec![mat.clone()]);

        let expected_result = hash.hash_iter(mat.clone().vertically_packed_row(0));
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_single_2x2() {
        let (hash, compress, mmcs) = setup();

        // mat = [
        //   0 1
        //   2 1
        // ]
        let mat = RowMajorMatrix::new(vec![F::ZERO, F::ONE, F::TWO, F::ONE], 2);

        let (commit, _) = mmcs.commit(vec![mat]);

        let expected_result = compress.compress([
            hash.hash_slice(&[F::ZERO, F::ONE]),
            hash.hash_slice(&[F::TWO, F::ONE]),
        ]);
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_single_2x3() {
        let (hash, compress, mmcs) = setup();
        let default_digest = [F::ZERO; 8];

        // mat = [
        //   0 1
        //   2 1
        //   2 2
        // ]
        let mat = RowMajorMatrix::new(vec![F::ZERO, F::ONE, F::TWO, F::ONE, F::TWO, F::TWO], 2);

        let (commit, _) = mmcs.commit(vec![mat]);

        let expected_result = compress.compress([
            compress.compress([
                hash.hash_slice(&[F::ZERO, F::ONE]),
                hash.hash_slice(&[F::TWO, F::ONE]),
            ]),
            compress.compress([hash.hash_slice(&[F::TWO, F::TWO]), default_digest]),
        ]);
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_mixed() {
        let (hash, compress, mmcs) = setup();
        let default_digest = [F::ZERO; 8];

        // mat_1 = [
        //   0 1
        //   2 1
        //   2 2
        //   2 1
        //   2 2
        // ]
        let mat_1 = RowMajorMatrix::new(
            vec![
                F::ZERO,
                F::ONE,
                F::TWO,
                F::ONE,
                F::TWO,
                F::TWO,
                F::TWO,
                F::ONE,
                F::TWO,
                F::TWO,
            ],
            2,
        );
        // mat_2 = [
        //   1 2 1
        //   0 2 2
        //   1 2 1
        // ]
        let mat_2 = RowMajorMatrix::new(
            vec![
                F::ONE,
                F::TWO,
                F::ONE,
                F::ZERO,
                F::TWO,
                F::TWO,
                F::ONE,
                F::TWO,
                F::ONE,
            ],
            3,
        );
        let dims = [mat_1.dimensions(), mat_2.dimensions()];

        let (commit, prover_data) = mmcs.commit(vec![mat_1, mat_2]);

        let mat_1_leaf_hashes = [
            hash.hash_slice(&[F::ZERO, F::ONE]),
            hash.hash_slice(&[F::TWO, F::ONE]),
            hash.hash_slice(&[F::TWO, F::TWO]),
            hash.hash_slice(&[F::TWO, F::ONE]),
            hash.hash_slice(&[F::TWO, F::TWO]),
        ];
        let mat_2_leaf_hashes = [
            hash.hash_slice(&[F::ONE, F::TWO, F::ONE]),
            hash.hash_slice(&[F::ZERO, F::TWO, F::TWO]),
            hash.hash_slice(&[F::ONE, F::TWO, F::ONE]),
        ];

        let expected_result = compress.compress([
            compress.compress([
                compress.compress([
                    compress.compress([mat_1_leaf_hashes[0], mat_1_leaf_hashes[1]]),
                    mat_2_leaf_hashes[0],
                ]),
                compress.compress([
                    compress.compress([mat_1_leaf_hashes[2], mat_1_leaf_hashes[3]]),
                    mat_2_leaf_hashes[1],
                ]),
            ]),
            compress.compress([
                compress.compress([
                    compress.compress([mat_1_leaf_hashes[4], default_digest]),
                    mat_2_leaf_hashes[2],
                ]),
                default_digest,
            ]),
        ]);

        assert_eq!(commit, expected_result);

        let (opened_values, proof) = mmcs.open_batch(2, &prover_data);
        assert_eq!(
            opened_values,
            vec![vec![F::TWO, F::TWO], vec![F::ZERO, F::TWO, F::TWO]]
        );

        // The opening of a mixed-height batch must also round-trip through the verifier.
        mmcs.verify_batch(&commit, &dims, 2, &opened_values, &proof)
            .expect("expected verification to succeed");
    }

    #[test]
    fn commit_either_order() {
        let (_, _, mmcs) = setup();
        let mut rng = thread_rng();

        let input_1 = RowMajorMatrix::<F>::rand(&mut rng, 5, 8);
        let input_2 = RowMajorMatrix::<F>::rand(&mut rng, 3, 16);

        let (commit_1_2, _) = mmcs.commit(vec![input_1.clone(), input_2.clone()]);
        let (commit_2_1, _) = mmcs.commit(vec![input_2, input_1]);
        assert_eq!(commit_1_2, commit_2_1);
    }

    #[test]
    #[should_panic]
    fn mismatched_heights() {
        let (_, _, mmcs) = setup();

        // Heights 8 and 7 both round up to 8 but differ; committing them must panic.
        let large_mat = RowMajorMatrix::new(
            [1, 2, 3, 4, 5, 6, 7, 8].map(F::from_canonical_u8).to_vec(),
            1,
        );
        let small_mat =
            RowMajorMatrix::new([1, 2, 3, 4, 5, 6, 7].map(F::from_canonical_u8).to_vec(), 1);
        let _ = mmcs.commit(vec![large_mat, small_mat]);
    }

    #[test]
    fn verify_tampered_proof_fails() {
        let (_, _, mmcs) = setup();

        // 4 mats with 8 rows, 1 column
        let large_mats = (0..4).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 1));
        let large_mat_dims = (0..4).map(|_| Dimensions {
            height: 8,
            width: 1,
        });
        // 4 mats with 8 rows, 2 columns
        let small_mats = (0..4).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 2));
        let small_mat_dims = (0..4).map(|_| Dimensions {
            height: 8,
            width: 2,
        });

        let (commit, prover_data) = mmcs.commit(large_mats.chain(small_mats).collect_vec());

        let (opened_values, mut proof) = mmcs.open_batch(3, &prover_data);
        // Corrupt the first sibling so the recomputed root cannot match the commitment.
        proof[0][0] += F::ONE;
        let err = mmcs
            .verify_batch(
                &commit,
                &large_mat_dims.chain(small_mat_dims).collect_vec(),
                3,
                &opened_values,
                &proof,
            )
            .expect_err("expected verification to fail");
        assert!(matches!(err, MerkleTreeError::RootMismatch));
    }

    #[test]
    fn size_gaps() {
        let (_, _, mmcs) = setup();

        // 4 mats with 1000 rows, 8 columns
        let large_mats = (0..4).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 1000, 8));
        let large_mat_dims = (0..4).map(|_| Dimensions {
            height: 1000,
            width: 8,
        });

        // 5 mats with 70 rows, 8 columns
        let medium_mats = (0..5).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 70, 8));
        let medium_mat_dims = (0..5).map(|_| Dimensions {
            height: 70,
            width: 8,
        });

        // 6 mats with 8 rows, 8 columns
        let small_mats = (0..6).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 8));
        let small_mat_dims = (0..6).map(|_| Dimensions {
            height: 8,
            width: 8,
        });

        // 7 tiny mats with 1 row, 8 columns
        let tiny_mats = (0..7).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 1, 8));
        let tiny_mat_dims = (0..7).map(|_| Dimensions {
            height: 1,
            width: 8,
        });

        let (commit, prover_data) = mmcs.commit(
            large_mats
                .chain(medium_mats)
                .chain(small_mats)
                .chain(tiny_mats)
                .collect_vec(),
        );

        let (opened_values, proof) = mmcs.open_batch(6, &prover_data);
        mmcs.verify_batch(
            &commit,
            &large_mat_dims
                .chain(medium_mat_dims)
                .chain(small_mat_dims)
                .chain(tiny_mat_dims)
                .collect_vec(),
            6,
            &opened_values,
            &proof,
        )
        .expect("expected verification to succeed");
    }

    #[test]
    fn different_widths() {
        let (_, _, mmcs) = setup();

        // 10 mats with 32 rows where the ith mat has i + 1 cols
        let mats = (0..10)
            .map(|i| RowMajorMatrix::<F>::rand(&mut thread_rng(), 32, i + 1))
            .collect_vec();
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let (opened_values, proof) = mmcs.open_batch(17, &prover_data);
        mmcs.verify_batch(&commit, &dims, 17, &opened_values, &proof)
            .expect("expected verification to succeed");
    }

    #[test]
    fn verify_rejects_mismatched_batch_size() {
        let (_, _, mmcs) = setup();

        let mat = RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 2);
        let dims = [mat.dimensions()];
        let (commit, prover_data) = mmcs.commit(vec![mat]);

        let (mut opened_values, proof) = mmcs.open_batch(3, &prover_data);
        // One more opened row than there are matrices.
        opened_values.push(vec![F::ONE]);
        let result = mmcs.verify_batch(&commit, &dims, 3, &opened_values, &proof);
        assert!(matches!(result, Err(MerkleTreeError::WrongBatchSize)));
    }

    #[test]
    fn verify_rejects_empty_batch() {
        let (_, _, mmcs) = setup();

        let mat = RowMajorMatrix::<F>::rand(&mut thread_rng(), 4, 1);
        let (commit, _) = mmcs.commit(vec![mat]);

        let result = mmcs.verify_batch(&commit, &[], 0, &[], &Vec::new());
        assert!(matches!(result, Err(MerkleTreeError::EmptyBatch)));
    }

    #[test]
    fn verify_rejects_wrong_proof_length() {
        let (_, _, mmcs) = setup();

        let mat = RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 2);
        let dims = [mat.dimensions()];
        let (commit, prover_data) = mmcs.commit(vec![mat]);

        let (opened_values, mut proof) = mmcs.open_batch(5, &prover_data);
        // Height 8 needs 3 siblings; drop one.
        proof.pop();
        let result = mmcs.verify_batch(&commit, &dims, 5, &opened_values, &proof);
        assert!(matches!(
            result,
            Err(MerkleTreeError::WrongHeight {
                max_height: 8,
                num_siblings: 2
            })
        ));
    }
}