openvm_circuit/system/memory/tree/public_values.rs
1use std::{collections::BTreeMap, sync::Arc};
2
3use openvm_stark_backend::{p3_field::PrimeField32, p3_util::log2_strict_usize};
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7use crate::{
8 arch::hasher::Hasher,
9 system::memory::{
10 dimensions::MemoryDimensions, paged_vec::Address, tree::MemoryNode, MemoryImage,
11 },
12};
13
/// Position of the user public values address space, counted from
/// `MemoryDimensions::as_offset` (the absolute address space is `as_offset + 2`).
pub const PUBLIC_VALUES_ADDRESS_SPACE_OFFSET: u32 = 2;
15
16#[derive(Clone, Debug, Serialize, Deserialize)]
18#[serde(bound(
19 serialize = "F: Serialize, [F; CHUNK]: Serialize",
20 deserialize = "F: Deserialize<'de>, [F; CHUNK]: Deserialize<'de>"
21))]
22pub struct UserPublicValuesProof<const CHUNK: usize, F> {
23 pub proof: Vec<[F; CHUNK]>,
26 pub public_values: Vec<F>,
28 pub public_values_commit: [F; CHUNK],
32}
33
34#[derive(Error, Debug)]
35pub enum UserPublicValuesProofError {
36 #[error("unexpected length: {0}")]
37 UnexpectedLength(usize),
38 #[error("incorrect proof length: {0} (expected {1})")]
39 IncorrectProofLength(usize, usize),
40 #[error("user public values do not match commitment")]
41 UserPublicValuesCommitMismatch,
42 #[error("final memory root mismatch")]
43 FinalMemoryRootMismatch,
44}
45
46impl<const CHUNK: usize, F: PrimeField32> UserPublicValuesProof<CHUNK, F> {
47 pub fn compute(
51 memory_dimensions: MemoryDimensions,
52 num_public_values: usize,
53 hasher: &(impl Hasher<CHUNK, F> + Sync),
54 final_memory: &MemoryImage<F>,
55 ) -> Self {
56 let proof = compute_merkle_proof_to_user_public_values_root(
57 memory_dimensions,
58 num_public_values,
59 hasher,
60 final_memory,
61 );
62 let public_values =
63 extract_public_values(&memory_dimensions, num_public_values, final_memory);
64 let public_values_commit = hasher.merkle_root(&public_values);
65 UserPublicValuesProof {
66 proof,
67 public_values,
68 public_values_commit,
69 }
70 }
71
72 pub fn verify(
73 &self,
74 hasher: &impl Hasher<CHUNK, F>,
75 memory_dimensions: MemoryDimensions,
76 final_memory_root: [F; CHUNK],
77 ) -> Result<(), UserPublicValuesProofError> {
78 let pv_commit = self.public_values_commit;
83 let pv_as = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset;
85 let pv_start_idx = memory_dimensions.label_to_index((pv_as, 0));
86 let pvs = &self.public_values;
87 if pvs.len() % CHUNK != 0 || !(pvs.len() / CHUNK).is_power_of_two() {
88 return Err(UserPublicValuesProofError::UnexpectedLength(pvs.len()));
89 }
90 let pv_height = log2_strict_usize(pvs.len() / CHUNK);
91 let proof_len = memory_dimensions.overall_height() - pv_height;
92 let idx_prefix = pv_start_idx >> pv_height;
93 if self.proof.len() != proof_len {
95 return Err(UserPublicValuesProofError::IncorrectProofLength(
96 self.proof.len(),
97 proof_len,
98 ));
99 }
100 let mut curr_root = pv_commit;
101 for (i, sibling_hash) in self.proof.iter().enumerate() {
102 curr_root = if idx_prefix & (1 << i) != 0 {
103 hasher.compress(sibling_hash, &curr_root)
104 } else {
105 hasher.compress(&curr_root, sibling_hash)
106 }
107 }
108 if curr_root != final_memory_root {
109 return Err(UserPublicValuesProofError::FinalMemoryRootMismatch);
110 }
111 if hasher.merkle_root(pvs) != pv_commit {
113 return Err(UserPublicValuesProofError::UserPublicValuesCommitMismatch);
114 }
115
116 Ok(())
117 }
118}
119
120fn compute_merkle_proof_to_user_public_values_root<const CHUNK: usize, F: PrimeField32>(
121 memory_dimensions: MemoryDimensions,
122 num_public_values: usize,
123 hasher: &(impl Hasher<CHUNK, F> + Sync),
124 final_memory: &MemoryImage<F>,
125) -> Vec<[F; CHUNK]> {
126 assert_eq!(
127 num_public_values % CHUNK,
128 0,
129 "num_public_values must be a multiple of memory chunk {CHUNK}"
130 );
131 let root = MemoryNode::tree_from_memory(memory_dimensions, final_memory, hasher);
132 let num_pv_chunks: usize = num_public_values / CHUNK;
133 assert!(
135 num_pv_chunks.is_power_of_two(),
136 "pv_height must be a power of two"
137 );
138 let pv_height = log2_strict_usize(num_pv_chunks);
139 let address_leading_zeros = memory_dimensions.address_height - pv_height;
140
141 let mut curr_node = Arc::new(root);
142 let mut proof = Vec::with_capacity(memory_dimensions.as_height + address_leading_zeros);
143 for i in 0..memory_dimensions.as_height {
144 let bit = 1 << (memory_dimensions.as_height - i - 1);
145 if let MemoryNode::NonLeaf { left, right, .. } = curr_node.as_ref().clone() {
146 if PUBLIC_VALUES_ADDRESS_SPACE_OFFSET & bit != 0 {
147 curr_node = right;
148 proof.push(left.hash());
149 } else {
150 curr_node = left;
151 proof.push(right.hash());
152 }
153 } else {
154 unreachable!()
155 }
156 }
157 for _ in 0..address_leading_zeros {
158 if let MemoryNode::NonLeaf { left, right, .. } = curr_node.as_ref().clone() {
159 curr_node = left;
160 proof.push(right.hash());
161 } else {
162 unreachable!()
163 }
164 }
165 proof.reverse();
166 proof
167}
168
169pub fn extract_public_values<F: PrimeField32>(
170 memory_dimensions: &MemoryDimensions,
171 num_public_values: usize,
172 final_memory: &MemoryImage<F>,
173) -> Vec<F> {
174 let f_as_start = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset;
176 let f_as_end = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + memory_dimensions.as_offset + 1;
177
178 let final_memory: BTreeMap<Address, F> = final_memory.items().collect();
181
182 let used_pvs: Vec<_> = final_memory
183 .range((f_as_start, 0)..(f_as_end, 0))
184 .map(|(&(_, pointer), &value)| (pointer as usize, value))
185 .collect();
186 if let Some(&last_pv) = used_pvs.last() {
187 assert!(
188 last_pv.0 < num_public_values || last_pv.1 == F::ZERO,
189 "Last public value is out of bounds"
190 );
191 }
192 let mut public_values = F::zero_vec(num_public_values);
193 for (i, pv) in used_pvs {
194 if i < num_public_values {
195 public_values[i] = pv;
196 }
197 }
198 public_values
199}
200
#[cfg(test)]
mod tests {
    use openvm_stark_backend::p3_field::FieldAlgebra;
    use openvm_stark_sdk::p3_baby_bear::BabyBear;

    use super::{UserPublicValuesProof, PUBLIC_VALUES_ADDRESS_SPACE_OFFSET};
    use crate::{
        arch::{hasher::poseidon2::vm_poseidon2_hasher, SystemConfig},
        system::memory::{paged_vec::AddressMap, tree::MemoryNode, CHUNK},
    };

    type F = BabyBear;

    /// Writes a single non-zero value into the last public value slot, builds the proof,
    /// and checks that it verifies against the root of the same memory image.
    #[test]
    fn test_public_value_happy_path() {
        let mut config = SystemConfig::default();
        config.memory_config.as_height = 4;
        config.memory_config.pointer_max_bits = 5;
        let dims = config.memory_config.memory_dimensions();

        let num_public_values = 16;
        let pv_address_space = PUBLIC_VALUES_ADDRESS_SPACE_OFFSET + dims.as_offset;
        let memory = AddressMap::from_iter(
            dims.as_offset,
            1 << dims.as_height,
            1 << dims.address_height,
            [((pv_address_space, 15), F::ONE)],
        );

        // Every slot is zero except the one written above.
        let expected_pvs: Vec<F> = (0..num_public_values)
            .map(|slot| if slot == 15 { F::ONE } else { F::ZERO })
            .collect();

        let hasher = vm_poseidon2_hasher();
        let pv_proof = UserPublicValuesProof::<{ CHUNK }, F>::compute(
            dims,
            num_public_values,
            &hasher,
            &memory,
        );
        assert_eq!(pv_proof.public_values, expected_pvs);

        let final_memory_root = MemoryNode::tree_from_memory(dims, &memory, &hasher).hash();
        pv_proof
            .verify(&hasher, dims, final_memory_root)
            .unwrap();
    }
}