p3_merkle_tree/
mmcs.rs

1use alloc::vec::Vec;
2use core::cmp::Reverse;
3use core::marker::PhantomData;
4
5use itertools::Itertools;
6use p3_commit::Mmcs;
7use p3_field::PackedValue;
8use p3_matrix::{Dimensions, Matrix};
9use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
10use p3_util::log2_ceil_usize;
11use serde::{Deserialize, Serialize};
12
13use crate::MerkleTree;
14use crate::MerkleTreeError::{EmptyBatch, RootMismatch, WrongBatchSize, WrongHeight};
15
/// A vector commitment scheme backed by a `MerkleTree`.
///
/// A single root commits to a batch of matrices, which may have different heights.
///
/// Generics:
/// - `P`: a leaf value. In the `Mmcs` impl this is a packed type (`P: PackedValue`) and the
///   committed matrices hold `P::Value` elements.
/// - `PW`: an element of a digest. Also a packed type; digests are
///   `[PW::Value; DIGEST_ELEMS]` arrays.
/// - `H`: the leaf hasher
/// - `C`: the digest compression function
/// - `DIGEST_ELEMS`: the number of `PW::Value` elements in one digest
#[derive(Copy, Clone, Debug)]
pub struct MerkleTreeMmcs<P, PW, H, C, const DIGEST_ELEMS: usize> {
    /// Hashes opened matrix rows into leaf digests.
    hash: H,
    /// Compresses two digests into one when moving up the tree.
    compress: C,
    /// Marks `P` and `PW` as used without storing values of those types.
    _phantom: PhantomData<(P, PW)>,
}
29
/// Reasons a `MerkleTreeMmcs` batch opening can fail verification.
///
/// The type is `Copy` and `Eq` so callers can compare against an expected variant directly,
/// e.g. `assert_eq!(err, MerkleTreeError::RootMismatch)`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum MerkleTreeError {
    /// The number of opened rows differs from the number of matrix dimensions supplied.
    WrongBatchSize,
    /// An opened row's length does not match its matrix's width.
    ///
    /// NOTE(review): the width check in `verify_batch` is currently disabled, so this
    /// variant is not produced there.
    WrongWidth,
    /// The proof does not contain exactly one sibling digest per tree layer.
    WrongHeight {
        /// Height of the tallest matrix in the batch.
        max_height: usize,
        /// Number of sibling digests actually present in the proof.
        num_siblings: usize,
    },
    /// The root recomputed from the openings and proof differs from the commitment.
    RootMismatch,
    /// The batch contained no matrices.
    EmptyBatch,
}
41
42impl<P, PW, H, C, const DIGEST_ELEMS: usize> MerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS> {
43    pub const fn new(hash: H, compress: C) -> Self {
44        Self {
45            hash,
46            compress,
47            _phantom: PhantomData,
48        }
49    }
50}
51
52impl<P, PW, H, C, const DIGEST_ELEMS: usize> Mmcs<P::Value>
53    for MerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS>
54where
55    P: PackedValue,
56    PW: PackedValue,
57    H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>,
58    H: CryptographicHasher<P, [PW; DIGEST_ELEMS]>,
59    H: Sync,
60    C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>,
61    C: PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>,
62    C: Sync,
63    PW::Value: Eq,
64    [PW::Value; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>,
65{
66    type ProverData<M> = MerkleTree<P::Value, PW::Value, M, DIGEST_ELEMS>;
67    type Commitment = Hash<P::Value, PW::Value, DIGEST_ELEMS>;
68    type Proof = Vec<[PW::Value; DIGEST_ELEMS]>;
69    type Error = MerkleTreeError;
70
71    fn commit<M: Matrix<P::Value>>(
72        &self,
73        inputs: Vec<M>,
74    ) -> (Self::Commitment, Self::ProverData<M>) {
75        let tree = MerkleTree::new::<P, PW, H, C>(&self.hash, &self.compress, inputs);
76        let root = tree.root();
77        (root, tree)
78    }
79
80    fn open_batch<M: Matrix<P::Value>>(
81        &self,
82        index: usize,
83        prover_data: &MerkleTree<P::Value, PW::Value, M, DIGEST_ELEMS>,
84    ) -> (Vec<Vec<P::Value>>, Vec<[PW::Value; DIGEST_ELEMS]>) {
85        let max_height = self.get_max_height(prover_data);
86        let log_max_height = log2_ceil_usize(max_height);
87
88        let openings = prover_data
89            .leaves
90            .iter()
91            .map(|matrix| {
92                let log2_height = log2_ceil_usize(matrix.height());
93                let bits_reduced = log_max_height - log2_height;
94                let reduced_index = index >> bits_reduced;
95                matrix.row(reduced_index).collect()
96            })
97            .collect_vec();
98
99        let proof: Vec<_> = (0..log_max_height)
100            .map(|i| prover_data.digest_layers[i][(index >> i) ^ 1])
101            .collect();
102
103        (openings, proof)
104    }
105
106    fn get_matrices<'a, M: Matrix<P::Value>>(
107        &self,
108        prover_data: &'a Self::ProverData<M>,
109    ) -> Vec<&'a M> {
110        prover_data.leaves.iter().collect()
111    }
112
113    fn verify_batch(
114        &self,
115        commit: &Self::Commitment,
116        dimensions: &[Dimensions],
117        mut index: usize,
118        opened_values: &[Vec<P::Value>],
119        proof: &Self::Proof,
120    ) -> Result<(), Self::Error> {
121        // Check that the openings have the correct shape.
122        if dimensions.len() != opened_values.len() {
123            return Err(WrongBatchSize);
124        }
125
126        // TODO: Disabled for now since TwoAdicFriPcs and CirclePcs currently pass 0 for width.
127        // for (dims, opened_vals) in zip_eq(dimensions.iter(), opened_values) {
128        //     if opened_vals.len() != dims.width {
129        //         return Err(WrongWidth);
130        //     }
131        // }
132
133        // TODO: Disabled for now, CirclePcs sometimes passes a height that's off by 1 bit.
134        let Some(max_height) = dimensions.iter().map(|dim| dim.height).max() else {
135            // dimensions is empty
136            return Err(EmptyBatch);
137        };
138        let log_max_height = log2_ceil_usize(max_height);
139        if proof.len() != log_max_height {
140            return Err(WrongHeight {
141                max_height,
142                num_siblings: proof.len(),
143            });
144        }
145
146        let mut heights_tallest_first = dimensions
147            .iter()
148            .enumerate()
149            .sorted_by_key(|(_, dims)| Reverse(dims.height))
150            .peekable();
151
152        let Some(mut curr_height_padded) = heights_tallest_first
153            .peek()
154            .map(|x| x.1.height.next_power_of_two())
155        else {
156            // dimensions is empty
157            return Err(EmptyBatch);
158        };
159
160        let mut root = self.hash.hash_iter_slices(
161            heights_tallest_first
162                .peeking_take_while(|(_, dims)| {
163                    dims.height.next_power_of_two() == curr_height_padded
164                })
165                .map(|(i, _)| opened_values[i].as_slice()),
166        );
167
168        for &sibling in proof.iter() {
169            let (left, right) = if index & 1 == 0 {
170                (root, sibling)
171            } else {
172                (sibling, root)
173            };
174
175            root = self.compress.compress([left, right]);
176            index >>= 1;
177            curr_height_padded >>= 1;
178
179            let next_height = heights_tallest_first
180                .peek()
181                .map(|(_, dims)| dims.height)
182                .filter(|h| h.next_power_of_two() == curr_height_padded);
183            if let Some(next_height) = next_height {
184                let next_height_openings_digest = self.hash.hash_iter_slices(
185                    heights_tallest_first
186                        .peeking_take_while(|(_, dims)| dims.height == next_height)
187                        .map(|(i, _)| opened_values[i].as_slice()),
188                );
189
190                root = self.compress.compress([root, next_height_openings_digest]);
191            }
192        }
193
194        if commit == &root {
195            Ok(())
196        } else {
197            Err(RootMismatch)
198        }
199    }
200}
201
#[cfg(test)]
mod tests {
    use alloc::vec;

    use itertools::Itertools;
    use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
    use p3_commit::Mmcs;
    use p3_field::{Field, FieldAlgebra};
    use p3_matrix::dense::RowMajorMatrix;
    use p3_matrix::{Dimensions, Matrix};
    use p3_symmetric::{
        CryptographicHasher, PaddingFreeSponge, PseudoCompressionFunction, TruncatedPermutation,
    };
    use rand::thread_rng;

    use super::{MerkleTreeError, MerkleTreeMmcs};

    type F = BabyBear;

    type Perm = Poseidon2BabyBear<16>;
    type MyHash = PaddingFreeSponge<Perm, 16, 8, 8>;
    type MyCompress = TruncatedPermutation<Perm, 2, 8, 16>;
    type MyMmcs =
        MerkleTreeMmcs<<F as Field>::Packing, <F as Field>::Packing, MyHash, MyCompress, 8>;

    /// Builds an MMCS over a freshly randomised Poseidon2 permutation.
    fn new_mmcs() -> MyMmcs {
        let perm = Perm::new_from_rng_128(&mut thread_rng());
        MyMmcs::new(MyHash::new(perm.clone()), MyCompress::new(perm))
    }

    #[test]
    fn commit_single_1x8() {
        let perm = Perm::new_from_rng_128(&mut thread_rng());
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash.clone(), compress.clone());

        // v = [2, 1, 2, 2, 0, 0, 1, 0]
        let v = vec![
            F::TWO,
            F::ONE,
            F::TWO,
            F::TWO,
            F::ZERO,
            F::ZERO,
            F::ONE,
            F::ZERO,
        ];
        let (commit, _) = mmcs.commit_vec(v.clone());

        let expected_result = compress.compress([
            compress.compress([
                compress.compress([hash.hash_item(v[0]), hash.hash_item(v[1])]),
                compress.compress([hash.hash_item(v[2]), hash.hash_item(v[3])]),
            ]),
            compress.compress([
                compress.compress([hash.hash_item(v[4]), hash.hash_item(v[5])]),
                compress.compress([hash.hash_item(v[6]), hash.hash_item(v[7])]),
            ]),
        ]);
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_single_8x1() {
        let perm = Perm::new_from_rng_128(&mut thread_rng());
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash.clone(), compress.clone());

        let mat = RowMajorMatrix::<F>::rand(&mut thread_rng(), 1, 8);
        let (commit, _) = mmcs.commit(vec![mat.clone()]);

        let expected_result = hash.hash_iter(mat.clone().vertically_packed_row(0));
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_single_2x2() {
        let perm = Perm::new_from_rng_128(&mut thread_rng());
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash.clone(), compress.clone());

        // mat = [
        //   0 1
        //   2 1
        // ]
        let mat = RowMajorMatrix::new(vec![F::ZERO, F::ONE, F::TWO, F::ONE], 2);

        let (commit, _) = mmcs.commit(vec![mat]);

        let expected_result = compress.compress([
            hash.hash_slice(&[F::ZERO, F::ONE]),
            hash.hash_slice(&[F::TWO, F::ONE]),
        ]);
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_single_2x3() {
        let perm = Perm::new_from_rng_128(&mut thread_rng());
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash.clone(), compress.clone());
        let default_digest = [F::ZERO; 8];

        // mat = [
        //   0 1
        //   2 1
        //   2 2
        // ]
        let mat = RowMajorMatrix::new(vec![F::ZERO, F::ONE, F::TWO, F::ONE, F::TWO, F::TWO], 2);

        let (commit, _) = mmcs.commit(vec![mat]);

        let expected_result = compress.compress([
            compress.compress([
                hash.hash_slice(&[F::ZERO, F::ONE]),
                hash.hash_slice(&[F::TWO, F::ONE]),
            ]),
            compress.compress([hash.hash_slice(&[F::TWO, F::TWO]), default_digest]),
        ]);
        assert_eq!(commit, expected_result);
    }

    #[test]
    fn commit_mixed() {
        let perm = Perm::new_from_rng_128(&mut thread_rng());
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash.clone(), compress.clone());
        let default_digest = [F::ZERO; 8];

        // mat_1 = [
        //   0 1
        //   2 1
        //   2 2
        //   2 1
        //   2 2
        // ]
        let mat_1 = RowMajorMatrix::new(
            vec![
                F::ZERO,
                F::ONE,
                F::TWO,
                F::ONE,
                F::TWO,
                F::TWO,
                F::TWO,
                F::ONE,
                F::TWO,
                F::TWO,
            ],
            2,
        );
        // mat_2 = [
        //   1 2 1
        //   0 2 2
        //   1 2 1
        // ]
        let mat_2 = RowMajorMatrix::new(
            vec![
                F::ONE,
                F::TWO,
                F::ONE,
                F::ZERO,
                F::TWO,
                F::TWO,
                F::ONE,
                F::TWO,
                F::ONE,
            ],
            3,
        );
        let dims = vec![mat_1.dimensions(), mat_2.dimensions()];

        let (commit, prover_data) = mmcs.commit(vec![mat_1, mat_2]);

        let mat_1_leaf_hashes = [
            hash.hash_slice(&[F::ZERO, F::ONE]),
            hash.hash_slice(&[F::TWO, F::ONE]),
            hash.hash_slice(&[F::TWO, F::TWO]),
            hash.hash_slice(&[F::TWO, F::ONE]),
            hash.hash_slice(&[F::TWO, F::TWO]),
        ];
        let mat_2_leaf_hashes = [
            hash.hash_slice(&[F::ONE, F::TWO, F::ONE]),
            hash.hash_slice(&[F::ZERO, F::TWO, F::TWO]),
            hash.hash_slice(&[F::ONE, F::TWO, F::ONE]),
        ];

        let expected_result = compress.compress([
            compress.compress([
                compress.compress([
                    compress.compress([mat_1_leaf_hashes[0], mat_1_leaf_hashes[1]]),
                    mat_2_leaf_hashes[0],
                ]),
                compress.compress([
                    compress.compress([mat_1_leaf_hashes[2], mat_1_leaf_hashes[3]]),
                    mat_2_leaf_hashes[1],
                ]),
            ]),
            compress.compress([
                compress.compress([
                    compress.compress([mat_1_leaf_hashes[4], default_digest]),
                    mat_2_leaf_hashes[2],
                ]),
                default_digest,
            ]),
        ]);

        assert_eq!(commit, expected_result);

        let (opened_values, proof) = mmcs.open_batch(2, &prover_data);
        assert_eq!(
            opened_values,
            vec![vec![F::TWO, F::TWO], vec![F::ZERO, F::TWO, F::TWO]]
        );

        // The opening of two matrices with different padded heights must also verify.
        mmcs.verify_batch(&commit, &dims, 2, &opened_values, &proof)
            .expect("expected verification to succeed");
    }

    #[test]
    fn commit_either_order() {
        let mut rng = thread_rng();
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash, compress);

        let input_1 = RowMajorMatrix::<F>::rand(&mut rng, 5, 8);
        let input_2 = RowMajorMatrix::<F>::rand(&mut rng, 3, 16);

        let (commit_1_2, _) = mmcs.commit(vec![input_1.clone(), input_2.clone()]);
        let (commit_2_1, _) = mmcs.commit(vec![input_2, input_1]);
        assert_eq!(commit_1_2, commit_2_1);
    }

    #[test]
    #[should_panic]
    fn mismatched_heights() {
        let mut rng = thread_rng();
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash, compress);

        // attempt to commit to a mat with 8 rows and a mat with 7 rows. this should panic.
        let large_mat = RowMajorMatrix::new(
            [1, 2, 3, 4, 5, 6, 7, 8].map(F::from_canonical_u8).to_vec(),
            1,
        );
        let small_mat =
            RowMajorMatrix::new([1, 2, 3, 4, 5, 6, 7].map(F::from_canonical_u8).to_vec(), 1);
        let _ = mmcs.commit(vec![large_mat, small_mat]);
    }

    #[test]
    fn verify_tampered_proof_fails() {
        let mut rng = thread_rng();
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash, compress);

        // 4 matrices with 8 rows and 1 column, 4 matrices with 8 rows and 2 columns
        let large_mats = (0..4).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 1));
        let large_mat_dims = (0..4).map(|_| Dimensions {
            height: 8,
            width: 1,
        });
        let small_mats = (0..4).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 2));
        let small_mat_dims = (0..4).map(|_| Dimensions {
            height: 8,
            width: 2,
        });

        let (commit, prover_data) = mmcs.commit(large_mats.chain(small_mats).collect_vec());

        // open the 3rd row of each matrix, mess with proof, and verify
        let (opened_values, mut proof) = mmcs.open_batch(3, &prover_data);
        proof[0][0] += F::ONE;
        let result = mmcs.verify_batch(
            &commit,
            &large_mat_dims.chain(small_mat_dims).collect_vec(),
            3,
            &opened_values,
            &proof,
        );
        assert!(
            matches!(result, Err(MerkleTreeError::RootMismatch)),
            "expected verification to fail with RootMismatch"
        );
    }

    #[test]
    fn verify_rejects_wrong_batch_size() {
        let mmcs = new_mmcs();
        let mat = RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 2);
        let dims = vec![mat.dimensions()];
        let (commit, prover_data) = mmcs.commit(vec![mat]);
        let (opened_values, proof) = mmcs.open_batch(3, &prover_data);

        // Supply no opened rows for the single committed matrix.
        let result = mmcs.verify_batch(&commit, &dims, 3, &opened_values[..0], &proof);
        assert!(matches!(result, Err(MerkleTreeError::WrongBatchSize)));
    }

    #[test]
    fn verify_rejects_empty_batch() {
        let mmcs = new_mmcs();
        let mat = RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 2);
        let (commit, prover_data) = mmcs.commit(vec![mat]);
        let (_, proof) = mmcs.open_batch(0, &prover_data);

        let result = mmcs.verify_batch(&commit, &[], 0, &[], &proof);
        assert!(matches!(result, Err(MerkleTreeError::EmptyBatch)));
    }

    #[test]
    fn verify_rejects_wrong_proof_length() {
        let mmcs = new_mmcs();
        let mat = RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 2);
        let dims = vec![mat.dimensions()];
        let (commit, prover_data) = mmcs.commit(vec![mat]);
        let (opened_values, mut proof) = mmcs.open_batch(3, &prover_data);

        // An 8-row matrix needs log2(8) = 3 siblings; drop one.
        proof.pop();
        let result = mmcs.verify_batch(&commit, &dims, 3, &opened_values, &proof);
        assert!(matches!(
            result,
            Err(MerkleTreeError::WrongHeight {
                max_height: 8,
                num_siblings: 2
            })
        ));
    }

    #[test]
    fn size_gaps() {
        let mut rng = thread_rng();
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash, compress);

        // 4 mats with 1000 rows, 8 columns
        let large_mats = (0..4).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 1000, 8));
        let large_mat_dims = (0..4).map(|_| Dimensions {
            height: 1000,
            width: 8,
        });

        // 5 mats with 70 rows, 8 columns
        let medium_mats = (0..5).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 70, 8));
        let medium_mat_dims = (0..5).map(|_| Dimensions {
            height: 70,
            width: 8,
        });

        // 6 mats with 8 rows, 8 columns
        let small_mats = (0..6).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 8, 8));
        let small_mat_dims = (0..6).map(|_| Dimensions {
            height: 8,
            width: 8,
        });

        // 7 tiny mats with 1 row, 8 columns
        let tiny_mats = (0..7).map(|_| RowMajorMatrix::<F>::rand(&mut thread_rng(), 1, 8));
        let tiny_mat_dims = (0..7).map(|_| Dimensions {
            height: 1,
            width: 8,
        });

        let (commit, prover_data) = mmcs.commit(
            large_mats
                .chain(medium_mats)
                .chain(small_mats)
                .chain(tiny_mats)
                .collect_vec(),
        );

        // open the 6th row of each matrix and verify
        let (opened_values, proof) = mmcs.open_batch(6, &prover_data);
        mmcs.verify_batch(
            &commit,
            &large_mat_dims
                .chain(medium_mat_dims)
                .chain(small_mat_dims)
                .chain(tiny_mat_dims)
                .collect_vec(),
            6,
            &opened_values,
            &proof,
        )
        .expect("expected verification to succeed");
    }

    #[test]
    fn different_widths() {
        let mut rng = thread_rng();
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        let mmcs = MyMmcs::new(hash, compress);

        // 10 mats with 32 rows where the ith mat has i + 1 cols
        let mats = (0..10)
            .map(|i| RowMajorMatrix::<F>::rand(&mut thread_rng(), 32, i + 1))
            .collect_vec();
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let (opened_values, proof) = mmcs.open_batch(17, &prover_data);
        mmcs.verify_batch(&commit, &dims, 17, &opened_values, &proof)
            .expect("expected verification to succeed");
    }
}