1use std::{
2 array,
3 borrow::{Borrow, BorrowMut},
4 iter,
5 sync::Arc,
6};
7
8use openvm_circuit_primitives_derive::AlignedBorrow;
9use openvm_stark_backend::{
10 config::{StarkGenericConfig, Val},
11 interaction::{InteractionBuilder, PermutationCheckBus},
12 p3_air::{Air, AirBuilder, BaseAir},
13 p3_field::{FieldAlgebra, PrimeField32},
14 p3_matrix::{dense::RowMajorMatrix, Matrix},
15 p3_maybe_rayon::prelude::*,
16 prover::{cpu::CpuBackend, types::AirProvingContext},
17 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
18 Chip, ChipUsageGetter,
19};
20use rustc_hash::FxHashSet;
21use tracing::instrument;
22
23use super::{merkle::SerialReceiver, online::INITIAL_TIMESTAMP, TimestampedValues};
24use crate::{
25 arch::{hasher::Hasher, ADDR_SPACE_OFFSET},
26 system::memory::{
27 dimensions::MemoryDimensions, offline_checker::MemoryBus, MemoryAddress, MemoryImage,
28 TimestampedEquipartition,
29 },
30};
31
/// Trace columns of [`PersistentBoundaryAir`]. One row describes one `CHUNK`-sized block of
/// memory ("leaf") identified by `(address_space, leaf_label)`.
///
/// Trace generation writes two rows per touched leaf: an initial row (`expand_direction = 1`)
/// and a final row (`expand_direction = -1`). Padding rows are all zero.
#[repr(C)]
#[derive(Debug, AlignedBorrow)]
pub struct PersistentBoundaryCols<T, const CHUNK: usize> {
    /// Row kind. The AIR constrains it to `{-1, 0, 1}`: 1 for initial rows, -1 for final
    /// rows, 0 for padding. It is also used as the interaction multiplicity.
    pub expand_direction: T,
    /// Address space of the leaf, stored un-offset. `eval` subtracts `ADDR_SPACE_OFFSET`
    /// before sending it on the Merkle bus.
    pub address_space: T,
    /// Index of the leaf within its address space; the first pointer of the leaf is
    /// `leaf_label * CHUNK`.
    pub leaf_label: T,
    /// The `CHUNK` memory cells of the leaf (initial or final, depending on the row).
    pub values: [T; CHUNK],
    /// Hash of `values` as computed by the chip's hasher during finalization.
    pub hash: [T; CHUNK],
    /// Timestamp of the memory-bus interaction; constrained to zero on initial rows.
    pub timestamp: T,
}
47
/// AIR for the persistent-memory boundary.
///
/// Each real row links a memory leaf's values to three buses: the memory bus (values and
/// timestamp), the Merkle bus (the leaf hash), and the compression bus (checking `hash`
/// against `values`).
#[derive(Clone, Debug)]
pub struct PersistentBoundaryAir<const CHUNK: usize> {
    /// Memory layout dimensions. Not read by this AIR's `eval`; stored for other users.
    pub memory_dims: MemoryDimensions,
    /// Bus carrying `(address, values, timestamp)` memory interactions.
    pub memory_bus: MemoryBus,
    /// Bus connecting leaf hashes to the Merkle tree chip.
    pub merkle_bus: PermutationCheckBus,
    /// Bus used to relate `values` to `hash` through the compression function.
    pub compression_bus: PermutationCheckBus,
}
62
63impl<const CHUNK: usize, F> BaseAir<F> for PersistentBoundaryAir<CHUNK> {
64 fn width(&self) -> usize {
65 PersistentBoundaryCols::<F, CHUNK>::width()
66 }
67}
68
69impl<const CHUNK: usize, F> BaseAirWithPublicValues<F> for PersistentBoundaryAir<CHUNK> {}
70impl<const CHUNK: usize, F> PartitionedBaseAir<F> for PersistentBoundaryAir<CHUNK> {}
71
impl<const CHUNK: usize, AB: InteractionBuilder> Air<AB> for PersistentBoundaryAir<CHUNK> {
    /// Constrains a single row of [`PersistentBoundaryCols`] and emits its interactions on
    /// the Merkle, compression and memory buses.
    fn eval(&self, builder: &mut AB) {
        let main = builder.main();
        let local = main.row_slice(0);
        let local: &PersistentBoundaryCols<AB::Var, CHUNK> = (*local).borrow();

        // expand_direction = expand_direction^3  <=>  x * (x - 1) * (x + 1) = 0,
        // so expand_direction is restricted to {-1, 0, 1}.
        builder.assert_eq(
            local.expand_direction,
            local.expand_direction * local.expand_direction * local.expand_direction,
        );

        // Given the constraint above, x * (x + 1) is nonzero only when expand_direction = 1
        // (where it equals 2), so rows with expand_direction = 1 must have timestamp 0.
        // NOTE(review): trace generation fills these rows with INITIAL_TIMESTAMP, so this
        // relies on INITIAL_TIMESTAMP being 0 — confirm.
        builder
            .when(local.expand_direction * (local.expand_direction + AB::F::ONE))
            .assert_zero(local.timestamp);

        // Merkle-bus message:
        //   [expand_direction, 0, address_space - ADDR_SPACE_OFFSET, leaf_label, hash...]
        // sent with multiplicity expand_direction (0 on padding rows).
        let mut expand_fields = vec![
            local.expand_direction.into(),
            // Constant 0 — presumably the node height (leaf level) in the Merkle chip's
            // message format; confirm against the Merkle AIR.
            AB::Expr::ZERO,
            local.address_space - AB::F::from_canonical_u32(ADDR_SPACE_OFFSET),
            local.leaf_label.into(),
        ];
        expand_fields.extend(local.hash.map(Into::into));
        self.merkle_bus
            .interact(builder, expand_fields, local.expand_direction.into());

        // Compression-bus message: values, then CHUNK zeros, then hash. The multiplicity
        // expand_direction^2 is 1 on both initial and final rows and 0 on padding rows.
        // Presumably this attests hash = compress(values, 0) — confirm against the
        // compression chip.
        self.compression_bus.interact(
            builder,
            iter::empty()
                .chain(local.values.map(Into::into))
                .chain(iter::repeat_n(AB::Expr::ZERO, CHUNK))
                .chain(local.hash.map(Into::into)),
            local.expand_direction * local.expand_direction,
        );

        // Memory-bus interaction for the CHUNK cells starting at pointer leaf_label * CHUNK,
        // with count expand_direction: +1 on initial rows, -1 on final rows, 0 on padding.
        self.memory_bus
            .send(
                MemoryAddress::new(
                    local.address_space,
                    local.leaf_label * AB::F::from_canonical_usize(CHUNK),
                ),
                local.values.to_vec(),
                local.timestamp,
            )
            .eval(builder, local.expand_direction);
    }
}
124
/// Chip that generates the boundary trace for [`PersistentBoundaryAir`].
pub struct PersistentBoundaryChip<F, const CHUNK: usize> {
    /// The AIR whose trace this chip produces.
    pub air: PersistentBoundaryAir<CHUNK>,
    /// Touched memory leaves. Accumulated via `touch_range` in the `Running` state, then
    /// replaced by `finalize` with `Final` records.
    pub touched_labels: TouchedLabels<F, CHUNK>,
    /// Optional fixed trace height set by `set_overridden_height`. It is rounded up to a
    /// power of two during trace generation.
    overridden_height: Option<usize>,
}
130
/// Memory leaves touched during execution, in one of two phases.
#[derive(Debug)]
pub enum TouchedLabels<F, const CHUNK: usize> {
    /// Collection phase: `(address_space, label)` pairs recorded via `touch`.
    Running(FxHashSet<(u32, u32)>),
    /// After finalization: one record per leaf, built by `PersistentBoundaryChip::finalize`.
    Final(Vec<FinalTouchedLabel<F, CHUNK>>),
}
136
/// Finalized data for one touched leaf; becomes an initial/final row pair in the trace.
#[derive(Debug)]
pub struct FinalTouchedLabel<F, const CHUNK: usize> {
    /// Address space of the leaf.
    address_space: u32,
    /// Leaf index, computed in `finalize` as `ptr / CHUNK`.
    label: u32,
    /// Values read from the initial memory image for this leaf.
    init_values: [F; CHUNK],
    /// Values from the final memory for this leaf.
    final_values: [F; CHUNK],
    /// `hasher.hash(&init_values)`.
    init_hash: [F; CHUNK],
    /// `hasher.hash(&final_values)`.
    final_hash: [F; CHUNK],
    /// Timestamp recorded for this leaf in the final memory.
    final_timestamp: u32,
}
147
148impl<F: PrimeField32, const CHUNK: usize> Default for TouchedLabels<F, CHUNK> {
149 fn default() -> Self {
150 Self::Running(FxHashSet::default())
151 }
152}
153
154impl<F: PrimeField32, const CHUNK: usize> TouchedLabels<F, CHUNK> {
155 fn touch(&mut self, address_space: u32, label: u32) {
156 match self {
157 TouchedLabels::Running(touched_labels) => {
158 touched_labels.insert((address_space, label));
159 }
160 _ => panic!("Cannot touch after finalization"),
161 }
162 }
163
164 pub fn is_empty(&self) -> bool {
165 match self {
166 TouchedLabels::Running(touched_labels) => touched_labels.is_empty(),
167 TouchedLabels::Final(touched_labels) => touched_labels.is_empty(),
168 }
169 }
170
171 pub fn len(&self) -> usize {
172 match self {
173 TouchedLabels::Running(touched_labels) => touched_labels.len(),
174 TouchedLabels::Final(touched_labels) => touched_labels.len(),
175 }
176 }
177}
178
179impl<const CHUNK: usize, F: PrimeField32> PersistentBoundaryChip<F, CHUNK> {
180 pub fn new(
181 memory_dimensions: MemoryDimensions,
182 memory_bus: MemoryBus,
183 merkle_bus: PermutationCheckBus,
184 compression_bus: PermutationCheckBus,
185 ) -> Self {
186 Self {
187 air: PersistentBoundaryAir {
188 memory_dims: memory_dimensions,
189 memory_bus,
190 merkle_bus,
191 compression_bus,
192 },
193 touched_labels: Default::default(),
194 overridden_height: None,
195 }
196 }
197
198 pub fn set_overridden_height(&mut self, overridden_height: usize) {
199 self.overridden_height = Some(overridden_height);
200 }
201
202 pub fn touch_range(&mut self, address_space: u32, pointer: u32, len: u32) {
203 let start_label = pointer / CHUNK as u32;
204 let end_label = (pointer + len - 1) / CHUNK as u32;
205 for label in start_label..=end_label {
206 self.touched_labels.touch(address_space, label);
207 }
208 }
209
210 #[instrument(name = "boundary_finalize", level = "debug", skip_all)]
211 pub(crate) fn finalize<H>(
212 &mut self,
213 initial_memory: &MemoryImage,
214 final_memory: &TimestampedEquipartition<F, CHUNK>,
216 hasher: &H,
217 ) where
218 H: Hasher<CHUNK, F> + Sync + for<'a> SerialReceiver<&'a [F]>,
219 {
220 let final_touched_labels: Vec<_> = final_memory
221 .par_iter()
222 .map(|&((addr_space, ptr), ts_values)| {
223 let TimestampedValues { timestamp, values } = ts_values;
224 let init_values = array::from_fn(|i| unsafe {
226 initial_memory.get_f::<F>(addr_space, ptr + i as u32)
227 });
228 let initial_hash = hasher.hash(&init_values);
229 let final_hash = hasher.hash(&values);
230 FinalTouchedLabel {
231 address_space: addr_space,
232 label: ptr / CHUNK as u32,
233 init_values,
234 final_values: values,
235 init_hash: initial_hash,
236 final_hash,
237 final_timestamp: timestamp,
238 }
239 })
240 .collect();
241 for l in &final_touched_labels {
242 hasher.receive(&l.init_values);
243 hasher.receive(&l.final_values);
244 }
245 self.touched_labels = TouchedLabels::Final(final_touched_labels);
246 }
247}
248
249impl<const CHUNK: usize, RA, SC> Chip<RA, CpuBackend<SC>> for PersistentBoundaryChip<Val<SC>, CHUNK>
250where
251 SC: StarkGenericConfig,
252 Val<SC>: PrimeField32,
253{
254 fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<CpuBackend<SC>> {
255 let trace = {
256 let width = PersistentBoundaryCols::<Val<SC>, CHUNK>::width();
257 let mut height = (2 * self.touched_labels.len()).next_power_of_two();
259 if let Some(mut oh) = self.overridden_height {
260 oh = oh.next_power_of_two();
261 assert!(
262 oh >= height,
263 "Overridden height is less than the required height"
264 );
265 height = oh;
266 }
267 let mut rows = Val::<SC>::zero_vec(height * width);
268
269 let touched_labels = match &self.touched_labels {
270 TouchedLabels::Final(touched_labels) => touched_labels,
271 _ => panic!("Cannot generate trace before finalization"),
272 };
273
274 rows.par_chunks_mut(2 * width)
275 .zip(touched_labels.par_iter())
276 .for_each(|(row, touched_label)| {
277 let (initial_row, final_row) = row.split_at_mut(width);
278 *initial_row.borrow_mut() = PersistentBoundaryCols {
279 expand_direction: Val::<SC>::ONE,
280 address_space: Val::<SC>::from_canonical_u32(touched_label.address_space),
281 leaf_label: Val::<SC>::from_canonical_u32(touched_label.label),
282 values: touched_label.init_values,
283 hash: touched_label.init_hash,
284 timestamp: Val::<SC>::from_canonical_u32(INITIAL_TIMESTAMP),
285 };
286
287 *final_row.borrow_mut() = PersistentBoundaryCols {
288 expand_direction: Val::<SC>::NEG_ONE,
289 address_space: Val::<SC>::from_canonical_u32(touched_label.address_space),
290 leaf_label: Val::<SC>::from_canonical_u32(touched_label.label),
291 values: touched_label.final_values,
292 hash: touched_label.final_hash,
293 timestamp: Val::<SC>::from_canonical_u32(touched_label.final_timestamp),
294 };
295 });
296 Arc::new(RowMajorMatrix::new(rows, width))
297 };
298 AirProvingContext::simple_no_pis(trace)
299 }
300}
301
302impl<const CHUNK: usize, F: PrimeField32> ChipUsageGetter for PersistentBoundaryChip<F, CHUNK> {
303 fn air_name(&self) -> String {
304 "Boundary".to_string()
305 }
306
307 fn current_trace_height(&self) -> usize {
308 2 * self.touched_labels.len()
309 }
310
311 fn trace_width(&self) -> usize {
312 PersistentBoundaryCols::<F, CHUNK>::width()
313 }
314}