openvm_circuit/system/memory/merkle/
public_values.rs1use itertools::Itertools;
2use openvm_stark_backend::{p3_field::PrimeField32, p3_util::log2_strict_usize};
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5use tracing::instrument;
6
7use crate::{
8 arch::{hasher::Hasher, MemoryCellType, ADDR_SPACE_OFFSET},
9 system::memory::{dimensions::MemoryDimensions, online::LinearMemory, MemoryImage},
10};
11
/// Index of the address space that stores the user public values.
pub const PUBLIC_VALUES_AS: u32 = 3;
/// [`PUBLIC_VALUES_AS`] expressed relative to `ADDR_SPACE_OFFSET`.
pub const PUBLIC_VALUES_ADDRESS_SPACE_OFFSET: u32 = PUBLIC_VALUES_AS - ADDR_SPACE_OFFSET;
14
/// Merkle proof tying a list of user public values to a final memory root.
///
/// Built by [`UserPublicValuesProof::compute`] and checked by
/// [`UserPublicValuesProof::verify`].
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound(
    serialize = "F: Serialize, [F; CHUNK]: Serialize",
    deserialize = "F: Deserialize<'de>, [F; CHUNK]: Deserialize<'de>"
))]
pub struct UserPublicValuesProof<const CHUNK: usize, F> {
    /// Sibling hashes on the path from the public values subtree root to the memory root,
    /// ordered bottom-up: `proof[0]` is the sibling of the public values subtree root.
    pub proof: Vec<[F; CHUNK]>,
    /// The user public values, one field element per value.
    pub public_values: Vec<F>,
    /// Commitment to `public_values`; `verify` checks it against the hasher's Merkle root of
    /// `public_values` and uses it as the starting node of the Merkle path.
    pub public_values_commit: [F; CHUNK],
}
32
/// Reasons why [`UserPublicValuesProof::verify`] rejects a proof.
#[derive(Error, Debug)]
pub enum UserPublicValuesProofError {
    /// The number of public values (the payload) is not acceptable, e.g. it is not `CHUNK`
    /// times a power of two.
    #[error("unexpected length: {0}")]
    UnexpectedLength(usize),
    /// The proof has the wrong number of sibling hashes; fields are `(actual, expected)`.
    #[error("incorrect proof length: {0} (expected {1})")]
    IncorrectProofLength(usize, usize),
    /// The Merkle root of the public values differs from the stored commitment.
    #[error("user public values do not match commitment")]
    UserPublicValuesCommitMismatch,
    /// Folding the commitment with the proof's siblings does not give the expected memory root.
    #[error("final memory root mismatch")]
    FinalMemoryRootMismatch,
}
44
45impl<const CHUNK: usize, F: PrimeField32> UserPublicValuesProof<CHUNK, F> {
46 #[instrument(name = "compute_user_public_values_proof", skip_all)]
54 pub fn compute(
55 memory_dimensions: MemoryDimensions,
56 num_public_values: usize,
57 hasher: &(impl Hasher<CHUNK, F> + Sync),
58 final_memory: &MemoryImage,
59 top_tree: &[[F; CHUNK]],
60 ) -> Self {
61 let public_values = extract_public_values(num_public_values, final_memory)
62 .iter()
63 .map(|&x| F::from_canonical_u8(x))
64 .collect_vec();
65 let public_values_commit = hasher.merkle_root(&public_values);
66 let proof = compute_merkle_proof_to_user_public_values_root(
67 memory_dimensions,
68 num_public_values,
69 hasher,
70 top_tree,
71 );
72 UserPublicValuesProof {
73 proof,
74 public_values,
75 public_values_commit,
76 }
77 }
78
79 pub fn verify(
80 &self,
81 hasher: &impl Hasher<CHUNK, F>,
82 memory_dimensions: MemoryDimensions,
83 final_memory_root: [F; CHUNK],
84 ) -> Result<(), UserPublicValuesProofError> {
85 let pv_commit = self.public_values_commit;
90 let pv_as = PUBLIC_VALUES_AS;
92 let pv_start_idx = memory_dimensions.label_to_index((pv_as, 0));
93 let pvs = &self.public_values;
94 if !pvs.len().is_multiple_of(CHUNK) || !(pvs.len() / CHUNK).is_power_of_two() {
95 return Err(UserPublicValuesProofError::UnexpectedLength(pvs.len()));
96 }
97 let pv_height = log2_strict_usize(pvs.len() / CHUNK);
98 let proof_len = memory_dimensions.overall_height() - pv_height;
99 let idx_prefix = pv_start_idx >> pv_height;
100 if self.proof.len() != proof_len {
102 return Err(UserPublicValuesProofError::IncorrectProofLength(
103 self.proof.len(),
104 proof_len,
105 ));
106 }
107 let mut curr_root = pv_commit;
108 for (i, sibling_hash) in self.proof.iter().enumerate() {
109 curr_root = if idx_prefix & (1 << i) != 0 {
110 hasher.compress(sibling_hash, &curr_root)
111 } else {
112 hasher.compress(&curr_root, sibling_hash)
113 }
114 }
115 if curr_root != final_memory_root {
116 return Err(UserPublicValuesProofError::FinalMemoryRootMismatch);
117 }
118 if hasher.merkle_root(pvs) != pv_commit {
120 return Err(UserPublicValuesProofError::UserPublicValuesCommitMismatch);
121 }
122
123 Ok(())
124 }
125}
126
127fn compute_merkle_proof_to_user_public_values_root<const CHUNK: usize, F: PrimeField32>(
128 memory_dimensions: MemoryDimensions,
129 num_public_values: usize,
130 hasher: &(impl Hasher<CHUNK, F> + Sync),
131 top_tree: &[[F; CHUNK]],
132) -> Vec<[F; CHUNK]> {
133 assert_eq!(
134 num_public_values % CHUNK,
135 0,
136 "num_public_values must be a multiple of memory chunk {CHUNK}"
137 );
138 let address_height = memory_dimensions.address_height;
139 let addr_space_height = memory_dimensions.addr_space_height;
140 assert_eq!(top_tree.len(), (2 << addr_space_height) - 1);
141 let num_pv_chunks: usize = num_public_values / CHUNK;
142 assert!(
144 num_pv_chunks.is_power_of_two(),
145 "pv_height must be a power of two"
146 );
147 let pv_height = log2_strict_usize(num_pv_chunks);
148 let address_leading_zeros = address_height - pv_height;
149
150 let mut cur_node_idx = 1; let mut proof = Vec::with_capacity(addr_space_height + address_leading_zeros);
152 let zero_nodes: Vec<_> = (0..address_height)
153 .scan(hasher.hash(&[F::ZERO; CHUNK]), |acc, _| {
154 let result = Some(*acc);
155 *acc = hasher.compress(acc, acc);
156 result
157 })
158 .collect();
159 for i in 0..addr_space_height {
160 let bit = 1 << (memory_dimensions.addr_space_height - i - 1);
161 if (PUBLIC_VALUES_AS - ADDR_SPACE_OFFSET) & bit != 0 {
163 proof.push(top_tree[cur_node_idx * 2 - 1]);
164 cur_node_idx = cur_node_idx * 2 + 1;
165 } else {
166 proof.push(top_tree[cur_node_idx * 2]);
167 cur_node_idx *= 2;
168 }
169 }
170 for i in 0..address_leading_zeros {
171 proof.push(zero_nodes[address_height - 1 - i]);
173 }
174 proof.reverse();
175 proof
176}
177
178pub fn extract_public_values(num_public_values: usize, final_memory: &MemoryImage) -> Vec<u8> {
179 let mut public_values: Vec<u8> = {
180 assert_eq!(
181 final_memory.config[PUBLIC_VALUES_AS as usize].layout,
182 MemoryCellType::U8
183 );
184 final_memory.mem[PUBLIC_VALUES_AS as usize]
185 .as_slice()
186 .to_vec()
187 };
188
189 assert!(
190 public_values.len() >= num_public_values,
191 "Public values address space has {} elements, but configuration has num_public_values={}",
192 public_values.len(),
193 num_public_values
194 );
195 public_values.truncate(num_public_values);
196 public_values
197}
198
#[cfg(test)]
mod tests {
    use openvm_stark_backend::p3_field::FieldAlgebra;
    use openvm_stark_sdk::p3_baby_bear::BabyBear;

    use super::UserPublicValuesProof;
    use crate::{
        arch::{hasher::poseidon2::vm_poseidon2_hasher, MemoryConfig, SystemConfig},
        system::memory::{
            merkle::{public_values::PUBLIC_VALUES_AS, tree::MerkleTree},
            online::GuestMemory,
            AddressMap, CHUNK,
        },
    };

    type F = BabyBear;

    /// Writes a single non-zero byte into the public values address space, then checks that
    /// the computed proof exposes the expected values and verifies against the memory root.
    #[test]
    fn test_public_value_happy_path() {
        let addr_space_height = 4;
        let num_public_values = 16;

        // Small memory configuration so the Merkle tree stays tiny.
        let mut vm_config = SystemConfig::default().without_continuations();
        vm_config.memory_config.addr_space_height = addr_space_height;
        vm_config.memory_config.pointer_max_bits = 5;
        let memory_dimensions = vm_config.memory_config.memory_dimensions();

        // Only the public values address space gets any cells.
        let mut addr_spaces_config = MemoryConfig::empty_address_space_configs(4);
        addr_spaces_config[PUBLIC_VALUES_AS as usize].num_cells = num_public_values;
        let mut memory = GuestMemory {
            memory: AddressMap::new(addr_spaces_config),
        };
        // SAFETY: the 4-byte write at offset 12 stays within the 16 cells configured above.
        // NOTE(review): the full contract of `GuestMemory::write` is not visible here; confirm.
        unsafe {
            memory.write::<u8, 4>(PUBLIC_VALUES_AS, 12, [0, 0, 0, 1]);
        }

        // The only non-zero byte lands at index 12 + 3 = 15.
        let mut expected_pvs = F::zero_vec(num_public_values);
        expected_pvs[15] = F::ONE;

        let hasher = vm_poseidon2_hasher();
        let merkle_tree = MerkleTree::from_memory(&memory.memory, &memory_dimensions, &hasher);
        let top_tree = merkle_tree.top_tree(addr_space_height);
        let pv_proof = UserPublicValuesProof::<{ CHUNK }, F>::compute(
            memory_dimensions,
            num_public_values,
            &hasher,
            &memory.memory,
            &top_tree,
        );
        assert_eq!(pv_proof.public_values, expected_pvs);

        // Rebuild the tree independently to obtain the root the proof must verify against.
        let final_memory_root =
            MerkleTree::from_memory(&memory.memory, &memory_dimensions, &hasher).root();
        pv_proof
            .verify(&hasher, memory_dimensions, final_memory_root)
            .unwrap();
    }
}