openvm_circuit/system/memory/merkle/
public_values.rs

1use itertools::Itertools;
2use openvm_stark_backend::{p3_field::PrimeField32, p3_util::log2_strict_usize};
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5use tracing::instrument;
6
7use crate::{
8    arch::{hasher::Hasher, MemoryCellType, ADDR_SPACE_OFFSET},
9    system::memory::{dimensions::MemoryDimensions, online::LinearMemory, MemoryImage},
10};
11
12pub const PUBLIC_VALUES_AS: u32 = 3;
13pub const PUBLIC_VALUES_ADDRESS_SPACE_OFFSET: u32 = PUBLIC_VALUES_AS - ADDR_SPACE_OFFSET;
14
15/// Merkle proof for user public values in the memory state.
16#[derive(Clone, Debug, Serialize, Deserialize)]
17#[serde(bound(
18    serialize = "F: Serialize, [F; CHUNK]: Serialize",
19    deserialize = "F: Deserialize<'de>, [F; CHUNK]: Deserialize<'de>"
20))]
21pub struct UserPublicValuesProof<const CHUNK: usize, F> {
22    /// Proof of the path from the root of public values to the memory root in the format of
23    /// sequence of sibling node hashes.
24    pub proof: Vec<[F; CHUNK]>,
25    /// Raw public values. Its length should be (a power of two) * CHUNK.
26    pub public_values: Vec<F>,
27    /// Merkle root of public values. The computation of this value follows the same logic of
28    /// `MemoryNode`. The merkle tree doesn't pad because the length `public_values` implies the
29    /// merkle tree is always a full binary tree.
30    pub public_values_commit: [F; CHUNK],
31}
32
33#[derive(Error, Debug)]
34pub enum UserPublicValuesProofError {
35    #[error("unexpected length: {0}")]
36    UnexpectedLength(usize),
37    #[error("incorrect proof length: {0} (expected {1})")]
38    IncorrectProofLength(usize, usize),
39    #[error("user public values do not match commitment")]
40    UserPublicValuesCommitMismatch,
41    #[error("final memory root mismatch")]
42    FinalMemoryRootMismatch,
43}
44
45impl<const CHUNK: usize, F: PrimeField32> UserPublicValuesProof<CHUNK, F> {
46    /// Computes the proof of the public values from the final memory state and the Merkle top
47    /// sub-tree of address space roots. This function will re-compute the empty merkle roots of
48    /// each height `0..=address_height` internally.
49    ///
50    /// Assumption:
51    /// - `num_public_values` is a power of two * CHUNK. It cannot be 0.
52    /// - `top_tree` is 0-indexed and a segment tree of length `2 * 2^addr_space_height - 1`.
53    #[instrument(name = "compute_user_public_values_proof", skip_all)]
54    pub fn compute(
55        memory_dimensions: MemoryDimensions,
56        num_public_values: usize,
57        hasher: &(impl Hasher<CHUNK, F> + Sync),
58        final_memory: &MemoryImage,
59        top_tree: &[[F; CHUNK]],
60    ) -> Self {
61        let public_values = extract_public_values(num_public_values, final_memory)
62            .iter()
63            .map(|&x| F::from_canonical_u8(x))
64            .collect_vec();
65        let public_values_commit = hasher.merkle_root(&public_values);
66        let proof = compute_merkle_proof_to_user_public_values_root(
67            memory_dimensions,
68            num_public_values,
69            hasher,
70            top_tree,
71        );
72        UserPublicValuesProof {
73            proof,
74            public_values,
75            public_values_commit,
76        }
77    }
78
79    pub fn verify(
80        &self,
81        hasher: &impl Hasher<CHUNK, F>,
82        memory_dimensions: MemoryDimensions,
83        final_memory_root: [F; CHUNK],
84    ) -> Result<(), UserPublicValuesProofError> {
85        // Verify user public values Merkle proof:
86        // 0. Get correct indices for Merkle proof based on memory dimensions
87        // 1. Verify user public values commitment with respect to the final memory root.
88        // 2. Compare user public values commitment with Merkle root of user public values.
89        let pv_commit = self.public_values_commit;
90        // 0.
91        let pv_as = PUBLIC_VALUES_AS;
92        let pv_start_idx = memory_dimensions.label_to_index((pv_as, 0));
93        let pvs = &self.public_values;
94        if !pvs.len().is_multiple_of(CHUNK) || !(pvs.len() / CHUNK).is_power_of_two() {
95            return Err(UserPublicValuesProofError::UnexpectedLength(pvs.len()));
96        }
97        let pv_height = log2_strict_usize(pvs.len() / CHUNK);
98        let proof_len = memory_dimensions.overall_height() - pv_height;
99        let idx_prefix = pv_start_idx >> pv_height;
100        // 1.
101        if self.proof.len() != proof_len {
102            return Err(UserPublicValuesProofError::IncorrectProofLength(
103                self.proof.len(),
104                proof_len,
105            ));
106        }
107        let mut curr_root = pv_commit;
108        for (i, sibling_hash) in self.proof.iter().enumerate() {
109            curr_root = if idx_prefix & (1 << i) != 0 {
110                hasher.compress(sibling_hash, &curr_root)
111            } else {
112                hasher.compress(&curr_root, sibling_hash)
113            }
114        }
115        if curr_root != final_memory_root {
116            return Err(UserPublicValuesProofError::FinalMemoryRootMismatch);
117        }
118        // 2. Compute merkle root of public values
119        if hasher.merkle_root(pvs) != pv_commit {
120            return Err(UserPublicValuesProofError::UserPublicValuesCommitMismatch);
121        }
122
123        Ok(())
124    }
125}
126
127fn compute_merkle_proof_to_user_public_values_root<const CHUNK: usize, F: PrimeField32>(
128    memory_dimensions: MemoryDimensions,
129    num_public_values: usize,
130    hasher: &(impl Hasher<CHUNK, F> + Sync),
131    top_tree: &[[F; CHUNK]],
132) -> Vec<[F; CHUNK]> {
133    assert_eq!(
134        num_public_values % CHUNK,
135        0,
136        "num_public_values must be a multiple of memory chunk {CHUNK}"
137    );
138    let address_height = memory_dimensions.address_height;
139    let addr_space_height = memory_dimensions.addr_space_height;
140    assert_eq!(top_tree.len(), (2 << addr_space_height) - 1);
141    let num_pv_chunks: usize = num_public_values / CHUNK;
142    // This enforces the number of public values cannot be 0.
143    assert!(
144        num_pv_chunks.is_power_of_two(),
145        "pv_height must be a power of two"
146    );
147    let pv_height = log2_strict_usize(num_pv_chunks);
148    let address_leading_zeros = address_height - pv_height;
149
150    let mut cur_node_idx = 1; // root
151    let mut proof = Vec::with_capacity(addr_space_height + address_leading_zeros);
152    let zero_nodes: Vec<_> = (0..address_height)
153        .scan(hasher.hash(&[F::ZERO; CHUNK]), |acc, _| {
154            let result = Some(*acc);
155            *acc = hasher.compress(acc, acc);
156            result
157        })
158        .collect();
159    for i in 0..addr_space_height {
160        let bit = 1 << (memory_dimensions.addr_space_height - i - 1);
161        // Recall: top_tree is 0-indexed, but cur_node_idx is 1-indexed
162        if (PUBLIC_VALUES_AS - ADDR_SPACE_OFFSET) & bit != 0 {
163            proof.push(top_tree[cur_node_idx * 2 - 1]);
164            cur_node_idx = cur_node_idx * 2 + 1;
165        } else {
166            proof.push(top_tree[cur_node_idx * 2]);
167            cur_node_idx *= 2;
168        }
169    }
170    for i in 0..address_leading_zeros {
171        // node is always on the left, the sibling is always zero node hash
172        proof.push(zero_nodes[address_height - 1 - i]);
173    }
174    proof.reverse();
175    proof
176}
177
178pub fn extract_public_values(num_public_values: usize, final_memory: &MemoryImage) -> Vec<u8> {
179    let mut public_values: Vec<u8> = {
180        assert_eq!(
181            final_memory.config[PUBLIC_VALUES_AS as usize].layout,
182            MemoryCellType::U8
183        );
184        final_memory.mem[PUBLIC_VALUES_AS as usize]
185            .as_slice()
186            .to_vec()
187    };
188
189    assert!(
190        public_values.len() >= num_public_values,
191        "Public values address space has {} elements, but configuration has num_public_values={}",
192        public_values.len(),
193        num_public_values
194    );
195    public_values.truncate(num_public_values);
196    public_values
197}
198
#[cfg(test)]
mod tests {
    use openvm_stark_backend::p3_field::FieldAlgebra;
    use openvm_stark_sdk::p3_baby_bear::BabyBear;

    use super::{UserPublicValuesProof, UserPublicValuesProofError};
    use crate::{
        arch::{hasher::poseidon2::vm_poseidon2_hasher, MemoryConfig, SystemConfig},
        system::memory::{
            dimensions::MemoryDimensions,
            merkle::{public_values::PUBLIC_VALUES_AS, tree::MerkleTree},
            online::GuestMemory,
            AddressMap, CHUNK,
        },
    };

    type F = BabyBear;
    type Proof = UserPublicValuesProof<{ CHUNK }, F>;

    const ADDR_SPACE_HEIGHT: usize = 4;
    const NUM_PUBLIC_VALUES: usize = 16;

    /// Builds a small memory whose public-values address space holds `NUM_PUBLIC_VALUES` bytes,
    /// all zero except the last one (set to 1), and proves those values against the memory
    /// root.
    ///
    /// Returns the memory dimensions, the proof and the final memory root.
    fn prove_sample_public_values() -> (MemoryDimensions, Proof, [F; CHUNK]) {
        let mut vm_config = SystemConfig::default().without_continuations();
        vm_config.memory_config.addr_space_height = ADDR_SPACE_HEIGHT;
        vm_config.memory_config.pointer_max_bits = 5;
        let memory_dimensions = vm_config.memory_config.memory_dimensions();

        let mut addr_spaces_config = MemoryConfig::empty_address_space_configs(4);
        addr_spaces_config[PUBLIC_VALUES_AS as usize].num_cells = NUM_PUBLIC_VALUES;
        let mut memory = GuestMemory {
            memory: AddressMap::new(addr_spaces_config),
        };
        // SAFETY: the write covers cells 12..16 of the public-values address space, all within
        // the NUM_PUBLIC_VALUES (16) cells configured above.
        unsafe {
            memory.write::<u8, 4>(PUBLIC_VALUES_AS, 12, [0, 0, 0, 1]);
        }

        let hasher = vm_poseidon2_hasher();
        let tree = MerkleTree::from_memory(&memory.memory, &memory_dimensions, &hasher);
        let top_tree = tree.top_tree(ADDR_SPACE_HEIGHT);
        let pv_proof = Proof::compute(
            memory_dimensions,
            NUM_PUBLIC_VALUES,
            &hasher,
            &memory.memory,
            &top_tree,
        );
        let final_memory_root =
            MerkleTree::from_memory(&memory.memory, &memory_dimensions, &hasher).root();
        (memory_dimensions, pv_proof, final_memory_root)
    }

    #[test]
    fn test_public_value_happy_path() {
        let (memory_dimensions, pv_proof, final_memory_root) = prove_sample_public_values();

        let mut expected_pvs = F::zero_vec(NUM_PUBLIC_VALUES);
        expected_pvs[NUM_PUBLIC_VALUES - 1] = F::ONE;
        assert_eq!(pv_proof.public_values, expected_pvs);

        let hasher = vm_poseidon2_hasher();
        pv_proof
            .verify(&hasher, memory_dimensions, final_memory_root)
            .unwrap();
    }

    #[test]
    fn test_public_value_tampering_is_rejected() {
        let (memory_dimensions, pv_proof, final_memory_root) = prove_sample_public_values();
        let hasher = vm_poseidon2_hasher();
        let verify =
            |proof: &Proof| proof.verify(&hasher, memory_dimensions, final_memory_root);
        assert!(verify(&pv_proof).is_ok());

        // A changed value no longer hashes to the (still valid) commitment.
        let mut tampered = pv_proof.clone();
        tampered.public_values[0] = F::ONE;
        assert!(matches!(
            verify(&tampered),
            Err(UserPublicValuesProofError::UserPublicValuesCommitMismatch)
        ));

        // A changed commitment no longer leads to the final memory root.
        let mut tampered = pv_proof.clone();
        tampered.public_values_commit[0] += F::ONE;
        assert!(matches!(
            verify(&tampered),
            Err(UserPublicValuesProofError::FinalMemoryRootMismatch)
        ));

        // Dropping a sibling hash makes the path too short.
        let mut tampered = pv_proof.clone();
        tampered.proof.pop();
        assert!(matches!(
            verify(&tampered),
            Err(UserPublicValuesProofError::IncorrectProofLength(_, _))
        ));

        // An extra value breaks the `CHUNK * 2^k` length requirement.
        let mut tampered = pv_proof;
        tampered.public_values.push(F::ZERO);
        assert!(matches!(
            verify(&tampered),
            Err(UserPublicValuesProofError::UnexpectedLength(_))
        ));
    }
}