openvm_circuit/system/memory/tree/
public_values.rs

1use std::{collections::BTreeMap, sync::Arc};
2
3use openvm_stark_backend::{p3_field::PrimeField32, p3_util::log2_strict_usize};
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7use crate::{
8    arch::hasher::Hasher,
9    system::memory::{
10        dimensions::MemoryDimensions, paged_vec::Address, tree::MemoryNode, MemoryImage,
11    },
12};
13
/// Offset of the user public values address space, relative to `MemoryDimensions::as_offset`.
///
/// The absolute public values address space is `PUBLIC_VALUES_ADDRESS_SPACE_OFFSET +
/// memory_dimensions.as_offset`.
pub const PUBLIC_VALUES_ADDRESS_SPACE_OFFSET: u32 = 2;
15
16/// Merkle proof for user public values in the memory state.
17#[derive(Clone, Debug, Serialize, Deserialize)]
18#[serde(bound(
19    serialize = "F: Serialize, [F; CHUNK]: Serialize",
20    deserialize = "F: Deserialize<'de>, [F; CHUNK]: Deserialize<'de>"
21))]
22pub struct UserPublicValuesProof<const CHUNK: usize, F> {
23    /// Proof of the path from the root of public values to the memory root in the format of
24    /// sequence of sibling node hashes.
25    pub proof: Vec<[F; CHUNK]>,
26    /// Raw public values. Its length should be a power of two * CHUNK.
27    pub public_values: Vec<F>,
28    /// Merkle root of public values. The computation of this value follows the same logic of
29    /// `MemoryNode`. The merkle tree doesn't pad because the length `public_values` implies the
30    /// merkle tree is always a full binary tree.
31    pub public_values_commit: [F; CHUNK],
32}
33
34#[derive(Error, Debug)]
35pub enum UserPublicValuesProofError {
36    #[error("unexpected length: {0}")]
37    UnexpectedLength(usize),
38    #[error("incorrect proof length: {0} (expected {1})")]
39    IncorrectProofLength(usize, usize),
40    #[error("user public values do not match commitment")]
41    UserPublicValuesCommitMismatch,
42    #[error("final memory root mismatch")]
43    FinalMemoryRootMismatch,
44}
45
46impl<const CHUNK: usize, F: PrimeField32> UserPublicValuesProof<CHUNK, F> {
47    /// Computes the proof of the public values from the final memory state.
48    /// Assumption:
49    /// - `num_public_values` is a power of two * CHUNK. It cannot be 0.
50    pub fn compute(
51        memory_dimensions: MemoryDimensions,
52        num_public_values: usize,
53        hasher: &(impl Hasher<CHUNK, F> + Sync),
54        final_memory: &MemoryImage<F>,
55    ) -> Self {
56        let proof = compute_merkle_proof_to_user_public_values_root(
57            memory_dimensions,
58            num_public_values,
59            hasher,
60            final_memory,
61        );
62        let public_values =
63            extract_public_values(&memory_dimensions, num_public_values, final_memory);
64        let public_values_commit = hasher.merkle_root(&public_values);
65        UserPublicValuesProof {
66            proof,
67            public_values,
68            public_values_commit,
69        }
70    }
71
72    pub fn verify(
73        &self,
74        hasher: &impl Hasher<CHUNK, F>,
75        memory_dimensions: MemoryDimensions,
76        final_memory_root: [F; CHUNK],
77    ) -> Result<(), UserPublicValuesProofError> {
78        // Verify user public values Merkle proof:
79        // 0. Get correct indices for Merkle proof based on memory dimensions
80        // 1. Verify user public values commitment with respect to the final memory root.
81        // 2. Compare user public values commitment with Merkle root of user public values.
82        let pv_commit = self.public_values_commit;
83        // 0.
84        let pv_as = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset;
85        let pv_start_idx = memory_dimensions.label_to_index((pv_as, 0));
86        let pvs = &self.public_values;
87        if pvs.len() % CHUNK != 0 || !(pvs.len() / CHUNK).is_power_of_two() {
88            return Err(UserPublicValuesProofError::UnexpectedLength(pvs.len()));
89        }
90        let pv_height = log2_strict_usize(pvs.len() / CHUNK);
91        let proof_len = memory_dimensions.overall_height() - pv_height;
92        let idx_prefix = pv_start_idx >> pv_height;
93        // 1.
94        if self.proof.len() != proof_len {
95            return Err(UserPublicValuesProofError::IncorrectProofLength(
96                self.proof.len(),
97                proof_len,
98            ));
99        }
100        let mut curr_root = pv_commit;
101        for (i, sibling_hash) in self.proof.iter().enumerate() {
102            curr_root = if idx_prefix & (1 << i) != 0 {
103                hasher.compress(sibling_hash, &curr_root)
104            } else {
105                hasher.compress(&curr_root, sibling_hash)
106            }
107        }
108        if curr_root != final_memory_root {
109            return Err(UserPublicValuesProofError::FinalMemoryRootMismatch);
110        }
111        // 2. Compute merkle root of public values
112        if hasher.merkle_root(pvs) != pv_commit {
113            return Err(UserPublicValuesProofError::UserPublicValuesCommitMismatch);
114        }
115
116        Ok(())
117    }
118}
119
120fn compute_merkle_proof_to_user_public_values_root<const CHUNK: usize, F: PrimeField32>(
121    memory_dimensions: MemoryDimensions,
122    num_public_values: usize,
123    hasher: &(impl Hasher<CHUNK, F> + Sync),
124    final_memory: &MemoryImage<F>,
125) -> Vec<[F; CHUNK]> {
126    assert_eq!(
127        num_public_values % CHUNK,
128        0,
129        "num_public_values must be a multiple of memory chunk {CHUNK}"
130    );
131    let root = MemoryNode::tree_from_memory(memory_dimensions, final_memory, hasher);
132    let num_pv_chunks: usize = num_public_values / CHUNK;
133    // This enforces the number of public values cannot be 0.
134    assert!(
135        num_pv_chunks.is_power_of_two(),
136        "pv_height must be a power of two"
137    );
138    let pv_height = log2_strict_usize(num_pv_chunks);
139    let address_leading_zeros = memory_dimensions.address_height - pv_height;
140
141    let mut curr_node = Arc::new(root);
142    let mut proof = Vec::with_capacity(memory_dimensions.as_height + address_leading_zeros);
143    for i in 0..memory_dimensions.as_height {
144        let bit = 1 << (memory_dimensions.as_height - i - 1);
145        if let MemoryNode::NonLeaf { left, right, .. } = curr_node.as_ref().clone() {
146            if PUBLIC_VALUES_ADDRESS_SPACE_OFFSET & bit != 0 {
147                curr_node = right;
148                proof.push(left.hash());
149            } else {
150                curr_node = left;
151                proof.push(right.hash());
152            }
153        } else {
154            unreachable!()
155        }
156    }
157    for _ in 0..address_leading_zeros {
158        if let MemoryNode::NonLeaf { left, right, .. } = curr_node.as_ref().clone() {
159            curr_node = left;
160            proof.push(right.hash());
161        } else {
162            unreachable!()
163        }
164    }
165    proof.reverse();
166    proof
167}
168
169pub fn extract_public_values<F: PrimeField32>(
170    memory_dimensions: &MemoryDimensions,
171    num_public_values: usize,
172    final_memory: &MemoryImage<F>,
173) -> Vec<F> {
174    // All (addr, value) pairs in the public value address space.
175    let f_as_start = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset;
176    let f_as_end = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset + 1;
177
178    // This clones the entire memory. Ideally this should run in time proportional to
179    // the size of the PV address space, not entire memory.
180    let final_memory: BTreeMap<Address, F> = final_memory.items().collect();
181
182    let used_pvs: Vec<_> = final_memory
183        .range((f_as_start, 0)..(f_as_end, 0))
184        .map(|(&(_, pointer), &value)| (pointer as usize, value))
185        .collect();
186    if let Some(&last_pv) = used_pvs.last() {
187        assert!(
188            last_pv.0 < num_public_values || last_pv.1 == F::ZERO,
189            "Last public value is out of bounds"
190        );
191    }
192    let mut public_values = F::zero_vec(num_public_values);
193    for (i, pv) in used_pvs {
194        if i < num_public_values {
195            public_values[i] = pv;
196        }
197    }
198    public_values
199}
200
#[cfg(test)]
mod tests {
    use openvm_stark_backend::p3_field::FieldAlgebra;
    use openvm_stark_sdk::p3_baby_bear::BabyBear;

    use super::{
        UserPublicValuesProof, UserPublicValuesProofError, PUBLIC_VALUES_ADDRESS_SPACE_OFFSET,
    };
    use crate::{
        arch::{hasher::poseidon2::vm_poseidon2_hasher, SystemConfig},
        system::memory::{
            dimensions::MemoryDimensions, paged_vec::AddressMap, tree::MemoryNode, CHUNK,
        },
    };

    type F = BabyBear;

    /// Number of public values used by every test in this module.
    const NUM_PUBLIC_VALUES: usize = 16;

    /// Builds a small memory image and derives everything the tests need from it.
    ///
    /// The only non-zero cell in the image is public value 15, set to one.
    /// Returns the memory dimensions, the public values proof computed from the image, and
    /// the root hash of the memory Merkle tree.
    fn build_proof() -> (MemoryDimensions, UserPublicValuesProof<{ CHUNK }, F>, [F; CHUNK]) {
        let mut vm_config = SystemConfig::default();
        vm_config.memory_config.as_height = 4;
        vm_config.memory_config.pointer_max_bits = 5;
        let memory_dimensions = vm_config.memory_config.memory_dimensions();
        let pv_as = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset;
        let memory = AddressMap::from_iter(
            memory_dimensions.as_offset,
            1 << memory_dimensions.as_height,
            1 << memory_dimensions.address_height,
            [((pv_as, 15), F::ONE)],
        );

        let hasher = vm_poseidon2_hasher();
        let pv_proof = UserPublicValuesProof::<{ CHUNK }, F>::compute(
            memory_dimensions,
            NUM_PUBLIC_VALUES,
            &hasher,
            &memory,
        );
        let final_memory_root =
            MemoryNode::tree_from_memory(memory_dimensions, &memory, &hasher).hash();
        (memory_dimensions, pv_proof, final_memory_root)
    }

    #[test]
    fn test_public_value_happy_path() {
        let (memory_dimensions, pv_proof, final_memory_root) = build_proof();
        let mut expected_pvs = F::zero_vec(NUM_PUBLIC_VALUES);
        expected_pvs[15] = F::ONE;
        assert_eq!(pv_proof.public_values, expected_pvs);
        pv_proof
            .verify(&vm_poseidon2_hasher(), memory_dimensions, final_memory_root)
            .unwrap();
    }

    #[test]
    fn test_tampered_public_values_rejected() {
        let (memory_dimensions, mut pv_proof, final_memory_root) = build_proof();
        pv_proof.public_values[0] = F::ONE;
        let result = pv_proof.verify(&vm_poseidon2_hasher(), memory_dimensions, final_memory_root);
        assert!(matches!(
            result,
            Err(UserPublicValuesProofError::UserPublicValuesCommitMismatch)
        ));
    }

    #[test]
    fn test_wrong_final_memory_root_rejected() {
        let (memory_dimensions, pv_proof, mut final_memory_root) = build_proof();
        final_memory_root[0] += F::ONE;
        let result = pv_proof.verify(&vm_poseidon2_hasher(), memory_dimensions, final_memory_root);
        assert!(matches!(
            result,
            Err(UserPublicValuesProofError::FinalMemoryRootMismatch)
        ));
    }

    #[test]
    fn test_truncated_proof_rejected() {
        let (memory_dimensions, mut pv_proof, final_memory_root) = build_proof();
        let expected_len = pv_proof.proof.len();
        pv_proof.proof.pop();
        let result = pv_proof.verify(&vm_poseidon2_hasher(), memory_dimensions, final_memory_root);
        assert!(matches!(
            result,
            Err(UserPublicValuesProofError::IncorrectProofLength(got, want))
                if got == expected_len - 1 && want == expected_len
        ));
    }

    #[test]
    fn test_non_chunk_multiple_length_rejected() {
        let (memory_dimensions, mut pv_proof, final_memory_root) = build_proof();
        pv_proof.public_values.push(F::ZERO);
        let result = pv_proof.verify(&vm_poseidon2_hasher(), memory_dimensions, final_memory_root);
        assert!(matches!(
            result,
            Err(UserPublicValuesProofError::UnexpectedLength(len)) if len == NUM_PUBLIC_VALUES + 1
        ));
    }
}