openvm_circuit/system/memory/merkle/
public_values.rs1use itertools::Itertools;
2use openvm_stark_backend::{p3_field::PrimeField32, p3_util::log2_strict_usize};
3use serde::{Deserialize, Serialize};
4use thiserror::Error;
5use tracing::instrument;
6
7use crate::{
8 arch::{hasher::Hasher, MemoryCellType, ADDR_SPACE_OFFSET},
9 system::memory::{
10 dimensions::MemoryDimensions, merkle::tree::MerkleTree, online::LinearMemory, MemoryImage,
11 },
12};
13
/// Index of the address space that holds the user public values.
///
/// `extract_public_values` in this file reads the public values from
/// `final_memory.mem[PUBLIC_VALUES_AS]`.
pub const PUBLIC_VALUES_AS: u32 = 3;
/// [`PUBLIC_VALUES_AS`] expressed relative to `ADDR_SPACE_OFFSET`.
///
/// The proof builder in this file uses the bits of this same difference to choose the path
/// through the address-space levels of the memory Merkle tree.
pub const PUBLIC_VALUES_ADDRESS_SPACE_OFFSET: u32 = PUBLIC_VALUES_AS - ADDR_SPACE_OFFSET;
16
/// Merkle proof that a list of user public values is contained in the final memory state.
///
/// Built by [`UserPublicValuesProof::compute`] and checked by [`UserPublicValuesProof::verify`].
#[derive(Clone, Debug, Serialize, Deserialize)]
#[serde(bound(
    serialize = "F: Serialize, [F; CHUNK]: Serialize",
    deserialize = "F: Deserialize<'de>, [F; CHUNK]: Deserialize<'de>"
))]
pub struct UserPublicValuesProof<const CHUNK: usize, F> {
    /// Sibling node hashes on the path from the root of the public values subtree up to the
    /// memory root, ordered bottom-up (the sibling closest to the public values comes first).
    pub proof: Vec<[F; CHUNK]>,
    /// The raw public values as field elements. `verify` requires the length to be
    /// `CHUNK` times a power of two.
    pub public_values: Vec<F>,
    /// Commitment to `public_values`; `compute` sets it to `hasher.merkle_root(&public_values)`
    /// and `verify` re-derives and compares it.
    pub public_values_commit: [F; CHUNK],
}
34
/// Reasons why [`UserPublicValuesProof::verify`] rejects a proof.
#[derive(Error, Debug)]
pub enum UserPublicValuesProofError {
    /// The number of public values was rejected, e.g. because it is not `CHUNK` times a power
    /// of two. The payload is the offending length.
    #[error("unexpected length: {0}")]
    UnexpectedLength(usize),
    /// The number of sibling hashes in the proof does not match the expected path length.
    /// Payload: `(actual, expected)`.
    #[error("incorrect proof length: {0} (expected {1})")]
    IncorrectProofLength(usize, usize),
    /// The Merkle root recomputed from the public values differs from the stored commitment.
    #[error("user public values do not match commitment")]
    UserPublicValuesCommitMismatch,
    /// Folding the sibling hashes onto the commitment does not reproduce the expected final
    /// memory root.
    #[error("final memory root mismatch")]
    FinalMemoryRootMismatch,
}
46
47impl<const CHUNK: usize, F: PrimeField32> UserPublicValuesProof<CHUNK, F> {
48 #[instrument(name = "compute_user_public_values_proof", skip_all)]
54 pub fn compute(
55 memory_dimensions: MemoryDimensions,
56 num_public_values: usize,
57 hasher: &(impl Hasher<CHUNK, F> + Sync),
58 final_memory: &MemoryImage,
59 ) -> Self {
60 let proof = compute_merkle_proof_to_user_public_values_root(
61 memory_dimensions,
62 num_public_values,
63 hasher,
64 final_memory,
65 );
66 let public_values = extract_public_values(num_public_values, final_memory)
67 .iter()
68 .map(|&x| F::from_canonical_u8(x))
69 .collect_vec();
70 let public_values_commit = hasher.merkle_root(&public_values);
71 UserPublicValuesProof {
72 proof,
73 public_values,
74 public_values_commit,
75 }
76 }
77
78 pub fn verify(
79 &self,
80 hasher: &impl Hasher<CHUNK, F>,
81 memory_dimensions: MemoryDimensions,
82 final_memory_root: [F; CHUNK],
83 ) -> Result<(), UserPublicValuesProofError> {
84 let pv_commit = self.public_values_commit;
89 let pv_as = PUBLIC_VALUES_AS;
91 let pv_start_idx = memory_dimensions.label_to_index((pv_as, 0));
92 let pvs = &self.public_values;
93 if pvs.len() % CHUNK != 0 || !(pvs.len() / CHUNK).is_power_of_two() {
94 return Err(UserPublicValuesProofError::UnexpectedLength(pvs.len()));
95 }
96 let pv_height = log2_strict_usize(pvs.len() / CHUNK);
97 let proof_len = memory_dimensions.overall_height() - pv_height;
98 let idx_prefix = pv_start_idx >> pv_height;
99 if self.proof.len() != proof_len {
101 return Err(UserPublicValuesProofError::IncorrectProofLength(
102 self.proof.len(),
103 proof_len,
104 ));
105 }
106 let mut curr_root = pv_commit;
107 for (i, sibling_hash) in self.proof.iter().enumerate() {
108 curr_root = if idx_prefix & (1 << i) != 0 {
109 hasher.compress(sibling_hash, &curr_root)
110 } else {
111 hasher.compress(&curr_root, sibling_hash)
112 }
113 }
114 if curr_root != final_memory_root {
115 return Err(UserPublicValuesProofError::FinalMemoryRootMismatch);
116 }
117 if hasher.merkle_root(pvs) != pv_commit {
119 return Err(UserPublicValuesProofError::UserPublicValuesCommitMismatch);
120 }
121
122 Ok(())
123 }
124}
125
126fn compute_merkle_proof_to_user_public_values_root<const CHUNK: usize, F: PrimeField32>(
127 memory_dimensions: MemoryDimensions,
128 num_public_values: usize,
129 hasher: &(impl Hasher<CHUNK, F> + Sync),
130 final_memory: &MemoryImage,
131) -> Vec<[F; CHUNK]> {
132 assert_eq!(
133 num_public_values % CHUNK,
134 0,
135 "num_public_values must be a multiple of memory chunk {CHUNK}"
136 );
137 let tree = MerkleTree::<F, CHUNK>::from_memory(final_memory, &memory_dimensions, hasher);
138 let num_pv_chunks: usize = num_public_values / CHUNK;
139 assert!(
141 num_pv_chunks.is_power_of_two(),
142 "pv_height must be a power of two"
143 );
144 let pv_height = log2_strict_usize(num_pv_chunks);
145 let address_leading_zeros = memory_dimensions.address_height - pv_height;
146
147 let mut cur_node_idx = 1; let mut proof = Vec::with_capacity(memory_dimensions.addr_space_height + address_leading_zeros);
149 for i in 0..memory_dimensions.addr_space_height {
150 let bit = 1 << (memory_dimensions.addr_space_height - i - 1);
151 if (PUBLIC_VALUES_AS - ADDR_SPACE_OFFSET) & bit != 0 {
152 proof.push(tree.get_node(cur_node_idx * 2));
153 cur_node_idx = cur_node_idx * 2 + 1;
154 } else {
155 proof.push(tree.get_node(cur_node_idx * 2 + 1));
156 cur_node_idx *= 2;
157 }
158 }
159 for _ in 0..address_leading_zeros {
160 proof.push(tree.get_node(cur_node_idx * 2 + 1));
162 cur_node_idx *= 2;
163 }
164 proof.reverse();
165 proof
166}
167
168pub fn extract_public_values(num_public_values: usize, final_memory: &MemoryImage) -> Vec<u8> {
169 let mut public_values: Vec<u8> = {
170 assert_eq!(
171 final_memory.config[PUBLIC_VALUES_AS as usize].layout,
172 MemoryCellType::U8
173 );
174 final_memory.mem[PUBLIC_VALUES_AS as usize]
175 .as_slice()
176 .to_vec()
177 };
178
179 assert!(
180 public_values.len() >= num_public_values,
181 "Public values address space has {} elements, but configuration has num_public_values={}",
182 public_values.len(),
183 num_public_values
184 );
185 public_values.truncate(num_public_values);
186 public_values
187}
188
#[cfg(test)]
mod tests {
    use openvm_stark_backend::p3_field::FieldAlgebra;
    use openvm_stark_sdk::p3_baby_bear::BabyBear;

    use super::UserPublicValuesProof;
    use crate::{
        arch::{hasher::poseidon2::vm_poseidon2_hasher, MemoryConfig, SystemConfig},
        system::memory::{
            merkle::{public_values::PUBLIC_VALUES_AS, tree::MerkleTree},
            online::GuestMemory,
            AddressMap, CHUNK,
        },
    };

    type F = BabyBear;

    /// A proof computed from a memory image verifies against that image's Merkle root and
    /// exposes the bytes written to the public values address space.
    #[test]
    fn test_public_value_happy_path() {
        let num_public_values = 16;

        // Small memory geometry so the whole tree is cheap to hash.
        let mut system_config = SystemConfig::default().without_continuations();
        system_config.memory_config.addr_space_height = 4;
        system_config.memory_config.pointer_max_bits = 5;
        let memory_dimensions = system_config.memory_config.memory_dimensions();

        // Only the public values address space gets any cells.
        let mut as_configs = MemoryConfig::empty_address_space_configs(4);
        as_configs[PUBLIC_VALUES_AS as usize].num_cells = num_public_values;
        let mut guest_memory = GuestMemory {
            memory: AddressMap::new(as_configs),
        };
        // SAFETY: the public values address space was sized to `num_public_values` (16) cells
        // above, so the 4 cells starting at offset 12 are in bounds.
        // NOTE(review): assumes this address space uses the u8 cell layout — confirm.
        unsafe {
            guest_memory.write::<u8, 4>(PUBLIC_VALUES_AS, 12, [0, 0, 0, 1]);
        }

        // The single non-zero byte lands in the last public value.
        let mut expected_pvs = F::zero_vec(num_public_values);
        expected_pvs[15] = F::ONE;

        let hasher = vm_poseidon2_hasher();
        let pv_proof = UserPublicValuesProof::<{ CHUNK }, F>::compute(
            memory_dimensions,
            num_public_values,
            &hasher,
            &guest_memory.memory,
        );
        assert_eq!(pv_proof.public_values, expected_pvs);

        let final_memory_root =
            MerkleTree::from_memory(&guest_memory.memory, &memory_dimensions, &hasher).root();
        pv_proof
            .verify(&hasher, memory_dimensions, final_memory_root)
            .unwrap();
    }
}