openvm_circuit/system/memory/merkle/public_values.rs

1use itertools::Itertools;
2use openvm_stark_backend::{p3_field::PrimeField32, p3_util::log2_strict_usize};
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5use tracing::instrument;
6
7use crate::{
8    arch::{hasher::Hasher, MemoryCellType, ADDR_SPACE_OFFSET},
9    system::memory::{
10        dimensions::MemoryDimensions, merkle::tree::MerkleTree, online::LinearMemory, MemoryImage,
11    },
12};
13
/// Address space that holds the user public values.
pub const PUBLIC_VALUES_AS: u32 = 3;
/// `PUBLIC_VALUES_AS - ADDR_SPACE_OFFSET`: the address-space index whose bits select the path to
/// the public values address space through the address-space levels of the memory Merkle tree.
pub const PUBLIC_VALUES_ADDRESS_SPACE_OFFSET: u32 = PUBLIC_VALUES_AS - ADDR_SPACE_OFFSET;
16
/// Merkle proof for user public values in the memory state.
///
/// Binds the raw `public_values` to a final memory Merkle root in two steps:
/// `public_values_commit` is the Merkle root of `public_values`, and `proof` is the sibling path
/// from that commitment up to the memory root.
#[derive(Clone, Debug, Serialize, Deserialize)]
// Explicit serde bounds: the `[F; CHUNK]` fields need their own (de)serialization bounds in
// addition to `F`'s. NOTE(review): presumably because the derive cannot infer bounds for
// const-generic arrays — confirm.
#[serde(bound(
    serialize = "F: Serialize, [F; CHUNK]: Serialize",
    deserialize = "F: Deserialize<'de>, [F; CHUNK]: Deserialize<'de>"
))]
pub struct UserPublicValuesProof<const CHUNK: usize, F> {
    /// Proof of the path from the root of public values to the memory root in the format of
    /// sequence of sibling node hashes.
    ///
    /// Ordered bottom-up: the first entry is the sibling of the public values root.
    pub proof: Vec<[F; CHUNK]>,
    /// Raw public values. Its length should be (a power of two) * CHUNK.
    pub public_values: Vec<F>,
    /// Merkle root of public values. The computation of this value follows the same logic of
    /// `MemoryNode`. The merkle tree doesn't pad because the length `public_values` implies the
    /// merkle tree is always a full binary tree.
    pub public_values_commit: [F; CHUNK],
}
34
/// Reasons a [`UserPublicValuesProof`] can fail verification.
#[derive(Error, Debug)]
pub enum UserPublicValuesProofError {
    /// `public_values` has an unsupported length. The payload is that length.
    #[error("unexpected length: {0}")]
    UnexpectedLength(usize),
    /// `proof` has the wrong number of sibling hashes. Payload: (actual, expected).
    #[error("incorrect proof length: {0} (expected {1})")]
    IncorrectProofLength(usize, usize),
    /// `public_values_commit` is not the Merkle root of `public_values`.
    #[error("user public values do not match commitment")]
    UserPublicValuesCommitMismatch,
    /// Folding `public_values_commit` with `proof` does not reproduce the final memory root.
    #[error("final memory root mismatch")]
    FinalMemoryRootMismatch,
}
46
47impl<const CHUNK: usize, F: PrimeField32> UserPublicValuesProof<CHUNK, F> {
48    /// Computes the proof of the public values from the final memory state.
49    /// Assumption:
50    /// - `num_public_values` is a power of two * CHUNK. It cannot be 0.
51    // TODO[jpw]: this currently reconstructs the merkle tree from final memory; we should avoid
52    // this. We should make this a function within SystemChipComplex
53    #[instrument(name = "compute_user_public_values_proof", skip_all)]
54    pub fn compute(
55        memory_dimensions: MemoryDimensions,
56        num_public_values: usize,
57        hasher: &(impl Hasher<CHUNK, F> + Sync),
58        final_memory: &MemoryImage,
59    ) -> Self {
60        let proof = compute_merkle_proof_to_user_public_values_root(
61            memory_dimensions,
62            num_public_values,
63            hasher,
64            final_memory,
65        );
66        let public_values = extract_public_values(num_public_values, final_memory)
67            .iter()
68            .map(|&x| F::from_canonical_u8(x))
69            .collect_vec();
70        let public_values_commit = hasher.merkle_root(&public_values);
71        UserPublicValuesProof {
72            proof,
73            public_values,
74            public_values_commit,
75        }
76    }
77
78    pub fn verify(
79        &self,
80        hasher: &impl Hasher<CHUNK, F>,
81        memory_dimensions: MemoryDimensions,
82        final_memory_root: [F; CHUNK],
83    ) -> Result<(), UserPublicValuesProofError> {
84        // Verify user public values Merkle proof:
85        // 0. Get correct indices for Merkle proof based on memory dimensions
86        // 1. Verify user public values commitment with respect to the final memory root.
87        // 2. Compare user public values commitment with Merkle root of user public values.
88        let pv_commit = self.public_values_commit;
89        // 0.
90        let pv_as = PUBLIC_VALUES_AS;
91        let pv_start_idx = memory_dimensions.label_to_index((pv_as, 0));
92        let pvs = &self.public_values;
93        if pvs.len() % CHUNK != 0 || !(pvs.len() / CHUNK).is_power_of_two() {
94            return Err(UserPublicValuesProofError::UnexpectedLength(pvs.len()));
95        }
96        let pv_height = log2_strict_usize(pvs.len() / CHUNK);
97        let proof_len = memory_dimensions.overall_height() - pv_height;
98        let idx_prefix = pv_start_idx >> pv_height;
99        // 1.
100        if self.proof.len() != proof_len {
101            return Err(UserPublicValuesProofError::IncorrectProofLength(
102                self.proof.len(),
103                proof_len,
104            ));
105        }
106        let mut curr_root = pv_commit;
107        for (i, sibling_hash) in self.proof.iter().enumerate() {
108            curr_root = if idx_prefix & (1 << i) != 0 {
109                hasher.compress(sibling_hash, &curr_root)
110            } else {
111                hasher.compress(&curr_root, sibling_hash)
112            }
113        }
114        if curr_root != final_memory_root {
115            return Err(UserPublicValuesProofError::FinalMemoryRootMismatch);
116        }
117        // 2. Compute merkle root of public values
118        if hasher.merkle_root(pvs) != pv_commit {
119            return Err(UserPublicValuesProofError::UserPublicValuesCommitMismatch);
120        }
121
122        Ok(())
123    }
124}
125
126fn compute_merkle_proof_to_user_public_values_root<const CHUNK: usize, F: PrimeField32>(
127    memory_dimensions: MemoryDimensions,
128    num_public_values: usize,
129    hasher: &(impl Hasher<CHUNK, F> + Sync),
130    final_memory: &MemoryImage,
131) -> Vec<[F; CHUNK]> {
132    assert_eq!(
133        num_public_values % CHUNK,
134        0,
135        "num_public_values must be a multiple of memory chunk {CHUNK}"
136    );
137    let tree = MerkleTree::<F, CHUNK>::from_memory(final_memory, &memory_dimensions, hasher);
138    let num_pv_chunks: usize = num_public_values / CHUNK;
139    // This enforces the number of public values cannot be 0.
140    assert!(
141        num_pv_chunks.is_power_of_two(),
142        "pv_height must be a power of two"
143    );
144    let pv_height = log2_strict_usize(num_pv_chunks);
145    let address_leading_zeros = memory_dimensions.address_height - pv_height;
146
147    let mut cur_node_idx = 1; // root
148    let mut proof = Vec::with_capacity(memory_dimensions.addr_space_height + address_leading_zeros);
149    for i in 0..memory_dimensions.addr_space_height {
150        let bit = 1 << (memory_dimensions.addr_space_height - i - 1);
151        if (PUBLIC_VALUES_AS - ADDR_SPACE_OFFSET) & bit != 0 {
152            proof.push(tree.get_node(cur_node_idx * 2));
153            cur_node_idx = cur_node_idx * 2 + 1;
154        } else {
155            proof.push(tree.get_node(cur_node_idx * 2 + 1));
156            cur_node_idx *= 2;
157        }
158    }
159    for _ in 0..address_leading_zeros {
160        // always go left
161        proof.push(tree.get_node(cur_node_idx * 2 + 1));
162        cur_node_idx *= 2;
163    }
164    proof.reverse();
165    proof
166}
167
168pub fn extract_public_values(num_public_values: usize, final_memory: &MemoryImage) -> Vec<u8> {
169    let mut public_values: Vec<u8> = {
170        assert_eq!(
171            final_memory.config[PUBLIC_VALUES_AS as usize].layout,
172            MemoryCellType::U8
173        );
174        final_memory.mem[PUBLIC_VALUES_AS as usize]
175            .as_slice()
176            .to_vec()
177    };
178
179    assert!(
180        public_values.len() >= num_public_values,
181        "Public values address space has {} elements, but configuration has num_public_values={}",
182        public_values.len(),
183        num_public_values
184    );
185    public_values.truncate(num_public_values);
186    public_values
187}
188
#[cfg(test)]
mod tests {
    use openvm_stark_backend::p3_field::FieldAlgebra;
    use openvm_stark_sdk::p3_baby_bear::BabyBear;

    use super::UserPublicValuesProof;
    use crate::{
        arch::{hasher::poseidon2::vm_poseidon2_hasher, MemoryConfig, SystemConfig},
        system::memory::{
            merkle::{public_values::PUBLIC_VALUES_AS, tree::MerkleTree},
            online::GuestMemory,
            AddressMap, CHUNK,
        },
    };

    type F = BabyBear;
    /// Round trip: computes a public values proof from a small memory image, checks the
    /// extracted values, and verifies the proof against the root of a Merkle tree built from the
    /// same memory.
    #[test]
    fn test_public_value_happy_path() {
        // Small memory dimensions keep the Merkle tree tiny.
        let mut vm_config = SystemConfig::default().without_continuations();
        vm_config.memory_config.addr_space_height = 4;
        vm_config.memory_config.pointer_max_bits = 5;
        let memory_dimensions = vm_config.memory_config.memory_dimensions();
        let num_public_values = 16;
        // Give the public values address space exactly `num_public_values` cells.
        let mut addr_spaces_config = MemoryConfig::empty_address_space_configs(4);
        addr_spaces_config[PUBLIC_VALUES_AS as usize].num_cells = num_public_values;
        let mut memory = GuestMemory {
            memory: AddressMap::new(addr_spaces_config),
        };
        // Write [0, 0, 0, 1] at pointer 12; per `expected_pvs` below, only public value 15 ends
        // up non-zero.
        // SAFETY (review): pointer 12 plus 4 bytes stays within the 16 cells configured above —
        // presumably the precondition of `GuestMemory::write`; confirm against its docs.
        unsafe {
            memory.write::<u8, 4>(PUBLIC_VALUES_AS, 12, [0, 0, 0, 1]);
        }
        let mut expected_pvs = F::zero_vec(num_public_values);
        expected_pvs[15] = F::ONE;

        let hasher = vm_poseidon2_hasher();
        let pv_proof = UserPublicValuesProof::<{ CHUNK }, F>::compute(
            memory_dimensions,
            num_public_values,
            &hasher,
            &memory.memory,
        );
        assert_eq!(pv_proof.public_values, expected_pvs);
        // Independently rebuild the memory Merkle root and check the proof against it.
        let final_memory_root =
            MerkleTree::from_memory(&memory.memory, &memory_dimensions, &hasher).root();
        pv_proof
            .verify(&hasher, memory_dimensions, final_memory_root)
            .unwrap();
    }
}