1use alloc::vec::Vec;
2use core::cell::RefCell;
3
4use itertools::Itertools;
5use p3_commit::Mmcs;
6use p3_field::PackedValue;
7use p3_matrix::dense::RowMajorMatrix;
8use p3_matrix::stack::HorizontalPair;
9use p3_matrix::{Dimensions, Matrix};
10use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
11use p3_util::zip_eq::zip_eq;
12use rand::distributions::{Distribution, Standard};
13use rand::Rng;
14use serde::de::DeserializeOwned;
15use serde::{Deserialize, Serialize};
16
17use crate::{MerkleTree, MerkleTreeError, MerkleTreeMmcs};
18
/// A vector commitment scheme backed by a [`MerkleTreeMmcs`] whose leaves are salted.
///
/// On `commit`, every row of each input matrix is extended with `SALT_ELEMS` random elements
/// drawn from `rng`, and the salted matrices are committed with the wrapped MMCS. Openings
/// return the salts together with the Merkle siblings so that a verifier can rebuild the salted
/// leaf. The salting is intended to make the commitment hiding.
///
/// NOTE(review): hiding presumably requires `R` to be a properly seeded cryptographically secure
/// RNG, and `SALT_ELEMS` (times the size of `P::Value`) to meet the target security level —
/// confirm against the intended use.
///
/// Generics:
/// - `P`: packed leaf value type
/// - `PW`: packed digest word type
/// - `H`: leaf hasher
/// - `C`: digest compression function
/// - `R`: random number generator used to sample salts
#[derive(Clone, Debug)]
pub struct MerkleTreeHidingMmcs<P, PW, H, C, R, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize>
{
    /// Non-hiding MMCS that commits to the salted matrices.
    inner: MerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS>,
    /// Salt source. Wrapped in a `RefCell` because `commit` only receives `&self` but must
    /// mutate the RNG; note this makes the whole type `!Sync`.
    rng: RefCell<R>,
}
44
45impl<P, PW, H, C, R, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize>
46 MerkleTreeHidingMmcs<P, PW, H, C, R, DIGEST_ELEMS, SALT_ELEMS>
47{
48 pub fn new(hash: H, compress: C, rng: R) -> Self {
49 let inner = MerkleTreeMmcs::new(hash, compress);
50 Self {
51 inner,
52 rng: rng.into(),
53 }
54 }
55}
56
57impl<P, PW, H, C, R, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize> Mmcs<P::Value>
58 for MerkleTreeHidingMmcs<P, PW, H, C, R, DIGEST_ELEMS, SALT_ELEMS>
59where
60 P: PackedValue,
61 P::Value: Serialize + DeserializeOwned,
62 PW: PackedValue,
63 H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>,
64 H: CryptographicHasher<P, [PW; DIGEST_ELEMS]>,
65 H: Sync,
66 C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>,
67 C: PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>,
68 C: Sync,
69 R: Rng + Clone,
70 PW::Value: Eq,
71 [PW::Value; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>,
72 Standard: Distribution<P::Value>,
73{
74 type ProverData<M> =
75 MerkleTree<P::Value, PW::Value, HorizontalPair<M, RowMajorMatrix<P::Value>>, DIGEST_ELEMS>;
76 type Commitment = Hash<P::Value, PW::Value, DIGEST_ELEMS>;
77 type Proof = (Vec<Vec<P::Value>>, Vec<[PW::Value; DIGEST_ELEMS]>);
79 type Error = MerkleTreeError;
80
81 fn commit<M: Matrix<P::Value>>(
82 &self,
83 inputs: Vec<M>,
84 ) -> (Self::Commitment, Self::ProverData<M>) {
85 let salted_inputs = inputs
86 .into_iter()
87 .map(|mat| {
88 let salts =
89 RowMajorMatrix::rand(&mut *self.rng.borrow_mut(), mat.height(), SALT_ELEMS);
90 HorizontalPair::new(mat, salts)
91 })
92 .collect();
93 self.inner.commit(salted_inputs)
94 }
95
96 fn open_batch<M: Matrix<P::Value>>(
97 &self,
98 index: usize,
99 prover_data: &Self::ProverData<M>,
100 ) -> (
101 Vec<Vec<P::Value>>,
102 (Vec<Vec<P::Value>>, Vec<[PW::Value; DIGEST_ELEMS]>),
103 ) {
104 let (salted_openings, siblings) = self.inner.open_batch(index, prover_data);
105 let (openings, salts): (Vec<_>, Vec<_>) = salted_openings
106 .into_iter()
107 .map(|row| {
108 let (a, b) = row.split_at(row.len() - SALT_ELEMS);
109 (a.to_vec(), b.to_vec())
110 })
111 .unzip();
112 (openings, (salts, siblings))
113 }
114
115 fn get_matrices<'a, M: Matrix<P::Value>>(
116 &self,
117 prover_data: &'a Self::ProverData<M>,
118 ) -> Vec<&'a M> {
119 prover_data.leaves.iter().map(|mat| &mat.first).collect()
120 }
121
122 fn verify_batch(
123 &self,
124 commit: &Self::Commitment,
125 dimensions: &[Dimensions],
126 index: usize,
127 opened_values: &[Vec<P::Value>],
128 proof: &Self::Proof,
129 ) -> Result<(), Self::Error> {
130 let (salts, siblings) = proof;
131
132 let opened_salted_values = zip_eq(opened_values, salts, MerkleTreeError::WrongBatchSize)?
133 .map(|(opened, salt)| opened.iter().chain(salt.iter()).copied().collect_vec())
134 .collect_vec();
135
136 self.inner
137 .verify_batch(commit, dimensions, index, &opened_salted_values, siblings)
138 }
139}
140
#[cfg(test)]
mod tests {
    use alloc::vec;

    use itertools::Itertools;
    use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
    use p3_commit::Mmcs;
    use p3_field::{Field, FieldAlgebra};
    use p3_matrix::dense::RowMajorMatrix;
    use p3_matrix::Matrix;
    use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation};
    use rand::prelude::*;

    use super::MerkleTreeHidingMmcs;
    use crate::MerkleTreeError;

    type F = BabyBear;
    const SALT_ELEMS: usize = 4;

    type Perm = Poseidon2BabyBear<16>;
    type MyHash = PaddingFreeSponge<Perm, 16, 8, 8>;
    type MyCompress = TruncatedPermutation<Perm, 2, 8, 16>;
    type MyMmcs = MerkleTreeHidingMmcs<
        <F as Field>::Packing,
        <F as Field>::Packing,
        MyHash,
        MyCompress,
        ThreadRng,
        8,
        SALT_ELEMS,
    >;

    /// Builds a hiding MMCS whose leaf hash and compression share one randomly sampled
    /// Poseidon2 permutation.
    fn setup_mmcs() -> MyMmcs {
        let perm = Perm::new_from_rng_128(&mut thread_rng());
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        MyMmcs::new(hash, compress, thread_rng())
    }

    /// Ten random 32-row matrices of widths 1..=10.
    fn random_matrices() -> Vec<RowMajorMatrix<F>> {
        (0..10)
            .map(|i| RowMajorMatrix::<F>::rand(&mut thread_rng(), 32, i + 1))
            .collect_vec()
    }

    #[test]
    #[should_panic]
    fn mismatched_heights() {
        let mmcs = setup_mmcs();

        // Heights 8 and 7 cannot share one binary Merkle tree.
        let large_mat = RowMajorMatrix::new(
            [1, 2, 3, 4, 5, 6, 7, 8].map(F::from_canonical_u8).to_vec(),
            1,
        );
        let small_mat =
            RowMajorMatrix::new([1, 2, 3, 4, 5, 6, 7].map(F::from_canonical_u8).to_vec(), 1);
        let _ = mmcs.commit(vec![large_mat, small_mat]);
    }

    #[test]
    fn different_widths() -> Result<(), MerkleTreeError> {
        let mmcs = setup_mmcs();
        let mats = random_matrices();
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let (opened_values, proof) = mmcs.open_batch(17, &prover_data);
        mmcs.verify_batch(&commit, &dims, 17, &opened_values, &proof)
    }

    #[test]
    fn tampered_opening_is_rejected() {
        let mmcs = setup_mmcs();
        let mats = random_matrices();
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let (mut opened_values, proof) = mmcs.open_batch(17, &prover_data);
        opened_values[3][0] += F::ONE;
        assert!(mmcs
            .verify_batch(&commit, &dims, 17, &opened_values, &proof)
            .is_err());
    }

    #[test]
    fn tampered_salt_is_rejected() {
        let mmcs = setup_mmcs();
        let mats = random_matrices();
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let (opened_values, mut proof) = mmcs.open_batch(17, &prover_data);
        // `proof.0` holds the per-matrix salts.
        proof.0[3][0] += F::ONE;
        assert!(mmcs
            .verify_batch(&commit, &dims, 17, &opened_values, &proof)
            .is_err());
    }
}