1use std::{
2 borrow::{Borrow, BorrowMut},
3 cmp::min,
4 sync::Arc,
5};
6
7use itertools::zip_eq;
8use openvm_circuit_primitives::{
9 is_less_than_array::{
10 IsLtArrayAuxCols, IsLtArrayIo, IsLtArraySubAir, IsLtArrayWhenTransitionAir,
11 },
12 utils::{compose, implies},
13 var_range::{SharedVariableRangeCheckerChip, VariableRangeCheckerBus},
14 SubAir, TraceSubRowGenerator,
15};
16use openvm_circuit_primitives_derive::AlignedBorrow;
17use openvm_stark_backend::{
18 config::{StarkGenericConfig, Val},
19 interaction::InteractionBuilder,
20 p3_air::{Air, AirBuilder, BaseAir},
21 p3_field::{Field, FieldAlgebra, PrimeField32},
22 p3_matrix::{dense::RowMajorMatrix, Matrix},
23 p3_maybe_rayon::prelude::*,
24 prover::{cpu::CpuBackend, types::AirProvingContext},
25 rap::{BaseAirWithPublicValues, PartitionedBaseAir},
26 Chip, ChipUsageGetter,
27};
28use static_assertions::const_assert;
29use tracing::instrument;
30
31use super::TimestampedEquipartition;
32use crate::system::memory::{
33 offline_checker::{MemoryBus, AUX_LEN},
34 MemoryAddress,
35};
36
37#[cfg(test)]
38mod tests;
39
40const ADDR_ELTS: usize = 2;
42const NUM_AS_LIMBS: usize = 1;
43const_assert!(NUM_AS_LIMBS <= AUX_LEN);
44
/// Trace columns of the volatile memory boundary AIR.
///
/// Each valid row (`is_valid = 1`) describes one memory cell `(address space, pointer)`.
/// Padding rows have `is_valid = 0`.
///
/// The field order defines the column layout of the trace (`#[repr(C)]` + `AlignedBorrow`,
/// borrowed directly from a row slice in `eval`). Do not reorder fields.
#[repr(C)]
#[derive(Clone, Copy, Debug, AlignedBorrow)]
pub struct VolatileBoundaryCols<T> {
    /// Address space split into limbs. The limbs are recombined with `compose` at
    /// `range_max_bits` bits per limb, and each limb is range-checked on valid rows.
    pub addr_space_limbs: [T; NUM_AS_LIMBS],
    /// Pointer split into limbs, composed and range-checked the same way as `addr_space_limbs`.
    pub pointer_limbs: [T; AUX_LEN],

    /// Value sent on the memory bus with timestamp 0 (the chip's trace generator writes 0 here).
    pub initial_data: T,
    /// Value received from the memory bus together with `final_timestamp`.
    pub final_data: T,
    /// Timestamp paired with `final_data` on the memory bus.
    pub final_timestamp: T,

    /// Boolean. It is 1 on rows describing a memory cell and 0 on padding.
    /// Valid rows must form a prefix of the trace.
    pub is_valid: T,
    /// Auxiliary columns for `IsLtArraySubAir`. They witness that this row's
    /// `[addr_space, pointer]` is less than the next row's whenever the next row is valid.
    pub addr_lt_aux: IsLtArrayAuxCols<T, ADDR_ELTS, AUX_LEN>,
}
59
/// AIR that connects each touched volatile memory cell's initial value (at timestamp 0) to its
/// final value and timestamp on the memory bus.
///
/// It also forces `(address space, pointer)` to increase strictly between consecutive valid
/// rows, so no cell appears twice.
#[derive(Clone, Debug)]
pub struct VolatileBoundaryAir {
    /// Bus on which this AIR sends the initial state and receives the final state of each cell.
    pub memory_bus: MemoryBus,
    /// Address-ordering sub-AIR, built with `IsLtArraySubAir::new(..).when_transition()`.
    pub addr_lt_air: IsLtArrayWhenTransitionAir<ADDR_ELTS>,

    /// Bit width range-checked for each address-space limb (trailing limbs may be 0 bits).
    addr_space_limb_bits: [usize; NUM_AS_LIMBS],
    /// Bit width range-checked for each pointer limb (trailing limbs may be 0 bits).
    pointer_limb_bits: [usize; AUX_LEN],
}
68
69impl VolatileBoundaryAir {
70 pub fn new(
71 memory_bus: MemoryBus,
72 addr_space_max_bits: usize,
73 pointer_max_bits: usize,
74 range_bus: VariableRangeCheckerBus,
75 ) -> Self {
76 let addr_lt_air =
77 IsLtArraySubAir::<ADDR_ELTS>::new(range_bus, addr_space_max_bits.max(pointer_max_bits))
78 .when_transition();
79 let range_max_bits = range_bus.range_max_bits;
80 let mut addr_space_limb_bits = [0; NUM_AS_LIMBS];
81 let mut bits_remaining = addr_space_max_bits;
82 for limb_bits in &mut addr_space_limb_bits {
83 *limb_bits = min(bits_remaining, range_max_bits);
84 bits_remaining -= *limb_bits;
85 }
86 assert_eq!(bits_remaining, 0, "addr_space_max_bits={addr_space_max_bits} with {NUM_AS_LIMBS} limbs exceeds range_max_bits={range_max_bits}");
87 let mut pointer_limb_bits = [0; AUX_LEN];
88 let mut bits_remaining = pointer_max_bits;
89 for limb_bits in &mut pointer_limb_bits {
90 *limb_bits = min(bits_remaining, range_max_bits);
91 bits_remaining -= *limb_bits;
92 }
93 assert_eq!(bits_remaining, 0, "pointer_max_bits={pointer_max_bits} with {AUX_LEN} limbs exceeds range_max_bits={range_max_bits}");
94 Self {
95 memory_bus,
96 addr_lt_air,
97 addr_space_limb_bits,
98 pointer_limb_bits,
99 }
100 }
101
102 pub fn range_bus(&self) -> VariableRangeCheckerBus {
103 self.addr_lt_air.0.lt.bus
104 }
105}
106
107impl<F: Field> BaseAirWithPublicValues<F> for VolatileBoundaryAir {}
108impl<F: Field> PartitionedBaseAir<F> for VolatileBoundaryAir {}
109impl<F: Field> BaseAir<F> for VolatileBoundaryAir {
110 fn width(&self) -> usize {
111 VolatileBoundaryCols::<F>::width()
112 }
113}
114
115impl<AB: InteractionBuilder> Air<AB> for VolatileBoundaryAir {
116 fn eval(&self, builder: &mut AB) {
117 let main = builder.main();
118
119 let [local, next] = [0, 1].map(|i| main.row_slice(i));
120 let local: &VolatileBoundaryCols<_> = (*local).borrow();
121 let next: &VolatileBoundaryCols<_> = (*next).borrow();
122
123 builder.assert_bool(local.is_valid);
124
125 builder
127 .when_transition()
128 .assert_one(implies(next.is_valid, local.is_valid));
129
130 for (&limb, limb_bits) in zip_eq(&local.addr_space_limbs, self.addr_space_limb_bits) {
132 self.range_bus()
133 .range_check(limb, limb_bits)
134 .eval(builder, local.is_valid);
135 }
136 for (&limb, limb_bits) in zip_eq(&local.pointer_limbs, self.pointer_limb_bits) {
138 self.range_bus()
139 .range_check(limb, limb_bits)
140 .eval(builder, local.is_valid);
141 }
142 let range_max_bits = self.range_bus().range_max_bits;
143 let [addr_space, next_addr_space] = [&local.addr_space_limbs, &next.addr_space_limbs]
146 .map(|limbs| compose::<AB::Expr>(limbs, range_max_bits));
147 let [pointer, next_pointer] = [&local.pointer_limbs, &next.pointer_limbs]
148 .map(|limbs| compose::<AB::Expr>(limbs, range_max_bits));
149
150 let lt_io = IsLtArrayIo {
153 x: [addr_space.clone(), pointer.clone()],
154 y: [next_addr_space, next_pointer],
155 out: AB::Expr::ONE,
156 count: next.is_valid.into(),
157 };
158 self.addr_lt_air
161 .eval(builder, (lt_io, (&local.addr_lt_aux).into()));
162
163 self.memory_bus
165 .send(
166 MemoryAddress::new(addr_space.clone(), pointer.clone()),
167 vec![local.initial_data],
168 AB::Expr::ZERO,
169 )
170 .eval(builder, local.is_valid);
171
172 self.memory_bus
174 .receive(
175 MemoryAddress::new(addr_space.clone(), pointer.clone()),
176 vec![local.final_data],
177 local.final_timestamp,
178 )
179 .eval(builder, local.is_valid);
180 }
181}
182
/// Trace generator for [`VolatileBoundaryAir`].
pub struct VolatileBoundaryChip<F> {
    /// The AIR whose trace this chip produces.
    pub air: VolatileBoundaryAir,
    /// Range checker used to decompose addresses into limbs and to fill the ordering aux columns.
    range_checker: SharedVariableRangeCheckerChip,
    /// Trace height requested via `set_overridden_height`, before power-of-two rounding.
    overridden_height: Option<usize>,
    /// Final memory state recorded by `finalize`; `None` until then.
    pub final_memory: Option<TimestampedEquipartition<F, 1>>,
    /// Bit width passed to the range checker when decomposing address spaces.
    addr_space_max_bits: usize,
    /// Bit width passed to the range checker when decomposing pointers.
    pointer_max_bits: usize,
}
191
192impl<F> VolatileBoundaryChip<F> {
193 pub fn new(
194 memory_bus: MemoryBus,
195 addr_space_max_bits: usize,
196 pointer_max_bits: usize,
197 range_checker: SharedVariableRangeCheckerChip,
198 ) -> Self {
199 let range_bus = range_checker.bus();
200 Self {
201 air: VolatileBoundaryAir::new(
202 memory_bus,
203 addr_space_max_bits,
204 pointer_max_bits,
205 range_bus,
206 ),
207 range_checker,
208 overridden_height: None,
209 final_memory: None,
210 addr_space_max_bits,
211 pointer_max_bits,
212 }
213 }
214}
215
impl<F: PrimeField32> VolatileBoundaryChip<F> {
    /// Requests a specific trace height.
    ///
    /// At trace generation the height must be at least the number of final memory entries,
    /// otherwise generation panics. It is then rounded up to a power of two.
    pub fn set_overridden_height(&mut self, overridden_height: usize) {
        self.overridden_height = Some(overridden_height);
    }
    /// Records the final memory state and replaces any previously recorded one.
    ///
    /// This must be called before trace generation.
    #[instrument(name = "boundary_finalize", level = "debug", skip_all)]
    pub fn finalize(&mut self, final_memory: TimestampedEquipartition<F, 1>) {
        self.final_memory = Some(final_memory);
    }
}
227
impl<RA, SC: StarkGenericConfig> Chip<RA, CpuBackend<SC>> for VolatileBoundaryChip<Val<SC>>
where
    Val<SC>: PrimeField32,
{
    /// Builds the boundary trace from the memory recorded by `finalize`.
    ///
    /// Row `i < final_memory.len()` holds the `i`-th final-memory entry. The entries are expected
    /// to be sorted by `(address space, pointer)`; this is only checked by a `debug_assert`.
    /// All remaining rows are zero padding.
    ///
    /// The trace height is `overridden_height` if set, otherwise `final_memory.len()`, rounded up
    /// to the next power of two. An empty memory therefore still yields 1 row.
    ///
    /// # Panics
    /// - If `finalize` has not been called.
    /// - If the overridden height is smaller than the number of final-memory entries.
    fn generate_proving_ctx(&self, _: RA) -> AirProvingContext<CpuBackend<SC>> {
        let width = self.trace_width();
        let addr_lt_air = &self.air.addr_lt_air;
        let final_memory = self
            .final_memory
            .clone()
            .expect("Trace generation should be called after finalize");
        let trace_height = if let Some(height) = self.overridden_height {
            assert!(
                height >= final_memory.len(),
                "Overridden height is less than the required height"
            );
            height
        } else {
            final_memory.len()
        };
        let trace_height = trace_height.next_power_of_two();

        // Materialize as a Vec so row i can look up entry i + 1 for the ordering witness.
        let sorted_final_memory: Vec<_> = final_memory.into_par_iter().collect();
        let memory_len = sorted_final_memory.len();

        let range_checker = self.range_checker.as_ref();
        let mut rows = Val::<SC>::zero_vec(trace_height * width);
        // `zip` stops after `memory_len` rows, so padding rows stay all-zero.
        rows.par_chunks_mut(width)
            .zip(sorted_final_memory.par_iter())
            .enumerate()
            .for_each(|(i, (row, ((addr_space, ptr), timestamped_values)))| {
                let [data] = timestamped_values.values;
                let row: &mut VolatileBoundaryCols<_> = row.borrow_mut();
                // Split the address into limbs matching the AIR's `compose` / range checks.
                // NOTE(review): presumably `decompose` also records range-check multiplicities.
                range_checker.decompose(
                    *addr_space,
                    self.addr_space_max_bits,
                    &mut row.addr_space_limbs,
                );
                range_checker.decompose(*ptr, self.pointer_max_bits, &mut row.pointer_limbs);
                row.initial_data = Val::<SC>::ZERO;
                row.final_data = data;
                row.final_timestamp = Val::<SC>::from_canonical_u32(timestamped_values.timestamp);
                row.is_valid = Val::<SC>::ONE;

                // Every valid row except the last witnesses addr(i) < addr(i + 1). The last valid
                // row's successor is padding (count 0) or, when the trace is full, the wraparound
                // row handled after this loop.
                if i != memory_len - 1 {
                    let (next_addr_space, next_ptr) = sorted_final_memory[i + 1].0;
                    let mut out = Val::<SC>::ZERO;
                    addr_lt_air.0.generate_subrow(
                        (
                            self.range_checker.as_ref(),
                            &[
                                Val::<SC>::from_canonical_u32(*addr_space),
                                Val::<SC>::from_canonical_u32(*ptr),
                            ],
                            &[
                                Val::<SC>::from_canonical_u32(next_addr_space),
                                Val::<SC>::from_canonical_u32(next_ptr),
                            ],
                        ),
                        ((&mut row.addr_lt_aux).into(), &mut out),
                    );
                    debug_assert_eq!(out, Val::<SC>::ONE, "Addresses are not sorted");
                }
            });
        // The last trace row's "next" row wraps around to row 0. In `eval`, the ordering
        // sub-AIR is given `count = next.is_valid`, which is 1 here whenever memory is non-empty.
        // The aux columns are therefore filled with a (0, 0) vs (0, 0) comparison rather than
        // left zero.
        // NOTE(review): this assumes the sub-AIR's interactions are not gated by
        // `when_transition` — confirm in IsLtArrayWhenTransitionAir.
        // This must run after the parallel fill. It does not overwrite a witness written by the
        // loop, because the loop skips the last valid row.
        if memory_len > 0 {
            let mut out = Val::<SC>::ZERO;
            let row: &mut VolatileBoundaryCols<_> = rows[width * (trace_height - 1)..].borrow_mut();
            addr_lt_air.0.generate_subrow(
                (
                    self.range_checker.as_ref(),
                    &[Val::<SC>::ZERO, Val::<SC>::ZERO],
                    &[Val::<SC>::ZERO, Val::<SC>::ZERO],
                ),
                ((&mut row.addr_lt_aux).into(), &mut out),
            );
        }

        let trace = Arc::new(RowMajorMatrix::new(rows, width));
        AirProvingContext::simple_no_pis(trace)
    }
}
317
318impl<F: PrimeField32> ChipUsageGetter for VolatileBoundaryChip<F> {
319 fn air_name(&self) -> String {
320 "Boundary".to_string()
321 }
322
323 fn current_trace_height(&self) -> usize {
324 if let Some(final_memory) = &self.final_memory {
325 final_memory.len()
326 } else {
327 0
328 }
329 }
330
331 fn trace_width(&self) -> usize {
332 VolatileBoundaryCols::<F>::width()
333 }
334}