p3_merkle_tree/
hiding_mmcs.rs

1use alloc::vec::Vec;
2use core::cell::RefCell;
3
4use itertools::Itertools;
5use p3_commit::Mmcs;
6use p3_field::PackedValue;
7use p3_matrix::dense::RowMajorMatrix;
8use p3_matrix::stack::HorizontalPair;
9use p3_matrix::{Dimensions, Matrix};
10use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction};
11use p3_util::zip_eq::zip_eq;
12use rand::distributions::{Distribution, Standard};
13use rand::Rng;
14use serde::de::DeserializeOwned;
15use serde::{Deserialize, Serialize};
16
17use crate::{MerkleTree, MerkleTreeError, MerkleTreeMmcs};
18
19/// A vector commitment scheme backed by a `MerkleTree`.
20///
21/// This is similar to `MerkleTreeMmcs`, but each leaf is "salted" with random elements. This is
22/// done to turn the Merkle tree into a hiding commitment. See e.g. Section 3 of
23/// [Interactive Oracle Proofs](https://eprint.iacr.org/2016/116).
24///
25/// `SALT_ELEMS` should be set such that the product of `SALT_ELEMS` with the size of the value
26/// (`P::Value`) is at least the target security parameter.
27///
28/// `R` should be an appropriately seeded cryptographically secure pseudorandom number generator
29/// (CSPRNG). Something like `ThreadRng` may work, although it relies on the operating system to
30/// provide sufficient entropy.
31///
32/// Generics:
33/// - `P`: a leaf value
34/// - `PW`: an element of a digest
35/// - `H`: the leaf hasher
36/// - `C`: the digest compression function
37/// - `R`: a random number generator for blinding leaves
38#[derive(Clone, Debug)]
39pub struct MerkleTreeHidingMmcs<P, PW, H, C, R, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize>
40{
41    inner: MerkleTreeMmcs<P, PW, H, C, DIGEST_ELEMS>,
42    rng: RefCell<R>,
43}
44
45impl<P, PW, H, C, R, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize>
46    MerkleTreeHidingMmcs<P, PW, H, C, R, DIGEST_ELEMS, SALT_ELEMS>
47{
48    pub fn new(hash: H, compress: C, rng: R) -> Self {
49        let inner = MerkleTreeMmcs::new(hash, compress);
50        Self {
51            inner,
52            rng: rng.into(),
53        }
54    }
55}
56
57impl<P, PW, H, C, R, const DIGEST_ELEMS: usize, const SALT_ELEMS: usize> Mmcs<P::Value>
58    for MerkleTreeHidingMmcs<P, PW, H, C, R, DIGEST_ELEMS, SALT_ELEMS>
59where
60    P: PackedValue,
61    P::Value: Serialize + DeserializeOwned,
62    PW: PackedValue,
63    H: CryptographicHasher<P::Value, [PW::Value; DIGEST_ELEMS]>,
64    H: CryptographicHasher<P, [PW; DIGEST_ELEMS]>,
65    H: Sync,
66    C: PseudoCompressionFunction<[PW::Value; DIGEST_ELEMS], 2>,
67    C: PseudoCompressionFunction<[PW; DIGEST_ELEMS], 2>,
68    C: Sync,
69    R: Rng + Clone,
70    PW::Value: Eq,
71    [PW::Value; DIGEST_ELEMS]: Serialize + for<'de> Deserialize<'de>,
72    Standard: Distribution<P::Value>,
73{
74    type ProverData<M> =
75        MerkleTree<P::Value, PW::Value, HorizontalPair<M, RowMajorMatrix<P::Value>>, DIGEST_ELEMS>;
76    type Commitment = Hash<P::Value, PW::Value, DIGEST_ELEMS>;
77    /// The first item is salts; the second is the usual Merkle proof (sibling digests).
78    type Proof = (Vec<Vec<P::Value>>, Vec<[PW::Value; DIGEST_ELEMS]>);
79    type Error = MerkleTreeError;
80
81    fn commit<M: Matrix<P::Value>>(
82        &self,
83        inputs: Vec<M>,
84    ) -> (Self::Commitment, Self::ProverData<M>) {
85        let salted_inputs = inputs
86            .into_iter()
87            .map(|mat| {
88                let salts =
89                    RowMajorMatrix::rand(&mut *self.rng.borrow_mut(), mat.height(), SALT_ELEMS);
90                HorizontalPair::new(mat, salts)
91            })
92            .collect();
93        self.inner.commit(salted_inputs)
94    }
95
96    fn open_batch<M: Matrix<P::Value>>(
97        &self,
98        index: usize,
99        prover_data: &Self::ProverData<M>,
100    ) -> (
101        Vec<Vec<P::Value>>,
102        (Vec<Vec<P::Value>>, Vec<[PW::Value; DIGEST_ELEMS]>),
103    ) {
104        let (salted_openings, siblings) = self.inner.open_batch(index, prover_data);
105        let (openings, salts): (Vec<_>, Vec<_>) = salted_openings
106            .into_iter()
107            .map(|row| {
108                let (a, b) = row.split_at(row.len() - SALT_ELEMS);
109                (a.to_vec(), b.to_vec())
110            })
111            .unzip();
112        (openings, (salts, siblings))
113    }
114
115    fn get_matrices<'a, M: Matrix<P::Value>>(
116        &self,
117        prover_data: &'a Self::ProverData<M>,
118    ) -> Vec<&'a M> {
119        prover_data.leaves.iter().map(|mat| &mat.first).collect()
120    }
121
122    fn verify_batch(
123        &self,
124        commit: &Self::Commitment,
125        dimensions: &[Dimensions],
126        index: usize,
127        opened_values: &[Vec<P::Value>],
128        proof: &Self::Proof,
129    ) -> Result<(), Self::Error> {
130        let (salts, siblings) = proof;
131
132        let opened_salted_values = zip_eq(opened_values, salts, MerkleTreeError::WrongBatchSize)?
133            .map(|(opened, salt)| opened.iter().chain(salt.iter()).copied().collect_vec())
134            .collect_vec();
135
136        self.inner
137            .verify_batch(commit, dimensions, index, &opened_salted_values, siblings)
138    }
139}
140
#[cfg(test)]
mod tests {
    use alloc::vec;
    use alloc::vec::Vec;

    use itertools::Itertools;
    use p3_baby_bear::{BabyBear, Poseidon2BabyBear};
    use p3_commit::Mmcs;
    use p3_field::{Field, FieldAlgebra};
    use p3_matrix::dense::RowMajorMatrix;
    use p3_matrix::Matrix;
    use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation};
    use rand::prelude::*;

    use super::MerkleTreeHidingMmcs;
    use crate::MerkleTreeError;

    type F = BabyBear;
    const SALT_ELEMS: usize = 4;

    type Perm = Poseidon2BabyBear<16>;
    type MyHash = PaddingFreeSponge<Perm, 16, 8, 8>;
    type MyCompress = TruncatedPermutation<Perm, 2, 8, 16>;
    type MyMmcs = MerkleTreeHidingMmcs<
        <F as Field>::Packing,
        <F as Field>::Packing,
        MyHash,
        MyCompress,
        ThreadRng,
        8,
        SALT_ELEMS,
    >;

    /// Builds a hiding MMCS over a freshly sampled Poseidon2 permutation, salting via `ThreadRng`.
    fn new_mmcs() -> MyMmcs {
        let mut rng = thread_rng();
        let perm = Perm::new_from_rng_128(&mut rng);
        let hash = MyHash::new(perm.clone());
        let compress = MyCompress::new(perm);
        MyMmcs::new(hash, compress, thread_rng())
    }

    /// Ten random matrices with 32 rows each, where the `i`-th matrix has `i + 1` columns.
    fn mats_of_different_widths() -> Vec<RowMajorMatrix<F>> {
        (0..10)
            .map(|i| RowMajorMatrix::<F>::rand(&mut thread_rng(), 32, i + 1))
            .collect_vec()
    }

    #[test]
    #[should_panic]
    fn mismatched_heights() {
        let mmcs = new_mmcs();

        // attempt to commit to a mat with 8 rows and a mat with 7 rows. this should panic.
        let large_mat = RowMajorMatrix::new(
            [1, 2, 3, 4, 5, 6, 7, 8].map(F::from_canonical_u8).to_vec(),
            1,
        );
        let small_mat =
            RowMajorMatrix::new([1, 2, 3, 4, 5, 6, 7].map(F::from_canonical_u8).to_vec(), 1);
        let _ = mmcs.commit(vec![large_mat, small_mat]);
    }

    #[test]
    fn different_widths() -> Result<(), MerkleTreeError> {
        let mmcs = new_mmcs();
        let mats = mats_of_different_widths();
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let (opened_values, proof) = mmcs.open_batch(17, &prover_data);
        mmcs.verify_batch(&commit, &dims, 17, &opened_values, &proof)
    }

    #[test]
    fn tampered_opening_is_rejected() {
        let mmcs = new_mmcs();
        let mats = mats_of_different_widths();
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let (mut opened_values, proof) = mmcs.open_batch(17, &prover_data);
        // Changing any opened value must change the recomputed leaf hash.
        opened_values[3][0] += F::ONE;
        assert!(mmcs
            .verify_batch(&commit, &dims, 17, &opened_values, &proof)
            .is_err());
    }

    #[test]
    fn tampered_salt_is_rejected() {
        let mmcs = new_mmcs();
        let mats = mats_of_different_widths();
        let dims = mats.iter().map(|m| m.dimensions()).collect_vec();

        let (commit, prover_data) = mmcs.commit(mats);
        let (opened_values, mut proof) = mmcs.open_batch(17, &prover_data);
        // The salt is hashed into the leaf, so altering it must also break verification.
        proof.0[3][0] += F::ONE;
        assert!(mmcs
            .verify_batch(&commit, &dims, 17, &opened_values, &proof)
            .is_err());
    }
}