1use std::{
2 array,
3 borrow::{Borrow, BorrowMut},
4 iter,
5 sync::Arc,
6};
7
8use openvm_circuit_primitives_derive::AlignedBorrow;
9use openvm_stark_backend::{
10 config::{StarkGenericConfig, Val},
11 interaction::{InteractionBuilder, PermutationCheckBus},
12 p3_air::{Air, AirBuilder, BaseAir},
13 p3_field::{FieldAlgebra, PrimeField32},
14 p3_matrix::{dense::RowMajorMatrix, Matrix},
15 p3_maybe_rayon::prelude::*,
16 prover::types::AirProofInput,
17 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
18 AirRef, Chip, ChipUsageGetter,
19};
20use rustc_hash::FxHashSet;
21
22use super::merkle::SerialReceiver;
23use crate::{
24 arch::hasher::Hasher,
25 system::memory::{
26 dimensions::MemoryDimensions, offline_checker::MemoryBus, MemoryAddress, MemoryImage,
27 TimestampedEquipartition, INITIAL_TIMESTAMP,
28 },
29};
30
/// One row of the persistent-memory boundary trace.
///
/// Trace generation emits two rows per touched `CHUNK`-sized block: an "initial" row and a
/// "final" row. All remaining rows are all-zero padding. The column layout is given by
/// `#[repr(C)]` together with the field order below, so do not reorder fields.
#[repr(C)]
#[derive(Debug, AlignedBorrow)]
pub struct PersistentBoundaryCols<T, const CHUNK: usize> {
    /// Row kind: 1 on an initial row, -1 (`NEG_ONE`) on a final row, 0 on padding rows.
    /// The AIR constrains it to {-1, 0, 1}.
    pub expand_direction: T,
    /// Address space of the block.
    pub address_space: T,
    /// Block index; the block starts at pointer `leaf_label * CHUNK`.
    pub leaf_label: T,
    /// The `CHUNK` memory cells of the block: initial contents on an initial row, final
    /// contents on a final row.
    pub values: [T; CHUNK],
    /// Hash of `values` as computed by the chip's hasher during finalization.
    pub hash: [T; CHUNK],
    /// Access timestamp. Constrained to zero on initial rows; the block's final timestamp
    /// on final rows.
    pub timestamp: T,
}
46
/// AIR for the persistent-memory boundary.
///
/// Each row interacts with three buses:
/// - the memory bus, with the block contents;
/// - the Merkle bus, with the block's hash;
/// - the compression bus, with (`values`, zeros, `hash`).
#[derive(Clone, Debug)]
pub struct PersistentBoundaryAir<const CHUNK: usize> {
    /// Memory layout parameters. Only `as_offset` is read here: it is subtracted from the
    /// address space in the Merkle bus message.
    pub memory_dims: MemoryDimensions,
    /// Bus carrying (address, values, timestamp) memory accesses.
    pub memory_bus: MemoryBus,
    /// Permutation bus carrying (direction, 0, address_space - as_offset, leaf_label, hash).
    pub merkle_bus: PermutationCheckBus,
    /// Permutation bus carrying (values, CHUNK zeros, hash).
    pub compression_bus: PermutationCheckBus,
}
61
62impl<const CHUNK: usize, F> BaseAir<F> for PersistentBoundaryAir<CHUNK> {
63 fn width(&self) -> usize {
64 PersistentBoundaryCols::<F, CHUNK>::width()
65 }
66}
67
68impl<const CHUNK: usize, F> BaseAirWithPublicValues<F> for PersistentBoundaryAir<CHUNK> {}
69impl<const CHUNK: usize, F> PartitionedBaseAir<F> for PersistentBoundaryAir<CHUNK> {}
70
impl<const CHUNK: usize, AB: InteractionBuilder> Air<AB> for PersistentBoundaryAir<CHUNK> {
    /// Constrains a single boundary row. Only row 0 of the window is read, so there are no
    /// transition constraints.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &PersistentBoundaryCols<AB::Var, CHUNK> = (*local).borrow();

        // d == d^3  <=>  d * (d - 1) * (d + 1) == 0, so `expand_direction` is in {-1, 0, 1}.
        builder.assert_eq(
            local.expand_direction,
            local.expand_direction * local.expand_direction * local.expand_direction,
        );

        // For d in {-1, 0, 1}, d * (d + 1) is nonzero only when d == 1 (it equals 2, assuming
        // the field characteristic is not 2). This forces `timestamp == 0` on initial rows.
        // NOTE(review): trace generation writes INITIAL_TIMESTAMP into this column, so this
        // relies on INITIAL_TIMESTAMP being 0 — confirm.
        builder
            .when(local.expand_direction * (local.expand_direction + AB::F::ONE))
            .assert_zero(local.timestamp);

        // Merkle bus message: (direction, 0, address_space - as_offset, leaf_label, hash).
        // Its multiplicity is the direction itself: +1 initial, -1 final, 0 padding.
        let mut expand_fields = vec![
            local.expand_direction.into(),
            AB::Expr::ZERO,
            local.address_space - AB::F::from_canonical_u32(self.memory_dims.as_offset),
            local.leaf_label.into(),
        ];
        expand_fields.extend(local.hash.map(Into::into));
        self.merkle_bus
            .interact(builder, expand_fields, local.expand_direction.into());

        // Compression bus message: (values, CHUNK zeros, hash), with multiplicity d^2. That is
        // 1 on both initial and final rows and 0 on padding rows.
        // Presumably this ties `hash` to the compression of `values` with a zero block —
        // confirm against the compression-side chip.
        self.compression_bus.interact(
            builder,
            iter::empty()
                .chain(local.values.map(Into::into))
                .chain(iter::repeat_n(AB::Expr::ZERO, CHUNK))
                .chain(local.hash.map(Into::into)),
            local.expand_direction * local.expand_direction,
        );

        // Memory bus: the block starting at (address_space, leaf_label * CHUNK) with its
        // values and timestamp, counted with multiplicity `expand_direction`. Initial rows
        // send; on final rows the -1 count presumably acts as a receive.
        self.memory_bus
            .send(
                MemoryAddress::new(
                    local.address_space,
                    local.leaf_label * AB::F::from_canonical_usize(CHUNK),
                ),
                local.values.to_vec(),
                local.timestamp,
            )
            .eval(builder, local.expand_direction);
    }
}
123
/// Trace generator for [`PersistentBoundaryAir`].
///
/// Lifecycle, as enforced by the methods on this type:
/// 1. Record accessed blocks with `touch_range`.
/// 2. Call `finalize` exactly once.
/// 3. Generate the trace.
pub struct PersistentBoundaryChip<F, const CHUNK: usize> {
    pub air: PersistentBoundaryAir<CHUNK>,
    /// Touched blocks: `Running` until `finalize`, `Final` afterwards.
    touched_labels: TouchedLabels<F, CHUNK>,
    /// Requested trace height, if any. At trace generation it is rounded up to a power of
    /// two and must be at least the required height.
    overridden_height: Option<usize>,
}
129
/// The set of touched `(address_space, label)` blocks, in one of two phases.
#[derive(Debug)]
enum TouchedLabels<F, const CHUNK: usize> {
    /// Still collecting: deduplicated `(address_space, label)` pairs.
    Running(FxHashSet<(u32, u32)>),
    /// After finalization: per-block contents, hashes and timestamps.
    Final(Vec<FinalTouchedLabel<F, CHUNK>>),
}
135
/// Finalized data for one touched `CHUNK`-sized block, built during finalization.
#[derive(Debug)]
struct FinalTouchedLabel<F, const CHUNK: usize> {
    address_space: u32,
    /// Block index; the block's first pointer is `label * CHUNK`.
    label: u32,
    /// Contents read from the initial memory image (cells absent there are zero).
    init_values: [F; CHUNK],
    /// Contents read from the final memory.
    final_values: [F; CHUNK],
    /// Hasher output for `init_values`.
    init_hash: [F; CHUNK],
    /// Hasher output for `final_values`.
    final_hash: [F; CHUNK],
    /// Timestamp recorded for this block in the final memory.
    final_timestamp: u32,
}
146
147impl<F: PrimeField32, const CHUNK: usize> Default for TouchedLabels<F, CHUNK> {
148 fn default() -> Self {
149 Self::Running(FxHashSet::default())
150 }
151}
152
153impl<F: PrimeField32, const CHUNK: usize> TouchedLabels<F, CHUNK> {
154 fn touch(&mut self, address_space: u32, label: u32) {
155 match self {
156 TouchedLabels::Running(touched_labels) => {
157 touched_labels.insert((address_space, label));
158 }
159 _ => panic!("Cannot touch after finalization"),
160 }
161 }
162 fn len(&self) -> usize {
163 match self {
164 TouchedLabels::Running(touched_labels) => touched_labels.len(),
165 TouchedLabels::Final(touched_labels) => touched_labels.len(),
166 }
167 }
168}
169
170impl<const CHUNK: usize, F: PrimeField32> PersistentBoundaryChip<F, CHUNK> {
171 pub fn new(
172 memory_dimensions: MemoryDimensions,
173 memory_bus: MemoryBus,
174 merkle_bus: PermutationCheckBus,
175 compression_bus: PermutationCheckBus,
176 ) -> Self {
177 Self {
178 air: PersistentBoundaryAir {
179 memory_dims: memory_dimensions,
180 memory_bus,
181 merkle_bus,
182 compression_bus,
183 },
184 touched_labels: Default::default(),
185 overridden_height: None,
186 }
187 }
188
189 pub fn set_overridden_height(&mut self, overridden_height: usize) {
190 self.overridden_height = Some(overridden_height);
191 }
192
193 pub fn touch_range(&mut self, address_space: u32, pointer: u32, len: u32) {
194 let start_label = pointer / CHUNK as u32;
195 let end_label = (pointer + len - 1) / CHUNK as u32;
196 for label in start_label..=end_label {
197 self.touched_labels.touch(address_space, label);
198 }
199 }
200
201 pub fn finalize<H>(
202 &mut self,
203 initial_memory: &MemoryImage<F>,
204 final_memory: &TimestampedEquipartition<F, CHUNK>,
205 hasher: &mut H,
206 ) where
207 H: Hasher<CHUNK, F> + Sync + for<'a> SerialReceiver<&'a [F]>,
208 {
209 match &mut self.touched_labels {
210 TouchedLabels::Running(touched_labels) => {
211 let final_touched_labels: Vec<_> = touched_labels
212 .par_iter()
213 .map(|&(address_space, label)| {
214 let pointer = label * CHUNK as u32;
215 let init_values = array::from_fn(|i| {
216 *initial_memory
217 .get(&(address_space, pointer + i as u32))
218 .unwrap_or(&F::ZERO)
219 });
220 let initial_hash = hasher.hash(&init_values);
221 let timestamped_values = final_memory.get(&(address_space, label)).unwrap();
222 let final_hash = hasher.hash(×tamped_values.values);
223 FinalTouchedLabel {
224 address_space,
225 label,
226 init_values,
227 final_values: timestamped_values.values,
228 init_hash: initial_hash,
229 final_hash,
230 final_timestamp: timestamped_values.timestamp,
231 }
232 })
233 .collect();
234 for l in &final_touched_labels {
235 hasher.receive(&l.init_values);
236 hasher.receive(&l.final_values);
237 }
238 self.touched_labels = TouchedLabels::Final(final_touched_labels);
239 }
240 _ => panic!("Cannot finalize after finalization"),
241 }
242 }
243}
244
245impl<const CHUNK: usize, SC: StarkGenericConfig> Chip<SC> for PersistentBoundaryChip<Val<SC>, CHUNK>
246where
247 Val<SC>: PrimeField32,
248{
249 fn air(&self) -> AirRef<SC> {
250 Arc::new(self.air.clone())
251 }
252
253 fn generate_air_proof_input(self) -> AirProofInput<SC> {
254 let trace = {
255 let width = PersistentBoundaryCols::<Val<SC>, CHUNK>::width();
256 let mut height = (2 * self.touched_labels.len()).next_power_of_two();
258 if let Some(mut oh) = self.overridden_height {
259 oh = oh.next_power_of_two();
260 assert!(
261 oh >= height,
262 "Overridden height is less than the required height"
263 );
264 height = oh;
265 }
266 let mut rows = Val::<SC>::zero_vec(height * width);
267
268 let touched_labels = match self.touched_labels {
269 TouchedLabels::Final(touched_labels) => touched_labels,
270 _ => panic!("Cannot generate trace before finalization"),
271 };
272
273 rows.par_chunks_mut(2 * width)
274 .zip(touched_labels.into_par_iter())
275 .for_each(|(row, touched_label)| {
276 let (initial_row, final_row) = row.split_at_mut(width);
277 *initial_row.borrow_mut() = PersistentBoundaryCols {
278 expand_direction: Val::<SC>::ONE,
279 address_space: Val::<SC>::from_canonical_u32(touched_label.address_space),
280 leaf_label: Val::<SC>::from_canonical_u32(touched_label.label),
281 values: touched_label.init_values,
282 hash: touched_label.init_hash,
283 timestamp: Val::<SC>::from_canonical_u32(INITIAL_TIMESTAMP),
284 };
285
286 *final_row.borrow_mut() = PersistentBoundaryCols {
287 expand_direction: Val::<SC>::NEG_ONE,
288 address_space: Val::<SC>::from_canonical_u32(touched_label.address_space),
289 leaf_label: Val::<SC>::from_canonical_u32(touched_label.label),
290 values: touched_label.final_values,
291 hash: touched_label.final_hash,
292 timestamp: Val::<SC>::from_canonical_u32(touched_label.final_timestamp),
293 };
294 });
295 RowMajorMatrix::new(rows, width)
296 };
297 AirProofInput::simple_no_pis(trace)
298 }
299}
300
301impl<const CHUNK: usize, F: PrimeField32> ChipUsageGetter for PersistentBoundaryChip<F, CHUNK> {
302 fn air_name(&self) -> String {
303 "Boundary".to_string()
304 }
305
306 fn current_trace_height(&self) -> usize {
307 2 * self.touched_labels.len()
308 }
309
310 fn trace_width(&self) -> usize {
311 PersistentBoundaryCols::<F, CHUNK>::width()
312 }
313}