1use alloc::vec::Vec;
23use core::cmp::Reverse;
24use core::marker::PhantomData;
25
26use itertools::Itertools;
27use p3_commit::{BatchOpening, BatchOpeningRef, Mmcs};
28use p3_field::PackedValue;
29use p3_matrix::{Dimensions, Matrix};
30use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
31use p3_util::{log2_ceil_usize, log2_strict_usize};
32use serde::{Deserialize, Serialize};
33use thiserror::Error;
34
35use crate::MerkleTree;
36use crate::MerkleTreeError::{
37 EmptyBatch, IncompatibleHeights, IndexOutOfBounds, RootMismatch, WrongBatchSize, WrongHeight,
38};
39
/// A Merkle-tree based multi-matrix commitment scheme (MMCS).
///
/// Several matrices, possibly of different heights, are committed together in one binary
/// Merkle tree; the commitment is the tree root. See `Mmcs::commit` for the height rules the
/// committed matrices must satisfy.
///
/// Type parameters (as used by the `Mmcs` implementation in this file):
/// - `P`: packed type whose `Value` is the element type of the committed matrices.
/// - `PW`: packed type whose `Value` is the element type of a digest.
/// - `H`: hasher mapping opened row values to `[PW::Value; DIGEST_ELEMS]` digests.
/// - `C`: two-to-one compression function combining two digests into one.
/// - `DIGEST_ELEMS`: number of `PW::Value` elements per digest.
#[derive(Copy, Clone, Debug)]
pub struct MerkleTreeMmcs<P, PW, H, C, const DIGEST_ELEMS: usize> {
    /// Hasher that turns the opened row values of one or more matrices into a digest.
    hash: H,

    /// Compression function that merges a `[left, right]` pair of digests into a parent digest.
    compress: C,

    /// Marks the packed types `P` and `PW`, which appear only in trait bounds.
    _phantom: PhantomData<(P, PW)>,
}
63
#[derive(Debug, Error)]
/// Errors returned when verifying a batch opening against a Merkle commitment.
pub enum MerkleTreeError {
    #[error("wrong batch size: number of openings does not match expected")]
    /// The number of opened rows differs from the number of matrix dimensions supplied.
    WrongBatchSize,

    #[error("wrong width: matrix has a different width than expected")]
    /// An opened row has a different width than its matrix.
    ///
    /// NOTE(review): not constructed anywhere in this file; `verify_batch` does not
    /// currently check opened-row widths.
    WrongWidth,

    #[error("wrong height: expected log_max_height {log_max_height}, got {num_siblings} siblings")]
    /// The opening proof does not contain one sibling digest per tree level.
    WrongHeight {
        /// log2 of the tallest matrix height rounded up to a power of two.
        log_max_height: usize,

        /// Number of sibling digests actually present in the proof.
        num_siblings: usize,
    },

    #[error("incompatible heights: matrices cannot share a common binary Merkle tree")]
    /// The claimed matrix heights cannot be laid out in one binary Merkle tree (for
    /// example, two distinct heights that round up to the same power of two).
    IncompatibleHeights,

    #[error("index out of bounds: index {index} exceeds max height {max_height}")]
    /// The opened index is not below the height of the tallest matrix.
    IndexOutOfBounds {
        /// Height of the tallest matrix in the batch.
        max_height: usize,
        /// The index that was requested.
        index: usize,
    },

    #[error("root mismatch: computed Merkle root does not match commitment")]
    /// The root recomputed from the opening differs from the commitment.
    RootMismatch,

    #[error("empty batch: attempted to open an empty batch with no committed matrices")]
    /// No matrix dimensions were supplied.
    EmptyBatch,
}
106
107impl<P, PW, H, C, const DIGEST_ELEMS: usize> MerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS> {
108 pub const fn new(hash: H, compress: C) -> Self {
110 Self {
111 hash,
112 compress,
113 _phantom: PhantomData,
114 }
115 }
116}
117
118impl<P, PW, H, C, const DIGEST_ELEMS: usize> Mmcs<P::Value>
119 for MerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS>
120where
121 P: PackedValue,
122 PW: PackedValue,
123 H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>
124 + CryptographicHasher<P, [PW; DIGEST_ELEMS]>
125 + Sync,
126 C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>
127 + PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>
128 + Sync,
129 PW::Value: Eq,
130 [PW::Value; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>,
131{
132 type ProverData<M> = MerkleTree<P::Value, PW::Value, M, DIGEST_ELEMS>;
133 type Commitment = Hash<P::Value, PW::Value, DIGEST_ELEMS>;
134 type Proof = Vec<[PW::Value; DIGEST_ELEMS]>;
135 type Error = MerkleTreeError;
136
137 fn commit<M: Matrix<P::Value>>(
138 &self,
139 inputs: Vec<M>,
140 ) -> (Self::Commitment, Self::ProverData<M>) {
141 if let Some(max_height) = inputs.iter().map(|m| m.height()).max()
142 && max_height > 0
143 {
144 let log_max_height = log2_ceil_usize(max_height);
145 for matrix in &inputs {
148 let height = matrix.height();
149 assert!(height > 0, "matrix height 0 not supported");
150
151 let log_height = log2_ceil_usize(height);
152 let bits_reduced = log_max_height - log_height;
153 let expected_height = ((max_height - 1) >> bits_reduced) + 1;
155
156 assert!(
157 height == expected_height,
158 "matrix height {height} incompatible with tallest height {max_height}; \
159 expected ceil_div({max_height}, 2^{bits_reduced}) = {expected_height} \
160 so every global index maps to a row at depth {bits_reduced}"
161 );
162 }
163 } else {
164 panic!("all matrices have height 0");
165 }
166
167 let tree = MerkleTree::new::<P, PW, H, C>(&self.hash, &self.compress, inputs);
168 let root = tree.root();
169 (root, tree)
170 }
171
172 fn open_batch<M: Matrix<P::Value>>(
180 &self,
181 index: usize,
182 prover_data: &MerkleTree<P::Value, PW::Value, M, DIGEST_ELEMS>,
183 ) -> BatchOpening<P::Value, Self> {
184 let max_height = self.get_max_height(prover_data);
185 assert!(
186 index < max_height,
187 "index {index} out of bounds for height {max_height}"
188 );
189 let log_max_height = log2_ceil_usize(max_height);
190
191 let openings = prover_data
193 .leaves
194 .iter()
195 .map(|matrix| {
196 let log2_height = log2_ceil_usize(matrix.height());
197 let bits_reduced = log_max_height - log2_height;
198 let reduced_index = index >> bits_reduced;
199 matrix.row(reduced_index).unwrap().into_iter().collect()
200 })
201 .collect_vec();
202
203 let proof = (0..log_max_height)
205 .map(|i| prover_data.digest_layers[i][(index >> i) ^ 1])
206 .collect();
207
208 BatchOpening::new(openings, proof)
209 }
210
211 fn get_matrices<'a, M: Matrix<P::Value>>(
212 &self,
213 prover_data: &'a Self::ProverData<M>,
214 ) -> Vec<&'a M> {
215 prover_data.leaves.iter().collect()
216 }
217
218 fn verify_batch(
231 &self,
232 commit: &Self::Commitment,
233 dimensions: &[Dimensions],
234 mut index: usize,
235 batch_proof: BatchOpeningRef<'_, P::Value, Self>,
236 ) -> Result<(), Self::Error> {
237 let (opened_values, opening_proof) = batch_proof.unpack();
238 if dimensions.len() != opened_values.len() {
240 return Err(WrongBatchSize);
241 }
242
243 let mut heights_tallest_first = dimensions
251 .iter()
252 .enumerate()
253 .sorted_by_key(|(_, dims)| Reverse(dims.height))
254 .peekable();
255
256 if !heights_tallest_first
258 .clone()
259 .map(|(_, dims)| dims.height)
260 .tuple_windows()
261 .all(|(curr, next)| {
262 curr == next || curr.next_power_of_two() != next.next_power_of_two()
263 })
264 {
265 return Err(IncompatibleHeights);
266 }
267
268 let (max_height, mut curr_height_padded) = match heights_tallest_first.peek() {
274 Some((_, dims)) => {
275 let max_height = dims.height;
276 let curr_height_padded = max_height.next_power_of_two();
277 let log_max_height = log2_strict_usize(curr_height_padded);
278 if opening_proof.len() != log_max_height {
279 return Err(WrongHeight {
280 log_max_height,
281 num_siblings: opening_proof.len(),
282 });
283 }
284 (max_height, curr_height_padded)
285 }
286 None => return Err(EmptyBatch),
287 };
288
289 if index >= max_height {
290 return Err(IndexOutOfBounds { max_height, index });
291 }
292
293 let mut root = self.hash.hash_iter_slices(
295 heights_tallest_first
296 .peeking_take_while(|(_, dims)| {
297 dims.height.next_power_of_two() == curr_height_padded
298 })
299 .map(|(i, _)| opened_values[i].as_slice()),
300 );
301
302 for &sibling in opening_proof {
303 let (left, right) = if index & 1 == 0 {
305 (root, sibling)
306 } else {
307 (sibling, root)
308 };
309
310 root = self.compress.compress([left, right]);
312 index >>= 1;
313 curr_height_padded >>= 1;
314
315 let next_height = heights_tallest_first
317 .peek()
318 .map(|(_, dims)| dims.height)
319 .filter(|h| h.next_power_of_two() == curr_height_padded);
320 if let Some(next_height) = next_height {
321 let next_height_openings_digest = self.hash.hash_iter_slices(
323 heights_tallest_first
324 .peeking_take_while(|(_, dims)| dims.height == next_height)
325 .map(|(i, _)| opened_values[i].as_slice()),
326 );
327
328 root = self.compress.compress([root, next_height_openings_digest]);
329 }
330 }
331
332 if commit == &root {
334 Ok(())
335 } else {
336 Err(RootMismatch)
337 }
338 }
339}
340
#[cfg(test)]
/// Tests for `MerkleTreeMmcs`: exact root layouts for small inputs, order independence,
/// the `commit` height rules, and opening/verification round trips.
mod tests {
    use alloc::vec;

    use itertools::Itertools;
    use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
    use p3_commit::Mmcs;
    use p3_field::{Field, PrimeCharacteristicRing};
    use p3_matrix::dense::RowMajorMatrix;
    use p3_matrix::{Dimensions, Matrix};
    use p3_symmetric::{
        CryptographicHasher, PaddingFreeSponge, PseudoCompressionFunction, TruncatedPermutation,
    };
    use rand::SeedableRng;
    use rand::rngs::SmallRng;

    use super::MerkleTreeMmcs;

    type F = BabyBear;

    // Poseidon2 over BabyBear drives both the sponge hasher and the compression function;
    // the final `8` of `MyMmcs` is `DIGEST_ELEMS`, i.e. 8-element digests.
    type Perm = Poseidon2BabyBear<16>;
    type MyHash = PaddingFreeSponge<Perm, 16, 8, 8>;
    type MyCompress = TruncatedPermutation<Perm, 2, 8, 16>;
    type MyMmcs =
        MerkleTreeMmcs<<F as Field>::Packing, <F as Field>::Packing, MyHash, MyCompress, 8>;

    // Eight single-element rows: the expected root is a full depth-3 binary tree over the
    // per-element leaf hashes.
    #[test]
    fn commit_single_1x8() {
        let mut rng = SmallRng::seed_from_u64(1);
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash.clone(), compress.clone());

        let v = vec![
            F::TWO,
            F::ONE,
            F::TWO,
            F::TWO,
            F::ZERO,
            F::ZERO,
            F::ONE,
            F::ZERO,
        ];
        let (commit, _) = mmcs.commit_vec(v.clone());

        let expected_result = compress.compress([
            compress.compress([
                compress.compress([hash.hash_item(v[0]), hash.hash_item(v[1])]),
                compress.compress([hash.hash_item(v[2]), hash.hash_item(v[3])]),
            ]),
            compress.compress([
                compress.compress([hash.hash_item(v[4]), hash.hash_item(v[5])]),
                compress.compress([hash.hash_item(v[6]), hash.hash_item(v[7])]),
            ]),
        ]);
        assert_eq!(commit, expected_result);
    }

    // A single committed row: the root is just the hash of that row, with no compression.
    #[test]
    fn commit_single_8x1() {
        let mut rng = SmallRng::seed_from_u64(1);
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash.clone(), compress);

        let mat = RowMajorMatrix::<F>::rand(&mut rng, 1, 8);
        let (commit, _) = mmcs.commit(vec![mat.clone()]);

        let expected_result = hash.hash_iter(mat.vertically_packed_row(0));
        assert_eq!(commit, expected_result);
    }

    // Width-2 matrix with rows [0, 1] and [2, 1]: root = compress(hash(row0), hash(row1)).
    #[test]
    fn commit_single_2x2() {
        let mut rng = SmallRng::seed_from_u64(1);
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash.clone(), compress.clone());

        let mat = RowMajorMatrix::new(vec![F::ZERO, F::ONE, F::TWO, F::ONE], 2);

        let (commit, _) = mmcs.commit(vec![mat]);

        let expected_result = compress.compress([
            hash.hash_slice(&[F::ZERO, F::ONE]),
            hash.hash_slice(&[F::TWO, F::ONE]),
        ]);
        assert_eq!(commit, expected_result);
    }

    // Three rows of width 2: the tree is padded to four leaves, and the missing fourth leaf
    // is the all-zero default digest.
    #[test]
    fn commit_single_2x3() {
        let mut rng = SmallRng::seed_from_u64(1);
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash.clone(), compress.clone());
        let default_digest = [F::ZERO; 8];

        let mat = RowMajorMatrix::new(vec![F::ZERO, F::ONE, F::TWO, F::ONE, F::TWO, F::TWO], 2);

        let (commit, _) = mmcs.commit(vec![mat]);

        let expected_result = compress.compress([
            compress.compress([
                hash.hash_slice(&[F::ZERO, F::ONE]),
                hash.hash_slice(&[F::TWO, F::ONE]),
            ]),
            compress.compress([hash.hash_slice(&[F::TWO, F::TWO]), default_digest]),
        ]);
        assert_eq!(commit, expected_result);
    }

    // mat_1 has 5 rows (width 2), mat_2 has 3 rows (width 3) = ceil(5 / 2). mat_1 fills the
    // leaf layer (padded to 8 with default digests); each mat_2 row digest is compressed in
    // one level higher, and index `i` opens row `i >> 1` of mat_2.
    #[test]
    fn commit_mixed() {
        let mut rng = SmallRng::seed_from_u64(1);
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash.clone(), compress.clone());
        let default_digest = [F::ZERO; 8];

        let mat_1 = RowMajorMatrix::new(
            vec![
                F::ZERO,
                F::ONE,
                F::TWO,
                F::ONE,
                F::TWO,
                F::TWO,
                F::TWO,
                F::ONE,
                F::TWO,
                F::TWO,
            ],
            2,
        );
        let mat_2 = RowMajorMatrix::new(
            vec![
                F::ONE,
                F::TWO,
                F::ONE,
                F::ZERO,
                F::TWO,
                F::TWO,
                F::ONE,
                F::TWO,
                F::ONE,
            ],
            3,
        );

        let (commit, prover_data) = mmcs.commit(vec![mat_1, mat_2]);

        let mat_1_leaf_hashes = [
            hash.hash_slice(&[F::ZERO, F::ONE]),
            hash.hash_slice(&[F::TWO, F::ONE]),
            hash.hash_slice(&[F::TWO, F::TWO]),
            hash.hash_slice(&[F::TWO, F::ONE]),
            hash.hash_slice(&[F::TWO, F::TWO]),
        ];
        let mat_2_leaf_hashes = [
            hash.hash_slice(&[F::ONE, F::TWO, F::ONE]),
            hash.hash_slice(&[F::ZERO, F::TWO, F::TWO]),
            hash.hash_slice(&[F::ONE, F::TWO, F::ONE]),
        ];

        let expected_result = compress.compress([
            compress.compress([
                compress.compress([
                    compress.compress([mat_1_leaf_hashes[0], mat_1_leaf_hashes[1]]),
                    mat_2_leaf_hashes[0],
                ]),
                compress.compress([
                    compress.compress([mat_1_leaf_hashes[2], mat_1_leaf_hashes[3]]),
                    mat_2_leaf_hashes[1],
                ]),
            ]),
            compress.compress([
                compress.compress([
                    compress.compress([mat_1_leaf_hashes[4], default_digest]),
                    mat_2_leaf_hashes[2],
                ]),
                default_digest,
            ]),
        ]);

        assert_eq!(commit, expected_result);

        // Index 2 opens row 2 of mat_1 and row 2 >> 1 = 1 of mat_2.
        let (opened_values, _) = mmcs.open_batch(2, &prover_data).unpack();
        assert_eq!(
            opened_values,
            vec![vec![F::TWO, F::TWO], vec![F::ZERO, F::TWO, F::TWO]]
        );
    }

    // The root must not depend on the order in which matrices are passed to `commit`.
    #[test]
    fn commit_either_order() {
        let mut rng = SmallRng::seed_from_u64(1);
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash, compress);

        let input_1 = RowMajorMatrix::<F>::rand(&mut rng, 5, 8);
        let input_2 = RowMajorMatrix::<F>::rand(&mut rng, 3, 16);

        let (commit_1_2, _) = mmcs.commit(vec![input_1.clone(), input_2.clone()]);
        let (commit_2_1, _) = mmcs.commit(vec![input_2, input_1]);
        assert_eq!(commit_1_2, commit_2_1);
    }

    // Heights 8 and 7 both round up to 8, but 7 != ceil(8 / 1), so `commit` must panic.
    #[test]
    #[should_panic]
    fn mismatched_heights() {
        let mut rng = SmallRng::seed_from_u64(1);
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash, compress);

        let large_mat = RowMajorMatrix::new([1, 2, 3, 4, 5, 6, 7, 8].map(F::from_u8).to_vec(), 1);
        let small_mat = RowMajorMatrix::new([1, 2, 3, 4, 5, 6, 7].map(F::from_u8).to_vec(), 1);
        let _ = mmcs.commit(vec![large_mat, small_mat]);
    }

    // Eight matrices, all 8 rows tall (four of width 1, four of width 2). Perturbing one
    // element of the first sibling digest must make verification fail.
    #[test]
    fn verify_tampered_proof_fails() {
        let mut rng = SmallRng::seed_from_u64(1);
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash, compress);

        let mut mats = (0..4)
            .map(|_| RowMajorMatrix::<F>::rand(&mut rng, 8, 1))
            .collect_vec();
        let large_mat_dims = (0..4).map(|_| Dimensions {
            height: 8,
            width: 1,
        });
        mats.extend((0..4).map(|_| RowMajorMatrix::<F>::rand(&mut rng, 8, 2)));
        let small_mat_dims = (0..4).map(|_| Dimensions {
            height: 8,
            width: 2,
        });

        let (commit, prover_data) = mmcs.commit(mats);

        let mut batch_opening = mmcs.open_batch(3, &prover_data);
        batch_opening.opening_proof[0][0] += F::ONE;
        mmcs.verify_batch(
            &commit,
            &large_mat_dims.chain(small_mat_dims).collect_vec(),
            3,
            (&batch_opening).into(),
        )
        .expect_err("expected verification to fail");
    }

    // Heights 1000, 125, 8 and 1 are each ceil(1000 / 2^k) for k = 0, 3, 7, 10, so they skip
    // several tree layers yet are still compatible; openings at assorted indices must verify.
    #[test]
    fn size_gaps() {
        let mut rng = SmallRng::seed_from_u64(1);
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash, compress);

        let mut mats = (0..4)
            .map(|_| RowMajorMatrix::<F>::rand(&mut rng, 1000, 8))
            .collect_vec();
        let large_mat_dims = (0..4).map(|_| Dimensions {
            height: 1000,
            width: 8,
        });

        mats.extend((0..5).map(|_| RowMajorMatrix::<F>::rand(&mut rng, 125, 8)));
        let medium_mat_dims = (0..5).map(|_| Dimensions {
            height: 125,
            width: 8,
        });

        mats.extend((0..6).map(|_| RowMajorMatrix::<F>::rand(&mut rng, 8, 8)));
        let small_mat_dims = (0..6).map(|_| Dimensions {
            height: 8,
            width: 8,
        });

        mats.extend((0..7).map(|_| RowMajorMatrix::<F>::rand(&mut rng, 1, 8)));
        let tiny_mat_dims = (0..7).map(|_| Dimensions {
            height: 1,
            width: 8,
        });

        let dims = large_mat_dims
            .chain(medium_mat_dims)
            .chain(small_mat_dims)
            .chain(tiny_mat_dims)
            .collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);

        for &index in &[0, 6, 124, 999] {
            let batch_opening = mmcs.open_batch(index, &prover_data);
            mmcs.verify_batch(&commit, &dims, index, (&batch_opening).into())
                .expect("expected verification to succeed");
        }
    }

    // Ten matrices of equal height (32) with widths 1..=10 are hashed together into the
    // same leaves; a round-trip opening must verify.
    #[test]
    fn different_widths() {
        let mut rng = SmallRng::seed_from_u64(1);
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash, compress);

        let mats = (0..10)
            .map(|i| RowMajorMatrix::<F>::rand(&mut rng, 32, i + 1))
            .collect_vec();
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let batch_opening = mmcs.open_batch(17, &prover_data);
        mmcs.verify_batch(&commit, &dims, 17, (&batch_opening).into())
            .expect("expected verification to succeed");
    }

    // Tallest height 11 pads to 16; a height-5 matrix pads to 8 (one level up) and would
    // need ceil(11 / 2) = 6 rows, otherwise global index 10 (row 5) has no row to open.
    #[test]
    #[should_panic(expected = "matrix height 5 incompatible")]
    fn commit_rejects_missing_leaf_coverage() {
        let mut rng = SmallRng::seed_from_u64(9);
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash, compress);

        let tallest = RowMajorMatrix::new(vec![F::ONE; 11], 1);
        let invalid = RowMajorMatrix::new(vec![F::ONE; 5], 1);

        let _ = mmcs.commit(vec![tallest, invalid]);
    }
}